# NOTE(review): removed non-Python residue ("Spaces: / Sleeping / Sleeping") — a
# Hugging Face Spaces status banner captured by the page extraction, not source code.
# ========= Headless / Offscreen safety (before any VTK import) =========
import os

# Force software, offscreen rendering so the app works on GPU-less/headless hosts.
# setdefault keeps any values already provided by the deployment environment.
os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
os.environ.setdefault("DISPLAY", "")  # blank DISPLAY prevents any X11 connection attempt
| # ========= Core setup ========= | |
| import shutil, time, tempfile, json, base64, threading, re, html as _html, asyncio | |
| import numpy as np | |
| import torch | |
| import pyvista as pv | |
| from scipy.spatial import cKDTree | |
| from vtk.util import numpy_support as nps | |
| import matplotlib.cm as cm | |
| from omegaconf import OmegaConf | |
| from huggingface_hub import hf_hub_download | |
| from accelerate import Accelerator | |
| from accelerate.utils import DistributedDataParallelKwargs | |
| import pickle | |
| from sklearn.metrics import pairwise_distances | |
| from train import get_single_latent | |
| from sklearn.neighbors import NearestNeighbors | |
| from utils.app_utils2 import ( | |
| create_visualization_points, | |
| create_visualization_stl, | |
| camera_from_bounds, | |
| bounds_from_points, | |
| convert_vtp_to_glb, | |
| convert_vtp_to_stl, | |
| time_function, | |
| print_timing, | |
| mesh_get_variable, | |
| mph_to_ms, | |
| get_boundary_conditions_text, | |
| compute_confidence_score, | |
| compute_cosine_score, | |
| decimate_mesh, | |
| ) | |
| # ========= trame ========= | |
| from trame.app import TrameApp | |
| from trame.decorators import change | |
| from trame.ui.vuetify3 import SinglePageLayout | |
| from trame.widgets import vuetify3 as v3, html | |
| from trame_vtk.widgets.vtk import VtkRemoteView | |
| # ========= VTK ========= | |
| from vtkmodules.vtkRenderingCore import ( | |
| vtkRenderer, | |
| vtkRenderWindow, | |
| vtkPolyDataMapper, | |
| vtkActor, | |
| vtkRenderWindowInteractor, | |
| ) | |
| from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor | |
| from vtkmodules.vtkIOGeometry import vtkSTLReader | |
| from vtkmodules.vtkFiltersCore import vtkTriangleFilter | |
| from vtkmodules.vtkCommonCore import vtkLookupTable | |
| from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera | |
| # ========= Writable paths / caches ========= | |
# All app data lives under a writable temp directory (container-safe).
DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
os.makedirs(DATA_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)  # matplotlib needs a writable config dir
# Sub-directories for uploaded geometry, predicted solutions, and model weights.
GEOM_DIR = os.path.join(DATA_DIR, "geometry")
SOLN_DIR = os.path.join(DATA_DIR, "solution")
WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")
for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
    os.makedirs(d, exist_ok=True)
# Redirect every Hugging Face cache into the writable temp area.
HF_DIR = os.path.join(DATA_DIR, "hf_home")
os.environ.setdefault("HF_HOME", HF_DIR)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
os.makedirs(HF_DIR, exist_ok=True)
# Fail fast at import time when the filesystem is read-only.
for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
    if not os.access(p, os.W_OK):
        raise RuntimeError(f"Not writable: {p}")
| # ========= Auto-decimation ladder ========= | |
def auto_target_reduction(num_cells: int) -> float:
    """Pick a mesh decimation ratio from the cell count.

    Larger meshes get more aggressive reduction; meshes of 10k cells or
    fewer are left untouched (0.0).
    """
    # (inclusive upper bound, reduction) ladder for small/medium meshes.
    ladder = (
        (10_000, 0.0),
        (20_000, 0.2),
        (50_000, 0.4),
        (100_000, 0.5),
        (500_000, 0.6),
    )
    for limit, reduction in ladder:
        if num_cells <= limit:
            return reduction
    # Note: the 1M boundary is exclusive, so exactly 1_000_000 cells -> 0.9.
    return 0.8 if num_cells < 1_000_000 else 0.9
| # ========= Registry / choices ========= | |
# Maps each demo application to its pretrained checkpoint (on the HF hub),
# local config file, model class name, and the output variables it predicts.
REGISTRY = {
    "Incompressible flow inside artery": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Incompressible flow inside artery/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_artery.pt"},
        "out_variable": ["pressure", "x_velocity", "y_velocity", "z_velocity"],
    },
    "Vehicle crash analysis": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Vehicle crash analysis/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_vehiclecrash.pt"},
        "out_variable": [
            "impact_force",
            "deformation",
            "energy_absorption",
            "x_displacement",
            "y_displacement",
            "z_displacement",
        ],
    },
    "Compressible flow over plane": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Compressible flow over plane/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_plane_transonic_v3.pt"},
        "out_variable": ["pressure"],
    },
    "Incompressible flow over car": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Incompressible flow over car/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_cadillac_v3.pt"},
        "out_variable": ["pressure"],
    },
}
def variables_for(dataset: str):
    """List the predictable output variables for *dataset*.

    Falls back to the checkpoint names when the registry entry has no
    "out_variable" field; an unknown dataset yields an empty list.
    """
    spec = REGISTRY.get(dataset, {})
    out_vars = spec.get("out_variable")
    if isinstance(out_vars, (list, tuple)):
        return list(out_vars)
    if isinstance(out_vars, str):
        return [out_vars]
    return list(spec.get("checkpoints", {}).keys())
# Analysis-type -> datasets shown in the "Sub Application" dropdown.
ANALYSIS_TYPE_MAPPING = {
    "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
    "Internal flow": ["Incompressible flow inside artery"],
    "Crash analysis": ["Vehicle crash analysis"],
}
ANALYSIS_TYPE = list(ANALYSIS_TYPE_MAPPING.keys())
# Initial UI selections when the app first loads.
DEFAULT_ANALYSIS_TYPE = "External flow"
DEFAULT_DATASET = "Incompressible flow over car"
VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
DEFAULT_VARIABLE = VAR_CHOICES0[0] if VAR_CHOICES0 else None
| # ========= Simple cache ========= | |
class GeometryCache:
    """Holds the uploaded mesh in both its original and working forms."""

    def __init__(self):
        self.original_mesh = None  # uploaded, cleaned (normals), BEFORE user re-decimation
        self.current_mesh = None  # what the app is actually using right now


