# ========= Headless / Offscreen safety (before any VTK import) =========
import asyncio
import base64
import html as _html
import json
import os
import pickle
import re
import shutil
import tempfile
import threading
import time

# Set before numpy / torch / pyvista / VTK are imported: the GL driver (Mesa) and
# OpenMP runtimes typically read these when the native libraries first load.
os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
os.environ.setdefault("DISPLAY", "")  # headless: never try to open an X11 window
os.environ.setdefault("OMP_NUM_THREADS", "4")

# ========= Writable paths / caches =========
# Configured before importing matplotlib and huggingface_hub, which resolve
# MPLCONFIGDIR / HF_HOME when they are first imported.
DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
os.makedirs(DATA_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)

GEOM_DIR = os.path.join(DATA_DIR, "geometry")
SOLN_DIR = os.path.join(DATA_DIR, "solution")
WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")
for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
    os.makedirs(d, exist_ok=True)

HF_DIR = os.path.join(DATA_DIR, "hf_home")
os.environ.setdefault("HF_HOME", HF_DIR)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
os.makedirs(HF_DIR, exist_ok=True)

for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
    if not os.access(p, os.W_OK):
        raise RuntimeError(f"Not writable: {p}")

# ========= Core setup =========
import numpy as np
import torch
import pyvista as pv

pv.OFF_SCREEN = True  # Tell PyVista to always render offscreen

import matplotlib.cm as cm
from scipy.spatial import cKDTree
from vtk.util import numpy_support as nps
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from sklearn.metrics import pairwise_distances
from train import get_single_latent
from sklearn.neighbors import NearestNeighbors
from utils.app_utils2 import (
    create_visualization_points,
    create_visualization_stl,
    camera_from_bounds,
    bounds_from_points,
    convert_vtp_to_glb,
    convert_vtp_to_stl,
    time_function,
    print_timing,
    mesh_get_variable,
    mph_to_ms,
    get_boundary_conditions_text,
    compute_confidence_score,
    compute_cosine_score,
    decimate_mesh,
)

# ========= trame =========
from trame.app import TrameApp
from trame.decorators import change
from trame.ui.vuetify3 import SinglePageLayout
from trame.widgets import vuetify3 as v3, html
from trame_vtk.widgets.vtk import VtkRemoteView

# ========= VTK =========
from vtkmodules.vtkRenderingCore import (
    vtkRenderer,
    vtkRenderWindow,
    vtkPolyDataMapper,
    vtkActor,
    vtkRenderWindowInteractor,
)
from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor
from vtkmodules.vtkIOGeometry import vtkSTLReader
from vtkmodules.vtkFiltersCore import vtkTriangleFilter
from vtkmodules.vtkCommonCore import vtkLookupTable
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera


# ========= Auto-decimation ladder =========
def auto_target_reduction(num_cells: int) -> float:
    """Return the default fraction of faces to remove for a mesh of *num_cells* faces.

    Small meshes (<= 10k faces) are left untouched; larger meshes are reduced
    progressively harder, up to 0.9 for meshes of one million faces or more.
    """
    if num_cells <= 10_000:
        return 0.0
    elif num_cells <= 20_000:
        return 0.2
    elif num_cells <= 50_000:
        return 0.4
    elif num_cells <= 100_000:
        return 0.5
    elif num_cells <= 500_000:
        return 0.6
    elif num_cells < 1_000_000:
        return 0.8
    else:
        return 0.9


# ========= Registry / choices =========
# Per-dataset model spec: Hugging Face repo, local config path, model class name
# (looked up on the project's ``models`` module), checkpoint file and the ordered
# list of output channels the model predicts.
REGISTRY = {
    "Incompressible flow inside artery": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Incompressible flow inside artery/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_artery.pt"},
        "out_variable": ["pressure", "x_velocity", "y_velocity", "z_velocity"],
    },
    "Vehicle crash analysis": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Vehicle crash analysis/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_vehiclecrash.pt"},
        "out_variable": [
            "impact_force",
            "deformation",
            "energy_absorption",
            "x_displacement",
            "y_displacement",
            "z_displacement",
        ],
    },
    "Compressible flow over plane": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Compressible flow over plane/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_plane_transonic_v3.pt"},
        "out_variable": ["pressure"],
    },
    "Incompressible flow over car": {
        "repo_id": "ansysresearch/pretrained_models",
        "config": "configs/app_configs/Incompressible flow over car/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": "ckpt_cadillac_v3.pt"},
        "out_variable": ["pressure"],
    },
}

# Directory holding per-dataset artefacts (configs, normalisation stats, PCA files).
APP_CONFIG_DIR = "configs/app_configs"


def _app_config_path(dataset: str, filename: str) -> str:
    """Return the path of *filename* inside *dataset*'s app-config directory."""
    return os.path.join(APP_CONFIG_DIR, dataset, filename)


def _load_transform_params(dataset: str) -> dict:
    """Load *dataset*'s ``full_transform_params.json``, closing the file afterwards."""
    with open(_app_config_path(dataset, "full_transform_params.json"), "r") as f:
        return json.load(f)


def variables_for(dataset: str):
    """List the predictable variables of *dataset*.

    Falls back to the checkpoint keys when the spec has no ``out_variable``, and to
    an empty list for unknown datasets.
    """
    spec = REGISTRY.get(dataset, {})
    ov = spec.get("out_variable")
    if isinstance(ov, str):
        return [ov]
    if isinstance(ov, (list, tuple)):
        return list(ov)
    return list(spec.get("checkpoints", {}).keys())


ANALYSIS_TYPE_MAPPING = {
    "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
    "Internal flow": ["Incompressible flow inside artery"],
    "Crash analysis": ["Vehicle crash analysis"],
}
ANALYSIS_TYPE = list(ANALYSIS_TYPE_MAPPING.keys())
DEFAULT_ANALYSIS_TYPE = "External flow"
DEFAULT_DATASET = "Incompressible flow over car"
VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
DEFAULT_VARIABLE = VAR_CHOICES0[0] if VAR_CHOICES0 else None


# ========= Simple cache =========
class GeometryCache:
    """Process-wide holder for the uploaded geometry."""

    def __init__(self):
        self.original_mesh = None  # uploaded, cleaned (normals), BEFORE user re-decimation
        self.current_mesh = None  # what the app is actually using right now


GEOMETRY_CACHE = GeometryCache()


