AnsysLPFMTrame-App / similarity /umap /umap_gradio_demo.py
udbhav
Recreate Trame_app branch with clean history
67fb03c
#!/usr/bin/env python3
"""
Gradio Demo for UMAP Visualization and 3D Geometry Viewer
Similar to the Streamlit demo but using Gradio interface
"""
import os
import sys
import numpy as np
import pandas as pd
import tempfile
import plotly.graph_objects as go
import plotly.express as px
from pathlib import Path
import vtk
import pyvista as pv
from typing import List, Tuple, Optional, Dict, Any
import gradio as gr
import json
from sklearn.preprocessing import StandardScaler
# Apply PyVista's built-in "document" plot theme at import time.
# NOTE: offscreen rendering is NOT enabled — the line below is commented out;
# uncomment it to force offscreen rendering.
pv.set_plot_theme("document")
# pv.global_theme.off_screen = True
def load_umap_data() -> Tuple[Optional[np.ndarray], Optional[List[str]], Optional[List[str]], Optional[np.ndarray]]:
    """Load UMAP analysis results from the ``files/`` directory.

    For each kind of file the candidates are tried in this order:

    1. the legacy name (e.g. ``files/umap_embedding.npy``);
    2. the multi-dataset name (e.g. ``files/combined_umap_embedding.npy``);
    3. any ``files/*_<suffix>.npy`` match, sorted so the pick is deterministic.

    The first existing candidate is used. The embedding, names and features
    files are required. The dataset-labels file is optional; without it every
    item is labelled ``"Unknown"``.

    Returns:
        ``(embedding, data_names, dataset_labels, features)``. On any failure
        (missing files, unreadable files, or inconsistent lengths) the error is
        printed and ``(None, None, None, None)`` is returned. The tuple always
        has four elements so callers can unpack it unconditionally.
    """
    import glob

    def _first_existing(paths: List[str]) -> Optional[str]:
        """Return the first path in *paths* that exists on disk, or None."""
        return next((path for path in paths if os.path.exists(path)), None)

    def _candidates(fixed: List[str], pattern: str) -> List[str]:
        """Fixed names first, then sorted glob matches for *pattern*."""
        return fixed + sorted(glob.glob(pattern))

    try:
        embedding_file = _first_existing(_candidates(
            ["files/umap_embedding.npy", "files/combined_umap_embedding.npy"],
            "files/*_umap_embedding.npy"))
        labels_file = _first_existing(_candidates(
            ["files/names.npy", "files/combined_names.npy"],
            "files/*_names.npy"))
        features_file = _first_existing(_candidates(
            ["files/features.npy", "files/combined_features.npy"],
            "files/*_features.npy"))
        dataset_labels_file = _first_existing(_candidates(
            ["files/labels_dataset_labels.npy", "files/combined_dataset_labels.npy"],
            "files/*_dataset_labels.npy"))

        missing = [
            kind
            for kind, path in (("embedding", embedding_file),
                               ("names", labels_file),
                               ("features", features_file))
            if path is None
        ]
        if missing:
            raise FileNotFoundError("Required UMAP data files not found: " + ", ".join(missing))

        # Load embedding coordinates
        embedding = np.load(embedding_file)
        print(f"βœ… Loaded UMAP embedding from {embedding_file}: {embedding.shape}")

        # NOTE(review): allow_pickle=True unpickles the file; acceptable only
        # because these files are produced locally by the analysis step.
        data_names = np.load(labels_file, allow_pickle=True).tolist()
        print(f"βœ… Loaded {len(data_names)} data names from {labels_file}")
        print(f" First 5 data names: {data_names[:5]}")

        # Load dataset labels (optional)
        if dataset_labels_file:
            dataset_labels = np.load(dataset_labels_file, allow_pickle=True).tolist()
            print(f"βœ… Loaded dataset labels from {dataset_labels_file}: {len(dataset_labels)} labels")
            print(f" First 5 dataset labels: {dataset_labels[:5]}")
        else:
            dataset_labels = ["Unknown"] * len(data_names)
            print("⚠️ No dataset labels found, using 'Unknown'")

        # Load features
        features = np.load(features_file)
        print(f"βœ… Loaded features from {features_file}: {features.shape}")

        # Plotting indexes embedding rows by position in data_names and
        # dataset_labels, so mismatched files (e.g. a legacy names.npy next to
        # a combined embedding) must be rejected here rather than crash later.
        if len(embedding) != len(data_names) or len(dataset_labels) != len(data_names):
            raise ValueError(
                f"UMAP data files are inconsistent: {len(embedding)} embedding rows, "
                f"{len(data_names)} names, {len(dataset_labels)} dataset labels")

        return embedding, data_names, dataset_labels, features
    except FileNotFoundError as e:
        print(f"❌ Error loading data: {e}")
        print("Please run the PhysicsNeMo analysis first!")
        return None, None, None, None
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        return None, None, None, None
def load_config(path: str = "config.yaml") -> Dict[str, Any]:
    """Load the demo configuration from a YAML file.

    Args:
        path: YAML file to read. Defaults to ``config.yaml`` in the current
            working directory, which is the file ``main`` checks for.

    Returns:
        The parsed mapping. Returns ``{}`` when PyYAML is unavailable, the
        file is missing or unreadable, the document is empty, or the top level
        is not a mapping. Callers can therefore always use ``config.get(...)``.
    """
    try:
        import yaml
        with open(path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except Exception as e:
        print(f"⚠️ Could not load config: {e}")
        return {}
    if config is None:
        # yaml.safe_load returns None for an empty document.
        return {}
    if not isinstance(config, dict):
        print(f"⚠️ Could not load config: expected a mapping in {path}, got {type(config).__name__}")
        return {}
    return config
# Fallback palette for datasets with no color in config['data']['data_folders'].
_DEFAULT_UMAP_COLORS = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']


def _umap_hover_text(embedding: np.ndarray, data_names: List[str],
                     dataset_labels: List[str]) -> List[str]:
    """Build one HTML hover label per point: name, dataset and 2D UMAP coordinates."""
    return [
        f"<b>{name}</b><br>"
        f"Dataset: {dataset}<br>"
        f"UMAP: ({embedding[i, 0]:.3f}, {embedding[i, 1]:.3f})<br>"
        f"<i>Click to view 3D geometry</i>"
        for i, (name, dataset) in enumerate(zip(data_names, dataset_labels))
    ]


def _umap_dataset_colors(dataset_labels: List[str], config: Dict[str, Any]) -> Dict[str, str]:
    """Map every distinct dataset label to its display color.

    Labels are ordered by first appearance. This is deterministic across runs,
    unlike iterating a ``set`` of strings, so legend order and fallback colors
    are stable. A color configured for the label in
    ``config['data']['data_folders']`` takes precedence. Otherwise the label's
    position selects a color from ``_DEFAULT_UMAP_COLORS``. Folder entries
    without both ``label`` and ``color`` are ignored.
    """
    configured = {}
    for folder in (config.get('data') or {}).get('data_folders') or []:
        if 'label' in folder and 'color' in folder:
            configured[folder['label']] = folder['color']
    return {
        dataset: configured.get(dataset, _DEFAULT_UMAP_COLORS[idx % len(_DEFAULT_UMAP_COLORS)])
        for idx, dataset in enumerate(dict.fromkeys(dataset_labels))
    }


def _apply_umap_layout(fig: go.Figure, title: str, title_size: int) -> None:
    """Apply the axis titles, size, click mode and legend layout shared by both UMAP plots."""
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            font=dict(size=title_size, color='darkblue')
        ),
        xaxis_title="UMAP Component 1",
        yaxis_title="UMAP Component 2",
        width=600,
        height=500,
        hovermode='closest',
        clickmode='event+select',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        ),
        margin=dict(l=60, r=60, t=100, b=60)
    )