# Single shared cache instance used across the app.
GEOMETRY_CACHE = GeometryCache()
| # ========= Model store ========= | |
class ModelStore:
    """Lazy, per-dataset cache of ``(cfg, model, device, accelerator)`` tuples.

    Checkpoints are fetched from the Hugging Face hub on first use, copied into
    WEIGHTS_DIR, loaded onto the best available device, and kept in memory so
    later requests are instant.
    """

    def __init__(self):
        self._cache = {}  # dataset name -> (cfg, model, device, accelerator)

    def _build(self, dataset: str, progress_cb=None):
        """Build (or return the cached) model stack for *dataset*.

        ``progress_cb``, when given, receives integer progress values (0-100).
        Raises RuntimeError when any stage of the build fails.
        """

        def tick(x):
            # Best-effort progress reporting: a broken UI callback must never
            # abort the build. (Was a bare `except:` — now scoped to Exception
            # so KeyboardInterrupt/SystemExit still propagate.)
            if progress_cb:
                try:
                    progress_cb(int(x))
                except Exception:
                    pass

        if dataset in self._cache:
            tick(12)
            return self._cache[dataset]
        print(f"🔧 Building model for {dataset}")
        start_time = time.time()
        try:
            spec = REGISTRY[dataset]
            repo_id = spec["repo_id"]
            ckpt_name = spec["checkpoints"]["best_val"]
            tick(6)
            t0 = time.time()
            # Download the checkpoint into the writable HF cache directory.
            ckpt_path_hf = hf_hub_download(
                repo_id=repo_id,
                filename=ckpt_name,
                repo_type="model",
                local_dir=HF_DIR,
                local_dir_use_symlinks=False,
            )
            print_timing("Model checkpoint download", t0)
            tick(8)
            t0 = time.time()
            # Keep a per-dataset copy under WEIGHTS_DIR (idempotent).
            ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
            os.makedirs(ckpt_local_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
            if not os.path.exists(ckpt_path):
                shutil.copy(ckpt_path_hf, ckpt_path)
            print_timing("Local model copy setup", t0)
            tick(9)
            t0 = time.time()
            cfg_path = spec["config"]
            if not os.path.exists(cfg_path):
                raise FileNotFoundError(f"Missing config: {cfg_path}")
            cfg = OmegaConf.load(cfg_path)
            cfg.save_latent = True  # latents are needed for the confidence score
            print_timing("Configuration loading", t0)
            tick(11)
            t0 = time.time()
            os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print_timing("Device init", t0)
            tick(12)
            t0 = time.time()
            # Imported lazily so module import stays cheap for non-inference paths.
            import models

            model_cls_name = spec["model_attr"]
            if not hasattr(models, model_cls_name):
                raise ValueError(f"Model '{model_cls_name}' not found")
            model = getattr(models, model_cls_name)(cfg).to(device)
            print_timing("Model build", t0)
            tick(14)
            t0 = time.time()
            state = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(state)
            model.eval()
            print_timing("Weights load", t0)
            tick(15)
            result = (cfg, model, device, accelerator)
            self._cache[dataset] = result
            print_timing(f"Total model build for {dataset}", start_time)
            return result
        except Exception as e:
            # Fix: print_timing expects a start timestamp; the original passed
            # the exception object here. Chain the cause for easier debugging.
            print_timing(f"Model build failed for {dataset}", start_time)
            raise RuntimeError(f"Failed to load model for dataset '{dataset}': {e}") from e

    def get(self, dataset: str, variable: str, progress_cb=None):
        """Return the model stack; *variable* is accepted for API symmetry only."""
        return self._build(dataset, progress_cb=progress_cb)


# Shared, process-wide model cache.
MODEL_STORE = ModelStore()
| # ========= Inference pipeline ========= | |
def _variable_index(dataset: str, variable: str) -> int:
    """Index of *variable* in the dataset's output channels (0 for scalar specs)."""
    out_vars = REGISTRY[dataset]["out_variable"]
    if isinstance(out_vars, str):
        return 0
    return out_vars.index(variable)
def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
    """Prepare model inputs from *mesh* for one output *variable*.

    Returns ``(pos, target, points, cosine_score)``:
      - pos: float tensor of per-point model inputs (coords, optionally + inlet velocity)
      - target: (N, 1) float tensor of ground-truth values for *variable*
      - points: (N, 3) numpy array of the shuffled raw mesh points
      - cosine_score: similarity of this mesh to the training distribution
    """
    jpath = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
    # Context manager fixes the file-handle leak of the original bare open().
    with open(jpath, "r") as fh:
        json_data = json.load(fh)
    pts = np.asarray(mesh.points, dtype=np.float32)
    N = pts.shape[0]
    # Fixed seed keeps the point shuffle reproducible across runs.
    rng = np.random.default_rng(42)
    idx = rng.permutation(N)
    points = pts[idx]
    tgt_np = mesh_get_variable(mesh, variable, N)[idx]
    pos = torch.from_numpy(points)
    target = torch.from_numpy(tgt_np).unsqueeze(-1)
    # Optionally append the freestream velocity as an extra input feature.
    if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
        if "freestream_velocity" in boundary_conditions:
            inlet_x_velocity = torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
            inlet_x_velocity = inlet_x_velocity.repeat(N, 1)[idx]
            pos = torch.cat((pos, inlet_x_velocity), dim=1)
    if getattr(cfg, "input_normalization", None) == "shift_axis":
        # Shift x/z so they start at 0 and center y around its midpoint.
        coords = pos[:, :3].clone()
        coords[:, 0] = coords[:, 0] - coords[:, 0].min()
        coords[:, 2] = coords[:, 2] - coords[:, 2].min()
        y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
        coords[:, 1] = coords[:, 1] - y_center
        pos[:, :3] = coords
    if getattr(cfg, "pos_embed_sincos", False):
        # Rescale inputs into [0, 1000] using the training-set bounds.
        mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
        maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
        pos = 1000.0 * (pos - mins) / (maxs - mins)
        pos = torch.clamp(pos, 0, 1000)
    cosine_score = compute_cosine_score(mesh, dataset)
    return pos, target, points, cosine_score
def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
    """Run the full prediction pipeline on the currently loaded geometry.

    Returns a dict with the shuffled points, prediction, ground truth, the
    similarity/confidence scores, and error metrics (MAE, MSE, relative L2, R²).
    Raises ValueError when no geometry has been uploaded yet.
    """

    def p(v):
        # Best-effort progress reporting (0-100); UI errors never abort inference.
        if progress_cb:
            try:
                progress_cb(int(v))
            except Exception:
                pass

    if GEOMETRY_CACHE.current_mesh is None:
        raise ValueError("No geometry loaded")
    p(5)
    cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
    p(15)
    pos, target, points, cosine_score = process_mesh_fast(
        GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
    )
    p(25)
    # ---- Confidence score (deliberately best-effort: failures fall back to 0.0) ----
    confidence_score = 0.0
    try:
        if dataset not in ["Incompressible flow inside artery"]:
            geom_path = os.path.join(GEOM_DIR, "geometry.stl")
            latent_features = get_single_latent(
                mesh_path=geom_path,
                config_path=os.path.join("configs/app_configs/", dataset, "config.yaml"),
                device=device,
                # .get avoids a KeyError when the caller omits freestream_velocity
                custom_velocity=boundary_conditions.get("freestream_velocity") if boundary_conditions else None,
                use_training_velocity=False,
                model=model,
            )
            embedding_path = os.path.join("configs/app_configs/", dataset, "pca_embedding.npy")
            pca_reducer_path = os.path.join("configs/app_configs/", dataset, "pca_reducer.pkl")
            scaler_path = os.path.join("configs/app_configs/", dataset, "pca_scaler.pkl")
            embedding = np.load(embedding_path)
            # Context managers fix the pickle file-handle leaks of the original
            # bare open() calls. NOTE: pickle files are trusted app assets here.
            with open(pca_reducer_path, "rb") as fh:
                pca_reducer = pickle.load(fh)
            with open(scaler_path, "rb") as fh:
                scaler = pickle.load(fh)
            train_pair_dists = pairwise_distances(embedding)
            sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0
            n_points, n_features = latent_features.shape
            np.random.seed(42)  # reproducible sub-/over-sampling of latent rows
            target_len = int(pca_reducer.n_features_in_ / 256)
            if n_points > target_len:
                latent_features = latent_features[np.random.choice(n_points, target_len, replace=False)]
            elif n_points < target_len:
                num_extra = target_len - n_points
                extra_indices = np.random.choice(n_points, num_extra, replace=True)
                latent_features = np.vstack([latent_features, latent_features[extra_indices]])
            latent_features = latent_features.flatten()
            confidence_score = compute_confidence_score(
                latent_features, embedding, scaler, pca_reducer, sigma
            )
    except Exception:
        # Confidence is optional UI information; it must not block the prediction.
        confidence_score = 0.0
    data = {
        "input_pos": pos.unsqueeze(0).to(device),
        "output_feat": target.unsqueeze(0).to(device),
    }
    with torch.no_grad():
        inp = data["input_pos"]
        _, N, _ = inp.shape
        chunk = int(getattr(cfg, "num_points", 10000))
        if getattr(cfg, "chunked_eval", False) and chunk < N:
            # Evaluate fixed-size chunks; the last chunk is padded up to full
            # size with leading points, and the padding is discarded afterwards.
            input_pos = data["input_pos"]
            chunk_size = cfg.num_points
            out_chunks = []
            total = (N + chunk_size - 1) // chunk_size
            for k, i in enumerate(range(0, N, chunk_size)):
                ch = input_pos[:, i : i + chunk_size, :]
                n_valid = ch.shape[1]
                if n_valid < chunk_size:
                    pad = input_pos[:, : chunk_size - n_valid, :]
                    ch = torch.cat([ch, pad], dim=1)
                data["input_pos"] = ch
                out_chunk = model(data)
                if isinstance(out_chunk, (tuple, list)):
                    out_chunk = out_chunk[0]
                out_chunks.append(out_chunk[:, :n_valid, :])
                p(25 + 60 * (k + 1) / max(1, total))
            outputs = torch.cat(out_chunks, dim=1)
        else:
            p(40)
            outputs = model(data)
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    p(85)
    vi = _variable_index(dataset, variable)
    pred = outputs[0, :, vi : vi + 1]
    if getattr(cfg, "normalization", "") == "std_norm":
        # De-normalize predictions with the training mean/std for this variable.
        fp = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
        with open(fp, "r") as fh:  # context manager fixes the json handle leak
            j = json.load(fh)
        mu = torch.tensor(float(j["scalars"][variable]["mean"]), device=pred.device)
        sd = torch.tensor(float(j["scalars"][variable]["std"]), device=pred.device)
        pred = pred * sd + mu
    pred_np = pred.squeeze().detach().cpu().numpy()
    tgt_np = target.squeeze().numpy()
    pred_t = torch.from_numpy(pred_np).unsqueeze(-1)
    tgt_t = torch.from_numpy(tgt_np).unsqueeze(-1)
    # Relative L2 error of the prediction against the ground truth.
    rel_l2 = torch.mean(
        torch.norm(pred_t.squeeze(-1) - tgt_t.squeeze(-1), p=2, dim=-1)
        / torch.norm(tgt_t.squeeze(-1), p=2, dim=-1)
    )
    # Coefficient of determination (R²); 0.0 when the target has no variance.
    tgt_mean = torch.mean(tgt_t)
    ss_tot = torch.sum((tgt_t - tgt_mean) ** 2)
    ss_res = torch.sum((tgt_t - pred_t) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)
    p(100)
    return {
        "points": np.asarray(points),
        "pred": np.asarray(pred_np),
        "tgt": np.asarray(tgt_np),
        "cosine_score": float(cosine_score),
        "confidence_score": float(confidence_score),
        "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
        "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
        "rel_l2": float(rel_l2.item()),
        "r_squared": float(r2.item()),
    }
| # ========= VTK helpers ========= | |
def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
    """Load an STL file, triangulate it, and wrap it in a flat-colored vtkActor."""
    reader = vtkSTLReader()
    reader.SetFileName(stl_path)
    reader.Update()
    # Guarantee an all-triangle surface before mapping.
    triangulate = vtkTriangleFilter()
    triangulate.SetInputConnection(reader.GetOutputPort())
    triangulate.Update()
    mapper = vtkPolyDataMapper()
    mapper.SetInputConnection(triangulate.GetOutputPort())
    actor = vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(*color)
    return actor
def build_jet_lut(vmin, vmax):
    """Build a 256-entry VTK lookup table mirroring matplotlib's 'jet' colormap.

    The table range is [vmin, vmax]; alpha is forced to fully opaque.
    """
    lut = vtkLookupTable()
    lut.SetRange(float(vmin), float(vmax))
    lut.SetNumberOfTableValues(256)
    lut.Build()
    try:
        cmap = cm.get_cmap("jet", 256)
    except AttributeError:
        # matplotlib >= 3.9 removed cm.get_cmap; use the colormap registry instead.
        from matplotlib import colormaps

        cmap = colormaps["jet"].resampled(256)
    for i in range(256):
        r_, g_, b_, _ = cmap(i)
        lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
    return lut
def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
    """Color an STL surface by mapping predicted values onto its vertices.

    Each STL vertex takes the prediction of its nearest neighbor in
    *points_xyz* (k-d tree lookup); the scalar array drives a jet-colored
    mapper clamped to [vmin, vmax].
    """
    reader = vtkSTLReader()
    reader.SetFileName(stl_path)
    reader.Update()
    poly = reader.GetOutput()
    surface_pts = nps.vtk_to_numpy(poly.GetPoints().GetData())
    # Nearest-neighbor transfer of prediction values onto the STL vertices.
    _, nearest = cKDTree(points_xyz).query(surface_pts, k=1)
    scalars = np.asarray(pred_vals, dtype=np.float32)[nearest]
    vtk_scalars = nps.numpy_to_vtk(scalars, deep=True)
    vtk_scalars.SetName(array_name)
    point_data = poly.GetPointData()
    point_data.AddArray(vtk_scalars)
    point_data.SetActiveScalars(array_name)
    mapper = vtkPolyDataMapper()
    mapper.SetInputData(poly)
    mapper.SetScalarModeToUsePointData()
    mapper.ScalarVisibilityOn()
    mapper.SetScalarRange(float(vmin), float(vmax))
    mapper.SetLookupTable(lut if lut is not None else build_jet_lut(vmin, vmax))
    mapper.UseLookupTableScalarRangeOn()
    actor = vtkActor()
    actor.SetMapper(mapper)
    return actor
def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
    """Replace any existing scalar bar on *renderer* with a fresh vertical one."""
    # Collect previous scalar bars first, then remove them, so only one is shown.
    actors_2d = renderer.GetActors2D()
    actors_2d.InitTraversal()
    stale = []
    for _ in range(actors_2d.GetNumberOfItems()):
        item = actors_2d.GetNextItem()
        if isinstance(item, vtkScalarBarActor):
            stale.append(item)
    for item in stale:
        renderer.RemoveActor2D(item)
    bar = vtkScalarBarActor()
    bar.SetLookupTable(lut)
    bar.SetOrientationToVertical()
    bar.SetLabelFormat(label_fmt)
    bar.SetNumberOfLabels(int(n_labels))
    bar.SetTitle(title)
    # Anchor near the right edge: normalized (x, y) origin plus relative size.
    bar.SetPosition(0.92, 0.05)
    bar.SetPosition2(0.06, 0.90)
    title_props = bar.GetTitleTextProperty()
    title_props.SetColor(1, 1, 1)
    title_props.SetBold(True)
    title_props.SetFontSize(22)
    label_props = bar.GetLabelTextProperty()
    label_props.SetColor(1, 1, 1)
    label_props.SetFontSize(18)
    renderer.AddActor2D(bar)
    return bar
def _on_velocity_change(self, velocity_mph, **_):
    """Keep velocity_mph numeric and in a reasonable range."""
    try:
        value = float(velocity_mph)
    except Exception:
        value = 45.0  # fallback default when the field is empty/non-numeric
    clamped = min(300.0, max(0.0, value))
    # Only write back when the value actually changed, to avoid update loops.
    if clamped != self.state.velocity_mph:
        self.state.velocity_mph = clamped
| # ---------- Small helpers ---------- | |
def poly_count(mesh: pv.PolyData) -> int:
    """Face count of *mesh*, preferring the strict count when available."""
    # Newer pyvista exposes n_faces_strict; fall back to the raw cell count.
    return mesh.n_faces_strict if hasattr(mesh, "n_faces_strict") else mesh.n_cells
def md_to_html(txt: str) -> str:
    """Render a tiny markdown subset (**bold**) to HTML, escaping everything else.

    Newlines become <br/> tags; empty/None input yields an empty string.
    """
    if not txt:
        return ""
    bold = re.compile(r"\*\*(.+?)\*\*")
    rendered = [bold.sub(r"<b>\1</b>", line) for line in _html.escape(txt).splitlines()]
    return "<br/>".join(rendered)
def bc_text_right(dataset: str) -> str:
    """HTML snippet of fixed reference conditions for the right-hand column."""
    reference_text = {
        "Incompressible flow over car": (
            "<b>Reference Density:</b> 1.225 kg/m³<br><br>"
            "<b>Reference Viscosity:</b> 1.789e-5 Pa·s<br><br>"
            "<b>Operating Pressure:</b> 101325 Pa"
        ),
        "Compressible flow over plane": (
            "<b>Ambient Temperature:</b> 218 K<br><br>"
            "<b>Cruising velocity:</b> 250.0 m/s or 560 mph"
        ),
    }
    return reference_text.get(dataset, "")
def bc_text_left(dataset: str) -> str:
    """HTML snippet of fixed reference conditions for the left-hand column."""
    if dataset != "Compressible flow over plane":
        return ""
    return (
        "<b>Reference Density:</b> 0.36 kg/m³<br><br>"
        "<b>Reference viscosity:</b> 1.716e-05 kg/(m·s)<br><br>"
        "<b>Operating Pressure:</b> 23842 Pa"
    )
| # ===================================================================== | |
| # ======================= APP ======================================= | |
| # ===================================================================== | |
| class PFMDemo(TrameApp): | |
    def __init__(self, server=None):
        """Create both offscreen VTK render pipelines, register controller
        triggers, seed all trame state defaults, and build the UI."""
        super().__init__(server)
        # ---------------- VTK RENDERERS ----------------
        # Geometry view: offscreen render window + trackball-camera interaction.
        self.ren_geom = vtkRenderer()
        self.ren_geom.SetBackground(0.10, 0.16, 0.22)
        self.rw_geom = vtkRenderWindow()
        self.rw_geom.SetOffScreenRendering(1)
        self.rw_geom.AddRenderer(self.ren_geom)
        self.rwi_geom = vtkRenderWindowInteractor()
        self.rwi_geom.SetRenderWindow(self.rw_geom)
        self.rwi_geom.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
        try:
            # Interactor init can fail on fully headless hosts; rendering still works.
            self.rwi_geom.Initialize()
            self.rwi_geom.Enable()
        except Exception:
            pass
        # Prediction view: same offscreen setup as the geometry view.
        self.ren_pred = vtkRenderer()
        self.ren_pred.SetBackground(0.10, 0.16, 0.22)
        self.rw_pred = vtkRenderWindow()
        self.rw_pred.SetOffScreenRendering(1)
        self.rw_pred.AddRenderer(self.ren_pred)
        self.rwi_pred = vtkRenderWindowInteractor()
        self.rwi_pred.SetRenderWindow(self.rw_pred)
        self.rwi_pred.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
        try:
            self.rwi_pred.Initialize()
            self.rwi_pred.Enable()
        except Exception:
            pass
        self.scalar_bar = None  # created lazily when a prediction is displayed
        # timers / flags (background work bookkeeping)
        self._predict_t0 = None
        self._infer_thread = None
        self._pre_upload_thread = None
        self._infer_heartbeat_on = False
        self._loop = None  # cached asyncio loop, see _ensure_loop()
        # ---------------- TRAME STATE ----------------
        s = self.state
        s.theme_dark = True
        s.analysis_types = ANALYSIS_TYPE
        s.analysis_type = DEFAULT_ANALYSIS_TYPE
        s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE]
        s.dataset = DEFAULT_DATASET
        s.variable_choices = variables_for(DEFAULT_DATASET)
        s.variable = s.variable_choices[0] if s.variable_choices else None
        # dialog (still kept)
        s.show_decimation_dialog = False
        s.decim_override_enabled = False
        s.decim_override_mode = "medium"
        s.decim_override_custom = 0.5
        # menu decimation defaults
        s.decim_enable = False  # user MUST toggle to override auto
        s.decim_target = 0.5
        # NOTE(review): original comment claimed "0 so small meshes can be
        # reduced" but the value is 5000 — confirm the intended minimum.
        s.decim_min_faces = 5000
        s.decim_max_faces = int(1e7)
        # register controller properly
        # self.server.controller.decimate_again = self.decimate_again
        # self.server.controller.add("decimate_again", self.decimate_again)
        ctrl = self.server.controller
        # ✅ this actually registers the trigger
        ctrl.add("decimate_again", self.decimate_again)
        ctrl.add("reset_mesh", self.reset_mesh)
        # Velocity controls only apply to the car dataset; plane flag tweaks BC text.
        s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car")
        s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane")
        s.velocity_mph = 45.0
        s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET)
        s.bc_left = bc_text_left(DEFAULT_DATASET)
        s.bc_right = bc_text_right(DEFAULT_DATASET)
        s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
        s.stats_html = "👋 Upload a geometry, then click Predict."
        s.upload = None
        # upload progress state
        s.is_uploading = False
        s.pm_upload = 0
        s.pm_elapsed_upload = 0.0
        s.upload_msg = ""
        # predict progress state
        s.is_predicting = False
        s.predict_progress = 0
        s.predict_msg = ""
        s.predict_elapsed = 0.0
        s.predict_est_total = 0.0
        s.pm_infer = 0
        s.pm_elapsed_infer = 0.0
        self._build_ui()
