#!/usr/bin/env python3
"""
Gradio Demo for UMAP Visualization and 3D Geometry Viewer

Similar to the Streamlit demo but using a Gradio interface. The left column
shows the 2-D UMAP embedding produced by the analysis step (read from the
``files/`` directory); the right column renders a 3-D point cloud of the VTP
mesh chosen in the dropdown and highlights that item in the embedding.
"""

import glob
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import vtk
import pyvista as pv
import gradio as gr
from sklearn.preprocessing import StandardScaler

# Use PyVista's "document" plot theme. Offscreen rendering is NOT enabled;
# uncomment the next line when running on a host without a display.
pv.set_plot_theme("document")
# pv.global_theme.off_screen = True

# Fallback marker colors for datasets that have no color in config.yaml.
DEFAULT_COLORS = ['blue', 'red', 'green', 'orange', 'purple',
                  'brown', 'pink', 'gray', 'olive', 'cyan']

# Upper bound on points drawn in the 3-D view; larger meshes are subsampled
# to keep the Plotly figure responsive.
MAX_RENDER_POINTS = 20000

# (embedding, data_names, dataset_labels, features); all None on failure.
UmapData = Tuple[Optional[np.ndarray], Optional[List[str]],
                 Optional[List[str]], Optional[np.ndarray]]


def _first_existing(candidates: List[str]) -> Optional[str]:
    """Return the first path in *candidates* that exists on disk, or None."""
    return next((path for path in candidates if os.path.exists(path)), None)


def _find_data_file(fixed_names: List[str], pattern: str) -> Optional[str]:
    """Return the first existing file among *fixed_names*, then *pattern* matches.

    Glob matches are sorted so the choice is deterministic when several
    per-dataset files exist (``glob.glob`` order is filesystem-dependent).
    """
    return _first_existing(list(fixed_names) + sorted(glob.glob(pattern)))


def load_umap_data() -> UmapData:
    """Load UMAP analysis results with dynamic file names.

    For each artifact the lookup order in ``files/`` is: the legacy name,
    the ``combined_*`` multi-dataset name, then any ``*_<kind>.npy`` file.

    Returns:
        ``(embedding, data_names, dataset_labels, features)``. On any failure
        all four items are ``None`` so callers can always unpack four values.
        When no dataset-label file exists every item is labelled "Unknown".
    """
    try:
        embedding_file = _find_data_file(
            ["files/umap_embedding.npy",            # Legacy fallback
             "files/combined_umap_embedding.npy"],  # Multiple datasets
            "files/*_umap_embedding.npy")
        # "*_names.npy" does not match "*_dataset_labels.npy" files.
        labels_file = _find_data_file(
            ["files/names.npy", "files/combined_names.npy"],
            "files/*_names.npy")
        features_file = _find_data_file(
            ["files/features.npy", "files/combined_features.npy"],
            "files/*_features.npy")
        dataset_labels_file = _find_data_file(
            ["files/labels_dataset_labels.npy",
             "files/combined_dataset_labels.npy"],
            "files/*_dataset_labels.npy")

        if not all([embedding_file, labels_file, features_file]):
            raise FileNotFoundError("Required UMAP data files not found")

        # Load embedding coordinates
        embedding = np.load(embedding_file)
        print(f"✅ Loaded UMAP embedding from {embedding_file}: {embedding.shape}")

        # allow_pickle=True is required for object arrays of strings. It can
        # execute arbitrary code, so only load files produced by the trusted
        # local analysis step.
        data_names = np.load(labels_file, allow_pickle=True).tolist()
        print(f"✅ Loaded {len(data_names)} data names from {labels_file}")
        print(f"   First 5 data names: {data_names[:5]}")

        if dataset_labels_file:
            dataset_labels = np.load(dataset_labels_file, allow_pickle=True).tolist()
            print(f"✅ Loaded dataset labels from {dataset_labels_file}: {len(dataset_labels)} labels")
            print(f"   First 5 dataset labels: {dataset_labels[:5]}")
        else:
            dataset_labels = ["Unknown"] * len(data_names)
            print("⚠️ No dataset labels found, using 'Unknown'")

        # Load features
        features = np.load(features_file)
        print(f"✅ Loaded features from {features_file}: {features.shape}")

        # The plots index embedding[:, 0/1] per item, so the arrays must agree.
        if embedding.ndim != 2 or embedding.shape[1] < 2:
            raise ValueError(
                f"UMAP embedding must be 2-D with at least 2 columns, got shape {embedding.shape}")
        if not len(embedding) == len(data_names) == len(dataset_labels):
            raise ValueError(
                f"Inconsistent UMAP data: {len(embedding)} embedding rows, "
                f"{len(data_names)} names, {len(dataset_labels)} dataset labels")

        return embedding, data_names, dataset_labels, features

    except FileNotFoundError as e:
        print(f"❌ Error loading data: {e}")
        print("Please run the PhysicsNeMo analysis first!")
        return None, None, None, None
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        return None, None, None, None


