# ========= Headless / Offscreen safety (before any VTK import) =========
# Native libraries (Mesa/VTK, OpenMP/BLAS runtimes) and some Python packages
# (matplotlib, huggingface_hub) typically read these variables once, when they
# are first loaded, so all of them are set here -- before numpy/torch/pyvista
# and every other import below. ``setdefault`` keeps values from the caller's
# environment.
import os
import tempfile

os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
os.environ.setdefault("DISPLAY", "")  # force headless: never open an X11 window
# Must precede numpy/torch: their OpenMP/BLAS thread pools are sized at load time.
os.environ.setdefault("OMP_NUM_THREADS", "4")
# matplotlib and huggingface_hub resolve their config/cache directories at
# import time; these are the same paths the "Writable paths / caches" section
# computes further down.
_EARLY_DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
_EARLY_HF_DIR = os.path.join(_EARLY_DATA_DIR, "hf_home")
os.makedirs(_EARLY_HF_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", _EARLY_DATA_DIR)
os.environ.setdefault("HF_HOME", _EARLY_HF_DIR)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", _EARLY_HF_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", _EARLY_HF_DIR)

# ========= Core setup =========
import asyncio
import base64
import html as _html
import json
import re
import shutil
import threading
import time

import numpy as np
import torch
import pyvista as pv

pv.OFF_SCREEN = True  # Tell PyVista to always render offscreen
from scipy.spatial import cKDTree
from vtk.util import numpy_support as nps
import matplotlib.cm as cm
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import pickle
from sklearn.metrics import pairwise_distances
from train import get_single_latent
from sklearn.neighbors import NearestNeighbors
from utils.app_utils2 import (
create_visualization_points,
create_visualization_stl,
camera_from_bounds,
bounds_from_points,
convert_vtp_to_glb,
convert_vtp_to_stl,
time_function,
print_timing,
mesh_get_variable,
mph_to_ms,
get_boundary_conditions_text,
compute_confidence_score,
compute_cosine_score,
decimate_mesh,
)
# ========= trame =========
from trame.app import TrameApp
from trame.decorators import change
from trame.ui.vuetify3 import SinglePageLayout
from trame.widgets import vuetify3 as v3, html
from trame_vtk.widgets.vtk import VtkRemoteView
# ========= VTK =========
from vtkmodules.vtkRenderingCore import (
vtkRenderer,
vtkRenderWindow,
vtkPolyDataMapper,
vtkActor,
vtkRenderWindowInteractor,
)
from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor
from vtkmodules.vtkIOGeometry import vtkSTLReader
from vtkmodules.vtkFiltersCore import vtkTriangleFilter
from vtkmodules.vtkCommonCore import vtkLookupTable
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
# ========= Writable paths / caches =========
# Root for everything the app writes at runtime (uploads, results, weights, caches).
DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
os.makedirs(DATA_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)

GEOM_DIR = os.path.join(DATA_DIR, "geometry")     # uploaded / decimated geometry
SOLN_DIR = os.path.join(DATA_DIR, "solution")
WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")   # per-dataset checkpoint copies
for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
    os.makedirs(d, exist_ok=True)

# Hugging Face download location (also exported for libraries that honour it).
HF_DIR = os.path.join(DATA_DIR, "hf_home")
os.environ.setdefault("HF_HOME", HF_DIR)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
os.makedirs(HF_DIR, exist_ok=True)

# Fail fast at import time rather than on the first upload/download.
for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
    writable = os.access(p, os.W_OK)
    if not writable:
        raise RuntimeError(f"Not writable: {p}")
# ========= Auto-decimation ladder =========
def auto_target_reduction(num_cells: int) -> float:
    """Pick the fraction of cells to remove when auto-decimating a mesh.

    Bigger meshes are reduced more aggressively: up to 10k cells are kept
    as-is, and anything with at least 1M cells loses 90%.

    Args:
        num_cells: number of cells (faces) in the mesh.

    Returns:
        Target reduction in ``[0.0, 0.9]``.
    """
    ladder = (
        (10_000, 0.0),
        (20_000, 0.2),
        (50_000, 0.4),
        (100_000, 0.5),
        (500_000, 0.6),
    )
    for upper_inclusive, reduction in ladder:
        if num_cells <= upper_inclusive:
            return reduction
    # The last boundary is exclusive: exactly 1,000,000 cells falls into 0.9.
    return 0.8 if num_cells < 1_000_000 else 0.9
# ========= Registry / choices =========
def _lpfm_registry_entry(app_name: str, checkpoint: str, out_variables: list) -> dict:
    """Build one REGISTRY spec for an ``ansysLPFMs`` model in the shared HF repo.

    The config path is ``configs/app_configs/<app_name>/config.yaml`` and the
    checkpoint is registered under the ``"best_val"`` key.
    """
    return {
        "repo_id": "ansysresearch/pretrained_models",
        "config": f"configs/app_configs/{app_name}/config.yaml",
        "model_attr": "ansysLPFMs",
        "checkpoints": {"best_val": checkpoint},
        "out_variable": list(out_variables),
    }


# Dataset name -> model spec (repo, config path, model class, checkpoint, outputs).
REGISTRY = {
    app_name: _lpfm_registry_entry(app_name, ckpt, outputs)
    for app_name, ckpt, outputs in (
        (
            "Incompressible flow inside artery",
            "ckpt_artery.pt",
            ["pressure", "x_velocity", "y_velocity", "z_velocity"],
        ),
        (
            "Vehicle crash analysis",
            "ckpt_vehiclecrash.pt",
            [
                "impact_force",
                "deformation",
                "energy_absorption",
                "x_displacement",
                "y_displacement",
                "z_displacement",
            ],
        ),
        (
            "Compressible flow over plane",
            "ckpt_plane_transonic_v3.pt",
            ["pressure"],
        ),
        (
            "Incompressible flow over car",
            "ckpt_cadillac_v3.pt",
            ["pressure"],
        ),
    )
}
def variables_for(dataset: str):
    """Return the predictable output variable names for *dataset* as a list.

    Uses the REGISTRY ``"out_variable"`` entry (a single string or a
    list/tuple). If none is declared, falls back to the dataset's checkpoint
    keys. Unknown datasets yield ``[]``.
    """
    spec = REGISTRY.get(dataset, {})
    declared = spec.get("out_variable")
    if isinstance(declared, (list, tuple)):
        return list(declared)
    if isinstance(declared, str):
        return [declared]
    # Nothing declared: expose the checkpoint names instead.
    return [ckpt_key for ckpt_key in spec.get("checkpoints", {})]
# Analysis category shown in the UI -> datasets (REGISTRY keys) offered under it.
ANALYSIS_TYPE_MAPPING = {
    "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
    "Internal flow": ["Incompressible flow inside artery"],
    "Crash analysis": ["Vehicle crash analysis"],
}
ANALYSIS_TYPE = [*ANALYSIS_TYPE_MAPPING]

# Initial UI selection.
DEFAULT_ANALYSIS_TYPE = "External flow"
DEFAULT_DATASET = "Incompressible flow over car"
VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
DEFAULT_VARIABLE = next(iter(VAR_CHOICES0), None)  # first choice, or None if empty
# ========= Simple cache =========
class GeometryCache:
    """Process-wide holder for the meshes of the current session.

    Both slots start empty (``None``).
    """

    def __init__(self):
        # Uploaded mesh after cleaning (normals), before any user re-decimation.
        self.original_mesh = None
        # Mesh the app is actually using right now.
        self.current_mesh = None