| def _ensure_loop(self): | |
| if self._loop is not None: | |
| return self._loop | |
| try: | |
| loop = asyncio.get_event_loop() | |
| except RuntimeError: | |
| loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(loop) | |
| self._loop = loop | |
| return loop | |
| def _run_coro(self, coro): | |
| loop = self._ensure_loop() | |
| if loop.is_running(): | |
| return asyncio.ensure_future(coro, loop=loop) | |
| return loop.run_until_complete(coro) | |
| async def _flush_async(self): | |
| try: | |
| self.server.state.flush() | |
| except Exception: | |
| pass | |
| await asyncio.sleep(0) | |
| def _build_ui(self): | |
| ctrl = self.server.controller | |
| with SinglePageLayout(self.server, full_height=True) as layout: | |
| layout.title.set_text("") # clear | |
| layout.title.hide = True # hide default | |
| with layout.toolbar: | |
| with v3.VContainer( | |
| fluid=True, | |
| style=( | |
| "max-width: 1800px;" # overall width | |
| "margin: 0 auto;" # center it | |
| "padding: 0 8px;" # ← left/right margin | |
| "box-sizing: border-box;" | |
| ), | |
| ): | |
| v3.VSpacer() | |
| html.Div( | |
| "Ansys: Physics Foundation Model", | |
| style=( | |
| "width:100%;" | |
| "text-align:center;" | |
| "font-size:34px;" | |
| "font-weight:900;" | |
| "letter-spacing:0.4px;" | |
| "line-height:1.2;" | |
| ), | |
| ) | |
| v3.VSpacer() | |
| # toolbar | |
| with layout.toolbar: | |
| # ← same margin container for the second toolbar row | |
| with v3.VContainer( | |
| fluid=True, | |
| style=( | |
| "max-width: 1800px;" | |
| "margin: 0 auto;" | |
| "padding: 0 8px;" | |
| "box-sizing: border-box;" | |
| ), | |
| ): | |
| v3.VSwitch( | |
| v_model=("theme_dark",), | |
| label="Dark Theme", | |
| inset=True, | |
| density="compact", | |
| hide_details=True, | |
| ) | |
| v3.VSpacer() | |
| with layout.content: | |
| html.Style(""" | |
| /* Small side padding for the whole app */ | |
| .v-application__wrap { | |
| padding-left: 8px; | |
| padding-right: 8px; | |
| padding-bottom: 8px; | |
| } | |
| :root { | |
| --pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; | |
| --pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace; | |
| } | |
| html, body, .v-application { | |
| margin: 0; | |
| padding: 0; | |
| font-family: var(--pfm-font-ui) !important; | |
| font-weight: 500; | |
| letter-spacing: .25px; | |
| -webkit-font-smoothing: antialiased; | |
| -moz-osx-font-smoothing: grayscale; | |
| text-rendering: optimizeLegibility; | |
| line-height: 1.5; | |
| font-size: 15.5px; | |
| color: #ECEFF4; | |
| } | |
| /* ... keep all your other typography / button / slider styles here ... */ | |
| .v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; } | |
| .v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; } | |
| /* (rest of your .pfm-* classes unchanged) */ | |
| """) | |
| # html.Style(""" | |
| # .v-theme--dark { background: #1F232B !important; } | |
| # .v-theme--light { background: #f5f6f8 !important; } | |
| # .v-theme--dark .pfm-card { background: #23272F !important; color: #fff !important; } | |
| # .v-theme--light .pfm-card { background: #ffffff !important; color: #1f232b !important; } | |
| # .v-theme--dark .pfm-viewer { background: #15171d !important; } | |
| # .v-theme--light .pfm-viewer { background: #e9edf3 !important; } | |
| # .pfm-card { border-radius: 16px !important; box-shadow: 0 6px 24px rgba(0,0,0,0.12); } | |
| # .pfm-progress .v-progress-linear { height: 22px !important; border-radius: 999px !important; } | |
| # .pfm-btn-big.v-btn { | |
| # height: 48px !important; | |
| # font-size: 18px !important; | |
| # font-weight: 600 !important; | |
| # letter-spacing: 1.2px; | |
| # text-transform: none !important; | |
| # border-radius: 999px !important; | |
| # } | |
| # .pfm-viewer { min-height: 420px; height: 650px !important; border-radius: 16px; } | |
| # """) | |
| html.Link( | |
| rel="stylesheet", | |
| href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap", | |
| ) | |
| with v3.VThemeProvider(theme=("theme_dark ? 'dark' : 'light'",)): | |
| with v3.VContainer( | |
| fluid=True, | |
| class_="pa-6", | |
| style=( | |
| "max-width: 2200px;" # max width of content | |
| "margin: 8px auto 16px auto;" # top / left-right / bottom | |
| "padding: 0 8px;" # inner left/right padding | |
| "box-sizing: border-box;" | |
| "background: rgba(255,255,255,0.02);" | |
| "border-radius: 16px;" | |
| ), | |
| ): | |
| # 1) Physics Application | |
| with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3): | |
| html.Div( | |
| "🧪 <b>Physics Application</b>", | |
| style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;", | |
| ) | |
| html.Div( | |
| "Select the type of analysis", | |
| style="font-size:24px;opacity:.82;margin-bottom:18px;", | |
| ) | |
| toggle = v3.VBtnToggle( | |
| v_model=("analysis_type", self.state.analysis_type), | |
| class_="mt-1", | |
| mandatory=True, | |
| rounded=True, | |
| ) | |
| # with toggle: | |
| # for at in ANALYSIS_TYPE: | |
| # v3.VBtn( | |
| # at, | |
| # value=at, | |
| # variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"), | |
| # class_="mr-2 pfm-toggle-xxl", | |
| # ) | |
| with toggle: | |
| for at in ANALYSIS_TYPE: | |
| v3.VBtn( | |
| at, | |
| value=at, | |
| variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"), | |
| class_="mr-2 pfm-toggle-xxl", | |
| style=( | |
| "font-size:18px;" | |
| "font-weight:800;" | |
| "letter-spacing:0.4px;" | |
| "text-transform:none;" | |
| ), | |
| ) | |
| # 2) Dataset + Variable | |
| with v3.VRow(dense=True, class_="mb-3"): | |
| with v3.VCol(cols=6): | |
| with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3): | |
| html.Div( | |
| "🧩 Sub Application", | |
| style="font-weight:700;font-size:24px;margin-bottom:14px;", | |
| ) | |
| v3.VSelect( | |
| v_model=("dataset", self.state.dataset), | |
| items=("dataset_choices", self.state.dataset_choices), | |
| hide_details=True, | |
| density="comfortable", | |
| style=( | |
| "font-size:24px;" | |
| "font-weight:800;" | |
| "height:56px;" | |
| "display:flex;" | |
| "align-items:center;" | |
| ), | |
| class_="pfm-big-select-subapp", | |
| menu_props={"content_class": "pfm-subapp-list"}, # <— key for dropdown items | |
| ) | |
| # v3.VSelect( | |
| # v_model=("dataset", self.state.dataset), | |
| # items=("dataset_choices", self.state.dataset_choices), | |
| # hide_details=True, | |
| # density="comfortable", | |
| # class_="pfm-big-select-subapp pfm-subapp-list", | |
| # style="font-size:21px;", | |
| # ) | |
| with v3.VCol(cols=6): | |
| with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3): | |
| html.Div( | |
| "📊 Variable to Predict", | |
| style="font-weight:700;font-size:20px;margin-bottom:14px;", | |
| ) | |
| v3.VSelect( | |
| v_model=("variable", self.state.variable), | |
| items=("variable_choices", self.state.variable_choices), | |
| hide_details=True, | |
| density="comfortable", | |
| class_="pfm-var-select", | |
| style=( | |
| "font-size:20px;" | |
| "font-weight:800;" | |
| "height:56px;" | |
| "display:flex;" | |
| "align-items:center;" | |
| ), | |
| menu_props={"content_class": "pfm-var-list"}, | |
| ) | |
| # v3.VSelect( | |
| # v_model=("variable", self.state.variable), | |
| # items=("variable_choices", self.state.variable_choices), | |
| # hide_details=True, | |
| # density="comfortable", | |
| # style="font-size:16px;", | |
| # ) | |
| # 3) Boundary Conditions | |
| with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3): | |
| html.Div( | |
| "🧱 Boundary Conditions", | |
| style="font-weight:700;font-size:22px;margin-bottom:16px;", | |
| ) | |
| # two columns: Left = velocity controls, Right = reference text | |
| with v3.VRow(class_="align-start", dense=True): | |
| # ---- LEFT: velocity slider / readout ---- | |
| with v3.VCol(cols=7, class_="pfm-vel"): | |
| html.Div( | |
| "🚗 Velocity (mph)", | |
| class_="pfm-vel-title", | |
| style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;", | |
| ) | |
| html.Div( | |
| "Set the inlet velocity in miles per hour", | |
| class_="pfm-vel-sub", | |
| style="margin-bottom:10px;font-size:20px;opacity:.95;", | |
| ) | |
| with v3.VRow(class_="align-center mt-3 mb-2", v_if=("show_velocity",)): | |
| with v3.VCol(cols=8): | |
| v3.VSlider( | |
| v_model=("velocity_mph",), | |
| min=30.0, | |
| max=80.0, | |
| step=0.1, | |
| thumb_label=True, | |
| style="height:54px;max-width:540px;", | |
| class_="pfm-vel-slider", | |
| ) | |
| with v3.VCol(cols=4): | |
| v3.VTextField( | |
| v_model=("velocity_mph",), | |
| type="number", | |
| min=0, | |
| max=300, | |
| step="0.1", | |
| density="comfortable", | |
| variant="outlined", | |
| hide_details=True, | |
| suffix="mph", | |
| style="max-width:160px;margin-left:8px;", | |
| ) | |
| html.Div( | |
| "{{ Number(velocity_mph).toFixed(0) }} / 80 " | |
| "<span style='opacity:.95'>" | |
| "({{ (Number(velocity_mph) * 0.44704).toFixed(2) }} m/s)</span>", | |
| v_if=("show_velocity",), | |
| class_="pfm-vel-readout", | |
| style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:4px;", | |
| ) | |
| # ---- RIGHT: fixed reference values (HTML from bc_text_right / bc_text_left) ---- | |
| with v3.VCol(cols=5, class_="pfm-bc-right"): | |
| html.Div( | |
| v_html=("bc_text_html", ""), | |
| style=( | |
| "margin-top:6px;" | |
| "font-size:18px;" | |
| "line-height:1.7;" | |
| "min-width:260px;" | |
| "max-width:360px;" | |
| "text-align:left;" | |
| ), | |
| ) | |
| # 4) Two viewers | |
| with v3.VRow(style="margin-top: 24px;"): | |
| # LEFT = upload | |
| with v3.VCol(cols=6): | |
| with v3.VRow(class_="align-center justify-space-between mb-2"): | |
| html.Div( | |
| "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📤 Input Geometry</span>", | |
| ) | |
| # ✅ working gear menu | |
| with v3.VMenu( | |
| location="bottom end", | |
| close_on_content_click=False, | |
| offset="4 8", | |
| ): | |
| # activator slot MUST expose { props } and we MUST bind them to the button | |
| with v3.Template(v_slot_activator="{ props }"): | |
| with v3.VBtn( | |
| icon=True, | |
| variant="text", | |
| density="comfortable", | |
| style="min-width:32px;", | |
| v_bind="props", # 👈 this is the key | |
| ): | |
| v3.VIcon("mdi-cog", size="22") | |
| # menu content | |
| with v3.VCard(class_="pa-4", style="min-width: 280px;"): | |
| html.Div("<b>Mesh decimation</b>", class_="mb-3", style="font-size:14px;") | |
| v3.VSwitch( | |
| v_model=("decim_enable",), | |
| label="Enable decimation", | |
| inset=True, | |
| hide_details=True, | |
| class_="mb-4", | |
| ) | |
| html.Div( | |
| "Target reduction (fraction of faces to remove)", | |
| class_="mb-1", | |
| style="font-size:12px;color:#9ca3af;", | |
| ) | |
| v3.VSlider( | |
| v_model=("decim_target",), | |
| min=0.0, | |
| max=0.999, | |
| step=0.001, | |
| hide_details=True, | |
| class_="mb-2", | |
| ) | |
| html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3") | |
| with v3.VRow(dense=True, class_="mb-3"): | |
| with v3.VCol(cols=6): | |
| html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1") | |
| v3.VTextField( | |
| v_model=("decim_min_faces",), | |
| type="number", | |
| density="compact", | |
| hide_details=True, | |
| ) | |
| with v3.VCol(cols=6): | |
| html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1") | |
| v3.VTextField( | |
| v_model=("decim_max_faces",), | |
| type="number", | |
| density="compact", | |
| hide_details=True, | |
| ) | |
| v3.VBtn( | |
| "Apply to current mesh", | |
| block=True, | |
| color="primary", | |
| class_="mt-2", | |
| click=self.decimate_again, | |
| ) | |
| v3.VBtn( | |
| "Reset to original mesh", | |
| block=True, | |
| variant="tonal", | |
| class_="mt-2", | |
| click=self.reset_mesh, # 👈 will call the controller you added | |
| ) | |
| v3.VFileInput( | |
| label="Select 3D File", | |
| style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;", | |
| multiple=False, | |
| show_size=True, | |
| accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb", | |
| v_model=("upload", None), | |
| clearable=True, | |
| ) | |
| with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"): | |
| self.view_geom = VtkRemoteView( | |
| self.rw_geom, | |
| interactive=True, | |
| interactive_ratio=1, | |
| server=self.server, | |
| ) | |
| with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress", | |
| rounded=True, elevation=3): | |
| html.Div("<b>Upload</b>", style="font-size:18px;") | |
| # progress bar: only while uploading | |
| v3.VProgressLinear( | |
| v_model=("pm_upload", 0), | |
| height=22, | |
| style="margin-top:10px;margin-bottom:10px;", | |
| color="primary", | |
| rounded=True, | |
| v_show=("is_uploading",), # 👈 bar disappears after upload | |
| ) | |
| # text: percentage + time + message, only while uploading | |
| html.Div( | |
| "{{ pm_upload }}% — {{ pm_elapsed_upload.toFixed(2) }}s — {{ upload_msg }}", | |
| style="font-size:14px;", | |
| v_show=("is_uploading",), # 👈 hide text after completion | |
| ) | |
| v3.VBtn( | |
| "🗑️ CLEAR", | |
| block=True, | |
| variant="tonal", | |
| class_="mt-3 pfm-btn-big", | |
| style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;", | |
| click=self.clear, | |
| ) | |
| # RIGHT = prediction | |
| with v3.VCol(cols=6): | |
| html.Div( | |
| "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📈 Prediction Results</span>", | |
| style="margin-bottom:10px;", | |
| ) | |
| html.Div( | |
| v_html=("stats_html",), | |
| class_="mb-3", | |
| style="font-size:20px;line-height:1.4;", | |
| ) | |
| # v3.VProgressLinear( | |
| # v_model=("predict_progress", 0), | |
| # height=22, | |
| # style="margin-top:6px;margin-bottom:12px;", | |
| # color="primary", | |
| # rounded=True, | |
| # indeterminate=("predict_progress < 10",), | |
| # v_show=("is_predicting",), | |
| # ) | |
| # html.Div( | |
| # "Predicting: {{ predict_progress }}%", | |
| # style="font-size:14px;margin-bottom:10px;", | |
| # v_show=("is_predicting",), | |
| # ) | |
| with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"): | |
| self.view_pred = VtkRemoteView( | |
| self.rw_pred, | |
| interactive=True, | |
| interactive_ratio=1, | |
| server=self.server, | |
| ) | |
| with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress", | |
| rounded=True, elevation=3): | |
| html.Div("<b>Inference</b>", style="font-size:18px;") | |
| # 🔴 OLD: v_model=("predict_progress", 0), indeterminate=... | |
| # 🟢 NEW: use pm_infer and a normal (non-indeterminate) bar | |
| v3.VProgressLinear( | |
| v_model=("pm_infer", 0), | |
| height=22, | |
| style="margin-top:6px;margin-bottom:12px;", | |
| color="success", | |
| rounded=True, | |
| indeterminate=("predict_progress <= 0",), | |
| v_show=("is_predicting",), # 👈 bar only visible while predicting | |
| ) | |
| # text line: % + elapsed time + current stage message | |
| html.Div( | |
| "{{ pm_infer }}% — {{ pm_elapsed_infer.toFixed(2) }}s — {{ predict_msg }}", | |
| style="font-size:14px;margin-bottom:10px;", | |
| # ❗ if you want the *text* to also disappear at the end, keep v_show; | |
| # if you want the final "✅ Prediction complete — 1.23s" to stay, REMOVE v_show | |
| v_show=("is_predicting",), | |
| ) | |
| v3.VBtn( | |
| "🚀 PREDICT", | |
| block=True, | |
| color="primary", | |
| class_="mt-3 pfm-btn-big", | |
| style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;", | |
| click=self.predict, | |
| ) | |
| layout.on_ready = self._first_paint | |
| def _first_paint(self, **_): | |
| for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)): | |
| try: | |
| rw.Render() | |
| except Exception: | |
| pass | |
| view.update() | |
| # --------------------------------------------------------- | |
| # UPLOAD (async) | |
| # --------------------------------------------------------- | |
| def _write_upload_to_disk(self, payload) -> str: | |
| if payload is None: | |
| raise ValueError("No file payload") | |
| if isinstance(payload, (list, tuple)): | |
| payload = payload[0] | |
| if isinstance(payload, str): | |
| return payload | |
| if not isinstance(payload, dict): | |
| raise ValueError(f"Unsupported payload: {type(payload)}") | |
| if payload.get("path"): | |
| return payload["path"] | |
| name = payload.get("name") or "upload" | |
| content = payload.get("content") | |
| if isinstance(content, str) and content.startswith("data:"): | |
| content = content.split(",", 1)[1] | |
| raw = base64.b64decode(content) if isinstance(content, str) else bytes(content) | |
| os.makedirs(GEOM_DIR, exist_ok=True) | |
| file_path = os.path.join(GEOM_DIR, name) | |
| with open(file_path, "wb") as f: | |
| f.write(raw) | |
| return file_path | |
| def _pre_upload_spinner_loop(self): | |
| s = self.state | |
| phase = 1 | |
| while self._pre_upload_on and not self._upload_actual_started and s.is_uploading: | |
| s.pm_upload = max(1, min(9, phase)) | |
| s.upload_msg = "Initializing upload..." | |
| try: | |
| self.server.state.flush() | |
| except Exception: | |
| pass | |
| phase = 1 if phase >= 9 else phase + 1 | |
| time.sleep(0.15) | |
| def _start_pre_upload_spinner(self): | |
| if self._pre_upload_thread and self._pre_upload_thread.is_alive(): | |
| return | |
| self._upload_actual_started = False | |
| self._pre_upload_on = True | |
| self._pre_upload_thread = threading.Thread( | |
| target=self._pre_upload_spinner_loop, daemon=True | |
| ) | |
| self._pre_upload_thread.start() | |
def _stop_pre_upload_spinner(self):
    """Signal the spinner thread to exit and drop the reference (no join)."""
    # The loop in _pre_upload_spinner_loop polls this flag each tick.
    self._pre_upload_on = False
    self._pre_upload_thread = None
