Spaces:
Sleeping
Sleeping
| import os, numpy as np, pyvista as pv, trimesh, matplotlib.pyplot as plt | |
| from scipy.spatial import cKDTree | |
| from functools import wraps | |
| import time | |
| from scipy.ndimage import gaussian_filter | |
| from vtk.util import numpy_support as nps | |
| from scipy.spatial.distance import cosine | |
def _to_vertex_colors(values):
    """Map scalar values to RGB vertex colours using the viridis colormap.

    Values are min-max normalised with NaN-aware limits (``np.nanmin``/``np.nanmax``).
    A constant field maps to the bottom colour of the colormap.

    Args:
        values: one scalar per vertex (assumed 1-D; a 2-D input would not index
            the colormap output correctly — TODO confirm callers).

    Returns:
        ``(N, 3)`` uint8 array of RGB colours.
    """
    v = np.asarray(values, dtype=float)
    vmin = float(np.nanmin(v))
    span = float(np.nanmax(v)) - vmin
    # Guard the division explicitly. Adding a tiny epsilon to vmax is not enough:
    # for large magnitudes it is lost to rounding and a constant field gives 0/0.
    if not span > 0:
        span = 1.0
    norm = (v - vmin) / span
    return (plt.get_cmap("viridis")(norm)[:, :3] * 255).astype(np.uint8)
def _pv_to_trimesh(poly: pv.PolyData, colors=None):
    """Convert a PyVista PolyData into an unprocessed ``trimesh.Trimesh``.

    The VTK face connectivity is decoded as fixed ``[3, i, j, k]`` records. If the
    polygons are not all triangles, the poly is triangulated first. Triangulation
    keeps the point set, so vertex colours still line up. A poly without polygons
    yields a vertex-only Trimesh.

    Args:
        poly: PyVista surface or point set.
        colors: optional per-vertex colours, one row per point. uint8 input is used
            as-is. Any other dtype is treated as 0..1 values and scaled to 0..255.
            Ignored when its length differs from the vertex count.

    Returns:
        ``trimesh.Trimesh`` built with ``process=False`` so vertex order (and thus the
        colour assignment) is preserved.
    """
    faces = np.asarray(poly.faces)
    # Every record of an all-triangle array starts with 3 at a multiple-of-4 offset.
    # The first non-triangle record breaks that pattern, so this check is exact.
    if faces.size and (faces.size % 4 or np.any(faces[::4] != 3)):
        poly = poly.triangulate()
        faces = np.asarray(poly.faces)
    V = np.asarray(poly.points)
    F = faces.reshape(-1, 4)[:, 1:4] if faces.size else None
    m = trimesh.Trimesh(vertices=V, faces=F, process=False)
    if colors is not None:
        colors = np.asarray(colors)
        if len(colors) == len(V):
            if colors.dtype != np.uint8:
                colors = (np.clip(colors, 0, 1) * 255).astype(np.uint8)
            m.visual.vertex_colors = colors
    return m
def _glb_write(mesh, path):
    """Serialise ``mesh`` as binary glTF (with vertex normals) and write it to ``path``."""
    # export_glb returns the GLB bytes; its second parameter is include_normals,
    # not an output path, so the bytes must be written out explicitly.
    data = trimesh.exchange.gltf.export_glb(mesh, include_normals=True)
    with open(path, "wb") as fh:
        fh.write(data)


def _glb_decimate(poly, config):
    """Return ``poly`` decimated per ``config``; return it unchanged when skipped or on failure."""
    try:
        original_faces = poly.n_cells
        target_faces = int(original_faces * (1 - config["target_reduction"]))
        target_faces = max(config["min_faces"], min(target_faces, config["max_faces"]))
        if original_faces <= target_faces or poly.faces.size == 0:
            return poly
        print(f"π§ Decimating VTP mesh: {original_faces} β {target_faces} faces")
        # Derive the reduction from the clamped target so min_faces/max_faces take effect.
        reduction = 1.0 - target_faces / original_faces
        tri = poly.triangulate()  # VTK decimation filters operate on triangles
        if config.get("preserve_topology", False):
            # PolyData.decimate (quadric) has no preserve_topology option; decimate_pro does.
            decimated = tri.decimate_pro(reduction, preserve_topology=True)
        else:
            decimated = tri.decimate(target_reduction=reduction)
        print(f"β VTP decimation successful: {decimated.n_cells} faces")
        return decimated
    except Exception as e:
        print(f"β οΈ VTP decimation failed: {e}, using original mesh")
        return poly