# ========= Model store =========
class ModelStore:
    """Build models on first use and cache ``(cfg, model, device, accelerator)`` per dataset."""

    def __init__(self):
        self._cache = {}

    def _build(self, dataset: str, progress_cb=None):
        """Download, configure and load the model for *dataset* (cached after the first call).

        *progress_cb*, if given, receives integer progress values; failures inside
        it are ignored. Returns ``(cfg, model, device, accelerator)``.

        Raises:
            RuntimeError: if any step fails; the original exception is chained.
        """

        def tick(x):
            if progress_cb:
                try:
                    progress_cb(int(x))
                except Exception:
                    pass  # progress reporting is best-effort only

        if dataset in self._cache:
            tick(12)
            return self._cache[dataset]

        print(f"🔧 Building model for {dataset}")
        start_time = time.time()
        try:
            spec = REGISTRY[dataset]
            repo_id = spec["repo_id"]
            ckpt_name = spec["checkpoints"]["best_val"]
            tick(6)

            t0 = time.time()
            ckpt_path_hf = hf_hub_download(
                repo_id=repo_id,
                filename=ckpt_name,
                repo_type="model",
                local_dir=HF_DIR,
                local_dir_use_symlinks=False,
            )
            print_timing("Model checkpoint download", t0)
            tick(8)

            t0 = time.time()
            ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
            os.makedirs(ckpt_local_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
            if not os.path.exists(ckpt_path):
                shutil.copy(ckpt_path_hf, ckpt_path)
            print_timing("Local model copy setup", t0)
            tick(9)

            t0 = time.time()
            cfg_path = spec["config"]
            if not os.path.exists(cfg_path):
                raise FileNotFoundError(f"Missing config: {cfg_path}")
            cfg = OmegaConf.load(cfg_path)
            cfg.save_latent = True
            print_timing("Configuration loading", t0)
            tick(11)

            t0 = time.time()
            os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print_timing("Device init", t0)
            tick(12)

            t0 = time.time()
            import models

            model_cls_name = spec["model_attr"]
            if not hasattr(models, model_cls_name):
                raise ValueError(f"Model '{model_cls_name}' not found")
            model = getattr(models, model_cls_name)(cfg).to(device)
            print_timing("Model build", t0)
            tick(14)

            t0 = time.time()
            state = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(state)
            model.eval()
            print_timing("Weights load", t0)
            tick(15)

            result = (cfg, model, device, accelerator)
            self._cache[dataset] = result
            print_timing(f"Total model build for {dataset}", start_time)
            return result
        except Exception as e:
            # print_timing expects a start timestamp (as in every call above), not the
            # exception; passing ``e`` could raise inside this handler and mask ``e``.
            print_timing(f"Model build failed for {dataset}", start_time)
            print(f"   error: {type(e).__name__}: {e}")
            raise RuntimeError(f"Failed to load model for dataset '{dataset}': {e}") from e

    def get(self, dataset: str, variable: str, progress_cb=None):
        """Return the cached model tuple for *dataset* (*variable* is currently unused)."""
        return self._build(dataset, progress_cb=progress_cb)


MODEL_STORE = ModelStore()


# ========= Inference pipeline =========
# Datasets for which the PCA-based confidence score is not computed.
_NO_CONFIDENCE_DATASETS = ("Incompressible flow inside artery",)


def _variable_index(dataset: str, variable: str) -> int:
    """Return the model output channel index of *variable* for *dataset*."""
    ov = REGISTRY[dataset]["out_variable"]
    return 0 if isinstance(ov, str) else ov.index(variable)


@time_function("Mesh Processing")
def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
    """Turn *mesh* into model inputs.

    Points are shuffled with a fixed seed (42), optionally augmented with the
    freestream velocity, and normalised as configured in *cfg*.

    Returns:
        ``(pos, target, points, cosine_score)``. ``pos`` is the normalised input
        tensor. ``target`` is the ``(N, 1)`` tensor of *variable* values. ``points``
        holds the shuffled raw coordinates, and ``cosine_score`` comes from
        ``compute_cosine_score``.
    """
    json_data = _load_transform_params(dataset)

    pts = np.asarray(mesh.points, dtype=np.float32)
    N = pts.shape[0]
    rng = np.random.default_rng(42)
    idx = rng.permutation(N)
    points = pts[idx]
    tgt_np = mesh_get_variable(mesh, variable, N)[idx]

    # torch.from_numpy shares memory with its source; copy so the in-place
    # "shift_axis" write below cannot alter the raw ``points`` returned to callers.
    pos = torch.from_numpy(points.copy())
    target = torch.from_numpy(tgt_np).unsqueeze(-1)

    if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
        if "freestream_velocity" in boundary_conditions:
            inlet_x_velocity = (
                torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
            )
            inlet_x_velocity = inlet_x_velocity.repeat(N, 1)[idx]
            pos = torch.cat((pos, inlet_x_velocity), dim=1)

    if getattr(cfg, "input_normalization", None) == "shift_axis":
        # x and z start at 0, y is centred on its mid-range.
        coords = pos[:, :3].clone()
        coords[:, 0] = coords[:, 0] - coords[:, 0].min()
        coords[:, 2] = coords[:, 2] - coords[:, 2].min()
        y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
        coords[:, 1] = coords[:, 1] - y_center
        pos[:, :3] = coords

    if getattr(cfg, "pos_embed_sincos", False):
        mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
        maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
        pos = 1000.0 * (pos - mins) / (maxs - mins)
        pos = torch.clamp(pos, 0, 1000)

    cosine_score = compute_cosine_score(mesh, dataset)
    return pos, target, points, cosine_score


def _compute_confidence(dataset: str, model, device, boundary_conditions):
    """Score the saved geometry's latent features against the training-set PCA embedding."""
    geom_path = os.path.join(GEOM_DIR, "geometry.stl")
    latent_features = get_single_latent(
        mesh_path=geom_path,
        config_path=_app_config_path(dataset, "config.yaml"),
        device=device,
        custom_velocity=boundary_conditions["freestream_velocity"] if boundary_conditions else None,
        use_training_velocity=False,
        model=model,
    )
    embedding = np.load(_app_config_path(dataset, "pca_embedding.npy"))
    # NOTE: pickle can execute code on load. These files ship with the app configs
    # and must never be replaced by user-supplied content.
    with open(_app_config_path(dataset, "pca_reducer.pkl"), "rb") as f:
        pca_reducer = pickle.load(f)
    with open(_app_config_path(dataset, "pca_scaler.pkl"), "rb") as f:
        scaler = pickle.load(f)

    train_pair_dists = pairwise_distances(embedding)
    sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0

    n_points, _n_features = latent_features.shape
    # Local generator: same draws as np.random.seed(42) + np.random.choice, without
    # resetting the process-wide numpy RNG as a side effect.
    rs = np.random.RandomState(42)
    # NOTE(review): assumes each latent row has 256 features, so the PCA input
    # length / 256 is the row count it was fitted on. Confirm against training.
    target_len = int(pca_reducer.n_features_in_ / 256)
    if n_points > target_len:
        latent_features = latent_features[rs.choice(n_points, target_len, replace=False)]
    elif n_points < target_len:
        extra_indices = rs.choice(n_points, target_len - n_points, replace=True)
        latent_features = np.vstack([latent_features, latent_features[extra_indices]])

    return compute_confidence_score(
        latent_features.flatten(), embedding, scaler, pca_reducer, sigma
    )


def _forward(model, cfg, data: dict, progress):
    """Run *model* on *data*.

    When ``cfg.chunked_eval`` is set and the cloud has more than ``num_points``
    points, the model runs in fixed-size chunks. Returns the first element if the
    model returns a tuple/list. ``data["input_pos"]`` is overwritten while chunking.
    """
    input_pos = data["input_pos"]
    n_total = input_pos.shape[1]
    chunk_size = int(getattr(cfg, "num_points", 10000))

    if not (getattr(cfg, "chunked_eval", False) and chunk_size < n_total):
        progress(40)
        outputs = model(data)
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    out_chunks = []
    n_chunks = (n_total + chunk_size - 1) // chunk_size
    for k, start in enumerate(range(0, n_total, chunk_size)):
        ch = input_pos[:, start : start + chunk_size, :]
        n_valid = ch.shape[1]
        if n_valid < chunk_size:
            # Pad the last chunk with leading points so the model always sees
            # chunk_size points; the padded outputs are dropped just below.
            ch = torch.cat([ch, input_pos[:, : chunk_size - n_valid, :]], dim=1)
        data["input_pos"] = ch
        out_chunk = model(data)
        if isinstance(out_chunk, (tuple, list)):
            out_chunk = out_chunk[0]
        out_chunks.append(out_chunk[:, :n_valid, :])
        progress(25 + 60 * (k + 1) / max(1, n_chunks))
    return torch.cat(out_chunks, dim=1)


@time_function("Inference")
def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
    """Predict *variable* on the cached geometry and compare it with the mesh's own values.

    Returns:
        A dict with ``points``, ``pred``, ``tgt`` (numpy arrays) and the scores
        ``cosine_score``, ``confidence_score``, ``abs_err``, ``mse_err``,
        ``rel_l2`` and ``r_squared``.

    Raises:
        ValueError: if no geometry has been loaded.
    """

    def p(v):
        if progress_cb:
            try:
                progress_cb(int(v))
            except Exception:
                pass  # progress reporting is best-effort only

    if GEOMETRY_CACHE.current_mesh is None:
        raise ValueError("No geometry loaded")

    p(5)
    cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
    p(15)
    pos, target, points, cosine_score = process_mesh_fast(
        GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
    )
    p(25)

    confidence_score = 0.0
    if dataset not in _NO_CONFIDENCE_DATASETS:
        try:
            confidence_score = _compute_confidence(dataset, model, device, boundary_conditions)
        except Exception as exc:
            # Best effort: a missing/incompatible PCA artefact must not block prediction.
            print(f"Confidence score unavailable for {dataset}: {type(exc).__name__}: {exc}")
            confidence_score = 0.0

    data = {
        "input_pos": pos.unsqueeze(0).to(device),
        "output_feat": target.unsqueeze(0).to(device),
    }
    with torch.no_grad():
        outputs = _forward(model, cfg, data, p)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    p(85)

    vi = _variable_index(dataset, variable)
    pred = outputs[0, :, vi : vi + 1]
    if getattr(cfg, "normalization", "") == "std_norm":
        stats = _load_transform_params(dataset)["scalars"][variable]
        mu = torch.tensor(float(stats["mean"]), device=pred.device)
        sd = torch.tensor(float(stats["std"]), device=pred.device)
        pred = pred * sd + mu

    pred_np = pred.squeeze().detach().cpu().numpy()
    tgt_np = target.squeeze().numpy()

    pred_t = torch.from_numpy(pred_np).reshape(-1)
    tgt_t = torch.from_numpy(tgt_np).reshape(-1)
    # Relative L2 over the whole field; inf/NaN if the target is identically zero.
    rel_l2 = torch.norm(pred_t - tgt_t, p=2, dim=-1) / torch.norm(tgt_t, p=2, dim=-1)
    ss_tot = torch.sum((tgt_t - torch.mean(tgt_t)) ** 2)
    ss_res = torch.sum((tgt_t - pred_t) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)

    p(100)
    return {
        "points": np.asarray(points),
        "pred": np.asarray(pred_np),
        "tgt": np.asarray(tgt_np),
        "cosine_score": float(cosine_score),
        "confidence_score": float(confidence_score),
        "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
        "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
        "rel_l2": float(rel_l2.item()),
        "r_squared": float(r2.item()),
    }


# ========= VTK helpers =========
def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
    """Read an STL file, triangulate it and return a uniformly coloured actor."""
    r = vtkSTLReader()
    r.SetFileName(stl_path)
    r.Update()
    tri = vtkTriangleFilter()
    tri.SetInputConnection(r.GetOutputPort())
    tri.Update()
    m = vtkPolyDataMapper()
    m.SetInputConnection(tri.GetOutputPort())
    a = vtkActor()
    a.SetMapper(m)
    a.GetProperty().SetColor(*color)
    return a


def build_jet_lut(vmin, vmax):
    """Return a 256-entry VTK lookup table over ``[vmin, vmax]`` sampled from matplotlib's "jet"."""
    try:
        # matplotlib >= 3.6; ``cm.get_cmap`` was removed in 3.9.
        from matplotlib import colormaps

        cmap = colormaps["jet"].resampled(256)
    except (ImportError, AttributeError):
        cmap = cm.get_cmap("jet", 256)

    lut = vtkLookupTable()
    lut.SetRange(float(vmin), float(vmax))
    lut.SetNumberOfTableValues(256)
    lut.Build()
    for i in range(256):
        r_, g_, b_, _ = cmap(i)
        lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
    return lut


def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
    """Colour the STL surface by prediction values.

    Each STL vertex takes the value of its nearest neighbour in *points_xyz*.
    Returns the colour-mapped actor.
    """
    r = vtkSTLReader()
    r.SetFileName(stl_path)
    r.Update()
    poly = r.GetOutput()

    stl_pts = nps.vtk_to_numpy(poly.GetPoints().GetData())
    tree = cKDTree(points_xyz)
    _, nn_idx = tree.query(stl_pts, k=1)
    scalars = np.asarray(pred_vals, dtype=np.float32)[nn_idx]

    vtk_arr = nps.numpy_to_vtk(scalars, deep=True)
    vtk_arr.SetName(array_name)
    poly.GetPointData().AddArray(vtk_arr)
    poly.GetPointData().SetActiveScalars(array_name)

    mapper = vtkPolyDataMapper()
    mapper.SetInputData(poly)
    mapper.SetScalarModeToUsePointData()
    mapper.ScalarVisibilityOn()
    mapper.SetScalarRange(float(vmin), float(vmax))
    if lut is None:
        lut = build_jet_lut(vmin, vmax)
    mapper.SetLookupTable(lut)
    mapper.UseLookupTableScalarRangeOn()

    actor = vtkActor()
    actor.SetMapper(mapper)
    return actor


def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
    """Replace any scalar bars on *renderer* with a new vertical bar for *lut*; return it."""
    existing = []
    ca = renderer.GetActors2D()
    ca.InitTraversal()
    for _ in range(ca.GetNumberOfItems()):
        a = ca.GetNextItem()
        if isinstance(a, vtkScalarBarActor):
            existing.append(a)
    # Collect first, remove afterwards: removing while traversing would skip items.
    for a in existing:
        renderer.RemoveActor2D(a)

    sbar = vtkScalarBarActor()
    sbar.SetLookupTable(lut)
    sbar.SetOrientationToVertical()
    sbar.SetLabelFormat(label_fmt)
    sbar.SetNumberOfLabels(int(n_labels))
    sbar.SetTitle(title)
    sbar.SetPosition(0.92, 0.05)
    sbar.SetPosition2(0.06, 0.90)

    tp = sbar.GetTitleTextProperty()
    tp.SetColor(1, 1, 1)
    tp.SetBold(True)
    tp.SetFontSize(22)

    lp = sbar.GetLabelTextProperty()
    lp.SetColor(1, 1, 1)
    lp.SetFontSize(18)

    renderer.AddActor2D(sbar)
    return sbar


# ---------- Small helpers ----------
def poly_count(mesh: pv.PolyData) -> int:
    """Return the face count, preferring ``n_faces_strict`` when PyVista provides it."""
    if hasattr(mesh, "n_faces_strict"):
        return mesh.n_faces_strict
    return mesh.n_cells


def md_to_html(txt: str) -> str:
    """Convert a tiny Markdown subset to HTML suitable for ``v_html``.

    The text is HTML-escaped first. Then ``**bold**`` becomes ``<b>bold</b>`` and
    line breaks become ``<br>``. Empty or None input returns ``""``.
    """
    if not txt:
        return ""
    safe = _html.escape(txt)
    safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
    return "<br>".join(safe.splitlines())


def bc_text_right(dataset: str) -> str:
    """Return the fixed reference-condition HTML for *dataset* ("" if none)."""
    if dataset == "Incompressible flow over car":
        return (
            "Reference Density: 1.225 kg/m³<br><br>"
            "Reference Viscosity: 1.789e-5 Pa·s<br><br>"
            "Operating Pressure: 101325 Pa"
        )
    if dataset == "Compressible flow over plane":
        return (
            "Ambient Temperature: 218 K<br><br>"
            "Cruising velocity: 250.0 m/s or 560 mph"
        )
    return ""


def bc_text_left(dataset: str) -> str:
    """Return the secondary reference-condition HTML for *dataset* ("" if none)."""
    if dataset == "Compressible flow over plane":
        return (
            "Reference Density: 0.36 kg/m³<br><br>"
            "Reference viscosity: 1.716e-05 kg/(m·s)<br><br>"
            "Operating Pressure: 23842 Pa"
        )
    return ""


# NOTE(review): the PFMDemo class that follows was not part of this edit and is reproduced unchanged.
# ===================================================================== # ======================= APP ======================================= # ===================================================================== class PFMDemo(TrameApp): def __init__(self, server=None): super().__init__(server) # ---------------- VTK RENDERERS ---------------- self.ren_geom = vtkRenderer() self.ren_geom.SetBackground(0.10, 0.16, 0.22) self.rw_geom = vtkRenderWindow() self.rw_geom.SetOffScreenRendering(1) self.rw_geom.AddRenderer(self.ren_geom) self.rwi_geom = vtkRenderWindowInteractor() self.rwi_geom.SetRenderWindow(self.rw_geom) self.rwi_geom.SetInteractorStyle(vtkInteractorStyleTrackballCamera()) try: self.rwi_geom.Initialize() self.rwi_geom.Enable() except Exception: pass self.ren_pred = vtkRenderer() self.ren_pred.SetBackground(0.10, 0.16, 0.22) self.rw_pred = vtkRenderWindow() self.rw_pred.SetOffScreenRendering(1) self.rw_pred.AddRenderer(self.ren_pred) self.rwi_pred = vtkRenderWindowInteractor() self.rwi_pred.SetRenderWindow(self.rw_pred) self.rwi_pred.SetInteractorStyle(vtkInteractorStyleTrackballCamera()) try: self.rwi_pred.Initialize() self.rwi_pred.Enable() except Exception: pass self.scalar_bar = None # timers / flags self._predict_t0 = None self._infer_thread = None self._pre_upload_thread = None self._infer_heartbeat_on = False self._loop = None # ---------------- TRAME STATE ---------------- s = self.state s.theme_dark = True s.analysis_types = ANALYSIS_TYPE s.analysis_type = DEFAULT_ANALYSIS_TYPE s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE] s.dataset = DEFAULT_DATASET s.variable_choices = variables_for(DEFAULT_DATASET) s.variable = s.variable_choices[0] if s.variable_choices else None # dialog (still kept) s.show_decimation_dialog = False s.decim_override_enabled = False s.decim_override_mode = "medium" s.decim_override_custom = 0.5 # menu decimation defaults s.decim_enable = False # user MUST toggle
to override auto s.decim_target = 0.5 s.decim_min_faces = 5000 # <= important: 0 so small meshes can be reduced s.decim_max_faces = int(1e7) # register controller properly # self.server.controller.decimate_again = self.decimate_again # self.server.controller.add("decimate_again", self.decimate_again) ctrl = self.server.controller # βœ… this actually registers the trigger ctrl.add("decimate_again", self.decimate_again) ctrl.add("reset_mesh", self.reset_mesh) s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car") s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane") s.velocity_mph = 45.0 s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET) s.bc_left = bc_text_left(DEFAULT_DATASET) s.bc_right = bc_text_right(DEFAULT_DATASET) s.bc_text_html = s.bc_right or md_to_html(s.bc_text) s.stats_html = "πŸ‘‹ Upload a geometry, then click Predict." s.upload = None # upload s.is_uploading = False s.pm_upload = 0 s.pm_elapsed_upload = 0.0 s.upload_msg = "" # predict s.is_predicting = False s.predict_progress = 0 s.predict_msg = "" s.predict_elapsed = 0.0 s.predict_est_total = 0.0 s.pm_infer = 0 s.pm_elapsed_infer = 0.0 self._build_ui() def _ensure_loop(self): if self._loop is not None: return self._loop try: loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) self._loop = loop return loop def _run_coro(self, coro): loop = self._ensure_loop() if loop.is_running(): return asyncio.ensure_future(coro, loop=loop) return loop.run_until_complete(coro) async def _flush_async(self): try: self.server.state.flush() except Exception: pass await asyncio.sleep(0) def _build_ui(self): ctrl = self.server.controller with SinglePageLayout(self.server, full_height=True) as layout: layout.title.set_text("") # clear layout.title.hide = True # hide default with layout.toolbar: with v3.VContainer( fluid=True, style=( "max-width: 1800px;" # overall width "margin: 0 auto;" # center it "padding: 0 8px;" # ← 
left/right margin "box-sizing: border-box;" ), ): v3.VSpacer() html.Div( "Ansys: Physics Foundation Model", style=( "width:100%;" "text-align:center;" "font-size:34px;" "font-weight:900;" "letter-spacing:0.4px;" "line-height:1.2;" ), ) v3.VSpacer() # toolbar with layout.toolbar: # ← same margin container for the second toolbar row with v3.VContainer( fluid=True, style=( "max-width: 1800px;" "margin: 0 auto;" "padding: 0 8px;" "box-sizing: border-box;" ), ): v3.VSwitch( v_model=("theme_dark",), label="Dark Theme", inset=True, density="compact", hide_details=True, ) v3.VSpacer() with layout.content: html.Style(""" /* Small side padding for the whole app */ .v-application__wrap { padding-left: 8px; padding-right: 8px; padding-bottom: 8px; } :root { --pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; --pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace; } html, body, .v-application { margin: 0; padding: 0; font-family: var(--pfm-font-ui) !important; font-weight: 500; letter-spacing: .25px; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; text-rendering: optimizeLegibility; line-height: 1.5; font-size: 15.5px; color: #ECEFF4; } /* ... keep all your other typography / button / slider styles here ... 
*/ .v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; } .v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; } /* (rest of your .pfm-* classes unchanged) */ """) # html.Style(""" # .v-theme--dark { background: #1F232B !important; } # .v-theme--light { background: #f5f6f8 !important; } # .v-theme--dark .pfm-card { background: #23272F !important; color: #fff !important; } # .v-theme--light .pfm-card { background: #ffffff !important; color: #1f232b !important; } # .v-theme--dark .pfm-viewer { background: #15171d !important; } # .v-theme--light .pfm-viewer { background: #e9edf3 !important; } # .pfm-card { border-radius: 16px !important; box-shadow: 0 6px 24px rgba(0,0,0,0.12); } # .pfm-progress .v-progress-linear { height: 22px !important; border-radius: 999px !important; } # .pfm-btn-big.v-btn { # height: 48px !important; # font-size: 18px !important; # font-weight: 600 !important; # letter-spacing: 1.2px; # text-transform: none !important; # border-radius: 999px !important; # } # .pfm-viewer { min-height: 420px; height: 650px !important; border-radius: 16px; } # """) html.Link( rel="stylesheet", href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap", ) with v3.VThemeProvider(theme=("theme_dark ? 
'dark' : 'light'",)): with v3.VContainer( fluid=True, class_="pa-6", style=( "max-width: 2200px;" # max width of content "margin: 8px auto 16px auto;" # top / left-right / bottom "padding: 0 8px;" # inner left/right padding "box-sizing: border-box;" "background: rgba(255,255,255,0.02);" "border-radius: 16px;" ), ): # 1) Physics Application with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3): html.Div( "πŸ§ͺ Physics Application", style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;", ) html.Div( "Select the type of analysis", style="font-size:24px;opacity:.82;margin-bottom:18px;", ) toggle = v3.VBtnToggle( v_model=("analysis_type", self.state.analysis_type), class_="mt-1", mandatory=True, rounded=True, ) # with toggle: # for at in ANALYSIS_TYPE: # v3.VBtn( # at, # value=at, # variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"), # class_="mr-2 pfm-toggle-xxl", # ) with toggle: for at in ANALYSIS_TYPE: v3.VBtn( at, value=at, variant=(f"analysis_type===`{at}` ? 
'elevated' : 'tonal'"), class_="mr-2 pfm-toggle-xxl", style=( "font-size:18px;" "font-weight:800;" "letter-spacing:0.4px;" "text-transform:none;" ), ) # 2) Dataset + Variable with v3.VRow(dense=True, class_="mb-3"): with v3.VCol(cols=6): with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3): html.Div( "🧩 Sub Application", style="font-weight:700;font-size:24px;margin-bottom:14px;", ) v3.VSelect( v_model=("dataset", self.state.dataset), items=("dataset_choices", self.state.dataset_choices), hide_details=True, density="comfortable", style=( "font-size:24px;" "font-weight:800;" "height:56px;" "display:flex;" "align-items:center;" ), class_="pfm-big-select-subapp", menuProps={"contentClass": "pfm-subapp-list"}, # <β€” key for dropdown items ) # v3.VSelect( # v_model=("dataset", self.state.dataset), # items=("dataset_choices", self.state.dataset_choices), # hide_details=True, # density="comfortable", # class_="pfm-big-select-subapp pfm-subapp-list", # style="font-size:21px;", # ) with v3.VCol(cols=6): with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3): html.Div( "πŸ“Š Variable to Predict", style="font-weight:700;font-size:20px;margin-bottom:14px;", ) v3.VSelect( v_model=("variable", self.state.variable), items=("variable_choices", self.state.variable_choices), hide_details=True, density="comfortable", class_="pfm-var-select", style=( "font-size:20px;" "font-weight:800;" "height:56px;" "display:flex;" "align-items:center;" ), menuProps={"contentClass": "pfm-var-list"}, ) # v3.VSelect( # v_model=("variable", self.state.variable), # items=("variable_choices", self.state.variable_choices), # hide_details=True, # density="comfortable", # style="font-size:16px;", # ) # 3) Boundary Conditions with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3): html.Div( "🧱 Boundary Conditions", style="font-weight:700;font-size:22px;margin-bottom:16px;", ) # two columns: Left = velocity controls, Right = reference text with 
v3.VRow(class_="align-start", dense=True): # ---- LEFT: velocity slider / readout ---- with v3.VCol(cols=7, class_="pfm-vel"): html.Div( "πŸš— Velocity (mph)", class_="pfm-vel-title", style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;", ) html.Div( "Set the inlet velocity in miles per hour", class_="pfm-vel-sub", style="margin-bottom:10px;font-size:20px;opacity:.95;", ) v3.VSlider( v_model=("velocity_mph", self.state.velocity_mph), min=30.0, max=80.0, step=0.1, thumb_label=True, v_if=("show_velocity",), style="height:54px;margin-top:12px;max-width:540px;", class_="mt-3 mb-3 pfm-vel-slider", ) html.Div( "{{ velocity_mph.toFixed(0) }} / 80 " "" "({{ (velocity_mph * 0.44704).toFixed(2) }} m/s)", v_if=("show_velocity",), class_="pfm-vel-readout", style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:6px;", ) # ---- RIGHT: fixed reference values (HTML from bc_text_right / bc_text_left) ---- with v3.VCol(cols=5, class_="pfm-bc-right"): html.Div( v_html=("bc_text_html", ""), style=( "margin-top:6px;" "font-size:18px;" "line-height:1.7;" "min-width:260px;" "max-width:360px;" "text-align:left;" ), ) # 4) Two viewers with v3.VRow(style="margin-top: 24px;"): # LEFT = upload with v3.VCol(cols=6): with v3.VRow(class_="align-center justify-space-between mb-2"): html.Div( "πŸ“€ Input Geometry", ) # βœ… working gear menu with v3.VMenu( location="bottom end", close_on_content_click=False, offset="4 8", ): # activator slot MUST expose { props } and we MUST bind them to the button with v3.Template(v_slot_activator="{ props }"): with v3.VBtn( icon=True, variant="text", density="comfortable", style="min-width:32px;", v_bind="props", # πŸ‘ˆ this is the key ): v3.VIcon("mdi-cog", size="22") # menu content with v3.VCard(class_="pa-4", style="min-width: 280px;"): html.Div("Mesh decimation", class_="mb-3", style="font-size:14px;") v3.VSwitch( v_model=("decim_enable",), label="Enable decimation", inset=True, hide_details=True, class_="mb-4", ) 
html.Div( "Target reduction (fraction of faces to remove)", class_="mb-1", style="font-size:12px;color:#9ca3af;", ) v3.VSlider( v_model=("decim_target",), min=0.0, max=0.999, step=0.001, hide_details=True, class_="mb-2", ) html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3") with v3.VRow(dense=True, class_="mb-3"): with v3.VCol(cols=6): html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1") v3.VTextField( v_model=("decim_min_faces",), type="number", density="compact", hide_details=True, ) with v3.VCol(cols=6): html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1") v3.VTextField( v_model=("decim_max_faces",), type="number", density="compact", hide_details=True, ) v3.VBtn( "Apply to current mesh", block=True, color="primary", class_="mt-2", click=self.decimate_again, ) v3.VBtn( "Reset to original mesh", block=True, variant="tonal", class_="mt-2", click=self.reset_mesh, # πŸ‘ˆ will call the controller you added ) v3.VFileInput( label="Select 3D File", style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;", multiple=False, show_size=True, accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb", v_model=("upload", None), clearable=True, ) with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"): self.view_geom = VtkRemoteView( self.rw_geom, interactive=True, interactive_ratio=1, server=self.server, ) with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress", rounded=True, elevation=3): html.Div("Upload", style="font-size:18px;") # progress bar: only while uploading v3.VProgressLinear( v_model=("pm_upload", 0), height=22, style="margin-top:10px;margin-bottom:10px;", color="primary", rounded=True, v_show=("is_uploading",), # πŸ‘ˆ bar disappears after upload ) # text: percentage + time + message, only while uploading html.Div( "{{ pm_upload }}% β€” {{ pm_elapsed_upload.toFixed(2) }}s β€” {{ upload_msg }}", style="font-size:14px;", v_show=("is_uploading",), # πŸ‘ˆ hide text after completion ) 
v3.VBtn( "πŸ—‘οΈ CLEAR", block=True, variant="tonal", class_="mt-3 pfm-btn-big", style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;", click=self.clear, ) # RIGHT = prediction with v3.VCol(cols=6): html.Div( "πŸ“ˆ Prediction Results", style="margin-bottom:10px;", ) html.Div( v_html=("stats_html",), class_="mb-3", style="font-size:20px;line-height:1.4;", ) # v3.VProgressLinear( # v_model=("predict_progress", 0), # height=22, # style="margin-top:6px;margin-bottom:12px;", # color="primary", # rounded=True, # indeterminate=("predict_progress < 10",), # v_show=("is_predicting",), # ) # html.Div( # "Predicting: {{ predict_progress }}%", # style="font-size:14px;margin-bottom:10px;", # v_show=("is_predicting",), # ) with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"): self.view_pred = VtkRemoteView( self.rw_pred, interactive=True, interactive_ratio=1, server=self.server, ) with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress", rounded=True, elevation=3): html.Div("Inference", style="font-size:18px;") # πŸ”΄ OLD: v_model=("predict_progress", 0), indeterminate=... 
# 🟒 NEW: use pm_infer and a normal (non-indeterminate) bar v3.VProgressLinear( v_model=("pm_infer", 0), height=22, style="margin-top:6px;margin-bottom:12px;", color="success", rounded=True, indeterminate=("predict_progress <= 0",), v_show=("is_predicting",), # πŸ‘ˆ bar only visible while predicting ) # text line: % + elapsed time + current stage message html.Div( "{{ pm_infer }}% β€” {{ pm_elapsed_infer.toFixed(2) }}s β€” {{ predict_msg }}", style="font-size:14px;margin-bottom:10px;", # ❗ if you want the *text* to also disappear at the end, keep v_show; # if you want the final "βœ… Prediction complete β€” 1.23s" to stay, REMOVE v_show v_show=("is_predicting",), ) v3.VBtn( "πŸš€ PREDICT", block=True, color="primary", class_="mt-3 pfm-btn-big", style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;", click=self.predict, ) layout.on_ready = self._first_paint def _first_paint(self, **_): for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)): try: rw.Render() except Exception: pass view.update() # --------------------------------------------------------- # UPLOAD (async) # --------------------------------------------------------- def _write_upload_to_disk(self, payload) -> str: if payload is None: raise ValueError("No file payload") if isinstance(payload, (list, tuple)): payload = payload[0] if isinstance(payload, str): return payload if not isinstance(payload, dict): raise ValueError(f"Unsupported payload: {type(payload)}") if payload.get("path"): return payload["path"] name = payload.get("name") or "upload" content = payload.get("content") if isinstance(content, str) and content.startswith("data:"): content = content.split(",", 1)[1] raw = base64.b64decode(content) if isinstance(content, str) else bytes(content) os.makedirs(GEOM_DIR, exist_ok=True) file_path = os.path.join(GEOM_DIR, name) with open(file_path, "wb") as f: f.write(raw) return file_path def _pre_upload_spinner_loop(self): s = self.state phase = 1 while 
self._pre_upload_on and not self._upload_actual_started and s.is_uploading: s.pm_upload = max(1, min(9, phase)) s.upload_msg = "Initializing upload..." try: self.server.state.flush() except Exception: pass phase = 1 if phase >= 9 else phase + 1 time.sleep(0.15) def _start_pre_upload_spinner(self): if self._pre_upload_thread and self._pre_upload_thread.is_alive(): return self._upload_actual_started = False self._pre_upload_on = True self._pre_upload_thread = threading.Thread( target=self._pre_upload_spinner_loop, daemon=True ) self._pre_upload_thread.start() def _stop_pre_upload_spinner(self): self._pre_upload_on = False self._pre_upload_thread = None async def _fake_upload_bump(self, stop_event: asyncio.Event): s = self.state while not stop_event.is_set() and s.pm_upload < 9: s.pm_upload += 1 await self._flush_async() await asyncio.sleep(0.05) async def _upload_worker_async(self, upload): s = self.state loop = self._ensure_loop() t0 = time.time() s.is_uploading = True s.upload_msg = "Starting upload..." s.pm_elapsed_upload = 0.0 s.pm_upload = 5 self.server.state.flush() await asyncio.sleep(0) fake_stop = asyncio.Event() fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop)) try: self._upload_actual_started = True self._stop_pre_upload_spinner() if not fake_stop.is_set(): fake_stop.set() try: await fake_task except asyncio.CancelledError: pass s.upload_msg = "Writing file to disk..." s.pm_upload = 10 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload) s.upload_msg = "Reading mesh..." 
s.pm_upload = 20 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() mesh = await loop.run_in_executor(None, pv.read, file_path) # 3) decimation (auto first) try: nf = poly_count(mesh) except Exception: nf = mesh.n_cells auto_tr = float(auto_target_reduction(nf)) # reflect auto in UI s.decim_target = auto_tr s.decim_min_faces = 5000 # <= allow decimation even for 27k faces s.decim_max_faces = int(1e7) target = auto_tr min_faces = 5000 max_faces = int(1e7) # user override if self.state.decim_enable: target = float(self.state.decim_target or 0.0) min_faces = int(self.state.decim_min_faces or 5000) max_faces = int(self.state.decim_max_faces or 1e7) if target > 0.0: s.upload_msg = f"Decimating mesh ({target:.3f})..." s.pm_upload = max(s.pm_upload, 45) s.pm_elapsed_upload = time.time() - t0 await self._flush_async() dec_cfg = { "enabled": True, "method": "pro", "target_reduction": target, "min_faces": min_faces, "max_faces": max_faces, } mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg) # 4) normals + save s.upload_msg = "Preparing geometry..." s.pm_upload = 75 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() def _normals_and_save(m): m_fixed = m.compute_normals( consistent_normals=True, auto_orient_normals=True, point_normals=True, cell_normals=False, inplace=False, ) geom_path_ = os.path.join(GEOM_DIR, "geometry.stl") m_fixed.save(geom_path_) return geom_path_, m_fixed geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh) # 5) update viewer self.ren_geom.RemoveAllViewProps() self.ren_geom.AddActor(make_actor_from_stl(geom_path)) self.ren_geom.ResetCamera() try: self.rw_geom.Render() except Exception: pass self.view_geom.update() # GEOMETRY_CACHE.current_mesh = mesh_fixed GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True) GEOMETRY_CACHE.current_mesh = mesh_fixed s.upload_msg = "βœ… Geometry ready." 
s.pm_upload = 100 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() except Exception as e: s.upload_msg = f"❌ Upload failed: {e}" s.pm_upload = 0 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() finally: s.is_uploading = False s.pm_elapsed_upload = time.time() - t0 await self._flush_async() if not fake_stop.is_set(): fake_stop.set() if not fake_task.done(): fake_task.cancel() try: await fake_task except Exception: pass @change("upload") def _on_upload_change(self, upload, **_): if not upload: return self._run_coro(self._upload_worker_async(upload)) def decimate_again(self): self._run_coro(self._decimate_again_async()) async def _decimate_again_async(self): s = self.state loop = self._ensure_loop() if GEOMETRY_CACHE.current_mesh is None: # nothing to decimate s.upload_msg = "No mesh to re-decimate" await self._flush_async() return # --- start "upload-like" progress for manual decimation --- t0 = time.time() s.is_uploading = True s.pm_upload = 5 s.pm_elapsed_upload = 0.0 s.upload_msg = "Starting mesh re-decimation..." await self._flush_async() try: # read parameters from UI try: target = float(s.decim_target) except Exception: target = 0.0 try: min_faces = int(s.decim_min_faces) except Exception: min_faces = 5000 try: max_faces = int(s.decim_max_faces) except Exception: max_faces = int(1e7) if (not s.decim_enable) or target <= 0.0: s.upload_msg = "Decimation disabled" s.pm_upload = 0 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() return # --- bump before heavy decimation call --- s.upload_msg = f"Re-decimating mesh ({target:.3f})..." 
s.pm_upload = 25 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() dec_cfg = { "enabled": True, "method": "pro", "target_reduction": target, "min_faces": min_faces, "max_faces": max_faces, } # heavy work on executor mesh = await loop.run_in_executor( None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg ) # --- normals + save --- s.upload_msg = "Recomputing normals & saving..." s.pm_upload = 70 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() def _normals_and_save(m): m_fixed = m.compute_normals( consistent_normals=True, auto_orient_normals=True, point_normals=True, cell_normals=False, inplace=False, ) geom_path_ = os.path.join(GEOM_DIR, "geometry.stl") m_fixed.save(geom_path_) return geom_path_, m_fixed geom_path, mesh_fixed = await loop.run_in_executor( None, _normals_and_save, mesh ) # --- update viewer --- s.upload_msg = "Updating viewer..." s.pm_upload = 90 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() self.ren_geom.RemoveAllViewProps() self.ren_geom.AddActor(make_actor_from_stl(geom_path)) self.ren_geom.ResetCamera() try: self.rw_geom.Render() except Exception: pass self.view_geom.update() GEOMETRY_CACHE.current_mesh = mesh_fixed # --- final bump --- s.upload_msg = "βœ… Re-decimated" s.pm_upload = 100 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() except Exception as e: s.upload_msg = f"❌ Re-decimation failed: {e}" s.pm_upload = 0 s.pm_elapsed_upload = time.time() - t0 await self._flush_async() finally: # hide bar + text after we’re done s.is_uploading = False await self._flush_async() def reset_mesh(self): self._run_coro(self._reset_mesh_async()) async def _reset_mesh_async(self): s = self.state if GEOMETRY_CACHE.original_mesh is None: s.upload_msg = "No original mesh to reset to" await self._flush_async() return # use the saved original orig = GEOMETRY_CACHE.original_mesh # save it again as current GEOMETRY_CACHE.current_mesh = orig # write to disk (so the STL on disk matches the 
viewer) geom_path = os.path.join(GEOM_DIR, "geometry.stl") orig.save(geom_path) # update viewer self.ren_geom.RemoveAllViewProps() self.ren_geom.AddActor(make_actor_from_stl(geom_path)) self.ren_geom.ResetCamera() try: self.rw_geom.Render() except Exception: pass self.view_geom.update() s.upload_msg = "↩️ Reset to original mesh" await self._flush_async() # --------------------------------------------------------- # prediction # --------------------------------------------------------- def _start_infer_heartbeat(self): if self._infer_thread and self._infer_thread.is_alive(): return def loop_fn(): while self._infer_heartbeat_on: if self.state.is_predicting and self._predict_t0 is not None: self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0) try: self.server.state.flush() except Exception: pass time.sleep(0.12) self._infer_heartbeat_on = True self._infer_thread = threading.Thread(target=loop_fn, daemon=True) self._infer_thread.start() def _stop_infer_heartbeat(self): self._infer_heartbeat_on = False self._infer_thread = None async def _predict_worker_async(self): s = self.state loop = self._ensure_loop() t0 = time.time() if GEOMETRY_CACHE.current_mesh is None: s.predict_msg = "❌ Please upload geometry first" s.is_predicting = False await self._flush_async() return s.is_predicting = True s.predict_progress = 1 s.pm_infer = 1 s.predict_msg = "Preparing inference..." self._predict_t0 = time.time() self._start_infer_heartbeat() await self._flush_async() try: dataset = s.dataset variable = s.variable boundary = ( {"freestream_velocity": mph_to_ms(s.velocity_mph)} if dataset == "Incompressible flow over car" else None ) s.predict_msg = "Loading model/checkpoint..." s.predict_progress = 5 s.pm_infer = 5 await self._flush_async() cfg, model, device, _ = await loop.run_in_executor( None, MODEL_STORE.get, dataset, variable, None ) s.predict_msg = "Processing mesh for inference..." 
s.predict_progress = 35 s.pm_infer = 35 await self._flush_async() def _run_full(): return run_inference_fast( dataset, variable, boundary_conditions=boundary, progress_cb=None, ) viz = await loop.run_in_executor(None, _run_full) s.predict_msg = "Preparing visualization..." s.predict_progress = 85 s.pm_infer = 85 await self._flush_async() stl_path = os.path.join(GEOM_DIR, "geometry.stl") vmin = float(np.min(viz["pred"])) vmax = float(np.max(viz["pred"])) if os.path.exists(stl_path): _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path) lut = build_jet_lut(vmin, vmax) colored_actor = color_actor_with_scalars_from_prediction( stl_path, viz["points"], viz["pred"], variable, vmin, vmax, lut=lut, ) self.ren_pred.AddActor(colored_actor) units = { "pressure": "Pa", "x_velocity": "m/s", "y_velocity": "m/s", "z_velocity": "m/s", }.get(variable, "") title = f"{variable} ({units})" if units else variable self.scalar_bar = add_or_update_scalar_bar( self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8 ) src_cam = self.ren_geom.GetActiveCamera() dst_cam = self.ren_pred.GetActiveCamera() if src_cam is not None and dst_cam is not None: dst_cam.SetPosition(src_cam.GetPosition()) dst_cam.SetFocalPoint(src_cam.GetFocalPoint()) dst_cam.SetViewUp(src_cam.GetViewUp()) dst_cam.SetParallelScale(src_cam.GetParallelScale()) cr = src_cam.GetClippingRange() dst_cam.SetClippingRange(cr) try: self.rw_pred.Render() except Exception: pass self.view_pred.update() raw_vmin = float(np.min(viz["pred"])) raw_vmax = float(np.max(viz["pred"])) s.stats_html = ( f"{variable} min:{raw_vmin:.3e} " f"max: {raw_vmax:.3e} " f"Confidence: {viz['confidence_score']:.4f}" ) s.predict_msg = "βœ… Prediction complete." 
s.predict_progress = 100 s.pm_infer = 100 s.predict_elapsed = time.time() - t0 s.pm_elapsed_infer = s.predict_elapsed await self._flush_async() except Exception as e: s.predict_msg = f"❌ Prediction failed: {e}" s.predict_progress = 0 s.pm_infer = 0 await self._flush_async() finally: s.is_predicting = False self._stop_infer_heartbeat() await self._flush_async() @time_function("Inference and Visualization") def predict(self, *_): self._run_coro(self._predict_worker_async()) # --------------------------------------------------------- # dataset wiring # --------------------------------------------------------- @change("analysis_type") def _on_analysis_type_change(self, analysis_type=None, **_): ds_list = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", []) default_ds = ds_list[0] if ds_list else None self.state.dataset_choices = ds_list if default_ds and self.state.dataset != default_ds: self.state.dataset = default_ds elif self.state.dataset: self._apply_dataset(self.state.dataset) @change("dataset") def _on_dataset_change(self, dataset=None, **_): if not dataset: return self._apply_dataset(dataset) def _apply_dataset(self, ds: str): s = self.state opts = variables_for(ds) if ds else [] s.variable_choices = opts s.variable = opts[0] if opts else None s.show_velocity = (ds == "Incompressible flow over car") s.is_plane = (ds == "Compressible flow over plane") s.bc_text = get_boundary_conditions_text(ds) s.bc_left = bc_text_left(ds) s.bc_right = bc_text_right(ds) s.bc_text_html = s.bc_right or md_to_html(s.bc_text) # --------------------------------------------------------- # clear # --------------------------------------------------------- def clear(self, *_): for d in [GEOM_DIR, SOLN_DIR]: if os.path.exists(d): shutil.rmtree(d) os.makedirs(d, exist_ok=True) s = self.state s.stats_html = "🧹 Cleared. Upload again." 
s.is_uploading = False s.pm_upload = 0 s.upload_msg = "" s.pm_elapsed_upload = 0.0 s.is_predicting = False s.predict_progress = 0 s.predict_msg = "" s.pm_infer = 0 s.pm_elapsed_infer = 0.0 self.ren_geom.RemoveAllViewProps() self.ren_pred.RemoveAllViewProps() for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)): try: rw.Render() except Exception: pass view.update() # ---------- main ---------- def main(): app = PFMDemo() app.server.controller.add("decimate_again", app.decimate_again) app.server.controller.add("reset_mesh", app.reset_mesh) # app.server.start(7860) port = int(os.environ.get("PORT", "7860")) # `host` is controlled via CLI (--host 0.0.0.0), here we just: app.server.start( port=port, open_browser=False, # βœ… do NOT try to open a browser in the container show_connection_info=True, # fine to keep, useful logs backend="aiohttp", exec_mode="main", ) if __name__ == "__main__": main()