GEOMETRY_CACHE = GeometryCache()
# ========= Model store =========
class ModelStore:
    """Lazily build and cache one ``(cfg, model, device, accelerator)`` tuple per dataset."""

    def __init__(self):
        self._cache = {}  # dataset name -> (cfg, model, device, accelerator)

    def _build(self, dataset: str, progress_cb=None):
        """Return the model bundle for *dataset*, building and caching it on first use.

        A build downloads the checkpoint named in ``REGISTRY[dataset]`` from the
        Hugging Face Hub and copies it under ``WEIGHTS_DIR/<dataset>``. It then
        loads the OmegaConf config (forcing ``save_latent = True``), instantiates
        ``models.<model_attr>(cfg)`` on CUDA if available (else CPU), and loads
        the weights.

        Args:
            dataset: key into ``REGISTRY``.
            progress_cb: optional callable receiving integer progress values
                (6..15 during a build, 12 on a cache hit). Exceptions it raises
                are ignored.

        Returns:
            ``(cfg, model, device, accelerator)``.

        Raises:
            RuntimeError: if any build step fails; the original error is chained.
        """

        def tick(x):
            if progress_cb:
                try:
                    progress_cb(int(x))
                except Exception:
                    # Progress reporting is best-effort and must not abort a build.
                    pass

        if dataset in self._cache:
            tick(12)
            return self._cache[dataset]

        print(f"π§ Building model for {dataset}")
        start_time = time.time()
        try:
            spec = REGISTRY[dataset]
            repo_id = spec["repo_id"]
            ckpt_name = spec["checkpoints"]["best_val"]
            tick(6)

            t0 = time.time()
            ckpt_path_hf = hf_hub_download(
                repo_id=repo_id,
                filename=ckpt_name,
                repo_type="model",
                local_dir=HF_DIR,
                local_dir_use_symlinks=False,
            )
            print_timing("Model checkpoint download", t0)
            tick(8)

            t0 = time.time()
            ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
            os.makedirs(ckpt_local_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
            if not os.path.exists(ckpt_path):
                shutil.copy(ckpt_path_hf, ckpt_path)
            print_timing("Local model copy setup", t0)
            tick(9)

            t0 = time.time()
            cfg_path = spec["config"]
            if not os.path.exists(cfg_path):
                raise FileNotFoundError(f"Missing config: {cfg_path}")
            cfg = OmegaConf.load(cfg_path)
            cfg.save_latent = True
            print_timing("Configuration loading", t0)
            tick(11)

            t0 = time.time()
            # NOTE(review): only effective if CUDA has not been initialized yet in
            # this process -- confirm this is still the intended GPU selection.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print_timing("Device init", t0)
            tick(12)

            t0 = time.time()
            import models

            model_cls_name = spec["model_attr"]
            if not hasattr(models, model_cls_name):
                raise ValueError(f"Model '{model_cls_name}' not found")
            model = getattr(models, model_cls_name)(cfg).to(device)
            print_timing("Model build", t0)
            tick(14)

            t0 = time.time()
            state = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(state)
            model.eval()
            print_timing("Weights load", t0)
            tick(15)

            result = (cfg, model, device, accelerator)
            self._cache[dataset] = result
            print_timing(f"Total model build for {dataset}", start_time)
            return result
        except Exception as e:
            # print_timing takes a start timestamp (as in every other call above);
            # the exception object is not one, so pass start_time and keep the
            # error itself in the raised message.
            print_timing(f"Model build failed for {dataset}", start_time)
            raise RuntimeError(
                f"Failed to load model for dataset '{dataset}': {e}"
            ) from e

    def get(self, dataset: str, variable: str, progress_cb=None):
        """Return ``(cfg, model, device, accelerator)`` for *dataset*.

        *variable* is accepted for call-site symmetry but is not used: one model
        serves every output variable of a dataset.
        """
        return self._build(dataset, progress_cb=progress_cb)


MODEL_STORE = ModelStore()
# ========= Inference pipeline =========
def _variable_index(dataset: str, variable: str) -> int:
    """Return the model output channel holding *variable* for *dataset*.

    A single-string ``out_variable`` means a one-channel model, so the index is
    always 0. Otherwise it is the position of *variable* in the list; a
    ``ValueError`` is raised if it is absent.
    """
    outputs = REGISTRY[dataset]["out_variable"]
    if isinstance(outputs, str):
        return 0
    return outputs.index(variable)
@time_function("Mesh Processing")
def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
    """Turn a mesh into model-ready tensors.

    Mesh points are shuffled with a fixed seed (42), so repeated calls give the
    same ordering.

    Args:
        mesh: mesh whose ``points`` are the model inputs.
        cfg: model config. Reads ``diff_input_velocity``, ``input_normalization``
            and ``pos_embed_sincos``.
        variable: name passed to ``mesh_get_variable`` to fetch target values.
        dataset: app dataset name. Selects
            ``configs/app_configs/<dataset>/full_transform_params.json``.
        boundary_conditions: optional dict. When ``cfg.diff_input_velocity`` is
            set, its ``"freestream_velocity"`` is appended as an input column.

    Returns:
        ``(pos, target, points, cosine_score)``:
        - ``pos``: float tensor of (normalized) model inputs, one row per point.
        - ``target``: target values with a trailing singleton dim.
        - ``points``: the shuffled float32 point array in the mesh's original
          coordinates.
        - ``cosine_score``: the value returned by ``compute_cosine_score``.
    """
    jpath = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
    with open(jpath, "r", encoding="utf-8") as fh:
        json_data = json.load(fh)

    pts = np.asarray(mesh.points, dtype=np.float32)
    N = pts.shape[0]
    rng = np.random.default_rng(42)
    idx = rng.permutation(N)
    points = pts[idx]
    tgt_np = mesh_get_variable(mesh, variable, N)[idx]

    # Copy: torch.from_numpy shares memory with ``points``. Without the copy, the
    # in-place "shift_axis" normalization below would also shift the ``points``
    # returned to callers as the prediction locations.
    pos = torch.from_numpy(points.copy())
    target = torch.from_numpy(tgt_np).unsqueeze(-1)

    if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
        if "freestream_velocity" in boundary_conditions:
            inlet_x_velocity = torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
            # All rows are identical, so they need no permutation.
            inlet_x_velocity = inlet_x_velocity.repeat(N, 1)
            pos = torch.cat((pos, inlet_x_velocity), dim=1)

    if getattr(cfg, "input_normalization", None) == "shift_axis":
        # x and z start at 0; y is centred on the middle of its range.
        coords = pos[:, :3].clone()
        coords[:, 0] = coords[:, 0] - coords[:, 0].min()
        coords[:, 2] = coords[:, 2] - coords[:, 2].min()
        y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
        coords[:, 1] = coords[:, 1] - y_center
        pos[:, :3] = coords

    if getattr(cfg, "pos_embed_sincos", False):
        # Rescale every input column to [0, 1000] using the training-set min/max.
        mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
        maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
        pos = 1000.0 * (pos - mins) / (maxs - mins)
        pos = torch.clamp(pos, 0, 1000)

    cosine_score = compute_cosine_score(mesh, dataset)
    return pos, target, points, cosine_score
def _read_json_file(path):
    """Load a JSON document from *path*, always closing the file."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _read_pickle_file(path):
    """Unpickle *path*, always closing the file.

    NOTE: unpickling can execute arbitrary code. Use this only on the app's own
    config artifacts, never on user-uploaded files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _latent_confidence_score(dataset, device, model, boundary_conditions):
    """Best-effort in-distribution confidence from the model's latent features.

    Returns 0.0 for the artery dataset, and also when any step fails (the
    reason is printed). Otherwise it returns the value of
    ``compute_confidence_score`` against the dataset's stored PCA embedding.
    """
    if dataset in ("Incompressible flow inside artery",):
        return 0.0
    try:
        app_cfg_dir = os.path.join("configs/app_configs/", dataset)
        latent_features = get_single_latent(
            mesh_path=os.path.join(GEOM_DIR, "geometry.stl"),
            config_path=os.path.join(app_cfg_dir, "config.yaml"),
            device=device,
            custom_velocity=boundary_conditions["freestream_velocity"] if boundary_conditions else None,
            use_training_velocity=False,
            model=model,
        )
        embedding = np.load(os.path.join(app_cfg_dir, "pca_embedding.npy"))
        pca_reducer = _read_pickle_file(os.path.join(app_cfg_dir, "pca_reducer.pkl"))
        scaler = _read_pickle_file(os.path.join(app_cfg_dir, "pca_scaler.pkl"))
        train_pair_dists = pairwise_distances(embedding)
        sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0

        # Resample latent rows to the count the PCA expects (its input size / 256;
        # presumably 256 features per latent row -- confirm).
        n_points, _ = latent_features.shape
        target_len = int(pca_reducer.n_features_in_ / 256)
        # Local RNG: same seed and stream as the legacy global np.random.seed(42),
        # but the process-wide RNG state is left alone.
        rng = np.random.RandomState(42)
        if n_points > target_len:
            latent_features = latent_features[rng.choice(n_points, target_len, replace=False)]
        elif n_points < target_len:
            extra_indices = rng.choice(n_points, target_len - n_points, replace=True)
            latent_features = np.vstack([latent_features, latent_features[extra_indices]])
        latent_features = latent_features.flatten()
        return compute_confidence_score(latent_features, embedding, scaler, pca_reducer, sigma)
    except Exception as exc:
        # Confidence is optional: report why it is missing instead of failing inference.
        print(f"Confidence score unavailable for '{dataset}': {exc}")
        return 0.0


def _run_model_forward(model, data, cfg, report):
    """Run *model* on ``data["input_pos"]`` and return the output tensor.

    When ``cfg.chunked_eval`` is set and there are more points than
    ``cfg.num_points`` (default 10000), the input is fed in chunks of that size.
    The last chunk is padded with leading points and the padded outputs are
    dropped. ``data["input_pos"]`` is overwritten with each chunk. Progress
    25..85 is reported via *report* while chunking, and 40 otherwise. If the
    model returns a tuple/list, its first element is used.
    """
    inp = data["input_pos"]
    _, n_total, _ = inp.shape
    # Convert once and use the same value for the decision and the slicing.
    chunk_size = int(getattr(cfg, "num_points", 10000))
    if not (getattr(cfg, "chunked_eval", False) and chunk_size < n_total):
        report(40)
        outputs = model(data)
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    out_chunks = []
    total = (n_total + chunk_size - 1) // chunk_size
    for k, start in enumerate(range(0, n_total, chunk_size)):
        ch = inp[:, start : start + chunk_size, :]
        n_valid = ch.shape[1]
        if n_valid < chunk_size:
            # Pad the tail chunk to full size; its padded outputs are sliced off below.
            ch = torch.cat([ch, inp[:, : chunk_size - n_valid, :]], dim=1)
        data["input_pos"] = ch
        out_chunk = model(data)
        if isinstance(out_chunk, (tuple, list)):
            out_chunk = out_chunk[0]
        out_chunks.append(out_chunk[:, :n_valid, :])
        report(25 + 60 * (k + 1) / max(1, total))
    return torch.cat(out_chunks, dim=1)


def _prediction_metrics(pred_np, tgt_np):
    """Return ``(rel_l2, r2)`` tensors comparing prediction and target arrays.

    ``r2`` is 0.0 when the target has zero variance.
    """
    pred_t = torch.from_numpy(pred_np).unsqueeze(-1)
    tgt_t = torch.from_numpy(tgt_np).unsqueeze(-1)
    rel_l2 = torch.mean(
        torch.norm(pred_t.squeeze(-1) - tgt_t.squeeze(-1), p=2, dim=-1)
        / torch.norm(tgt_t.squeeze(-1), p=2, dim=-1)
    )
    tgt_mean = torch.mean(tgt_t)
    ss_tot = torch.sum((tgt_t - tgt_mean) ** 2)
    ss_res = torch.sum((tgt_t - pred_t) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)
    return rel_l2, r2


@time_function("Inference")
def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
    """Predict *variable* on the currently loaded geometry.

    Args:
        dataset: REGISTRY key selecting the model and its config artifacts.
        variable: output variable to return.
        boundary_conditions: optional dict (e.g. ``"freestream_velocity"``),
            forwarded to mesh processing and latent extraction.
        progress_cb: optional callable receiving integer progress 5..100.
            Exceptions it raises are ignored.

    Returns:
        dict with ``points``, ``pred``, ``tgt`` (numpy arrays) and float
        ``cosine_score``, ``confidence_score``, ``abs_err``, ``mse_err``,
        ``rel_l2`` and ``r_squared``.

    Raises:
        ValueError: if no geometry is loaded.
    """

    def p(v):
        if progress_cb:
            try:
                progress_cb(int(v))
            except Exception:
                pass

    if GEOMETRY_CACHE.current_mesh is None:
        raise ValueError("No geometry loaded")
    p(5)
    cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
    p(15)
    pos, target, points, cosine_score = process_mesh_fast(
        GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
    )
    p(25)
    confidence_score = _latent_confidence_score(dataset, device, model, boundary_conditions)

    data = {
        "input_pos": pos.unsqueeze(0).to(device),
        "output_feat": target.unsqueeze(0).to(device),
    }
    with torch.no_grad():
        outputs = _run_model_forward(model, data, cfg, p)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    p(85)

    vi = _variable_index(dataset, variable)
    pred = outputs[0, :, vi : vi + 1]
    if getattr(cfg, "normalization", "") == "std_norm":
        # Undo the per-variable standardization used in training.
        stats = _read_json_file(
            os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
        )["scalars"][variable]
        mu = torch.tensor(float(stats["mean"]), device=pred.device)
        sd = torch.tensor(float(stats["std"]), device=pred.device)
        pred = pred * sd + mu

    pred_np = pred.squeeze().detach().cpu().numpy()
    tgt_np = target.squeeze().numpy()
    rel_l2, r2 = _prediction_metrics(pred_np, tgt_np)
    p(100)
    return {
        "points": np.asarray(points),
        "pred": np.asarray(pred_np),
        "tgt": np.asarray(tgt_np),
        "cosine_score": float(cosine_score),
        "confidence_score": float(confidence_score),
        "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
        "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
        "rel_l2": float(rel_l2.item()),
        "r_squared": float(r2.item()),
    }
# ========= VTK helpers =========
def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
    """Read *stl_path*, triangulate it, and return a single-colour ``vtkActor``.

    Args:
        stl_path: path to an STL file.
        color: RGB triple in [0, 1] applied to the actor's property.
    """
    reader = vtkSTLReader()
    reader.SetFileName(stl_path)
    reader.Update()

    triangulator = vtkTriangleFilter()
    triangulator.SetInputConnection(reader.GetOutputPort())
    triangulator.Update()

    mapper = vtkPolyDataMapper()
    mapper.SetInputConnection(triangulator.GetOutputPort())

    actor = vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(*color)
    return actor
def build_jet_lut(vmin, vmax, n_colors=256):
    """Build a ``vtkLookupTable`` sampling Matplotlib's "jet" colormap.

    Args:
        vmin, vmax: scalar range mapped onto the table.
        n_colors: number of table entries (default 256). All entries are opaque.

    Returns:
        The populated ``vtkLookupTable``.
    """
    import matplotlib

    n_colors = int(n_colors)
    try:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the registry.
        cmap = matplotlib.colormaps["jet"].resampled(n_colors)
    except AttributeError:
        # Older Matplotlib (< 3.6) lacks the registry and/or ``resampled``.
        cmap = cm.get_cmap("jet", n_colors)

    lut = vtkLookupTable()
    lut.SetRange(float(vmin), float(vmax))
    lut.SetNumberOfTableValues(n_colors)
    lut.Build()
    for i in range(n_colors):
        r_, g_, b_, _ = cmap(i)
        lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
    return lut
def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
    """Colour an STL surface by nearest-neighbour transfer of predicted values.

    Each STL vertex takes the value of its closest point in *points_xyz*. The
    values are stored as point array *array_name* (set active) and rendered
    through *lut*, or a fresh jet LUT over ``[vmin, vmax]`` when *lut* is None.

    Returns:
        A ``vtkActor`` for the coloured surface.
    """
    reader = vtkSTLReader()
    reader.SetFileName(stl_path)
    reader.Update()
    surface = reader.GetOutput()

    surface_xyz = nps.vtk_to_numpy(surface.GetPoints().GetData())
    _, nearest = cKDTree(points_xyz).query(surface_xyz, k=1)
    values = np.asarray(pred_vals, dtype=np.float32)[nearest]

    vtk_values = nps.numpy_to_vtk(values, deep=True)
    vtk_values.SetName(array_name)
    point_data = surface.GetPointData()
    point_data.AddArray(vtk_values)
    point_data.SetActiveScalars(array_name)

    mapper = vtkPolyDataMapper()
    mapper.SetInputData(surface)
    mapper.SetScalarModeToUsePointData()
    mapper.ScalarVisibilityOn()
    mapper.SetScalarRange(float(vmin), float(vmax))
    mapper.SetLookupTable(lut if lut is not None else build_jet_lut(vmin, vmax))
    mapper.UseLookupTableScalarRangeOn()

    actor = vtkActor()
    actor.SetMapper(mapper)
    return actor
def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
    """Replace any scalar bars on *renderer* with a new vertical one for *lut*.

    Args:
        renderer: renderer whose 2D actors are scanned and updated.
        lut: lookup table the bar displays.
        title: bar title.
        label_fmt: printf-style label format.
        n_labels: number of tick labels.

    Returns:
        The newly added ``vtkScalarBarActor``.
    """
    actors_2d = renderer.GetActors2D()
    actors_2d.InitTraversal()
    # Collect first, then remove, so the traversal is not disturbed.
    traversed = [actors_2d.GetNextItem() for _ in range(actors_2d.GetNumberOfItems())]
    for stale in [a for a in traversed if isinstance(a, vtkScalarBarActor)]:
        renderer.RemoveActor2D(stale)

    bar = vtkScalarBarActor()
    bar.SetLookupTable(lut)
    bar.SetOrientationToVertical()
    bar.SetLabelFormat(label_fmt)
    bar.SetNumberOfLabels(int(n_labels))
    bar.SetTitle(title)
    # Tall, narrow bar along the right edge of the viewport.
    bar.SetPosition(0.92, 0.05)
    bar.SetPosition2(0.06, 0.90)

    title_props = bar.GetTitleTextProperty()
    title_props.SetColor(1, 1, 1)
    title_props.SetBold(True)
    title_props.SetFontSize(22)

    label_props = bar.GetLabelTextProperty()
    label_props.SetColor(1, 1, 1)
    label_props.SetFontSize(18)

    renderer.AddActor2D(bar)
    return bar
# ---------- Small helpers ----------
def poly_count(mesh: pv.PolyData) -> int:
    """Return the mesh's face count.

    Uses ``n_faces_strict`` when the PyVista version provides it, and otherwise
    falls back to ``n_cells``.
    """
    try:
        return mesh.n_faces_strict
    except AttributeError:
        return mesh.n_cells
def md_to_html(txt: str) -> str:
    """Convert a tiny Markdown subset to HTML suitable for a ``v_html`` binding.

    The input is HTML-escaped first. Then ``**bold**`` becomes ``<b>bold</b>``
    and line breaks become ``<br>``. Empty or None input yields ``""``.
    """
    if not txt:
        return ""
    safe = _html.escape(txt)
    safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
    # The separator must be a single literal; a raw newline inside the quotes
    # would be an unterminated string.
    return "<br>".join(safe.splitlines())
def bc_text_right(dataset: str) -> str:
    """Return the fixed reference conditions for *dataset* as ``<br>``-joined HTML.

    Only the car and plane datasets have entries; other datasets yield ``""``.
    """
    lines = {
        "Incompressible flow over car": (
            "Reference Density: 1.225 kg/m\u00b3",
            "Reference Viscosity: 1.789e-5 Pa\u00b7s",
            "Operating Pressure: 101325 Pa",
        ),
        "Compressible flow over plane": (
            "Ambient Temperature: 218 K",
            "Cruising velocity: 250.0 m/s or 560 mph",
        ),
    }.get(dataset, ())
    return "<br>".join(lines)
def bc_text_left(dataset: str) -> str:
    """Return the plane dataset's reference fluid properties as ``<br>``-joined HTML.

    Every other dataset yields ``""``.
    """
    if dataset != "Compressible flow over plane":
        return ""
    return "<br>".join(
        (
            "Reference Density: 0.36 kg/m\u00b3",
            "Reference viscosity: 1.716e-05 kg/(m\u00b7s)",
            "Operating Pressure: 23842 Pa",
        )
    )
# =====================================================================
# ======================= APP =======================================
# =====================================================================
class PFMDemo(TrameApp):
def __init__(self, server=None):
    """Set up both offscreen VTK pipelines, seed the trame state and build the UI.

    Args:
        server: optional trame server, forwarded to ``TrameApp``.
    """
    super().__init__(server)

    # ---------------- VTK RENDERERS ----------------
    def _offscreen_pipeline():
        # renderer -> offscreen render window -> trackball-style interactor
        renderer = vtkRenderer()
        renderer.SetBackground(0.10, 0.16, 0.22)
        window = vtkRenderWindow()
        window.SetOffScreenRendering(1)
        window.AddRenderer(renderer)
        interactor = vtkRenderWindowInteractor()
        interactor.SetRenderWindow(window)
        interactor.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
        try:
            interactor.Initialize()
            interactor.Enable()
        except Exception:
            # Best effort: interactor setup may fail without a display; the
            # app still starts.
            pass
        return renderer, window, interactor

    self.ren_geom, self.rw_geom, self.rwi_geom = _offscreen_pipeline()  # input geometry
    self.ren_pred, self.rw_pred, self.rwi_pred = _offscreen_pipeline()  # prediction
    self.scalar_bar = None

    # timers / flags
    self._predict_t0 = None
    self._infer_thread = None
    self._pre_upload_thread = None
    # Flags read by the pre-upload spinner loop; defined up front so they always exist.
    self._pre_upload_on = False
    self._upload_actual_started = False
    self._infer_heartbeat_on = False
    self._loop = None

    # ---------------- TRAME STATE ----------------
    s = self.state
    s.theme_dark = True
    s.analysis_types = ANALYSIS_TYPE
    s.analysis_type = DEFAULT_ANALYSIS_TYPE
    s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE]
    s.dataset = DEFAULT_DATASET
    s.variable_choices = variables_for(DEFAULT_DATASET)
    s.variable = s.variable_choices[0] if s.variable_choices else None

    # decimation dialog (still kept)
    s.show_decimation_dialog = False
    s.decim_override_enabled = False
    s.decim_override_mode = "medium"
    s.decim_override_custom = 0.5

    # gear-menu decimation defaults
    s.decim_enable = False  # user MUST toggle to override auto
    s.decim_target = 0.5
    # NOTE(review): an earlier comment said this must be 0 so small meshes can
    # be reduced, but the value is 5000 -- confirm the intended minimum.
    s.decim_min_faces = 5000
    s.decim_max_faces = int(1e7)

    # register controller triggers
    ctrl = self.server.controller
    ctrl.add("decimate_again", self.decimate_again)
    ctrl.add("reset_mesh", self.reset_mesh)

    s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car")
    s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane")
    s.velocity_mph = 45.0
    s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET)
    s.bc_left = bc_text_left(DEFAULT_DATASET)
    s.bc_right = bc_text_right(DEFAULT_DATASET)
    s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
    s.stats_html = "π Upload a geometry, then click Predict."
    s.upload = None

    # upload
    s.is_uploading = False
    s.pm_upload = 0
    s.pm_elapsed_upload = 0.0
    s.upload_msg = ""

    # predict
    s.is_predicting = False
    s.predict_progress = 0
    s.predict_msg = ""
    s.predict_elapsed = 0.0
    s.predict_est_total = 0.0
    s.pm_infer = 0
    s.pm_elapsed_infer = 0.0

    self._build_ui()