def vtp_to_glb(in_vtp: str, out_glb: str, scalar: str | None = None, point_size_frac=0.004, decimation_config=None):
    """Read a .vtp and write a fast-to-load .glb for Gradio/three.js.

    Inputs with polygonal faces are optionally decimated, then triangulated and given
    consistent normals, and exported as a triangle mesh. Inputs without faces are
    treated as point clouds. Each point is replaced by a small sphere so it stays
    visible.

    Args:
        in_vtp: path of the input .vtp file.
        out_glb: path the .glb file is written to.
        scalar: optional point-data array name used for viridis vertex colours.
        point_size_frac: point-cloud sphere radius as a fraction of the bounding-box diagonal.
        decimation_config: optional dict with "target_reduction", "min_faces" and
            "max_faces", plus optional "enabled" and "preserve_topology". Decimation
            errors are printed and the undecimated mesh is used.

    Returns:
        ``out_glb``.
    """
    poly = pv.read(in_vtp)
    if decimation_config and decimation_config.get("enabled", True):
        poly = _glb_decimate(poly, decimation_config)

    if poly.faces.size > 0:
        surface = poly.triangulate().compute_normals(
            consistent_normals=True, auto_orient_normals=True,
            point_normals=True, cell_normals=False, inplace=False,
        )
        colors = _to_vertex_colors(surface.point_data[scalar]) if (scalar and scalar in surface.point_data) else None
        _glb_write(_pv_to_trimesh(surface, colors), out_glb)
        return out_glb

    # Point cloud: glyph every point with a small sphere.
    pc = pv.PolyData(poly.points)
    if scalar and scalar in poly.point_data:
        pc.point_data[scalar] = np.asarray(poly.point_data[scalar])
    xmin, xmax, ymin, ymax, zmin, zmax = pc.bounds
    diag = np.linalg.norm([xmax - xmin, ymax - ymin, zmax - zmin])
    radius = max(1e-8, point_size_frac * diag)
    sphere = pv.Sphere(radius=radius, theta_resolution=16, phi_resolution=16)
    glyphs = pc.glyph(geom=sphere, scale=False)
    colors = _to_vertex_colors(glyphs.point_data[scalar]) if (scalar and scalar in glyphs.point_data) else None
    _glb_write(_pv_to_trimesh(glyphs, colors), out_glb)
    return out_glb
| import matplotlib.pyplot as plt | |
| from matplotlib.cm import ScalarMappable | |
| from matplotlib.colors import Normalize | |
def save_colorbar_png(vmin, vmax, out_path, cmap_name="jet", units="", variable_name=""):
    """Render a vertical colorbar for ``[vmin, vmax]`` to a transparent PNG.

    Tick labels are mantissas relative to one shared power of ten. That power is
    taken from the larger of ``|vmin|`` and ``|vmax|`` and printed once above the
    bar as ``×10^k``; it is omitted when ``k == 0``. All text is white.

    Args:
        vmin, vmax: colour-scale limits.
        out_path: destination PNG path.
        cmap_name: Matplotlib colormap name.
        units: optional units, shown in parentheses after the label.
        variable_name: label text; "Value" is used when empty.

    Returns:
        ``out_path``.
    """
    from matplotlib.ticker import FuncFormatter

    fig, ax = plt.subplots(figsize=(0.1, 8.0), facecolor='none')
    try:
        ax.set_facecolor('none')
        sm = ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap_name)
        cb = plt.colorbar(sm, cax=ax, orientation="vertical")
        cb.ax.set_facecolor('none')
        cb.ax.tick_params(colors='white', labelsize=16)

        # Every tick is divided by the same power of ten that the header shows.
        # Otherwise 500 and 1000 would both render as 5.00 / 1.00 under one "×10^3".
        peak = max(abs(float(vmin)), abs(float(vmax)))
        exponent = int(np.floor(np.log10(peak))) if np.isfinite(peak) and peak > 0 else 0
        scale = 10.0 ** exponent

        def format_func(x, pos):
            if x == 0:
                return "0.00"  # avoid a "-0.00" label
            return f"{x / scale:.2f}"

        cb.ax.yaxis.set_major_formatter(FuncFormatter(format_func))
        if exponent != 0:
            cb.ax.text(0.5, 1.05, f"×10^{exponent}", transform=cb.ax.transAxes,
                       ha='center', va='bottom', fontsize=18, color='white', fontweight='bold')

        if units:
            label_text = f"{variable_name} ({units})" if variable_name else f"Value ({units})"
        else:
            label_text = variable_name if variable_name else "Value"
        cb.set_label(label_text, color='white', fontsize=20, fontweight='bold')
        cb.ax.yaxis.label.set_color('white')

        fig.savefig(out_path, bbox_inches="tight", dpi=200, facecolor='none', edgecolor='none', transparent=True)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    return out_path