| async def _fake_upload_bump(self, stop_event: asyncio.Event): | |
| s = self.state | |
| while not stop_event.is_set() and s.pm_upload < 9: | |
| s.pm_upload += 1 | |
| await self._flush_async() | |
| await asyncio.sleep(0.05) | |
async def _upload_worker_async(self, upload):
    """Full upload pipeline: persist file, read mesh, decimate, save, display.

    Runs on the server event loop; heavy work is pushed to an executor so
    the progress state (pm_upload / upload_msg / pm_elapsed_upload) keeps
    flushing to the UI between stages.
    """
    s = self.state
    loop = self._ensure_loop()
    t0 = time.time()
    s.is_uploading = True
    s.upload_msg = "Starting upload..."
    s.pm_elapsed_upload = 0.0
    s.pm_upload = 5
    self.server.state.flush()
    await asyncio.sleep(0)  # yield so the UI can paint the initial state
    # Background task that inches the bar toward 9% while nothing real happens.
    fake_stop = asyncio.Event()
    fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop))
    try:
        # Real work begins: stop both the thread spinner and the fake bump.
        self._upload_actual_started = True
        self._stop_pre_upload_spinner()
        if not fake_stop.is_set():
            fake_stop.set()
        try:
            await fake_task
        except asyncio.CancelledError:
            pass
        s.upload_msg = "Writing file to disk..."
        s.pm_upload = 10
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload)
        s.upload_msg = "Reading mesh..."
        s.pm_upload = 20
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        mesh = await loop.run_in_executor(None, pv.read, file_path)
        # 3) decimation (auto first)
        try:
            nf = poly_count(mesh)
        except Exception:
            nf = mesh.n_cells  # fallback when poly_count rejects the mesh
        auto_tr = float(auto_target_reduction(nf))
        # reflect auto in UI
        s.decim_target = auto_tr
        s.decim_min_faces = 5000  # <= allow decimation even for 27k faces
        s.decim_max_faces = int(1e7)
        target = auto_tr
        min_faces = 5000
        max_faces = int(1e7)
        # user override
        if self.state.decim_enable:
            target = float(self.state.decim_target or 0.0)
            min_faces = int(self.state.decim_min_faces or 5000)
            max_faces = int(self.state.decim_max_faces or 1e7)
        if target > 0.0:
            s.upload_msg = f"Decimating mesh ({target:.3f})..."
            s.pm_upload = max(s.pm_upload, 45)
            s.pm_elapsed_upload = time.time() - t0
            await self._flush_async()
            dec_cfg = {
                "enabled": True,
                "method": "pro",
                "target_reduction": target,
                "min_faces": min_faces,
                "max_faces": max_faces,
            }
            mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg)
        # 4) normals + save
        s.upload_msg = "Preparing geometry..."
        s.pm_upload = 75
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        def _normals_and_save(m):
            # Recompute consistent point normals, then persist as the working STL.
            m_fixed = m.compute_normals(
                consistent_normals=True,
                auto_orient_normals=True,
                point_normals=True,
                cell_normals=False,
                inplace=False,
            )
            geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
            m_fixed.save(geom_path_)
            return geom_path_, m_fixed
        geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
        # 5) update viewer
        self.ren_geom.RemoveAllViewProps()
        self.ren_geom.AddActor(make_actor_from_stl(geom_path))
        self.ren_geom.ResetCamera()
        try:
            self.rw_geom.Render()
        except Exception:
            pass
        self.view_geom.update()
        # GEOMETRY_CACHE.current_mesh = mesh_fixed
        # Keep a pristine deep copy so reset_mesh can restore the upload.
        GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True)
        GEOMETRY_CACHE.current_mesh = mesh_fixed
        s.upload_msg = "✅ Geometry ready."
        s.pm_upload = 100
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
    except Exception as e:
        s.upload_msg = f"❌ Upload failed: {e}"
        s.pm_upload = 0
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
    finally:
        s.is_uploading = False
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        # Make sure the fake bump task is fully torn down.
        if not fake_stop.is_set():
            fake_stop.set()
        if not fake_task.done():
            fake_task.cancel()
        try:
            await fake_task
        except Exception:
            pass