def _ensure_loop(self):
    """Return the cached asyncio event loop, obtaining it on first use.

    The current thread's loop is used when asyncio can provide one. Otherwise a
    new loop is created and set as current. The result is cached in
    ``self._loop``.
    """
    if self._loop is None:
        try:
            candidate = asyncio.get_event_loop()
        except RuntimeError:
            candidate = asyncio.new_event_loop()
            asyncio.set_event_loop(candidate)
        self._loop = candidate
    return self._loop
def _run_coro(self, coro):
    """Run *coro* on the app loop.

    If the loop is already running, the coroutine is scheduled and its Task is
    returned. Otherwise the coroutine is run to completion and its result is
    returned.
    """
    loop = self._ensure_loop()
    if not loop.is_running():
        return loop.run_until_complete(coro)
    return asyncio.ensure_future(coro, loop=loop)
async def _flush_async(self):
    """Push pending trame state to clients (best effort), then yield to the loop once."""
    try:
        pending = self.server.state
        pending.flush()
    except Exception:
        pass  # a failed flush must never break the caller's progress loop
    await asyncio.sleep(0)
def _build_ui(self):
    """Declare the single-page trame (Vuetify 3) layout.

    Widgets are two-way bound to trame state keys, including ``analysis_type``,
    ``dataset``, ``variable``, ``velocity_mph``, the ``decim_*`` keys,
    ``upload``, and the upload/inference progress keys. Buttons call
    ``self.decimate_again``, ``self.reset_mesh``, ``self.clear`` and
    ``self.predict``. This also creates ``self.view_geom`` and
    ``self.view_pred`` (remote views of ``self.rw_geom`` / ``self.rw_pred``)
    and sets ``layout.on_ready`` to ``self._first_paint``.
    """
    with SinglePageLayout(self.server, full_height=True) as layout:
        layout.title.set_text("")  # clear the default title text
        layout.title.hide = True  # and hide the default title widget

        # ---- Toolbar, row 1: centered application title ----
        with layout.toolbar:
            with v3.VContainer(
                fluid=True,
                style=(
                    "max-width: 1800px;"  # overall width
                    "margin: 0 auto;"  # center it
                    "padding: 0 8px;"  # left/right margin
                    "box-sizing: border-box;"
                ),
            ):
                v3.VSpacer()
                html.Div(
                    "Ansys: Physics Foundation Model",
                    style=(
                        "width:100%;"
                        "text-align:center;"
                        "font-size:34px;"
                        "font-weight:900;"
                        "letter-spacing:0.4px;"
                        "line-height:1.2;"
                    ),
                )
                v3.VSpacer()

        # ---- Toolbar, row 2: theme switch (same margins as row 1) ----
        with layout.toolbar:
            with v3.VContainer(
                fluid=True,
                style=(
                    "max-width: 1800px;"
                    "margin: 0 auto;"
                    "padding: 0 8px;"
                    "box-sizing: border-box;"
                ),
            ):
                v3.VSwitch(
                    v_model=("theme_dark",),
                    label="Dark Theme",
                    inset=True,
                    density="compact",
                    hide_details=True,
                )
                v3.VSpacer()

        with layout.content:
            html.Style("""
/* Small side padding for the whole app */
.v-application__wrap {
padding-left: 8px;
padding-right: 8px;
padding-bottom: 8px;
}
:root {
--pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
--pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace;
}
html, body, .v-application {
margin: 0;
padding: 0;
font-family: var(--pfm-font-ui) !important;
font-weight: 500;
letter-spacing: .25px;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
line-height: 1.5;
font-size: 15.5px;
color: #ECEFF4;
}
/* ... keep all your other typography / button / slider styles here ... */
.v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; }
.v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; }
/* (rest of your .pfm-* classes unchanged) */
""")
            html.Link(
                rel="stylesheet",
                href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap",
            )
            with v3.VThemeProvider(theme=("theme_dark ? 'dark' : 'light'",)):
                with v3.VContainer(
                    fluid=True,
                    class_="pa-6",
                    style=(
                        "max-width: 2200px;"  # max width of content
                        "margin: 8px auto 16px auto;"  # top / left-right / bottom
                        "padding: 0 8px;"  # inner left/right padding
                        "box-sizing: border-box;"
                        "background: rgba(255,255,255,0.02);"
                        "border-radius: 16px;"
                    ),
                ):
                    # 1) Physics application: analysis-type toggle
                    with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
                        html.Div(
                            "π§ͺ Physics Application",
                            style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;",
                        )
                        html.Div(
                            "Select the type of analysis",
                            style="font-size:24px;opacity:.82;margin-bottom:18px;",
                        )
                        with v3.VBtnToggle(
                            v_model=("analysis_type", self.state.analysis_type),
                            class_="mt-1",
                            mandatory=True,
                            rounded=True,
                        ):
                            for at in ANALYSIS_TYPE:
                                v3.VBtn(
                                    at,
                                    value=at,
                                    variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
                                    class_="mr-2 pfm-toggle-xxl",
                                    style=(
                                        "font-size:18px;"
                                        "font-weight:800;"
                                        "letter-spacing:0.4px;"
                                        "text-transform:none;"
                                    ),
                                )

                    # 2) Dataset + variable selectors
                    with v3.VRow(dense=True, class_="mb-3"):
                        with v3.VCol(cols=6):
                            with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
                                html.Div(
                                    "π§© Sub Application",
                                    style="font-weight:700;font-size:24px;margin-bottom:14px;",
                                )
                                v3.VSelect(
                                    v_model=("dataset", self.state.dataset),
                                    items=("dataset_choices", self.state.dataset_choices),
                                    hide_details=True,
                                    density="comfortable",
                                    style=(
                                        "font-size:24px;"
                                        "font-weight:800;"
                                        "height:56px;"
                                        "display:flex;"
                                        "align-items:center;"
                                    ),
                                    class_="pfm-big-select-subapp",
                                    menuProps={"contentClass": "pfm-subapp-list"},  # styles the dropdown items
                                )
                        with v3.VCol(cols=6):
                            with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
                                html.Div(
                                    "π Variable to Predict",
                                    style="font-weight:700;font-size:20px;margin-bottom:14px;",
                                )
                                v3.VSelect(
                                    v_model=("variable", self.state.variable),
                                    items=("variable_choices", self.state.variable_choices),
                                    hide_details=True,
                                    density="comfortable",
                                    class_="pfm-var-select",
                                    style=(
                                        "font-size:20px;"
                                        "font-weight:800;"
                                        "height:56px;"
                                        "display:flex;"
                                        "align-items:center;"
                                    ),
                                    menuProps={"contentClass": "pfm-var-list"},
                                )

                    # 3) Boundary conditions: velocity controls (left), reference values (right)
                    with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
                        html.Div(
                            "π§± Boundary Conditions",
                            style="font-weight:700;font-size:22px;margin-bottom:16px;",
                        )
                        with v3.VRow(class_="align-start", dense=True):
                            with v3.VCol(cols=7, class_="pfm-vel"):
                                html.Div(
                                    "π Velocity (mph)",
                                    class_="pfm-vel-title",
                                    style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;",
                                )
                                html.Div(
                                    "Set the inlet velocity in miles per hour",
                                    class_="pfm-vel-sub",
                                    style="margin-bottom:10px;font-size:20px;opacity:.95;",
                                )
                                v3.VSlider(
                                    v_model=("velocity_mph", self.state.velocity_mph),
                                    min=30.0, max=80.0, step=0.1,
                                    thumb_label=True,
                                    v_if=("show_velocity",),
                                    style="height:54px;margin-top:12px;max-width:540px;",
                                    class_="mt-3 mb-3 pfm-vel-slider",
                                )
                                html.Div(
                                    "{{ velocity_mph.toFixed(0) }} / 80 "
                                    ""
                                    "({{ (velocity_mph * 0.44704).toFixed(2) }} m/s)",
                                    v_if=("show_velocity",),
                                    class_="pfm-vel-readout",
                                    style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:6px;",
                                )
                            # Fixed reference values (HTML built by bc_text_right / bc_text_left).
                            with v3.VCol(cols=5, class_="pfm-bc-right"):
                                html.Div(
                                    v_html=("bc_text_html", ""),
                                    style=(
                                        "margin-top:6px;"
                                        "font-size:18px;"
                                        "line-height:1.7;"
                                        "min-width:260px;"
                                        "max-width:360px;"
                                        "text-align:left;"
                                    ),
                                )

                    # 4) Two viewers: input geometry (left), prediction (right)
                    with v3.VRow(style="margin-top: 24px;"):
                        with v3.VCol(cols=6):
                            with v3.VRow(class_="align-center justify-space-between mb-2"):
                                html.Div(
                                    "π€ Input Geometry",
                                )
                                # Gear menu with the mesh decimation controls.
                                with v3.VMenu(
                                    location="bottom end",
                                    close_on_content_click=False,
                                    offset="4 8",
                                ):
                                    # The activator slot exposes { props }; they must be bound to the button.
                                    with v3.Template(v_slot_activator="{ props }"):
                                        with v3.VBtn(
                                            icon=True,
                                            variant="text",
                                            density="comfortable",
                                            style="min-width:32px;",
                                            v_bind="props",
                                        ):
                                            v3.VIcon("mdi-cog", size="22")
                                    with v3.VCard(class_="pa-4", style="min-width: 280px;"):
                                        html.Div("Mesh decimation", class_="mb-3", style="font-size:14px;")
                                        v3.VSwitch(
                                            v_model=("decim_enable",),
                                            label="Enable decimation",
                                            inset=True,
                                            hide_details=True,
                                            class_="mb-4",
                                        )
                                        html.Div(
                                            "Target reduction (fraction of faces to remove)",
                                            class_="mb-1",
                                            style="font-size:12px;color:#9ca3af;",
                                        )
                                        v3.VSlider(
                                            v_model=("decim_target",),
                                            min=0.0,
                                            max=0.999,
                                            step=0.001,
                                            hide_details=True,
                                            class_="mb-2",
                                        )
                                        html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3")
                                        with v3.VRow(dense=True, class_="mb-3"):
                                            with v3.VCol(cols=6):
                                                html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
                                                v3.VTextField(
                                                    v_model=("decim_min_faces",),
                                                    type="number",
                                                    density="compact",
                                                    hide_details=True,
                                                )
                                            with v3.VCol(cols=6):
                                                html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
                                                v3.VTextField(
                                                    v_model=("decim_max_faces",),
                                                    type="number",
                                                    density="compact",
                                                    hide_details=True,
                                                )
                                        v3.VBtn(
                                            "Apply to current mesh",
                                            block=True,
                                            color="primary",
                                            class_="mt-2",
                                            click=self.decimate_again,
                                        )
                                        v3.VBtn(
                                            "Reset to original mesh",
                                            block=True,
                                            variant="tonal",
                                            class_="mt-2",
                                            click=self.reset_mesh,
                                        )
                            v3.VFileInput(
                                label="Select 3D File",
                                style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;",
                                multiple=False,
                                show_size=True,
                                accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb",
                                v_model=("upload", None),
                                clearable=True,
                            )
                            with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
                                self.view_geom = VtkRemoteView(
                                    self.rw_geom,
                                    interactive=True,
                                    interactive_ratio=1,
                                    server=self.server,
                                )
                            with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
                                           rounded=True, elevation=3):
                                html.Div("Upload", style="font-size:18px;")
                                # Progress bar and status text are shown only while uploading.
                                v3.VProgressLinear(
                                    v_model=("pm_upload", 0),
                                    height=22,
                                    style="margin-top:10px;margin-bottom:10px;",
                                    color="primary",
                                    rounded=True,
                                    v_show=("is_uploading",),
                                )
                                html.Div(
                                    "{{ pm_upload }}% β {{ pm_elapsed_upload.toFixed(2) }}s β {{ upload_msg }}",
                                    style="font-size:14px;",
                                    v_show=("is_uploading",),
                                )
                                v3.VBtn(
                                    "ποΈ CLEAR",
                                    block=True,
                                    variant="tonal",
                                    class_="mt-3 pfm-btn-big",
                                    style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
                                    click=self.clear,
                                )

                        # RIGHT = prediction
                        with v3.VCol(cols=6):
                            html.Div(
                                "π Prediction Results",
                                style="margin-bottom:10px;",
                            )
                            html.Div(
                                v_html=("stats_html",),
                                class_="mb-3",
                                style="font-size:20px;line-height:1.4;",
                            )
                            with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
                                self.view_pred = VtkRemoteView(
                                    self.rw_pred,
                                    interactive=True,
                                    interactive_ratio=1,
                                    server=self.server,
                                )
                            with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
                                           rounded=True, elevation=3):
                                html.Div("Inference", style="font-size:18px;")
                                # Driven by pm_infer. It is indeterminate only until the
                                # first progress value arrives (the same key as its value,
                                # rather than predict_progress).
                                v3.VProgressLinear(
                                    v_model=("pm_infer", 0),
                                    height=22,
                                    style="margin-top:6px;margin-bottom:12px;",
                                    color="success",
                                    rounded=True,
                                    indeterminate=("pm_infer <= 0",),
                                    v_show=("is_predicting",),
                                )
                                # Progress text. Drop v_show to keep the final message
                                # visible after prediction completes.
                                html.Div(
                                    "{{ pm_infer }}% β {{ pm_elapsed_infer.toFixed(2) }}s β {{ predict_msg }}",
                                    style="font-size:14px;margin-bottom:10px;",
                                    v_show=("is_predicting",),
                                )
                                v3.VBtn(
                                    "π PREDICT",
                                    block=True,
                                    color="primary",
                                    class_="mt-3 pfm-btn-big",
                                    style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
                                    click=self.predict,
                                )
        layout.on_ready = self._first_paint