def create_visualization_points(viz_data):
    """Build a jet-coloured trimesh point cloud from prediction data.

    Args:
        viz_data: mapping with "points" (point coordinates) and "pred" (one value
            per point — assumed; TODO confirm shape).

    Returns:
        ``(cloud, vmin, vmax)``: the re-oriented ``trimesh.points.PointCloud`` and
        the colour-scale limits taken from the predictions.
    """
    values = viz_data["pred"]
    lo = float(np.min(values))
    hi = float(np.max(values))
    rgba = plt.get_cmap("jet")(Normalize(vmin=lo, vmax=hi)(values))

    cloud = trimesh.points.PointCloud(viz_data["points"])
    cloud.visual.vertex_colors = (rgba[:, :3] * 255).astype(np.uint8)

    # Re-orient: -90 deg about x, then 180 deg about y, both around the centroid.
    pivot = cloud.centroid
    tilt = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0], pivot)
    flip = trimesh.transformations.rotation_matrix(np.radians(180), [0, 1, 0], pivot)
    cloud.apply_transform(flip @ tilt)
    return cloud, lo, hi
def create_visualization_stl(viz_data, stl_path):
    """Colour a surface mesh by nearest-neighbour transfer of predicted values.

    Args:
        viz_data: mapping with "points" (prediction locations; presumably (N, 3) —
            confirm) and "pred" (one value per prediction point).
        stl_path: path of the surface file to colour.

    Returns:
        ``(stl_mesh, vmin, vmax)``: the coloured, re-oriented ``trimesh.Trimesh`` and
        the jet colour-scale limits computed over all predictions.
    """
    # asarray so list inputs support the fancy indexing pred[idx] below.
    pred = np.asarray(viz_data["pred"])
    vmin, vmax = float(np.min(pred)), float(np.max(pred))
    cmap = plt.get_cmap("jet")
    norm = Normalize(vmin=vmin, vmax=vmax)

    # force="mesh" merges multi-body files into one Trimesh instead of returning a
    # Scene, which has no .vertices / .visual.vertex_colors.
    stl_mesh = trimesh.load(stl_path, force="mesh")
    stl_points = np.asarray(stl_mesh.vertices, dtype=np.float32)

    # Each surface vertex takes the prediction of its nearest prediction point.
    tree = cKDTree(np.asarray(viz_data["points"]))
    _, idx = tree.query(stl_points, k=1)
    stl_mesh.visual.vertex_colors = (cmap(norm(pred[idx]))[:, :3] * 255).astype(np.uint8)

    # Re-orient: -90 deg about x, then 180 deg about y, both around the centroid.
    center = stl_mesh.centroid
    rot_x = trimesh.transformations.rotation_matrix(angle=np.radians(-90), direction=[1, 0, 0], point=center)
    rot_y = trimesh.transformations.rotation_matrix(angle=np.radians(180), direction=[0, 1, 0], point=center)
    stl_mesh.apply_transform(rot_y @ rot_x)
    return stl_mesh, vmin, vmax