def load_config() -> Dict[str, Any]:
    """Load configuration from ``config.yaml``; return ``{}`` if unavailable."""
    try:
        import yaml
        with open("config.yaml", 'r', encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # An empty file yields None; a non-mapping top level is unusable here.
        return config if isinstance(config, dict) else {}
    except Exception as e:
        print(f"⚠️ Could not load config: {e}")
        return {}


def _data_folders(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return ``config['data']['data_folders']``, or ``[]`` if absent or null."""
    data_cfg = config.get('data') or {}
    return [folder for folder in (data_cfg.get('data_folders') or [])
            if isinstance(folder, dict)]


def _build_hover_text(embedding: np.ndarray, data_names: List[str],
                      dataset_labels: List[str]) -> List[str]:
    """Build one HTML hover label per point (Plotly renders ``<br>`` as a newline)."""
    # NOTE(review): no click event is wired to the plot in the UI below;
    # selection currently happens through the dropdown.
    return [
        f"<b>{name}</b><br>"
        f"Dataset: {dataset}<br>"
        f"UMAP: ({embedding[i, 0]:.3f}, {embedding[i, 1]:.3f})<br>"
        f"Click to view 3D geometry"
        for i, (name, dataset) in enumerate(zip(data_names, dataset_labels))
    ]


def _dataset_colors(dataset_labels: List[str],
                    config: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return ``(dataset, color)`` pairs in first-appearance order.

    Colors come from ``data.data_folders[*].label/color`` in the config; other
    datasets cycle through DEFAULT_COLORS. First-appearance order (instead of
    ``set`` order, which changes between runs with hash randomization) keeps
    fallback colors stable across runs and between the two UMAP plots.
    """
    color_map = {folder['label']: folder['color']
                 for folder in _data_folders(config)
                 if 'label' in folder and 'color' in folder}
    datasets = list(dict.fromkeys(dataset_labels))
    return [(dataset, color_map.get(dataset, DEFAULT_COLORS[i % len(DEFAULT_COLORS)]))
            for i, dataset in enumerate(datasets)]


def _apply_umap_layout(fig: go.Figure, title: str, title_size: int) -> None:
    """Apply the layout shared by both UMAP scatter plots to *fig* in place."""
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=title_size, color='darkblue')),
        xaxis_title="UMAP Component 1",
        yaxis_title="UMAP Component 2",
        width=600,
        height=500,
        hovermode='closest',
        clickmode='event+select',
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="center", x=0.5),
        margin=dict(l=60, r=60, t=100, b=60),
    )


def create_interactive_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str],
                                 config: Dict[str, Any]) -> go.Figure:
    """Create the UMAP scatter plot with one trace (and color) per dataset."""
    hover_text = _build_hover_text(embedding, data_names, dataset_labels)
    labels = np.asarray(dataset_labels)

    fig = go.Figure()
    for dataset, color in _dataset_colors(dataset_labels, config):
        # Global item indices; stored as customdata so a clicked point can
        # be mapped back to its data item.
        indices = np.flatnonzero(labels == dataset)
        fig.add_trace(go.Scatter(
            x=embedding[indices, 0],
            y=embedding[indices, 1],
            mode='markers',
            marker=dict(size=10, color=color,
                        line=dict(width=1, color='black'), opacity=0.8),
            name=dataset,
            hovertemplate=[hover_text[i] for i in indices],
            customdata=indices,
            showlegend=True,
        ))

    _apply_umap_layout(fig, " Interactive UMAP Visualization", 18)
    return fig