def _first_paint(self, **_):
for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
try:
rw.Render()
except Exception:
pass
view.update()
# ---------------------------------------------------------
# UPLOAD (async)
# ---------------------------------------------------------
def _write_upload_to_disk(self, payload) -> str:
if payload is None:
raise ValueError("No file payload")
if isinstance(payload, (list, tuple)):
payload = payload[0]
if isinstance(payload, str):
return payload
if not isinstance(payload, dict):
raise ValueError(f"Unsupported payload: {type(payload)}")
if payload.get("path"):
return payload["path"]
name = payload.get("name") or "upload"
content = payload.get("content")
if isinstance(content, str) and content.startswith("data:"):
content = content.split(",", 1)[1]
raw = base64.b64decode(content) if isinstance(content, str) else bytes(content)
os.makedirs(GEOM_DIR, exist_ok=True)
file_path = os.path.join(GEOM_DIR, name)
with open(file_path, "wb") as f:
f.write(raw)
return file_path
def _pre_upload_spinner_loop(self):
s = self.state
phase = 1
while self._pre_upload_on and not self._upload_actual_started and s.is_uploading:
s.pm_upload = max(1, min(9, phase))
s.upload_msg = "Initializing upload..."
try:
self.server.state.flush()
except Exception:
pass
phase = 1 if phase >= 9 else phase + 1
time.sleep(0.15)
def _start_pre_upload_spinner(self):
if self._pre_upload_thread and self._pre_upload_thread.is_alive():
return
self._upload_actual_started = False
self._pre_upload_on = True
self._pre_upload_thread = threading.Thread(
target=self._pre_upload_spinner_loop, daemon=True
)
self._pre_upload_thread.start()
def _stop_pre_upload_spinner(self):
self._pre_upload_on = False
self._pre_upload_thread = None
async def _fake_upload_bump(self, stop_event: asyncio.Event):
s = self.state
while not stop_event.is_set() and s.pm_upload < 9:
s.pm_upload += 1
await self._flush_async()
await asyncio.sleep(0.05)
async def _upload_worker_async(self, upload):
s = self.state
loop = self._ensure_loop()
t0 = time.time()
s.is_uploading = True
s.upload_msg = "Starting upload..."
s.pm_elapsed_upload = 0.0
s.pm_upload = 5
self.server.state.flush()
await asyncio.sleep(0)
fake_stop = asyncio.Event()
fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop))
try:
self._upload_actual_started = True
self._stop_pre_upload_spinner()
if not fake_stop.is_set():
fake_stop.set()
try:
await fake_task
except asyncio.CancelledError:
pass
s.upload_msg = "Writing file to disk..."
s.pm_upload = 10
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload)
s.upload_msg = "Reading mesh..."
s.pm_upload = 20
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
mesh = await loop.run_in_executor(None, pv.read, file_path)
# 3) decimation (auto first)
try:
nf = poly_count(mesh)
except Exception:
nf = mesh.n_cells
auto_tr = float(auto_target_reduction(nf))
# reflect auto in UI
s.decim_target = auto_tr
s.decim_min_faces = 5000 # <= allow decimation even for 27k faces
s.decim_max_faces = int(1e7)
target = auto_tr
min_faces = 5000
max_faces = int(1e7)
# user override
if self.state.decim_enable:
target = float(self.state.decim_target or 0.0)
min_faces = int(self.state.decim_min_faces or 5000)
max_faces = int(self.state.decim_max_faces or 1e7)
if target > 0.0:
s.upload_msg = f"Decimating mesh ({target:.3f})..."
s.pm_upload = max(s.pm_upload, 45)
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
dec_cfg = {
"enabled": True,
"method": "pro",
"target_reduction": target,
"min_faces": min_faces,
"max_faces": max_faces,
}
mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg)
# 4) normals + save
s.upload_msg = "Preparing geometry..."
s.pm_upload = 75
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
def _normals_and_save(m):
m_fixed = m.compute_normals(
consistent_normals=True,
auto_orient_normals=True,
point_normals=True,
cell_normals=False,
inplace=False,
)
geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
m_fixed.save(geom_path_)
return geom_path_, m_fixed
geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
# 5) update viewer
self.ren_geom.RemoveAllViewProps()
self.ren_geom.AddActor(make_actor_from_stl(geom_path))
self.ren_geom.ResetCamera()
try:
self.rw_geom.Render()
except Exception:
pass
self.view_geom.update()
# GEOMETRY_CACHE.current_mesh = mesh_fixed
GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True)
GEOMETRY_CACHE.current_mesh = mesh_fixed
s.upload_msg = "β
Geometry ready."
s.pm_upload = 100
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
except Exception as e:
s.upload_msg = f"β Upload failed: {e}"
s.pm_upload = 0
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
finally:
s.is_uploading = False
s.pm_elapsed_upload = time.time() - t0
await self._flush_async()
if not fake_stop.is_set():
fake_stop.set()
if not fake_task.done():
fake_task.cancel()
try:
await fake_task
except Exception:
pass
@change("upload")
def _on_upload_change(self, upload, **_):
    """Schedule the async upload worker whenever ``upload`` gets a truthy value.

    Cleared or empty values are ignored.
    """
    if upload:
        self._run_coro(self._upload_worker_async(upload))