def create_interactive_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str], config: Dict[str, Any]) -> go.Figure:
    """Create the UMAP scatter plot with one trace per dataset.

    Args:
        embedding: UMAP coordinates; columns 0 and 1 are plotted as x and y.
        data_names: Item names, aligned with the rows of ``embedding``.
        dataset_labels: Dataset label per item, aligned with ``data_names``.
        config: Parsed config; dataset colors are read from
            ``config['data']['data_folders']``.

    Returns:
        A Plotly figure. Each trace stores its points' row indices in
        ``customdata`` so click handlers can map a point back to its item.
    """
    hover_text = _umap_hover_text(embedding, data_names, dataset_labels)
    labels = np.asarray(dataset_labels)  # hoisted: compared once per dataset below
    fig = go.Figure()
    for dataset, color in _umap_dataset_colors(dataset_labels, config).items():
        dataset_indices = np.flatnonzero(labels == dataset)
        fig.add_trace(go.Scatter(
            x=embedding[dataset_indices, 0],
            y=embedding[dataset_indices, 1],
            mode='markers',
            marker=dict(
                size=10,
                color=color,
                line=dict(width=1, color='black'),
                opacity=0.8
            ),
            name=dataset,
            hovertemplate=[hover_text[i] for i in dataset_indices],
            customdata=dataset_indices,  # Store indices for click events
            showlegend=True
        ))
    _apply_umap_layout(fig, " Interactive UMAP Visualization", 18)
    return fig