| def _on_upload_change(self, upload, **_): | |
| if not upload: | |
| return | |
| self._run_coro(self._upload_worker_async(upload)) | |
def decimate_again(self):
    """Gear-menu handler: re-run decimation on the currently loaded mesh."""
    task = self._decimate_again_async()
    self._run_coro(task)
async def _decimate_again_async(self):
    """Re-run decimation on the cached current mesh using the gear-menu settings.

    Mirrors the upload pipeline's progress reporting (reuses the
    is_uploading / pm_upload state keys) so the same UI progress bar animates.
    """
    s = self.state
    loop = self._ensure_loop()
    if GEOMETRY_CACHE.current_mesh is None:
        # nothing to decimate
        s.upload_msg = "No mesh to re-decimate"
        await self._flush_async()
        return
    # --- start "upload-like" progress for manual decimation ---
    t0 = time.time()
    s.is_uploading = True
    s.pm_upload = 5
    s.pm_elapsed_upload = 0.0
    s.upload_msg = "Starting mesh re-decimation..."
    await self._flush_async()
    try:
        # read parameters from UI (each falls back independently on bad input)
        try:
            target = float(s.decim_target)
        except Exception:
            target = 0.0
        try:
            min_faces = int(s.decim_min_faces)
        except Exception:
            min_faces = 5000
        try:
            max_faces = int(s.decim_max_faces)
        except Exception:
            max_faces = int(1e7)
        if (not s.decim_enable) or target <= 0.0:
            s.upload_msg = "Decimation disabled"
            s.pm_upload = 0
            s.pm_elapsed_upload = time.time() - t0
            await self._flush_async()
            return
        # --- bump before heavy decimation call ---
        s.upload_msg = f"Re-decimating mesh ({target:.3f})..."
        s.pm_upload = 25
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        dec_cfg = {
            "enabled": True,
            "method": "pro",
            "target_reduction": target,
            "min_faces": min_faces,
            "max_faces": max_faces,
        }
        # heavy work on executor (keeps the event loop responsive)
        mesh = await loop.run_in_executor(
            None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg
        )
        # --- normals + save ---
        s.upload_msg = "Recomputing normals & saving..."
        s.pm_upload = 70
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        def _normals_and_save(m):
            # Recompute consistent point normals, then persist as the working STL.
            m_fixed = m.compute_normals(
                consistent_normals=True,
                auto_orient_normals=True,
                point_normals=True,
                cell_normals=False,
                inplace=False,
            )
            geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
            m_fixed.save(geom_path_)
            return geom_path_, m_fixed
        geom_path, mesh_fixed = await loop.run_in_executor(
            None, _normals_and_save, mesh
        )
        # --- update viewer ---
        s.upload_msg = "Updating viewer..."
        s.pm_upload = 90
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
        self.ren_geom.RemoveAllViewProps()
        self.ren_geom.AddActor(make_actor_from_stl(geom_path))
        self.ren_geom.ResetCamera()
        try:
            self.rw_geom.Render()
        except Exception:
            pass
        self.view_geom.update()
        # NOTE: original_mesh is left untouched so reset_mesh still works.
        GEOMETRY_CACHE.current_mesh = mesh_fixed
        # --- final bump ---
        s.upload_msg = "✅ Re-decimated"
        s.pm_upload = 100
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
    except Exception as e:
        s.upload_msg = f"❌ Re-decimation failed: {e}"
        s.pm_upload = 0
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
    finally:
        # hide bar + text after we're done
        s.is_uploading = False
        await self._flush_async()