def create_visualization_vtp(viz_data, variable_name):
    """Wrap prediction data in a PyVista point cloud that carries the predicted field.

    Args:
        viz_data: mapping with "points" (point coordinates) and "pred" (predicted values).
        variable_name: array name under which ``viz_data["pred"]`` is stored.

    Returns:
        ``pv.PolyData`` of the points with the prediction attached as ``variable_name``.
    """
    cloud = pv.PolyData(viz_data["points"])
    cloud[variable_name] = viz_data["pred"]
    return cloud
def camera_from_bounds(bounds, distance_scale=2.2):
    """Place a camera on the (+1, +1, +1) diagonal through the centre of a bounding box.

    Args:
        bounds: ``(xmin, xmax, ymin, ymax, zmin, zmax)``.
        distance_scale: distance from the box centre, in multiples of the box diagonal.
            A zero-size box uses a diagonal of 1.0 so the camera still sits off-centre.

    Returns:
        Camera position as a list of three floats.
    """
    box = np.asarray(bounds, dtype=float).reshape(3, 2)
    middle = box.mean(axis=1)
    size = box[:, 1] - box[:, 0]
    diagonal = float(np.linalg.norm(size) or 1.0)
    unit = np.ones(3) / np.sqrt(3.0)
    return (middle + distance_scale * diagonal * unit).tolist()
def bounds_from_points(points: np.ndarray):
    """Return the axis-aligned bounds of a point array with at least three columns.

    Returns:
        ``(xmin, xmax, ymin, ymax, zmin, zmax)`` taken from the first three columns.
    """
    lows, highs = points.min(axis=0), points.max(axis=0)
    return tuple(extreme[axis] for axis in range(3) for extreme in (lows, highs))
| # ========================== Example Loading Handlers ========================== | |
def convert_vtp_to_glb(vtp_path, output_dir):
    """Convert a VTP file to GLB inside ``output_dir``.

    Inputs without polygonal faces are exported as a uniform grey point cloud. The
    cloud is re-oriented by -90 deg about x, then 180 deg about y, around its
    centroid. Surfaces get consistent normals and are exported as a triangle mesh;
    no re-orientation is applied on that path.

    Args:
        vtp_path: input .vtp path.
        output_dir: output directory; the file is named "<basename of vtp_path>.glb".

    Returns:
        Path of the written GLB, or None if any step fails (the error is printed).
    """
    try:
        glb_path = os.path.join(output_dir, f"{os.path.basename(vtp_path)}.glb")
        # Triangulate once; both branches work on the triangulated data.
        mesh = pv.read(vtp_path).triangulate()
        print(" π Converting VTP to GLB...")
        # Classify by the presence of polygons. Cell counts or the first cell's type
        # misfire on surfaces that also carry vertex cells, and get_cell(0) fails on
        # cell-less point sets.
        if mesh.faces.size == 0:
            print(" π Point cloud -> GLB via trimesh")
            tmesh = trimesh.points.PointCloud(mesh.points)
            tmesh.visual.vertex_colors = np.tile([190, 190, 190], (mesh.n_points, 1))
            center = tmesh.centroid
            rot_x = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0], center)
            rot_y = trimesh.transformations.rotation_matrix(np.radians(180), [0, 1, 0], center)
            tmesh.apply_transform(rot_y @ rot_x)
        else:
            print(" π Triangular mesh -> GLB via trimesh")
            mesh_fixed = mesh.compute_normals(
                consistent_normals=True, auto_orient_normals=True,
                point_normals=True, cell_normals=False, inplace=False,
            )
            tmesh = _pv_to_trimesh(mesh_fixed)
        tmesh.export(glb_path)
        return glb_path
    except Exception as e:
        print(f"Error converting VTP to GLB: {str(e)}")
        return None