def create_highlighted_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str], selected_index: int,
                                 config: Dict[str, Any]) -> go.Figure:
    """Create the UMAP scatter plot with one point emphasised.

    Same inputs as ``create_interactive_umap_plot``, plus ``selected_index``,
    the row of the point to highlight.

    Non-selected points are drawn smaller and more transparent in their
    dataset color. The selected point gets its own trace: a large black
    square with a white outline, named ``"<dataset> - Selected"``. The title
    includes ``data_names[selected_index]``, so an out-of-range index raises
    ``IndexError``.
    """
    hover_text = _umap_hover_text(embedding, data_names, dataset_labels)
    labels = np.asarray(dataset_labels)
    fig = go.Figure()
    for dataset, color in _umap_dataset_colors(dataset_labels, config).items():
        dataset_indices = np.flatnonzero(labels == dataset)
        selected_mask = dataset_indices == selected_index
        other_indices = dataset_indices[~selected_mask]

        # Plot non-selected points
        if other_indices.size:
            fig.add_trace(go.Scatter(
                x=embedding[other_indices, 0],
                y=embedding[other_indices, 1],
                mode='markers',
                marker=dict(
                    size=8,
                    color=color,
                    line=dict(width=1, color='black'),
                    opacity=0.6
                ),
                name=dataset,
                hovertemplate=[hover_text[i] for i in other_indices],
                customdata=other_indices,
                showlegend=True
            ))

        # Plot selected point with highlight
        if selected_mask.any():
            selected_idx = dataset_indices[selected_mask][0]
            fig.add_trace(go.Scatter(
                x=[embedding[selected_idx, 0]],
                y=[embedding[selected_idx, 1]],
                mode='markers',
                marker=dict(
                    size=20,
                    color='black',
                    line=dict(width=2, color='white'),
                    opacity=1.0,
                    symbol='square'
                ),
                name=f"{dataset} - Selected",
                hovertemplate=[hover_text[selected_idx]],
                customdata=[selected_idx],
                showlegend=True
            ))
    _apply_umap_layout(
        fig,
        f"Interactive UMAP Visualization - Selected: {data_names[selected_index]}",
        16,
    )
    return fig
def load_vtp_file(file_path: str) -> Optional[pv.PolyData]:
    """Read the mesh stored at *file_path* using ``pv.read``.

    Returns the mesh on success. If the path does not exist, or reading it
    raises, the problem is printed and None is returned instead of
    propagating an exception.
    """
    try:
        if os.path.exists(file_path):
            loaded = pv.read(file_path)
            print(f"βœ… Loaded mesh: {loaded.n_points} points, {loaded.n_cells} cells")
            return loaded
        print(f"❌ File not found: {file_path}")
        return None
    except Exception as err:
        print(f"❌ Error loading {file_path}: {err}")
        return None