def reset_mesh(self):
    """Gear-menu handler: restore the original (pre-decimation) mesh."""
    task = self._reset_mesh_async()
    self._run_coro(task)
async def _reset_mesh_async(self):
    """Restore the pristine uploaded mesh as the current geometry and redisplay it."""
    state = self.state
    pristine = GEOMETRY_CACHE.original_mesh
    if pristine is None:
        state.upload_msg = "No original mesh to reset to"
        await self._flush_async()
        return
    # promote the saved original back to "current"
    GEOMETRY_CACHE.current_mesh = pristine
    # keep the on-disk STL in sync with what the viewer shows
    geom_path = os.path.join(GEOM_DIR, "geometry.stl")
    pristine.save(geom_path)
    # rebuild the geometry renderer from the freshly written STL
    self.ren_geom.RemoveAllViewProps()
    self.ren_geom.AddActor(make_actor_from_stl(geom_path))
    self.ren_geom.ResetCamera()
    try:
        self.rw_geom.Render()
    except Exception:
        pass  # offscreen render can fail harmlessly
    self.view_geom.update()
    state.upload_msg = "↩️ Reset to original mesh"
    await self._flush_async()
| # --------------------------------------------------------- | |
| # prediction | |
| # --------------------------------------------------------- | |
| def _start_infer_heartbeat(self): | |
| if self._infer_thread and self._infer_thread.is_alive(): | |
| return | |
| def loop_fn(): | |
| while self._infer_heartbeat_on: | |
| if self.state.is_predicting and self._predict_t0 is not None: | |
| self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0) | |
| try: | |
| self.server.state.flush() | |
| except Exception: | |
| pass | |
| time.sleep(0.12) | |
| self._infer_heartbeat_on = True | |
| self._infer_thread = threading.Thread(target=loop_fn, daemon=True) | |
| self._infer_thread.start() | |
def _stop_infer_heartbeat(self):
    """Ask the heartbeat thread to exit and drop the reference (no join)."""
    # The loop in _start_infer_heartbeat polls this flag each tick.
    self._infer_heartbeat_on = False
    self._infer_thread = None