def decimate_again(self):
    """Controller action: schedule a manual re-decimation of the current mesh."""
    job = self._decimate_again_async()
    self._run_coro(job)
async def _decimate_again_async(self):
    """Re-decimate ``GEOMETRY_CACHE.current_mesh`` using the UI settings.

    Reads ``decim_target`` / ``decim_min_faces`` / ``decim_max_faces``
    (falling back to 0.0 / 5000 / 1e7 when they cannot be converted). If
    ``decim_enable`` is off or the target is not a positive number, it only
    reports "Decimation disabled". The *current* mesh is decimated, so
    repeated calls compound; ``reset_mesh`` restores the uploaded geometry.
    Progress is shown on the upload bar, and failures are reported rather
    than raised.
    """
    s = self.state
    loop = self._ensure_loop()
    if GEOMETRY_CACHE.current_mesh is None:
        # nothing to decimate
        s.upload_msg = "No mesh to re-decimate"
        await self._flush_async()
        return

    t0 = time.time()

    async def _progress(msg, pct):
        # one progress step on the upload bar: message, percentage, elapsed time
        s.upload_msg = msg
        s.pm_upload = pct
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()

    # start "upload-like" progress for the manual decimation
    s.is_uploading = True
    s.pm_upload = 5
    s.pm_elapsed_upload = 0.0
    s.upload_msg = "Starting mesh re-decimation..."
    await self._flush_async()
    try:
        # read parameters from the UI, falling back to safe defaults
        try:
            target = float(s.decim_target)
        except Exception:
            target = 0.0
        try:
            min_faces = int(s.decim_min_faces)
        except Exception:
            min_faces = 5000
        try:
            max_faces = int(s.decim_max_faces)
        except Exception:
            max_faces = int(1e7)

        # `not target > 0.0` also rejects NaN, which `target <= 0.0` would let through
        if (not s.decim_enable) or not (target > 0.0):
            await _progress("Decimation disabled", 0)
            return

        await _progress(f"Re-decimating mesh ({target:.3f})...", 25)
        dec_cfg = {
            "enabled": True,
            "method": "pro",
            "target_reduction": target,
            "min_faces": min_faces,
            "max_faces": max_faces,
        }
        # heavy work on the executor
        mesh = await loop.run_in_executor(
            None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg
        )

        await _progress("Recomputing normals & saving...", 70)

        def _normals_and_save(m):
            m_fixed = m.compute_normals(
                consistent_normals=True,
                auto_orient_normals=True,
                point_normals=True,
                cell_normals=False,
                inplace=False,
            )
            geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
            m_fixed.save(geom_path_)
            return geom_path_, m_fixed

        geom_path, mesh_fixed = await loop.run_in_executor(
            None, _normals_and_save, mesh
        )

        await _progress("Updating viewer...", 90)
        self.ren_geom.RemoveAllViewProps()
        self.ren_geom.AddActor(make_actor_from_stl(geom_path))
        self.ren_geom.ResetCamera()
        try:
            self.rw_geom.Render()
        except Exception:
            pass  # best-effort render; the remote view is still updated
        self.view_geom.update()
        GEOMETRY_CACHE.current_mesh = mesh_fixed

        await _progress("✅ Re-decimated", 100)
    except Exception as e:
        await _progress(f"❌ Re-decimation failed: {e}", 0)
    finally:
        # hide bar + text now that we're done; record the final elapsed time
        s.is_uploading = False
        s.pm_elapsed_upload = time.time() - t0
        await self._flush_async()