def display_data_geometry(mesh: pv.PolyData, data_name: str, *,
                          max_points: int = 20000,
                          seed: Optional[int] = 0) -> Optional[go.Figure]:
    """Render a mesh's vertices as an interactive Plotly 3D point cloud.

    Only vertex positions (``mesh.points``) are plotted; faces are not drawn.
    Points are colored by their Z coordinate.

    Args:
        mesh: Mesh whose ``points`` array (columns x, y, z) is plotted.
        data_name: Name shown in the figure title.
        max_points: Upper bound on plotted points. Larger meshes are randomly
            subsampled, without replacement, to keep the browser responsive.
        seed: Seed for the subsampling RNG, so the same mesh always yields the
            same cloud. Pass None for a fresh random sample on every call.

    Returns:
        The Plotly figure, or None if building it failed (the error is printed).
    """
    try:
        points = np.asarray(mesh.points)
        # Sample points if too many for performance
        if len(points) > max_points:
            rng = np.random.default_rng(seed)
            points = points[rng.choice(len(points), max_points, replace=False)]
        fig = go.Figure(data=[go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=points[:, 2],
                colorscale='Viridis',
                opacity=0.8,
                line=dict(width=0.2, color='rgba(0,0,0,0.3)'),
                showscale=True,
                colorbar=dict(
                    title="Height (Z)"
                )
            ),
            name='Point Cloud',
            hovertemplate='<b>3D Point</b><br>' +
                          'X: %{x:.2f}<br>' +
                          'Y: %{y:.2f}<br>' +
                          'Z: %{z:.2f}<br>' +
                          '<extra></extra>'
        )])
        fig.update_layout(
            title=dict(
                text=f"3D Geometry: {data_name}",
                x=0.5,
                font=dict(size=16, color='darkblue')
            ),
            scene=dict(
                xaxis=dict(
                    title='X',
                    gridcolor='lightgray',
                    showbackground=True,
                    backgroundcolor='rgba(240,240,240,0.1)'
                ),
                yaxis=dict(
                    title='Y',
                    gridcolor='lightgray',
                    showbackground=True,
                    backgroundcolor='rgba(240,240,240,0.1)'
                ),
                zaxis=dict(
                    title='Z',
                    gridcolor='lightgray',
                    showbackground=True,
                    backgroundcolor='rgba(240,240,240,0.1)'
                ),
                camera=dict(
                    eye=dict(x=1.8, y=1.8, z=1.2),
                    center=dict(x=0, y=0, z=0),
                    up=dict(x=0, y=0, z=1)
                ),
                # 'data' keeps the true proportions of the geometry.
                aspectmode='data',
                bgcolor='white'
            ),
            width=600,
            height=500,
            margin=dict(l=20, r=20, t=60, b=20),
            paper_bgcolor='white',
            plot_bgcolor='white'
        )
        return fig
    except Exception as e:
        print(f"❌ 3D visualization failed: {e}")
        return None
def get_mesh_info(mesh: pv.PolyData, data_name: str) -> str:
    """Summarize a mesh as Markdown.

    The summary covers point/cell counts, the number of point-data arrays,
    per-axis bounds, the axis-aligned bounding-box volume and the shape of
    each point-data array. If reading any mesh attribute fails, an error
    string is returned instead of raising.
    """
    try:
        x_min, x_max, y_min, y_max, z_min, z_max = mesh.bounds
        # This is the axis-aligned bounding-box volume, not the volume
        # enclosed by the mesh surface, so it is labelled accordingly.
        bbox_volume = (x_max - x_min) * (y_max - y_min) * (z_max - z_min)
        array_names = list(mesh.point_data.keys())
        lines = [
            f"**📊 Mesh Information: {data_name}**",
            "",
            "**Basic Info:**",
            f"- Points: {mesh.n_points:,}",
            f"- Cells: {mesh.n_cells:,}",
            f"- Data Arrays: {len(array_names)}",
            "",
            "**Bounds:**",
            f"- X: [{x_min:.2f}, {x_max:.2f}]",
            f"- Y: [{y_min:.2f}, {y_max:.2f}]",
            f"- Z: [{z_min:.2f}, {z_max:.2f}]",
            "",
            f"**Bounding Box Volume:** {bbox_volume:.2f}",
        ]
        # Show data arrays if available
        if array_names:
            lines += ["", "**Point Data Arrays:**"]
            lines += [f"- {key}: {mesh.point_data[key].shape}" for key in array_names]
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"❌ Unable to get mesh info: {str(e)}"
def find_vtp_file(data_name: str, config: Dict[str, Any]) -> Optional[str]:
    """Locate the geometry file for *data_name* under the configured data folders.

    Every folder in ``config['data']['data_folders']`` whose ``path`` exists
    is searched recursively for files matching
    ``config['data']['file_pattern']`` (default ``'*.vtp'``). A file whose stem
    equals *data_name* in ANY folder wins. Only if there is none is the first
    case-insensitive substring match returned, in folder order and then sorted
    path order.

    Returns:
        The file path as a string, or None if nothing matches or the search
        fails (the reason is printed).
    """
    try:
        data_config = config.get('data', {}) or {}
        file_pattern = data_config.get('file_pattern', '*.vtp') or '*.vtp'
        candidates: List[Path] = []
        for folder_config in data_config.get('data_folders', []) or []:
            folder_path = Path(folder_config['path'])
            if folder_path.exists():
                # rglob(p) is equivalent to glob("**/" + p); sort for a stable order.
                candidates.extend(sorted(p for p in folder_path.rglob(file_pattern) if p.is_file()))
        # Look for an exact match across all folders first ...
        for vtp_file in candidates:
            if vtp_file.stem == data_name:
                return str(vtp_file)
        # ... so a partial match (e.g. "car_1" in "car_10") cannot shadow it.
        name_lower = data_name.lower()
        for vtp_file in candidates:
            if name_lower in vtp_file.stem.lower():
                return str(vtp_file)
        print(f"⚠️ Could not find VTP file for: {data_name}")
        return None
    except Exception as e:
        print(f"❌ Error finding VTP file: {e}")
        return None