async def _predict_worker_async(self):
    """Run model inference on the current geometry and display colored results.

    Drives the pm_infer progress bar and elapsed-time heartbeat, looks up
    the model for the selected dataset/variable, runs run_inference_fast on
    an executor, then colors the prediction onto the STL and mirrors the
    camera from the geometry view onto the prediction view.
    """
    s = self.state
    loop = self._ensure_loop()
    t0 = time.time()
    if GEOMETRY_CACHE.current_mesh is None:
        s.predict_msg = "❌ Please upload geometry first"
        s.is_predicting = False
        await self._flush_async()
        return
    s.is_predicting = True
    s.predict_progress = 1
    s.pm_infer = 1
    s.predict_msg = "Preparing inference..."
    self._predict_t0 = time.time()
    self._start_infer_heartbeat()  # background thread updates pm_elapsed_infer
    await self._flush_async()
    try:
        dataset = s.dataset
        variable = s.variable
        # only the car dataset exposes a user-set freestream velocity
        boundary = (
            {"freestream_velocity": mph_to_ms(s.velocity_mph)}
            if dataset == "Incompressible flow over car"
            else None
        )
        s.predict_msg = "Loading model/checkpoint..."
        s.predict_progress = 5
        s.pm_infer = 5
        await self._flush_async()
        cfg, model, device, _ = await loop.run_in_executor(
            None, MODEL_STORE.get, dataset, variable, None
        )
        s.predict_msg = "Processing mesh for inference..."
        s.predict_progress = 35
        s.pm_infer = 35
        await self._flush_async()
        def _run_full():
            # heavy inference; runs off the event loop
            return run_inference_fast(
                dataset,
                variable,
                boundary_conditions=boundary,
                progress_cb=None,
            )
        viz = await loop.run_in_executor(None, _run_full)
        s.predict_msg = "Preparing visualization..."
        s.predict_progress = 85
        s.pm_infer = 85
        await self._flush_async()
        stl_path = os.path.join(GEOM_DIR, "geometry.stl")
        vmin = float(np.min(viz["pred"]))
        vmax = float(np.max(viz["pred"]))
        if os.path.exists(stl_path):
            # STL-based visualization may override the raw color range
            _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path)
        lut = build_jet_lut(vmin, vmax)
        colored_actor = color_actor_with_scalars_from_prediction(
            stl_path,
            viz["points"],
            viz["pred"],
            variable,
            vmin,
            vmax,
            lut=lut,
        )
        # NOTE(review): the actor is added without clearing ren_pred first —
        # repeated predictions appear to stack actors; confirm intended.
        self.ren_pred.AddActor(colored_actor)
        units = {
            "pressure": "Pa",
            "x_velocity": "m/s",
            "y_velocity": "m/s",
            "z_velocity": "m/s",
        }.get(variable, "")
        title = f"{variable} ({units})" if units else variable
        self.scalar_bar = add_or_update_scalar_bar(
            self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8
        )
        # mirror the geometry view's camera onto the prediction view
        src_cam = self.ren_geom.GetActiveCamera()
        dst_cam = self.ren_pred.GetActiveCamera()
        if src_cam is not None and dst_cam is not None:
            dst_cam.SetPosition(src_cam.GetPosition())
            dst_cam.SetFocalPoint(src_cam.GetFocalPoint())
            dst_cam.SetViewUp(src_cam.GetViewUp())
            dst_cam.SetParallelScale(src_cam.GetParallelScale())
            cr = src_cam.GetClippingRange()
            dst_cam.SetClippingRange(cr)
        try:
            self.rw_pred.Render()
        except Exception:
            pass
        self.view_pred.update()
        # stats always report the raw prediction range (not the viz range)
        raw_vmin = float(np.min(viz["pred"]))
        raw_vmax = float(np.max(viz["pred"]))
        s.stats_html = (
            f"<b>{variable} min:</b>{raw_vmin:.3e} "
            f"<b>max:</b> {raw_vmax:.3e} "
            f"<b>Confidence:</b> {viz['confidence_score']:.4f}"
        )
        s.predict_msg = "✅ Prediction complete."
        s.predict_progress = 100
        s.pm_infer = 100
        s.predict_elapsed = time.time() - t0
        s.pm_elapsed_infer = s.predict_elapsed
        await self._flush_async()
    except Exception as e:
        s.predict_msg = f"❌ Prediction failed: {e}"
        s.predict_progress = 0
        s.pm_infer = 0
        await self._flush_async()
    finally:
        s.is_predicting = False
        self._stop_infer_heartbeat()
        await self._flush_async()