def convert_vtp_to_stl(vtp_path, output_dir, decimation_config=None):
    """Convert a VTP file to STL inside ``output_dir``, optionally decimating first.

    Args:
        vtp_path: input .vtp path.
        output_dir: output directory; the file is named "<basename of vtp_path>.stl".
        decimation_config: optional dict with "target_reduction", "min_faces" and
            "max_faces", plus optional "enabled". Decimation errors are printed and
            the undecimated mesh is kept.

    Returns:
        Path of the written STL, or None if reading or writing fails (the error is printed).
    """
    try:
        mesh = pv.read(vtp_path).triangulate()
        if decimation_config and decimation_config.get("enabled", True):
            try:
                original_faces = mesh.n_cells
                target_faces = int(original_faces * (1 - decimation_config["target_reduction"]))
                target_faces = max(decimation_config["min_faces"], min(target_faces, decimation_config["max_faces"]))
                if original_faces > target_faces and mesh.faces.size > 0:
                    print(f"π§ Decimating VTP mesh: {original_faces} β {target_faces} faces")
                    # Aim at the clamped target; the raw target_reduction would ignore min/max_faces.
                    mesh = mesh.decimate(target_reduction=1.0 - target_faces / original_faces)
                    print(f"β VTP decimation successful: {mesh.n_cells} faces")
            except Exception as e:
                print(f"β οΈ VTP decimation failed: {e}, using original mesh")
        mesh_fixed = mesh.compute_normals(
            consistent_normals=True, auto_orient_normals=True,
            point_normals=True, cell_normals=False, inplace=False,
        )
        geom_path = os.path.join(output_dir, f"{os.path.basename(vtp_path)}.stl")
        mesh_fixed.save(geom_path)
        return geom_path
    except Exception as e:
        print(f"Error converting VTP to STL: {str(e)}")
        return None
| # ========================== Unit Conversion Utilities ========================== | |
def mph_to_ms(mph):
    """Convert a speed from miles per hour to metres per second (1 mph = 0.44704 m/s exactly)."""
    metres_per_second_per_mph = 0.44704
    return mph * metres_per_second_per_mph
def ms_to_mph(ms):
    """Convert a speed from metres per second to miles per hour (inverse of 1 mph = 0.44704 m/s)."""
    metres_per_second_per_mph = 0.44704
    return ms / metres_per_second_per_mph
def time_function(func_name=None):
    """Decorator factory that logs the start, duration and failure of the wrapped call.

    Args:
        func_name: label used in the log lines; defaults to the wrapped function's ``__name__``.

    Returns:
        A decorator. The wrapper returns the wrapped function's result. Any
        ``Exception`` is re-raised after logging how long the call ran.
    """
    def decorator(func):
        name = func_name or func.__name__

        @wraps(func)  # keep __name__, __doc__ and __wrapped__ of the decorated function
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
            start_time = time.perf_counter()
            print(f"β±οΈ Starting {name}...")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                elapsed = time.perf_counter() - start_time
                print(f"β {name} failed after {elapsed:.3f} seconds: {str(e)}")
                raise
            elapsed = time.perf_counter() - start_time
            print(f"β {name} completed in {elapsed:.3f} seconds")
            return result
        return wrapper
    return decorator
def print_timing(message, start_time=None):
    """Report time elapsed since ``start_time`` and hand back a fresh timestamp.

    Call without ``start_time`` to get an initial ``time.time()`` stamp (nothing is
    printed). Pass the previous stamp to print "<message>: <seconds>". Either way the
    returned value is a new stamp for chaining.
    """
    if start_time is not None:
        print(f"β±οΈ {message}: {time.time() - start_time:.3f} seconds")
    return time.time()
def vtk_to_np(arr):
    """Return ``arr`` as a NumPy array.

    ndarrays are returned as-is. Anything else goes through ``np.asarray``. If that
    raises, VTK's ``numpy_support.vtk_to_numpy`` is used instead.
    """
    if not isinstance(arr, np.ndarray):
        try:
            arr = np.asarray(arr)
        except Exception:
            arr = nps.vtk_to_numpy(arr)
    return arr
def mesh_get_variable(mesh: pv.DataSet, var_name: str, npts: int):
    """Fetch ``var_name`` from a mesh as a flat float32 per-point array.

    Lookup order:
      1. point data — flattened as-is;
      2. cell data — first interpolated onto points via ``cell_data_to_point_data``;
      3. field data — a single value is broadcast to ``npts`` entries; any other size gives zeros;
      4. missing — ``npts`` zeros.
    """
    if var_name in mesh.point_data:
        source = mesh.point_data
    elif var_name in mesh.cell_data:
        source = mesh.cell_data_to_point_data().point_data
    else:
        source = None

    if source is not None:
        return vtk_to_np(source[var_name]).reshape(-1).astype(np.float32, copy=False)

    if var_name in mesh.field_data:
        field = vtk_to_np(mesh.field_data[var_name]).ravel()
        if field.size == 1:
            return np.full(npts, float(field[0]), dtype=np.float32)
    return np.zeros(npts, np.float32)