def create_gradio_interface():
    """Build the Gradio app for the UMAP demo.

    Loads the UMAP results (``load_umap_data``) and configuration
    (``load_config``). If no UMAP data could be loaded, returns a minimal
    ``gr.Interface`` that only reports the error.

    Otherwise returns a ``gr.Blocks`` app:
    - left: the UMAP scatter plot;
    - right: a dropdown of data items and a 3D geometry viewer.

    Choosing an item renders its geometry and redraws the UMAP plot with that
    item highlighted.
    """
    # Slice to four values so this works whether load_umap_data returns its
    # 4-tuple of results or a longer tuple of Nones on failure.
    embedding, data_names, dataset_labels, features = load_umap_data()[:4]
    config = load_config()
    if embedding is None:
        return gr.Interface(
            fn=lambda: "❌ No UMAP data found. Please run the PhysicsNeMo analysis first!",
            inputs=[],
            outputs=gr.Textbox(),
            title="📊 UMAP Visualization Demo",
            description="Error: No data available"
        )

    # Create UMAP plot (also the fallback value whenever a selection fails)
    umap_fig = create_interactive_umap_plot(embedding, data_names, dataset_labels, config)

    def _render_geometry(data_name):
        """Find, load and plot the geometry for *data_name*.

        Returns ``(figure, message)``. On success ``message`` is the mesh-info
        Markdown. On failure ``figure`` is None and ``message`` names the step
        that failed.
        """
        vtp_file = find_vtp_file(data_name, config)
        if not vtp_file:
            return None, f"❌ Could not find VTP file for: {data_name}"
        mesh = load_vtp_file(vtp_file)
        if mesh is None:
            return None, f"❌ Could not load geometry for: {data_name}"
        fig_3d = display_data_geometry(mesh, data_name)
        if fig_3d is None:
            return None, f"❌ Could not create 3D visualization for: {data_name}"
        return fig_3d, get_mesh_info(mesh, data_name)

    # NOTE(review): the next four handlers are not wired to any component
    # below; kept for a future click/text-input/info panel.
    def update_geometry_viewer_from_click(click_data):
        """Return ``(figure, message)`` for the point referenced by Plotly click data."""
        if not click_data or not click_data.get('points'):
            return None, "Click on any point in the UMAP plot to view its 3D geometry."
        # customdata holds the row index stored when the UMAP plot was built.
        point_index = click_data['points'][0].get('customdata')
        if point_index is None:
            return None, "❌ Could not get point index from click data."
        if not 0 <= point_index < len(data_names):
            return None, f"❌ Invalid point index: {point_index}"
        return _render_geometry(data_names[point_index])

    def handle_plot_click(click_data):
        """Handle plot click events."""
        return update_geometry_viewer_from_click(click_data)

    def update_geometry_from_input(input_text):
        """Return ``(figure, message)`` for an index or case-insensitive name fragment."""
        if not input_text or input_text.strip() == "":
            return None, "Enter a car name or index number to view its 3D geometry."
        input_text = input_text.strip()
        try:
            index = int(input_text)
        except ValueError:
            # Not a number: first item whose name contains the text.
            input_lower = input_text.lower()
            selected_car = next((name for name in data_names if input_lower in name.lower()), None)
            if selected_car is None:
                return None, f"❌ Could not find car matching '{input_text}'. Try a different name or index."
        else:
            if not 0 <= index < len(data_names):
                return None, f"❌ Index {index} is out of range. Valid range: 0-{len(data_names)-1}"
            selected_car = data_names[index]
        return _render_geometry(selected_car)

    def get_dataset_info():
        """Return Markdown with per-dataset item counts and the feature count."""
        if dataset_labels is None:
            return "No dataset information available."
        unique_datasets, counts = np.unique(dataset_labels, return_counts=True)
        info = "**📊 Dataset Information:**\n\n"
        for dataset, count in zip(unique_datasets, counts):
            info += f"• **{dataset}**: {count} items\n"
        info += f"\n**Total Data Items**: {len(data_names)}\n"
        info += f"**Total Features**: {features.shape[1] if features is not None else 'Unknown'}\n"
        # Note: Clustering information not available in this demo
        info += "\n**Note**: Clustering information not available in this demo"
        return info

    def update_geometry_from_selection(selected_data):
        """Return ``(geometry_figure, umap_figure)`` for a dropdown value.

        If the value is empty or unknown, or its geometry cannot be rendered,
        the geometry view is cleared. The UMAP plot then falls back to the
        unhighlighted figure rather than None, which would blank it.
        """
        if not selected_data or selected_data.strip() == "":
            return None, umap_fig
        selected_data = selected_data.strip()
        try:
            selected_index = data_names.index(selected_data)
        except ValueError:
            return None, umap_fig
        fig_3d, _ = _render_geometry(selected_data)
        if fig_3d is None:
            return None, umap_fig
        # Create updated UMAP plot with highlighted point
        fig_umap = create_highlighted_umap_plot(embedding, data_names, dataset_labels, selected_index, config)
        return fig_3d, fig_umap

    # Create the interface
    with gr.Blocks(title="📊 UMAP Visualization Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 Interactive UMAP Visualization Demo")
        with gr.Row():
            with gr.Column(scale=1):
                # UMAP Visualization
                gr.Markdown("## 📊 UMAP Embedding Visualization")
                umap_plot = gr.Plot(umap_fig, label="UMAP Plot")
            with gr.Column(scale=1):
                # 3D Geometry Viewer
                gr.Markdown("## 🎯 3D Geometry Viewer")
                # Quick selection input
                gr.Markdown("### Quick Selection")
                selection_dropdown = gr.Dropdown(
                    choices=data_names,
                    label="Select Data Item",
                    info="Type to search or select from dropdown",
                    allow_custom_value=True,
                    value=None
                )
                geometry_plot = gr.Plot(label="3D Geometry")
        # Update geometry (and the UMAP highlight) when selection changes
        selection_dropdown.change(
            fn=update_geometry_from_selection,
            inputs=[selection_dropdown],
            outputs=[geometry_plot, umap_plot]
        )
    return demo