def create_highlighted_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str], selected_index: int,
                                 config: Dict[str, Any]) -> go.Figure:
    """Create the UMAP plot with *selected_index* drawn as a large black square.

    Non-selected points are drawn smaller and more transparent than in
    ``create_interactive_umap_plot``; the selected point gets its own trace.
    """
    hover_text = _build_hover_text(embedding, data_names, dataset_labels)
    labels = np.asarray(dataset_labels)

    fig = go.Figure()
    for dataset, color in _dataset_colors(dataset_labels, config):
        indices = np.flatnonzero(labels == dataset)
        is_selected = indices == selected_index
        other_indices = indices[~is_selected]

        if other_indices.size:
            fig.add_trace(go.Scatter(
                x=embedding[other_indices, 0],
                y=embedding[other_indices, 1],
                mode='markers',
                marker=dict(size=8, color=color,
                            line=dict(width=1, color='black'), opacity=0.6),
                name=dataset,
                hovertemplate=[hover_text[i] for i in other_indices],
                customdata=other_indices,
                showlegend=True,
            ))

        if is_selected.any():
            selected_idx = int(indices[is_selected][0])
            fig.add_trace(go.Scatter(
                x=[embedding[selected_idx, 0]],
                y=[embedding[selected_idx, 1]],
                mode='markers',
                marker=dict(size=20, color='black',
                            line=dict(width=2, color='white'),
                            opacity=1.0, symbol='square'),
                name=f"{dataset} - Selected",
                hovertemplate=[hover_text[selected_idx]],
                customdata=[selected_idx],
                showlegend=True,
            ))

    _apply_umap_layout(
        fig,
        f"Interactive UMAP Visualization - Selected: {data_names[selected_index]}",
        16,
    )
    return fig


def load_vtp_file(file_path: str) -> Optional[pv.PolyData]:
    """Read *file_path* with PyVista; return None (after logging) on failure."""
    if not os.path.exists(file_path):
        print(f"❌ File not found: {file_path}")
        return None
    try:
        mesh = pv.read(file_path)
    except Exception as e:
        print(f"❌ Error loading {file_path}: {e}")
        return None
    print(f"✅ Loaded mesh: {mesh.n_points} points, {mesh.n_cells} cells")
    return mesh


def display_data_geometry(mesh: pv.PolyData, data_name: str,
                          max_points: int = MAX_RENDER_POINTS) -> Optional[go.Figure]:
    """Render *mesh* vertices as a Plotly 3-D point cloud colored by Z.

    Only vertices are drawn (faces are ignored). Meshes with more than
    *max_points* vertices are subsampled with a fixed seed, so the same mesh
    always produces the same picture. Returns None if rendering fails.
    """
    try:
        points = np.asarray(mesh.points)
        if len(points) > max_points:
            rng = np.random.default_rng(0)
            points = points[rng.choice(len(points), max_points, replace=False)]

        fig = go.Figure(data=[go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=points[:, 2],
                colorscale='Viridis',
                opacity=0.8,
                line=dict(width=0.2, color='rgba(0,0,0,0.3)'),
                showscale=True,
                colorbar=dict(title="Height (Z)"),
            ),
            name='Point Cloud',
            # <extra></extra> hides the secondary trace-name box.
            hovertemplate=('<b>3D Point</b><br>'
                           'X: %{x:.2f}<br>'
                           'Y: %{y:.2f}<br>'
                           'Z: %{z:.2f}<br>'
                           '<extra></extra>'),
        )])

        axis_style = dict(gridcolor='lightgray', showbackground=True,
                          backgroundcolor='rgba(240,240,240,0.1)')
        fig.update_layout(
            title=dict(text=f"3D Geometry: {data_name}", x=0.5,
                       font=dict(size=16, color='darkblue')),
            scene=dict(
                xaxis=dict(title='X', **axis_style),
                yaxis=dict(title='Y', **axis_style),
                zaxis=dict(title='Z', **axis_style),
                camera=dict(eye=dict(x=1.8, y=1.8, z=1.2),
                            center=dict(x=0, y=0, z=0),
                            up=dict(x=0, y=0, z=1)),
                aspectmode='data',
                bgcolor='white',
            ),
            width=600,
            height=500,
            margin=dict(l=20, r=20, t=60, b=20),
            paper_bgcolor='white',
            plot_bgcolor='white',
        )
        return fig
    except Exception as e:
        print(f"❌ 3D visualization failed: {e}")
        return None


def get_mesh_info(mesh: pv.PolyData, data_name: str) -> str:
    """Summarize *mesh* (counts, bounds, point-data arrays) as Markdown."""
    try:
        x_min, x_max, y_min, y_max, z_min, z_max = mesh.bounds
        # Axis-aligned bounding-box volume, not the volume enclosed by the mesh.
        bbox_volume = (x_max - x_min) * (y_max - y_min) * (z_max - z_min)
        array_names = list(mesh.point_data.keys())

        info = f"""
**📊 Mesh Information: {data_name}**

**Basic Info:**
- Points: {mesh.n_points:,}
- Cells: {mesh.n_cells:,}
- Data Arrays: {len(array_names)}

**Bounds:**
- X: [{x_min:.2f}, {x_max:.2f}]
- Y: [{y_min:.2f}, {y_max:.2f}]
- Z: [{z_min:.2f}, {z_max:.2f}]

**Bounding Box Volume:** {bbox_volume:.2f}
"""
        if array_names:
            info += "\n**Point Data Arrays:**\n"
            info += "".join(f"- {key}: {mesh.point_data[key].shape}\n"
                            for key in array_names)
        return info
    except Exception as e:
        return f"❌ Unable to get mesh info: {str(e)}"


