Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio Demo for UMAP Visualization and 3D Geometry Viewer | |
| Similar to the Streamlit demo but using Gradio interface | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| import pandas as pd | |
| import tempfile | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from pathlib import Path | |
| import vtk | |
| import pyvista as pv | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import gradio as gr | |
| import json | |
| from sklearn.preprocessing import StandardScaler | |
# Select PyVista's light "document" plotting theme. Note: this does not switch
# PyVista to off-screen rendering; the 3D viewer in this app draws with Plotly.
pv.set_plot_theme("document")
def load_umap_data() -> Tuple[Optional[np.ndarray], Optional[List[str]],
                              Optional[List[str]], Optional[np.ndarray]]:
    """Load the UMAP analysis results stored in the ``files/`` directory.

    For every artifact the legacy file name is tried first, then the
    ``combined_*`` name, then any ``files/*_<suffix>.npy`` match. Glob
    matches are sorted so the choice is deterministic when several
    single-dataset outputs are present.

    Returns:
        ``(embedding, data_names, dataset_labels, features)``.
        ``dataset_labels`` falls back to ``"Unknown"`` for every item when no
        labels file exists or its length does not match ``data_names``.
        On any failure all four values are ``None``. The tuple has the same
        length in both cases, so callers can always unpack four values.
    """
    import glob

    def _find_first(fixed_names: List[str], pattern: str) -> Optional[str]:
        """Return the first existing path among *fixed_names*, then sorted glob matches of *pattern*."""
        for path in fixed_names + sorted(glob.glob(pattern)):
            if os.path.exists(path):
                return path
        return None

    try:
        embedding_file = _find_first(
            ["files/umap_embedding.npy",            # legacy fallback
             "files/combined_umap_embedding.npy"],  # multiple datasets
            "files/*_umap_embedding.npy")
        labels_file = _find_first(
            ["files/names.npy", "files/combined_names.npy"],
            "files/*_names.npy")
        features_file = _find_first(
            ["files/features.npy", "files/combined_features.npy"],
            "files/*_features.npy")
        dataset_labels_file = _find_first(
            ["files/labels_dataset_labels.npy", "files/combined_dataset_labels.npy"],
            "files/*_dataset_labels.npy")

        if not all([embedding_file, labels_file, features_file]):
            raise FileNotFoundError("Required UMAP data files not found")

        embedding = np.load(embedding_file)
        print(f"β Loaded UMAP embedding from {embedding_file}: {embedding.shape}")

        # allow_pickle is needed for object/string arrays. Only load files
        # produced by the local analysis step, never untrusted uploads.
        data_names = np.load(labels_file, allow_pickle=True).tolist()
        print(f"β Loaded {len(data_names)} data names from {labels_file}")
        print(f" First 5 data names: {data_names[:5]}")

        # Plotting indexes embedding rows by name position, so a mismatch
        # would otherwise surface later as a confusing indexing error.
        if len(embedding) != len(data_names):
            raise ValueError(
                f"UMAP embedding has {len(embedding)} rows but "
                f"{len(data_names)} data names were loaded"
            )

        if dataset_labels_file:
            dataset_labels = np.load(dataset_labels_file, allow_pickle=True).tolist()
            print(f"β Loaded dataset labels from {dataset_labels_file}: {len(dataset_labels)} labels")
            print(f" First 5 dataset labels: {dataset_labels[:5]}")
            if len(dataset_labels) != len(data_names):
                print(
                    f"Dataset label count ({len(dataset_labels)}) does not match "
                    f"data name count ({len(data_names)}); using 'Unknown'"
                )
                dataset_labels = ["Unknown"] * len(data_names)
        else:
            dataset_labels = ["Unknown"] * len(data_names)
            print("β οΈ No dataset labels found, using 'Unknown'")

        features = np.load(features_file)
        print(f"β Loaded features from {features_file}: {features.shape}")
        return embedding, data_names, dataset_labels, features
    except FileNotFoundError as e:
        print(f"β Error loading data: {e}")
        print("Please run the PhysicsNeMo analysis first!")
        return None, None, None, None
    except Exception as e:
        print(f"β Unexpected error: {e}")
        return None, None, None, None