def main():
    """Check the working directory and UMAP outputs, then launch the Gradio demo.

    Returns:
        0 once ``launch`` returns normally. 1 if ``config.yaml`` is absent, if
        any required UMAP file pattern has no match under ``files/``, or if
        building or launching the interface raises.
    """
    import glob

    print("πŸš€ Launching Gradio UMAP Visualization Demo...")

    # The demo must be started from the umap directory, next to config.yaml.
    if not os.path.exists("config.yaml"):
        print("❌ Error: config.yaml not found. Please run this script from the umap directory.")
        return 1

    # Each required kind needs at least one file: a dataset-prefixed name or the legacy name.
    required_patterns = (
        ("files/*_umap_embedding.npy", "files/umap_embedding.npy"),
        ("files/*_features.npy", "files/features.npy"),
        ("files/*_names.npy", "files/names.npy"),
    )
    found = [glob.glob(prefixed) + glob.glob(legacy) for prefixed, legacy in required_patterns]
    if not all(found):
        for message in (
            "❌ Error: Required UMAP data files not found",
            "Looking for files matching patterns in files/ directory:",
            " - files/*_umap_embedding.npy or files/umap_embedding.npy",
            " - files/*_features.npy or files/features.npy",
            " - files/*_names.npy or files/names.npy",
            "Please run the PhysicsNeMo analysis first using: python run_umap.py",
        ):
            print(message)
        return 1

    try:
        create_gradio_interface().launch(
            server_name="0.0.0.0",
            server_port=7862,
            share=False,
            show_error=True,
            quiet=False,
        )
    except Exception as exc:
        print(f"❌ Error launching demo: {exc}")
        return 1
    return 0
if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; that becomes the process exit status.
    raise SystemExit(main())