def find_vtp_file(data_name: str, config: Dict[str, Any]) -> Optional[str]:
    """Find the mesh file for *data_name* under the configured data folders.

    Searches every ``data.data_folders[*].path`` recursively for files matching
    ``data.file_pattern`` (default ``*.vtp``). An exact stem match in ANY folder
    wins over a case-insensitive partial match; ties resolve in folder order,
    then sorted path order.
    """
    try:
        data_cfg = config.get('data') or {}
        file_pattern = data_cfg.get('file_pattern') or '*.vtp'

        candidates: List[Path] = []
        for folder_config in _data_folders(config):
            folder = folder_config.get('path')
            if not folder:
                continue
            folder_path = Path(folder)
            if folder_path.is_dir():
                candidates.extend(sorted(folder_path.glob(f"**/{file_pattern}")))

        name = str(data_name)
        for vtp_file in candidates:
            if vtp_file.stem == name:
                return str(vtp_file)

        needle = name.lower()
        for vtp_file in candidates:
            if needle in vtp_file.stem.lower():
                return str(vtp_file)

        print(f"⚠️ Could not find VTP file for: {data_name}")
        return None
    except Exception as e:
        print(f"❌ Error finding VTP file: {e}")
        return None


def create_gradio_interface():
    """Build the Gradio app.

    Returns a ``gr.Blocks`` demo with the UMAP plot and 3-D viewer, or an
    error-only ``gr.Interface`` when the UMAP data could not be loaded.
    """
    embedding, data_names, dataset_labels, features = load_umap_data()
    config = load_config()

    if embedding is None:
        return gr.Interface(
            fn=lambda: "❌ No UMAP data found. Please run the PhysicsNeMo analysis first!",
            inputs=[],
            outputs=gr.Textbox(),
            title="📊 UMAP Visualization Demo",
            description="Error: No data available"
        )

    umap_fig = create_interactive_umap_plot(embedding, data_names, dataset_labels, config)

    # Exact-name lookup; the first occurrence wins for duplicates (like list.index).
    name_to_index: Dict[str, int] = {}
    for idx, name in enumerate(data_names):
        name_to_index.setdefault(name, idx)

    def render_geometry(selected_car):
        """Find, load and render *selected_car*; return ``(figure, markdown)``.

        On failure the figure is None and the markdown is an error message.
        """
        vtp_file = find_vtp_file(selected_car, config)
        if not vtp_file:
            return None, f"❌ Could not find VTP file for: {selected_car}"
        mesh = load_vtp_file(vtp_file)
        if mesh is None:
            return None, f"❌ Could not load geometry for: {selected_car}"
        fig_3d = display_data_geometry(mesh, selected_car)
        if fig_3d is None:
            return None, f"❌ Could not create 3D visualization for: {selected_car}"
        return fig_3d, get_mesh_info(mesh, selected_car)

    def resolve_input(input_text):
        """Resolve an index or case-insensitive partial name to a data index.

        Returns ``(index, "")`` on success, else ``(None, error_message)``.
        """
        try:
            index = int(input_text)
        except ValueError:
            needle = input_text.lower()
            for i, car_name in enumerate(data_names):
                if needle in str(car_name).lower():
                    return i, ""
            return None, f"❌ Could not find car matching '{input_text}'. Try a different name or index."
        if 0 <= index < len(data_names):
            return index, ""
        return None, f"❌ Index {index} is out of range. Valid range: 0-{len(data_names)-1}"

    def update_geometry_viewer_from_click(click_data):
        """Update the 3D geometry viewer when a point is clicked."""
        if not click_data or not click_data.get('points'):
            return None, "Click on any point in the UMAP plot to view its 3D geometry."

        point_index = click_data['points'][0].get('customdata')
        # Plotly may deliver customdata wrapped in a one-element list.
        if isinstance(point_index, (list, tuple)):
            point_index = point_index[0] if point_index else None
        if point_index is None:
            return None, "❌ Could not get point index from click data."
        try:
            point_index = int(point_index)
        except (TypeError, ValueError):
            return None, "❌ Could not get point index from click data."

        if not 0 <= point_index < len(data_names):
            return None, f"❌ Invalid point index: {point_index}"
        return render_geometry(data_names[point_index])

    def handle_plot_click(click_data):
        """Handle plot click events."""
        return update_geometry_viewer_from_click(click_data)

    def update_geometry_from_input(input_text):
        """Update the geometry viewer from a typed index or name."""
        if not input_text or not input_text.strip():
            return None, "Enter a car name or index number to view its 3D geometry."
        index, error = resolve_input(input_text.strip())
        if index is None:
            return None, error
        return render_geometry(data_names[index])

    def update_geometry_from_selection(selected_data):
        """Update the 3-D viewer and highlight the selection in the UMAP plot.

        An empty or unresolvable selection keeps the base UMAP plot rather
        than clearing it. Custom typed values (``allow_custom_value=True``)
        fall back to index / partial-name lookup.
        """
        if not selected_data or not selected_data.strip():
            return None, umap_fig
        selected_data = selected_data.strip()

        selected_index = name_to_index.get(selected_data)
        if selected_index is None:
            selected_index, _ = resolve_input(selected_data)
            if selected_index is None:
                return None, umap_fig

        selected_name = data_names[selected_index]
        fig_umap = create_highlighted_umap_plot(
            embedding, data_names, dataset_labels, selected_index, config)
        # Highlight the item even when its geometry is unavailable.
        fig_3d, _ = render_geometry(selected_name)
        return fig_3d, fig_umap

    def get_dataset_info():
        """Summarize per-dataset item counts and feature dimensionality as Markdown."""
        if dataset_labels is None:
            return "No dataset information available."

        unique_datasets, counts = np.unique(dataset_labels, return_counts=True)
        info = "**📊 Dataset Information:**\n\n"
        for dataset, count in zip(unique_datasets, counts):
            info += f"• **{dataset}**: {count} items\n"
        info += f"\n**Total Data Items**: {len(data_names)}\n"
        n_features = (features.shape[1]
                      if features is not None and features.ndim > 1 else 'Unknown')
        info += f"**Total Features**: {n_features}\n"
        info += "\n**Note**: Clustering information not available in this demo"
        return info

    with gr.Blocks(title="📊 UMAP Visualization Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 Interactive UMAP Visualization Demo")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## 📊 UMAP Embedding Visualization")
                umap_plot = gr.Plot(umap_fig, label="UMAP Plot")

            with gr.Column(scale=1):
                gr.Markdown("## 🎯 3D Geometry Viewer")
                gr.Markdown("### Quick Selection")
                selection_dropdown = gr.Dropdown(
                    choices=data_names,
                    label="Select Data Item",
                    info="Type to search or select from dropdown",
                    allow_custom_value=True,
                    value=None
                )
                geometry_plot = gr.Plot(label="3D Geometry")

        selection_dropdown.change(
            fn=update_geometry_from_selection,
            inputs=[selection_dropdown],
            outputs=[geometry_plot, umap_plot]
        )

    return demo