def reset_mesh(self):
    """Controller action: schedule restoring the mesh saved at upload time."""
    job = self._reset_mesh_async()
    self._run_coro(job)
async def _reset_mesh_async(self):
    """Restore ``GEOMETRY_CACHE.original_mesh`` as the current geometry.

    The original is deep-copied before it becomes ``current_mesh``, so later
    in-place changes to the current mesh cannot reach the stored original
    (the upload worker stores its original as a deep copy for the same
    reason). The copy is saved to ``GEOM_DIR/geometry.stl`` off the event
    loop, so the file on disk matches the viewer, and then shown in the
    geometry view. Errors are reported in ``upload_msg`` instead of escaping
    the scheduled task.
    """
    s = self.state
    orig = GEOMETRY_CACHE.original_mesh
    if orig is None:
        s.upload_msg = "No original mesh to reset to"
        await self._flush_async()
        return

    loop = self._ensure_loop()
    geom_path = os.path.join(GEOM_DIR, "geometry.stl")
    try:
        restored = orig.copy(deep=True)

        def _save():
            os.makedirs(GEOM_DIR, exist_ok=True)
            restored.save(geom_path)

        # file I/O off the event loop, like the upload/decimation workers
        await loop.run_in_executor(None, _save)
        GEOMETRY_CACHE.current_mesh = restored

        self.ren_geom.RemoveAllViewProps()
        self.ren_geom.AddActor(make_actor_from_stl(geom_path))
        self.ren_geom.ResetCamera()
        try:
            self.rw_geom.Render()
        except Exception:
            pass  # best-effort render; the remote view is still updated
        self.view_geom.update()
        s.upload_msg = "↩️ Reset to original mesh"
    except Exception as e:
        s.upload_msg = f"❌ Reset failed: {e}"
    await self._flush_async()