def load_config(path: str = "config.yaml") -> Dict[str, Any]:
    """Load the YAML configuration file.

    Args:
        path: Location of the YAML file. Defaults to ``config.yaml`` in the
            current working directory.

    Returns:
        The parsed mapping, or ``{}`` in any of these cases: the file is
        missing or unreadable, PyYAML is not installed, or the document is
        empty or not a mapping. Callers chain ``config.get(...)``, so a dict
        is always returned.
    """
    try:
        import yaml  # imported lazily so a missing PyYAML only disables config
        with open(path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except Exception as e:
        print(f"β οΈ Could not load config: {e}")
        return {}
    # safe_load yields None for an empty file, and a scalar or list for a
    # non-mapping document. Any of these would break the `.get` calls downstream.
    if not isinstance(config, dict):
        return {}
    return config
def create_interactive_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str], config: Dict[str, Any]) -> go.Figure:
    """Build the 2D UMAP scatter plot with one trace per dataset.

    Args:
        embedding: Array whose columns 0 and 1 are plotted as x/y. Row ``i``
            belongs to ``data_names[i]`` and ``dataset_labels[i]``.
        data_names: Item names shown in the hover text.
        dataset_labels: Dataset label per item. One trace is drawn per label.
        config: Parsed config. ``data.data_folders[*]`` entries with both a
            ``label`` and a ``color`` override the default palette.

    Returns:
        A Plotly figure. Each trace's ``customdata`` holds the row indices of
        its points in ``embedding``.
    """
    hover_text = [
        f"<b>{name}</b><br>"
        f"Dataset: {dataset}<br>"
        f"UMAP: ({embedding[i, 0]:.3f}, {embedding[i, 1]:.3f})<br>"
        f"<i>Click to view 3D geometry</i>"
        for i, (name, dataset) in enumerate(zip(data_names, dataset_labels))
    ]

    # Config colors. Incomplete folder entries are skipped instead of raising KeyError.
    dataset_color_map = {
        folder['label']: folder['color']
        for folder in config.get('data', {}).get('data_folders', [])
        if 'label' in folder and 'color' in folder
    }
    default_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    labels_array = np.asarray(dataset_labels)  # built once, reused for every mask
    fig = go.Figure()
    # Sorted rather than raw set order: string-set iteration order changes
    # between interpreter runs (hash randomization). That made the legend
    # order and the fallback palette colors unstable.
    for dataset_idx, dataset in enumerate(sorted(set(dataset_labels), key=str)):
        dataset_indices = np.flatnonzero(labels_array == dataset)
        color = dataset_color_map.get(dataset, default_colors[dataset_idx % len(default_colors)])
        fig.add_trace(go.Scatter(
            x=embedding[dataset_indices, 0],
            y=embedding[dataset_indices, 1],
            mode='markers',
            marker=dict(
                size=10,
                color=color,
                line=dict(width=1, color='black'),
                opacity=0.8
            ),
            name=dataset,
            hovertemplate=[hover_text[i] for i in dataset_indices],
            customdata=dataset_indices,  # row indices, used to resolve clicks
            showlegend=True
        ))

    fig.update_layout(
        title=dict(
            text=" Interactive UMAP Visualization",
            x=0.5,
            font=dict(size=18, color='darkblue')
        ),
        xaxis_title="UMAP Component 1",
        yaxis_title="UMAP Component 2",
        width=600,
        height=500,
        hovermode='closest',
        clickmode='event+select',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        ),
        margin=dict(l=60, r=60, t=100, b=60)
    )
    return fig