def main():
    """Check prerequisites, then build and launch the Gradio demo.

    Returns a process exit code: 0 on a clean launch, 1 on any error.
    """
    print("🚀 Launching Gradio UMAP Visualization Demo...")

    if not os.path.exists("config.yaml"):
        print("❌ Error: config.yaml not found. Please run this script from the umap directory.")
        return 1

    # Mirror the file lookup in load_umap_data (the "*_" patterns also match
    # the combined_* multi-dataset files).
    embedding_files = glob.glob("files/*_umap_embedding.npy") + glob.glob("files/umap_embedding.npy")
    features_files = glob.glob("files/*_features.npy") + glob.glob("files/features.npy")
    labels_files = glob.glob("files/*_names.npy") + glob.glob("files/names.npy")

    if not embedding_files or not features_files or not labels_files:
        print("❌ Error: Required UMAP data files not found")
        print("Looking for files matching patterns in files/ directory:")
        print("  - files/*_umap_embedding.npy or files/umap_embedding.npy")
        print("  - files/*_features.npy or files/features.npy")
        print("  - files/*_names.npy or files/names.npy")
        print("Please run the PhysicsNeMo analysis first using: python run_umap.py")
        return 1

    try:
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7862,
            share=False,
            show_error=True,
            quiet=False
        )
        return 0
    except Exception as e:
        print(f"❌ Error launching demo: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())