def get_boundary_conditions_text(dataset):
    """Return the Markdown boundary-conditions text shown in the UI for ``dataset``.

    "Incompressible flow inside artery" and "Vehicle crash analysis" have no listed
    conditions and return "". Unknown dataset names get a "Not specified"
    placeholder.
    """
    if dataset in ("Incompressible flow inside artery", "Vehicle crash analysis"):
        return ""
    if dataset == "Incompressible flow over car":
        return (
            "\n"
            "**Reference Density:** 1.225 kg/m³\n"
            "**Reference Viscosity:** 1.789e-5 Pa·s\n"
            "**Operating Pressure:** 101325 Pa"
        )
    if dataset == "Compressible flow over plane":
        return (
            "\n"
            "**Reference Density:** 0.36 kg/m³\n"
            "**Reference viscosity:** 1.716e-05 kg/(m·s)\n"
            "**Operating Pressure:** 23842 Pa\n"
            # A blank line before "---" keeps it a horizontal rule; directly under text
            # Markdown treats it as a setext marker and turns that text into a heading.
            "\n"
            "---\n"
            "**Ambient Temperature:** 218 K\n"
            "**Cruising velocity:** 250.0 m/s or 560 mph\n"
        )
    # NOTE(review): the leading glyph below looks mis-encoded (original emoji lost); confirm.
    return "**π Boundary Conditions:** Not specified"
def get_boundary_conditions_left(dataset):
    """Return the left-column boundary conditions (reference state) as Markdown.

    Only "Compressible flow over plane" has content; every other dataset gets "".
    """
    if dataset != "Compressible flow over plane":
        return ""
    return (
        "\n"
        "**Reference Density:** 0.36 kg/m³\n"
        "**Reference viscosity:** 1.716e-05 kg/(m·s)\n"
        "**Operating Pressure:** 23842 Pa\n"
    )
def get_boundary_conditions_right(dataset):
    """Return the right-column boundary conditions (flight state) as Markdown.

    Only "Compressible flow over plane" has content; every other dataset gets "".
    """
    if dataset != "Compressible flow over plane":
        return ""
    return (
        "\n"
        "**Ambient Temperature:** 218 K\n"
        "**Cruising velocity:** 250.0 m/s or 560 mph\n"
    )
| ### utils to compute cosine score | |
def get_points(mesh, max_points=5000, rng=None):
    """Return the mesh's points, randomly subsampled to at most ``max_points`` rows.

    Args:
        mesh: object exposing a ``points`` array (e.g. a PyVista mesh).
        max_points: maximum number of rows returned; sampling is without replacement.
        rng: optional ``numpy.random.Generator`` for reproducible sampling. When None,
            the global ``np.random`` state is used.

    Returns:
        ``mesh.points`` itself when it has at most ``max_points`` rows, else a random subset.

    Raises:
        ValueError: if the points cannot be read or sampled; the original error is chained.
    """
    try:
        points = mesh.points
        if len(points) > max_points:
            chooser = np.random if rng is None else rng
            indices = chooser.choice(len(points), max_points, replace=False)
            points = points[indices]
        return points
    except Exception as e:
        raise ValueError(f"Error reading {mesh}: {e}") from e