def create_highlighted_umap_plot(embedding: np.ndarray, data_names: List[str],
                                 dataset_labels: List[str], selected_index: int,
                                 config: Dict[str, Any]) -> go.Figure:
    """Build the UMAP scatter plot with one item drawn as a large black square.

    Args:
        embedding: Array whose columns 0 and 1 are plotted as x/y. Row ``i``
            belongs to ``data_names[i]`` and ``dataset_labels[i]``.
        data_names: Item names shown in hover text and in the title.
        dataset_labels: Dataset label per item. Each dataset gets a trace for
            its unselected points, plus a separate "<dataset> - Selected"
            trace when it contains the selected item.
        selected_index: Row of the item to highlight. Must be a valid index
            into ``data_names`` (it is used for the title).
        config: Parsed config. ``data.data_folders[*]`` entries with both a
            ``label`` and a ``color`` override the default palette.

    Returns:
        A Plotly figure whose traces carry row indices in ``customdata``.
    """
    hover_text = [
        f"<b>{name}</b><br>"
        f"Dataset: {dataset}<br>"
        f"UMAP: ({embedding[i, 0]:.3f}, {embedding[i, 1]:.3f})<br>"
        f"<i>Click to view 3D geometry</i>"
        for i, (name, dataset) in enumerate(zip(data_names, dataset_labels))
    ]

    # Config colors. Incomplete folder entries are skipped instead of raising KeyError.
    dataset_color_map = {
        folder['label']: folder['color']
        for folder in config.get('data', {}).get('data_folders', [])
        if 'label' in folder and 'color' in folder
    }
    default_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    labels_array = np.asarray(dataset_labels)  # built once, reused for every mask
    fig = go.Figure()
    # Sorted rather than raw set order, so legend order and fallback colors
    # are reproducible across runs (string hashing is randomized per process).
    for dataset_idx, dataset in enumerate(sorted(set(dataset_labels), key=str)):
        dataset_indices = np.flatnonzero(labels_array == dataset)
        color = dataset_color_map.get(dataset, default_colors[dataset_idx % len(default_colors)])

        # Unselected points of this dataset, slightly smaller and faded.
        other_indices = dataset_indices[dataset_indices != selected_index]
        if other_indices.size:
            fig.add_trace(go.Scatter(
                x=embedding[other_indices, 0],
                y=embedding[other_indices, 1],
                mode='markers',
                marker=dict(
                    size=8,
                    color=color,
                    line=dict(width=1, color='black'),
                    opacity=0.6
                ),
                name=dataset,
                hovertemplate=[hover_text[i] for i in other_indices],
                customdata=other_indices,
                showlegend=True
            ))

        # The selected point, if it belongs to this dataset.
        if np.any(dataset_indices == selected_index):
            selected_idx = int(selected_index)
            fig.add_trace(go.Scatter(
                x=[embedding[selected_idx, 0]],
                y=[embedding[selected_idx, 1]],
                mode='markers',
                marker=dict(
                    size=20,
                    color='black',
                    line=dict(width=2, color='white'),
                    opacity=1.0,
                    symbol='square'
                ),
                name=f"{dataset} - Selected",
                hovertemplate=[hover_text[selected_idx]],
                customdata=[selected_idx],
                showlegend=True
            ))

    fig.update_layout(
        title=dict(
            text=f"Interactive UMAP Visualization - Selected: {data_names[selected_index]}",
            x=0.5,
            font=dict(size=16, color='darkblue')
        ),
        xaxis_title="UMAP Component 1",
        yaxis_title="UMAP Component 2",
        width=600,
        height=500,
        hovermode='closest',
        clickmode='event+select',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        ),
        margin=dict(l=60, r=60, t=100, b=60)
    )
    return fig
def load_vtp_file(file_path: str) -> Optional[pv.PolyData]:
    """Read a mesh file from disk with ``pv.read``.

    Returns the loaded mesh, or ``None`` when the path does not exist or
    reading raises. In every case a status line is printed.
    """
    try:
        if os.path.exists(file_path):
            loaded = pv.read(file_path)
            print(f"β Loaded mesh: {loaded.n_points} points, {loaded.n_cells} cells")
            return loaded
        print(f"β File not found: {file_path}")
        return None
    except Exception as err:
        print(f"β Error loading {file_path}: {err}")
        return None