def predict(self, *_):
    """Button handler: launch the async prediction worker."""
    task = self._predict_worker_async()
    self._run_coro(task)
| # --------------------------------------------------------- | |
| # dataset wiring | |
| # --------------------------------------------------------- | |
def _on_analysis_type_change(self, analysis_type=None, **_):
    """Sync dataset choices (and the selected dataset) with the analysis type."""
    choices = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", [])
    preferred = choices[0] if choices else None
    self.state.dataset_choices = choices
    if preferred and self.state.dataset != preferred:
        # assigning state.dataset triggers _on_dataset_change -> _apply_dataset
        self.state.dataset = preferred
    elif self.state.dataset:
        # dataset unchanged: re-apply so dependent state stays consistent
        self._apply_dataset(self.state.dataset)
| def _on_dataset_change(self, dataset=None, **_): | |
| if not dataset: | |
| return | |
| self._apply_dataset(dataset) | |
def _apply_dataset(self, ds: str):
    """Refresh variable choices, dataset-specific flags, and BC text for *ds*."""
    state = self.state
    options = variables_for(ds) if ds else []
    state.variable_choices = options
    state.variable = options[0] if options else None
    # dataset-specific UI toggles
    state.show_velocity = ds == "Incompressible flow over car"
    state.is_plane = ds == "Compressible flow over plane"
    # boundary-condition reference text (markdown source + HTML columns)
    state.bc_text = get_boundary_conditions_text(ds)
    state.bc_left = bc_text_left(ds)
    state.bc_right = bc_text_right(ds)
    state.bc_text_html = state.bc_right or md_to_html(state.bc_text)
| # --------------------------------------------------------- | |
| # clear | |
| # --------------------------------------------------------- | |
def clear(self, *_):
    """Wipe geometry/solution folders, reset all progress state, blank both views."""
    for directory in (GEOM_DIR, SOLN_DIR):
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory, exist_ok=True)
    state = self.state
    state.stats_html = "🧹 Cleared. Upload again."
    # upload progress reset
    state.is_uploading = False
    state.pm_upload = 0
    state.upload_msg = ""
    state.pm_elapsed_upload = 0.0
    # inference progress reset
    state.is_predicting = False
    state.predict_progress = 0
    state.predict_msg = ""
    state.pm_infer = 0
    state.pm_elapsed_infer = 0.0
    # empty both renderers and repaint
    self.ren_geom.RemoveAllViewProps()
    self.ren_pred.RemoveAllViewProps()
    pairs = ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred))
    for window, remote_view in pairs:
        try:
            window.Render()
        except Exception:
            pass
        remote_view.update()
| # ---------- main ---------- | |
def main():
    """Build the trame app, register the gear-menu controllers, serve on :7872."""
    demo = PFMDemo()
    controller = demo.server.controller
    controller.add("decimate_again", demo.decimate_again)
    controller.add("reset_mesh", demo.reset_mesh)
    demo.server.start(7872)


if __name__ == "__main__":
    main()