def compute_cosine_score(mesh, dataset, smooth_sigma=1):
    """
    Compute a cosine similarity score between a mesh's point distribution and the
    saved training distribution for ``dataset``.

    Points (subsampled by ``get_points``) are binned on the training histogram's bin
    edges. The histogram is Gaussian-smoothed, normalised, and compared with the
    stored training histogram.

    Args:
        mesh: PyVista mesh (anything with a ``points`` array).
        dataset: dataset name; selects ``configs/app_configs/<dataset>/train_dist.npz``,
            which must contain "hist", "edges0", "edges1" and "edges2".
        smooth_sigma: Gaussian smoothing sigma, in bins.

    Returns:
        float: cosine similarity (0-1); higher means closer to the training
        distribution. 0.0 when no test point falls inside the training bins.

    Raises:
        ValueError: if the training distribution file does not exist.
    """
    train_dist_path = os.path.join("configs/app_configs/", dataset, "train_dist.npz")
    if not os.path.exists(train_dist_path):
        raise ValueError(f"Training distribution file not found: {train_dist_path}")
    # The context manager closes the .npz file handle once the arrays are read.
    with np.load(train_dist_path) as data:
        # ravel() so a histogram saved in its 3-D shape still compares as a vector.
        train_hist = np.asarray(data["hist"]).ravel()
        bin_edges = [data["edges0"], data["edges1"], data["edges2"]]

    test_points = get_points(mesh)
    # density=True divides by the number of in-range points. When that number is 0,
    # NumPy produces NaNs (with a warning); that case is handled explicitly below.
    with np.errstate(divide="ignore", invalid="ignore"):
        test_hist, _ = np.histogramdd(test_points, bins=bin_edges, density=True)
    test_hist = gaussian_filter(test_hist, sigma=smooth_sigma).ravel()

    total = test_hist.sum()
    if not np.isfinite(total) or total <= 0:
        print(f"Cosine Score for {dataset}: 0.000000 (no points inside training bins)")
        return 0.0
    test_hist = test_hist / total

    # Small epsilon keeps both vectors strictly positive.
    epsilon = 1e-12
    cosine_similarity = 1 - cosine(train_hist + epsilon, test_hist + epsilon)
    print(f"Cosine Score for {dataset}: {cosine_similarity:.6f}")
    return cosine_similarity
# Client-side JavaScript snippet, presumably passed to the UI framework's JS hook
# (e.g. Gradio's `js=`) — confirm at the call site. If the page URL's `__theme`
# query parameter is not "dark", refresh() sets it to "dark" and reloads the page,
# forcing the dark theme.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
| # ========================== Mesh Decimation Utilities ========================== | |
def decimate_mesh(mesh, config=None):
    """
    Decimate a PyVista mesh for faster visualization.

    Point clouds are returned unchanged. A mesh counts as a point cloud when it has
    no cells, has one cell per point, or its first cell is VTK type 1 (vertex).
    Other meshes are triangulated and then decimated towards
    ``clamp(n_faces * (1 - target_reduction), min_faces, max_faces)`` faces.

    Args:
        mesh: PyVista mesh object.
        config: decimation settings with "target_reduction", "min_faces" and
            "max_faces". When None, the mesh is only triangulated, not decimated.

    Returns:
        The decimated mesh. The triangulated input is returned instead when it is
        already small enough, when decimation would drop below ``min_faces``, or
        when decimation fails. If the failure happens before triangulation, the
        input itself is returned.
    """
    try:
        if mesh.n_cells == 0 or mesh.n_points == mesh.n_cells or mesh.get_cell(0).type == 1:
            print("βοΈ Skipping decimation for point cloud")
            return mesh
        mesh = mesh.triangulate()
        if config is None:
            print("No decimation config provided, skipping decimation")
            return mesh

        original_faces = mesh.n_cells
        original_points = mesh.n_points
        target_faces = int(original_faces * (1 - config["target_reduction"]))
        target_faces = max(config["min_faces"], min(target_faces, config["max_faces"]))
        if original_faces <= target_faces:
            print(f"π Mesh already small enough: {original_faces} faces")
            return mesh

        # A closed triangle surface has about half as many vertices as faces (Euler: F ≈ 2V).
        print(f"π§ Decimating mesh: {original_faces} β {target_faces} faces ({original_points} β ~{target_faces // 2} points)")
        # Reduction derived from the clamped target so min_faces/max_faces are honoured.
        decimated = mesh.decimate(target_reduction=1.0 - target_faces / original_faces)

        # Guard against overshooting below the minimum face count.
        if decimated.n_cells < config["min_faces"]:
            print(f"β οΈ Decimation resulted in too few faces ({decimated.n_cells}), using original mesh")
            return mesh
        print(f"β Decimation successful: {decimated.n_cells} faces, {decimated.n_points} points")
        return decimated
    except Exception as e:
        print(f"β οΈ Decimation failed: {e}, using original mesh")
        return mesh