def display_data_geometry(mesh: pv.PolyData, data_name: str,
                          max_points: Optional[int] = 20000,
                          seed: Optional[int] = 0) -> go.Figure:
    """Render a mesh's vertices as an interactive Plotly 3D point cloud.

    Only the vertex coordinates (``mesh.points``) are drawn; faces are not.
    Points are colored by their Z coordinate.

    Args:
        mesh: Mesh whose ``points`` are plotted. Columns 0/1/2 are used as X/Y/Z.
        data_name: Label used in the figure title.
        max_points: Upper bound on plotted points, for browser performance.
            Larger clouds are randomly subsampled. ``None`` disables sampling.
        seed: Seed for the subsampling RNG, so the same mesh always shows the
            same sample. ``None`` draws a fresh sample on each call.

    Returns:
        The Plotly figure, or ``None`` if building it raised (the error is printed).
    """
    try:
        points = np.asarray(mesh.points)
        if max_points is not None and len(points) > max_points:
            # A local Generator avoids touching NumPy's global RNG state and
            # makes the subsample reproducible. Sorting keeps the mesh's point order.
            rng = np.random.default_rng(seed)
            keep = np.sort(rng.choice(len(points), max_points, replace=False))
            points = points[keep]

        fig = go.Figure(data=[go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=points[:, 2],
                colorscale='Viridis',
                opacity=0.8,
                line=dict(width=0.2, color='rgba(0,0,0,0.3)'),
                showscale=True,
                colorbar=dict(
                    title="Height (Z)"
                )
            ),
            name='Point Cloud',
            hovertemplate='<b>3D Point</b><br>' +
                          'X: %{x:.2f}<br>' +
                          'Y: %{y:.2f}<br>' +
                          'Z: %{z:.2f}<br>' +
                          '<extra></extra>'
        )])

        axis_style = dict(
            gridcolor='lightgray',
            showbackground=True,
            backgroundcolor='rgba(240,240,240,0.1)'
        )
        fig.update_layout(
            title=dict(
                text=f"3D Geometry: {data_name}",
                x=0.5,
                font=dict(size=16, color='darkblue')
            ),
            scene=dict(
                xaxis=dict(title='X', **axis_style),
                yaxis=dict(title='Y', **axis_style),
                zaxis=dict(title='Z', **axis_style),
                camera=dict(
                    eye=dict(x=1.8, y=1.8, z=1.2),
                    center=dict(x=0, y=0, z=0),
                    up=dict(x=0, y=0, z=1)
                ),
                aspectmode='data',
                bgcolor='white'
            ),
            width=600,
            height=500,
            margin=dict(l=20, r=20, t=60, b=20),
            paper_bgcolor='white',
            plot_bgcolor='white'
        )
        return fig
    except Exception as e:
        print(f"β 3D visualization failed: {e}")
        return None
def get_mesh_info(mesh: pv.PolyData, data_name: str) -> str:
    """Summarize a mesh as a Markdown string.

    The summary contains:
      - point and cell counts, and the number of point-data arrays;
      - the axis-aligned bounds and the bounding-box volume;
      - the shape of every point-data array.

    Returns an error message string, rather than raising, if reading any
    mesh attribute fails.
    """
    try:
        x_min, x_max, y_min, y_max, z_min, z_max = mesh.bounds
        # This is the volume of the axis-aligned bounding box, not the volume
        # enclosed by the mesh, so it is labelled as such.
        bbox_volume = (x_max - x_min) * (y_max - y_min) * (z_max - z_min)
        array_names = list(mesh.point_data.keys())

        # Built line by line (no source indentation inside a triple-quoted
        # string) so indented lines are not rendered as Markdown code blocks.
        lines = [
            "",
            f"**π Mesh Information: {data_name}**",
            "",
            "**Basic Info:**",
            f"- Points: {mesh.n_points:,}",
            f"- Cells: {mesh.n_cells:,}",
            f"- Data Arrays: {len(array_names)}",
            "",
            "**Bounds:**",
            f"- X: [{x_min:.2f}, {x_max:.2f}]",
            f"- Y: [{y_min:.2f}, {y_max:.2f}]",
            f"- Z: [{z_min:.2f}, {z_max:.2f}]",
            "",
            f"**Bounding Box Volume:** {bbox_volume:.2f}",
        ]
        if array_names:
            lines += ["", "**Point Data Arrays:**"]
            lines += [f"- {key}: {mesh.point_data[key].shape}" for key in array_names]
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"β Unable to get mesh info: {str(e)}"
def find_vtp_file(data_name: str, config: Dict[str, Any]) -> Optional[str]:
    """Locate the mesh file for *data_name* under the configured data folders.

    Files matching ``data.file_pattern`` (default ``*.vtp``) are searched
    recursively in every existing ``data.data_folders[*].path``. Candidates
    are sorted within each folder, and folders keep their config order.

    Matching rules, in order:
      1. An exact (case-sensitive) file-stem match in ANY folder wins.
      2. Otherwise, the first case-insensitive substring match is used.
    This prevents ``car_1`` resolving to ``car_10`` in an earlier folder.

    Returns:
        The file path as a string, or ``None`` when nothing matches or the
        config is malformed (the reason is printed).
    """
    try:
        data_config = config.get('data', {})
        file_pattern = data_config.get('file_pattern', '*.vtp')

        candidates = []
        for folder_config in data_config.get('data_folders', []):
            folder_path = Path(folder_config['path'])
            if folder_path.exists():
                # rglob(p) == glob("**/" + p). Sorting removes dependence on
                # filesystem listing order.
                candidates.extend(sorted(p for p in folder_path.rglob(file_pattern) if p.is_file()))

        for candidate in candidates:
            if candidate.stem == data_name:
                return str(candidate)

        wanted = str(data_name).lower()
        for candidate in candidates:
            if wanted in candidate.stem.lower():
                return str(candidate)

        print(f"β οΈ Could not find VTP file for: {data_name}")
        return None
    except Exception as e:
        print(f"β Error finding VTP file: {e}")
        return None