# ---------------------------------------------------------
# prediction
# ---------------------------------------------------------
def _start_infer_heartbeat(self):
# Start a daemon thread that keeps `pm_elapsed_infer` current during a
# prediction: seconds since `self._predict_t0`, clamped at 0, written only
# while `is_predicting` is true and `_predict_t0` is set. The thread also
# flushes server state best-effort (flush errors are ignored) and sleeps
# 0.12 s per iteration, looping until `_infer_heartbeat_on` is cleared.
# Returns immediately if the previously started heartbeat thread is alive.
# NOTE(review): `_stop_infer_heartbeat` also sets `_infer_thread = None`, so a
# stop followed quickly by a start may leave the old thread (still in its
# sleep) running next to the new one once the flag is True again -- confirm
# whether that matters.
if self._infer_thread and self._infer_thread.is_alive():
return
def loop_fn():
# heartbeat body: runs until _infer_heartbeat_on is cleared
while self._infer_heartbeat_on:
if self.state.is_predicting and self._predict_t0 is not None:
self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0)
try:
self.server.state.flush()
except Exception:
pass
time.sleep(0.12)
# set the flag before starting so the new thread does not exit immediately
self._infer_heartbeat_on = True
self._infer_thread = threading.Thread(target=loop_fn, daemon=True)
self._infer_thread.start()
def _stop_infer_heartbeat(self):
self._infer_heartbeat_on = False
self._infer_thread = None
async def _predict_worker_async(self):
    """Run inference for the selected dataset/variable and display the result.

    Requires ``GEOMETRY_CACHE.current_mesh``. Loads the model through
    ``MODEL_STORE.get`` and runs ``run_inference_fast`` on the default
    executor. It then colors ``GEOM_DIR/geometry.stl`` with the prediction in
    the prediction view, replacing any previously shown prediction actor.
    Finally it copies the geometry-view camera and writes min/max/confidence
    to ``stats_html``.

    Progress goes to ``predict_msg`` / ``predict_progress`` / ``pm_infer``,
    and a heartbeat thread refreshes the elapsed time meanwhile. Failures are
    reported in ``predict_msg`` rather than raised; ``is_predicting`` is
    always reset and the heartbeat stopped at the end.
    """
    s = self.state
    loop = self._ensure_loop()
    t0 = time.time()
    if GEOMETRY_CACHE.current_mesh is None:
        s.predict_msg = "❌ Please upload geometry first"
        s.is_predicting = False
        await self._flush_async()
        return

    async def _progress(msg, pct):
        # one progress step: message and both progress fields, then flush
        s.predict_msg = msg
        s.predict_progress = pct
        s.pm_infer = pct
        await self._flush_async()

    s.is_predicting = True
    s.predict_progress = 1
    s.pm_infer = 1
    s.predict_msg = "Preparing inference..."
    self._predict_t0 = time.time()
    self._start_infer_heartbeat()
    try:
        # flush inside the try so a failure here still reaches the cleanup below
        await self._flush_async()
        dataset = s.dataset
        variable = s.variable
        boundary = (
            {"freestream_velocity": mph_to_ms(s.velocity_mph)}
            if dataset == "Incompressible flow over car"
            else None
        )

        await _progress("Loading model/checkpoint...", 5)
        # load/warm the model off the event loop; the objects are not used here
        _cfg, _model, _device, _ = await loop.run_in_executor(
            None, MODEL_STORE.get, dataset, variable, None
        )

        await _progress("Processing mesh for inference...", 35)

        def _run_full():
            return run_inference_fast(
                dataset,
                variable,
                boundary_conditions=boundary,
                progress_cb=None,
            )

        viz = await loop.run_in_executor(None, _run_full)

        await _progress("Preparing visualization...", 85)
        stl_path = os.path.join(GEOM_DIR, "geometry.stl")
        # defaults, refined from the STL-based visualization when the file exists
        vmin = float(np.min(viz["pred"]))
        vmax = float(np.max(viz["pred"]))
        if os.path.exists(stl_path):
            _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path)
        lut = build_jet_lut(vmin, vmax)
        colored_actor = color_actor_with_scalars_from_prediction(
            stl_path,
            viz["points"],
            viz["pred"],
            variable,
            vmin,
            vmax,
            lut=lut,
        )
        # Replace, don't stack: without this every prediction adds another
        # overlapping colored copy of the geometry.
        previous_actor = getattr(self, "_pred_actor", None)
        if previous_actor is not None:
            self.ren_pred.RemoveActor(previous_actor)
        self.ren_pred.AddActor(colored_actor)
        self._pred_actor = colored_actor

        units = {
            "pressure": "Pa",
            "x_velocity": "m/s",
            "y_velocity": "m/s",
            "z_velocity": "m/s",
        }.get(variable, "")
        title = f"{variable} ({units})" if units else variable
        self.scalar_bar = add_or_update_scalar_bar(
            self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8
        )

        # show the prediction from the same viewpoint as the geometry view
        src_cam = self.ren_geom.GetActiveCamera()
        dst_cam = self.ren_pred.GetActiveCamera()
        if src_cam is not None and dst_cam is not None:
            dst_cam.SetPosition(src_cam.GetPosition())
            dst_cam.SetFocalPoint(src_cam.GetFocalPoint())
            dst_cam.SetViewUp(src_cam.GetViewUp())
            dst_cam.SetParallelScale(src_cam.GetParallelScale())
            dst_cam.SetClippingRange(src_cam.GetClippingRange())
        try:
            self.rw_pred.Render()
        except Exception:
            pass  # best-effort render; the remote view is still updated
        self.view_pred.update()

        raw_vmin = float(np.min(viz["pred"]))
        raw_vmax = float(np.max(viz["pred"]))
        s.stats_html = (
            f"{variable} min:{raw_vmin:.3e} "
            f"max: {raw_vmax:.3e} "
            f"Confidence: {viz['confidence_score']:.4f}"
        )
        s.predict_msg = "✅ Prediction complete."
        s.predict_progress = 100
        s.pm_infer = 100
        s.predict_elapsed = time.time() - t0
        s.pm_elapsed_infer = s.predict_elapsed
        await self._flush_async()
    except Exception as e:
        s.predict_msg = f"❌ Prediction failed: {e}"
        s.predict_progress = 0
        s.pm_infer = 0
        s.pm_elapsed_infer = time.time() - t0
        await self._flush_async()
    finally:
        s.is_predicting = False
        self._stop_infer_heartbeat()
        await self._flush_async()