def create_gradio_interface():
    """Build the Gradio app for browsing the UMAP embedding and 3D geometries.

    The UMAP results and config are loaded once, and every callback below
    closes over them. If no UMAP data can be loaded, a minimal
    ``gr.Interface`` that only shows an error message is returned instead.

    Only the dropdown is wired to a callback. Selecting an item renders its
    mesh in the 3D viewer and redraws the UMAP plot with that item
    highlighted. The click, text-input and dataset-info helpers are defined
    here but not connected to any component.
    """
    # Keep only the first four values so unpacking works whether the loader's
    # failure path returns four or five Nones.
    embedding, data_names, dataset_labels, features = tuple(load_umap_data())[:4]
    config = load_config()

    if embedding is None:
        return gr.Interface(
            fn=lambda: "β No UMAP data found. Please run the PhysicsNeMo analysis first!",
            inputs=[],
            outputs=gr.Textbox(),
            title="π UMAP Visualization Demo",
            description="Error: No data available"
        )

    # Un-highlighted figure. It is also returned to restore the UMAP plot
    # when a selection is cleared or cannot be resolved.
    umap_fig = create_interactive_umap_plot(embedding, data_names, dataset_labels, config)

    def _render_geometry(name):
        """Return ``(fig_3d, mesh_info, error)`` for *name*; ``error`` is None on success."""
        vtp_file = find_vtp_file(name, config)
        if not vtp_file:
            return None, None, f"β Could not find VTP file for: {name}"
        mesh = load_vtp_file(vtp_file)
        if mesh is None:
            return None, None, f"β Could not load geometry for: {name}"
        fig_3d = display_data_geometry(mesh, name)
        if fig_3d is None:
            return None, None, f"β Could not create 3D visualization for: {name}"
        return fig_3d, get_mesh_info(mesh, name), None

    def _show_geometry(name):
        """Return ``(fig_3d, markdown)``, where markdown is the mesh info or the error message."""
        fig_3d, mesh_info, error = _render_geometry(name)
        return (None, error) if error else (fig_3d, mesh_info)

    def _resolve_name(text):
        """Map free text to a data name: exact match first, else the first case-insensitive substring match."""
        if text in data_names:
            return text
        lowered = text.lower()
        return next((name for name in data_names if lowered in name.lower()), None)

    def update_geometry_viewer_from_click(click_data):
        """Update the 3D geometry viewer when a point is clicked."""
        if not click_data or not click_data.get('points'):
            return None, "Click on any point in the UMAP plot to view its 3D geometry."
        point_index = click_data['points'][0].get('customdata')
        if point_index is None:
            return None, "β Could not get point index from click data."
        # Negative indices would silently wrap around in Python lists.
        if not 0 <= point_index < len(data_names):
            return None, f"β Invalid point index: {point_index}"
        return _show_geometry(data_names[point_index])

    def handle_plot_click(click_data):
        """Handle plot click events."""
        return update_geometry_viewer_from_click(click_data)

    def update_geometry_from_input(input_text):
        """Update the geometry viewer from a typed name or index."""
        if not input_text or input_text.strip() == "":
            return None, "Enter a car name or index number to view its 3D geometry."
        input_text = input_text.strip()
        try:
            index = int(input_text)
        except ValueError:
            selected_car = _resolve_name(input_text)
            if selected_car is None:
                return None, f"β Could not find car matching '{input_text}'. Try a different name or index."
        else:
            if not 0 <= index < len(data_names):
                return None, f"β Index {index} is out of range. Valid range: 0-{len(data_names)-1}"
            selected_car = data_names[index]
        return _show_geometry(selected_car)

    def update_geometry_from_selection(selected_data):
        """Update the 3D viewer and the highlighted UMAP plot from the dropdown.

        The second output is never ``None``. Returning ``None`` would blank the
        UMAP plot, so the un-highlighted figure is returned when the value is
        empty or unresolvable.
        """
        if not selected_data or not str(selected_data).strip():
            return None, umap_fig
        # allow_custom_value=True lets users type partial names, so resolve them.
        selected_name = _resolve_name(str(selected_data).strip())
        if selected_name is None:
            return None, umap_fig
        selected_index = data_names.index(selected_name)
        fig_3d, _mesh_info, _error = _render_geometry(selected_name)
        # Highlight the item even when its geometry could not be rendered.
        fig_umap = create_highlighted_umap_plot(embedding, data_names, dataset_labels, selected_index, config)
        return fig_3d, fig_umap

    def get_dataset_info():
        """Return Markdown with per-dataset item counts and the feature count."""
        if dataset_labels is None:
            return "No dataset information available."
        unique_datasets, counts = np.unique(dataset_labels, return_counts=True)
        info = "**π Dataset Information:**\n\n"
        for dataset, count in zip(unique_datasets, counts):
            info += f"β’ **{dataset}**: {count} items\n"
        info += f"\n**Total Data Items**: {len(data_names)}\n"
        n_features = features.shape[1] if features is not None and features.ndim > 1 else 'Unknown'
        info += f"**Total Features**: {n_features}\n"
        info += "\n**Note**: Clustering information not available in this demo"
        return info

    with gr.Blocks(title="π UMAP Visualization Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# π Interactive UMAP Visualization Demo")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## π UMAP Embedding Visualization")
                umap_plot = gr.Plot(umap_fig, label="UMAP Plot")
            with gr.Column(scale=1):
                gr.Markdown("## π― 3D Geometry Viewer")
                gr.Markdown("### Quick Selection")
                selection_dropdown = gr.Dropdown(
                    choices=data_names,
                    label="Select Data Item",
                    info="Type to search or select from dropdown",
                    allow_custom_value=True,
                    value=None
                )
                geometry_plot = gr.Plot(label="3D Geometry")
        selection_dropdown.change(
            fn=update_geometry_from_selection,
            inputs=[selection_dropdown],
            outputs=[geometry_plot, umap_plot]
        )
    return demo
def main() -> int:
    """Check prerequisites, then build and launch the Gradio demo.

    Must be run from the directory that holds ``config.yaml`` and the
    ``files/`` analysis outputs. The server port defaults to 7862, but a
    ``GRADIO_SERVER_PORT`` environment variable takes precedence. This lets
    hosts that assign the port (presumably e.g. Spaces deployments; confirm
    for yours) run the app unchanged.

    Returns:
        Process exit code: 0 after ``launch`` returns, 1 if a prerequisite is
        missing or launching raises.
    """
    import glob

    print("π Launching Gradio UMAP Visualization Demo...")

    if not os.path.exists("config.yaml"):
        print("β Error: config.yaml not found. Please run this script from the umap directory.")
        return 1

    # Same dynamic naming scheme as the loader: "<prefix>_<suffix>.npy" or the legacy bare name.
    embedding_files = glob.glob("files/*_umap_embedding.npy") + glob.glob("files/umap_embedding.npy")
    features_files = glob.glob("files/*_features.npy") + glob.glob("files/features.npy")
    labels_files = glob.glob("files/*_names.npy") + glob.glob("files/names.npy")
    if not (embedding_files and features_files and labels_files):
        print("β Error: Required UMAP data files not found")
        print("Looking for files matching patterns in files/ directory:")
        print(" - files/*_umap_embedding.npy or files/umap_embedding.npy")
        print(" - files/*_features.npy or files/features.npy")
        print(" - files/*_names.npy or files/names.npy")
        print("Please run the PhysicsNeMo analysis first using: python run_umap.py")
        return 1

    try:
        # An explicitly passed server_port would override Gradio's own
        # environment handling, so read the variable here with the old default.
        server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7862"))
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=server_port,
            share=False,
            show_error=True,
            quiet=False
        )
        return 0
    except Exception as e:
        print(f"β Error launching demo: {e}")
        return 1
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())