@time_function("Inference and Visualization")
def predict(self, *_):
    """PREDICT button handler: pass the async prediction worker to ``_run_coro``.

    NOTE(review): the ``time_function`` wrapper presumably times only this
    scheduling call, not the inference itself -- confirm how ``_run_coro``
    executes the coroutine.
    """
    worker = self._predict_worker_async()
    self._run_coro(worker)
# ---------------------------------------------------------
# dataset wiring
# ---------------------------------------------------------
@change("analysis_type")
def _on_analysis_type_change(self, analysis_type=None, **_):
    """Refresh the dataset list for the selected analysis type.

    If the type's first dataset differs from the current ``dataset``, that
    state value is switched (presumably applied later by the ``dataset``
    change handler). Otherwise the current dataset, if any, is re-applied
    directly.
    """
    state = self.state
    datasets = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", [])
    first = datasets[0] if datasets else None
    state.dataset_choices = datasets
    if first and state.dataset != first:
        state.dataset = first
        return
    if state.dataset:
        self._apply_dataset(state.dataset)
@change("dataset")
def _on_dataset_change(self, dataset=None, **_):
    """Apply dataset-specific UI settings whenever ``dataset`` becomes truthy."""
    if dataset:
        self._apply_dataset(dataset)
def _apply_dataset(self, ds: str):
    """Update every dataset-dependent UI field for dataset ``ds``.

    Sets the variable choices (selecting the first one, or ``None`` when
    there are none), the velocity/plane visibility flags and the
    boundary-condition texts. The HTML text prefers the right-hand BC text
    and otherwise renders ``bc_text`` through ``md_to_html``.
    """
    state = self.state
    choices = [] if not ds else variables_for(ds)
    state.variable_choices = choices
    state.variable = choices[0] if choices else None
    state.show_velocity = ds == "Incompressible flow over car"
    state.is_plane = ds == "Compressible flow over plane"

    text = get_boundary_conditions_text(ds)
    state.bc_text = text
    state.bc_left = bc_text_left(ds)
    right = bc_text_right(ds)
    state.bc_right = right
    state.bc_text_html = right or md_to_html(text)
# ---------------------------------------------------------
# clear
# ---------------------------------------------------------
def clear(self, *_):
    """Reset the app to its initial, nothing-uploaded state.

    Deletes and recreates ``GEOM_DIR`` and ``SOLN_DIR``. Drops the cached
    meshes, so predict / re-decimate / reset cannot keep operating on
    geometry whose files were just deleted. Resets every upload and
    prediction progress field and empties both viewers.
    """
    for d in (GEOM_DIR, SOLN_DIR):
        if os.path.exists(d):
            shutil.rmtree(d)
        os.makedirs(d, exist_ok=True)

    # forget the uploaded geometry together with its files
    GEOMETRY_CACHE.current_mesh = None
    GEOMETRY_CACHE.original_mesh = None

    s = self.state
    s.stats_html = "🧹 Cleared. Upload again."
    s.is_uploading = False
    s.pm_upload = 0
    s.upload_msg = ""
    s.pm_elapsed_upload = 0.0
    s.is_predicting = False
    s.predict_progress = 0
    s.predict_msg = ""
    s.pm_infer = 0
    s.pm_elapsed_infer = 0.0

    self.ren_geom.RemoveAllViewProps()
    self.ren_pred.RemoveAllViewProps()
    for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
        try:
            rw.Render()
        except Exception:
            pass  # best-effort render; the remote view is still updated
        view.update()
# ---------- main ----------
def main():
    """Build the app, register its controller actions and start the server.

    The port is read from ``$PORT`` (default 7860). The host is controlled
    from the command line (e.g. ``--host 0.0.0.0``).
    """
    app = PFMDemo()
    ctrl = app.server.controller
    ctrl.add("decimate_again", app.decimate_again)
    ctrl.add("reset_mesh", app.reset_mesh)
    app.server.start(
        port=int(os.environ.get("PORT", "7860")),
        open_browser=False,  # running in a container: never try to open a browser
        show_connection_info=True,  # keep the connection details in the logs
        backend="aiohttp",
        exec_mode="main",
    )


if __name__ == "__main__":
    main()