udbhav commited on
Commit
67fb03c
·
0 Parent(s):

Recreate Trame_app branch with clean history

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +69 -0
  2. .gitattributes +35 -0
  3. .github/workflows/hf_sync.yml +27 -0
  4. .github/workflows/run_test.yml +32 -0
  5. .gitignore +80 -0
  6. Dockerfile +70 -0
  7. LICENSE +201 -0
  8. Main_app_trame.py +1701 -0
  9. Main_app_trame2.py +1766 -0
  10. README.md +8 -0
  11. ULIP +1 -0
  12. app.py +1766 -0
  13. configs/DrivAerML/config.yaml +40 -0
  14. configs/DriveAerNet/config.yaml +57 -0
  15. configs/app_configs/Compressible flow over plane/config.yaml +58 -0
  16. configs/app_configs/Compressible flow over plane/full_transform_params.json +37 -0
  17. configs/app_configs/Compressible flow over plane/pca_embedding.npy +3 -0
  18. configs/app_configs/Compressible flow over plane/pca_reducer.pkl +3 -0
  19. configs/app_configs/Compressible flow over plane/pca_scaler.pkl +3 -0
  20. configs/app_configs/Compressible flow over plane/train_dist.npz +3 -0
  21. configs/app_configs/Incompressible flow inside artery/config.yaml +39 -0
  22. configs/app_configs/Incompressible flow inside artery/full_transform_params.json +50 -0
  23. configs/app_configs/Incompressible flow inside artery/train_dist.npz +3 -0
  24. configs/app_configs/Incompressible flow over car/config.yaml +59 -0
  25. configs/app_configs/Incompressible flow over car/full_transform_params.json +40 -0
  26. configs/app_configs/Incompressible flow over car/pca_embedding.npy +3 -0
  27. configs/app_configs/Incompressible flow over car/pca_reducer.pkl +3 -0
  28. configs/app_configs/Incompressible flow over car/pca_scaler.pkl +3 -0
  29. configs/app_configs/Incompressible flow over car/train_dist.npz +3 -0
  30. configs/app_configs/Structural analysis of bracket/config.yaml +58 -0
  31. configs/app_configs/Structural analysis of bracket/full_transform_params.json +61 -0
  32. configs/app_configs/Structural analysis of bracket/train_dist.npz +3 -0
  33. configs/app_configs/Vehicle crash analysis/config.yaml +58 -0
  34. configs/app_configs/Vehicle crash analysis/full_transform_params.json +61 -0
  35. configs/app_configs/Vehicle crash analysis/train_dist.npz +3 -0
  36. configs/artery/config.yaml +38 -0
  37. configs/artery/full_transform_params.json +50 -0
  38. configs/cadillac/config.yaml +60 -0
  39. configs/deepjeb/config.yaml +58 -0
  40. configs/deepjeb/full_transform_params.json +61 -0
  41. configs/driveaerpp/config.yaml +56 -0
  42. configs/driveaerpp/full_transform_params.json +37 -0
  43. configs/elasticity/config.yaml +32 -0
  44. configs/plane_engine_test/config.yaml +57 -0
  45. configs/plane_transonic/config.yaml +62 -0
  46. configs/shapenet_car_pv/config.yaml +37 -0
  47. dataset_loader.py +188 -0
  48. datasets/DrivAerML/README.md +15 -0
  49. datasets/DrivAerML/data_preprocessing.yaml +20 -0
  50. datasets/DrivAerML/dataset_drivaerml.py +451 -0
.dockerignore ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Git
2
+ .git
3
+ .gitignore
4
+ .gitattributes
5
+
6
+ # Python
7
+ __pycache__
8
+ *.pyc
9
+ *.pyo
10
+ *.pyd
11
+ .Python
12
+ env
13
+ pip-log.txt
14
+ pip-delete-this-directory.txt
15
+ .tox
16
+ .coverage
17
+ .coverage.*
18
+ .cache
19
+ nosetests.xml
20
+ coverage.xml
21
+ *.cover
22
+ *.log
23
+ .git
24
+ .mypy_cache
25
+ .pytest_cache
26
+ .hypothesis
27
+
28
+ # Jupyter Notebook
29
+ .ipynb_checkpoints
30
+ *.ipynb
31
+
32
+ # IDEs
33
+ .vscode
34
+ .idea
35
+ *.swp
36
+ *.swo
37
+ *~
38
+
39
+ # OS
40
+ .DS_Store
41
+ .DS_Store?
42
+ ._*
43
+ .Spotlight-V100
44
+ .Trashes
45
+ ehthumbs.db
46
+ Thumbs.db
47
+
48
+ # Temporary files
49
+ *.tmp
50
+ *.temp
51
+ *.bak
52
+
53
+ # Large data files (if any)
54
+ *.h5
55
+ *.hdf5
56
+ *.pkl
57
+ *.pickle
58
+ *.npz
59
+ *.npy
60
+
61
+ # Documentation
62
+ README.md
63
+ docs/
64
+ *.md
65
+
66
+ # Test files
67
+ test_*
68
+ *_test.py
69
+ tests/
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/hf_sync.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face Space
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+
15
+ - name: Set up Git
16
+ run: |
17
+ git config --global user.email "[email protected]"
18
+ git config --global user.name "Navaneeth"
19
+
20
+ - name: Push to Hugging Face Space
21
+ env:
22
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
23
+ run: |
24
+ git remote add space https://user:${HF_TOKEN}@huggingface.co/spaces/ansysresearch/AnsysLPFM-App
25
+
26
+
27
+ git push --force space main
.github/workflows/run_test.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Run PyTest on Push to Main
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ pull_request:
7
+ branches: [ main ]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+
13
+ steps:
14
+ - name: Checkout repository
15
+ uses: actions/checkout@v3
16
+
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v5
19
+ with:
20
+ python-version: '3.12'
21
+
22
+ - name: Install dependencies
23
+ run: |
24
+ python -m pip install --upgrade pip
25
+ pip install -r requirements.txt
26
+ # Optional: if you use Hydra, OmegaConf, or HDBSCAN etc.
27
+ pip install pytest omegaconf hdbscan scikit-learn torch
28
+
29
+ - name: Run tests
30
+ run: |
31
+ export PYTHONPATH=$PYTHONPATH:$(pwd)
32
+ pytest -v --maxfail=1 --disable-warnings
.gitignore ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VSCode settings
2
+ .vscode/
3
+
4
+ # Data folders
5
+ Data/
6
+ # metrics/
7
+ # Python cache
8
+ __pycache__/
9
+ **/__pycache__/
10
+ *.pyc
11
+ *.pyo
12
+ *.pyd
13
+
14
+ # Python cache in specific directories
15
+ utils/*.pyc
16
+ utils/__pycache__/
17
+ trainers/*.pyc
18
+ trainers/__pycache__/
19
+ models/*.pyc
20
+ models/__pycache__/
21
+ datasets/*.pyc
22
+ datasets/__pycache__/
23
+ datasets/*/*.pyc
24
+ datasets/*/__pycache__/
25
+ postprocessing.py
26
+ *.mp4
27
+ *.stl
28
+ *.png
29
+ *.mp4
30
+ *.stl
31
+ *.png
32
+ datasets/artery/*.pyc
33
+ datasets/artery/__pycache__/
34
+ datasets/elasticity/*.pyc
35
+ datasets/elasticity/__pycache__/
36
+ datasets/DrivAerML/*.pyc
37
+ datasets/DrivAerML/__pycache__/
38
+ datasets/plane/*.pyc
39
+ datasets/plane/__pycache__/
40
+ datasets/deepjeb/*.pyc
41
+ datasets/deepjeb/__pycache__/
42
+ datasets/driveaerpp/*.pyc
43
+ datasets/driveaerpp/__pycache__/
44
+ datasets/shapenet_car_pv/*.pyc
45
+ datasets/shapenet_car_pv/__pycache__/
46
+ datasets/DriveAerNet/*.pyc
47
+ datasets/DriveAerNet/__pycache__/
48
+
49
+ # TensorBoard logs
50
+ tensorboard_logs/
51
+ log_files
52
+ Dockerfile
53
+ .dockerignore
54
+
55
+ __pycache__/
56
+ **/__pycache__/
57
+ *.pyc
58
+ *.pyo
59
+ *.pyd
60
+
61
+ # Output files
62
+ out_trans_driveraerml.out
63
+ *.log
64
+ # Ignore all __pycache__ in subfolders
65
+ datasets/DrivAerML/__pycache__/
66
+ datasets/DriveAerNet/__pycache__/
67
+ datasets/elasticity/__pycache__/
68
+ datasets/shapenet_car_pv/__pycache__/
69
+ models/__pycache__/
70
+ trainers/__pycache__/
71
+
72
+
73
+ # umap
74
+ *.png
75
+ similarity/
76
+ *.vtk
77
+ *.npy
78
+
79
+ # Ignore all experiment metrics and checkpoints
80
+ metrics/*
Dockerfile ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # Use NVIDIA CUDA base image (Ubuntu 22.04 jammy)
4
+ FROM nvidia/cuda:12.3.2-devel-ubuntu22.04
5
+
6
+ # Non-interactive apt
7
+ ENV DEBIAN_FRONTEND=noninteractive
8
+ ENV PYTHONUNBUFFERED=1
9
+ ENV CUDA_HOME=/usr/local/cuda
10
+ ENV PATH=${CUDA_HOME}/bin:/usr/local/bin:${PATH}
11
+ ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
12
+
13
+ # System deps + Deadsnakes PPA for Python 3.12
14
+ RUN apt-get update && apt-get install -y --no-install-recommends \
15
+ software-properties-common \
16
+ ca-certificates \
17
+ curl wget git build-essential \
18
+ libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender1 \
19
+ libgomp1 libgcc-s1 \
20
+ && add-apt-repository ppa:deadsnakes/ppa -y \
21
+ && apt-get update && apt-get install -y --no-install-recommends \
22
+ python3.12 python3.12-dev python3.12-venv \
23
+ && rm -rf /var/lib/apt/lists/*
24
+
25
+ # Install pip for Python 3.12
26
+ RUN curl -sS https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
27
+ && python3.12 /tmp/get-pip.py \
28
+ && rm /tmp/get-pip.py
29
+
30
+ # Make 'python' and 'pip' point to 3.12
31
+ RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \
32
+ ln -sf /usr/local/bin/pip3.12 /usr/local/bin/pip
33
+
34
+ # ---------------- Writable app data & HF cache ----------------
35
+ ENV APP_DATA_DIR=/data
36
+ ENV HF_HOME=/data/hf_home
37
+ ENV HUGGINGFACE_HUB_CACHE=/data/hf_home
38
+ ENV TRANSFORMERS_CACHE=/data/hf_home
39
+ ENV MPLCONFIGDIR=/data/matplotlib
40
+
41
+ RUN mkdir -p /data/geometry /data/solution /data/weights /data/hf_home /data/matplotlib \
42
+ && chmod -R 777 /data
43
+
44
+ # ---------------- Install frpc for Gradio share=True ----------------
45
+ RUN mkdir -p /data/hf_home/gradio/frpc && \
46
+ wget https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 -O /data/hf_home/gradio/frpc/frpc_linux_amd64_v0.3 && \
47
+ chmod +x /data/hf_home/gradio/frpc/frpc_linux_amd64_v0.3
48
+
49
+ # ---------------- Application setup ----------------
50
+ WORKDIR /app
51
+
52
+ # Copy requirements first (better caching)
53
+ COPY requirements.txt .
54
+
55
+ # Install Python deps
56
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
57
+ && pip install --no-cache-dir -r requirements.txt
58
+
59
+ # Copy the rest of the source
60
+ COPY . .
61
+
62
+ # Permissions
63
+ RUN chmod +x app.py
64
+ RUN chown -R 1000:1000 /app
65
+
66
+ # Expose Gradio port
67
+ EXPOSE 7860
68
+
69
+ # Run app
70
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Main_app_trame.py ADDED
@@ -0,0 +1,1701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Headless / Offscreen safety (before any VTK import) =========
2
+ import os
3
+ os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
4
+ os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
5
+ os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
6
+ os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
7
+ os.environ.setdefault("DISPLAY", "")
8
+
9
+ # ========= Core setup =========
10
+ import shutil, time, tempfile, json, base64, threading, re, html as _html, asyncio
11
+ import numpy as np
12
+ import torch
13
+ import pyvista as pv
14
+ from scipy.spatial import cKDTree
15
+ from vtk.util import numpy_support as nps
16
+ import matplotlib.cm as cm
17
+
18
+ from omegaconf import OmegaConf
19
+ from huggingface_hub import hf_hub_download
20
+ from accelerate import Accelerator
21
+ from accelerate.utils import DistributedDataParallelKwargs
22
+ import pickle
23
+ from sklearn.metrics import pairwise_distances
24
+ from train import get_single_latent
25
+ from sklearn.neighbors import NearestNeighbors
26
+
27
+ from utils.app_utils2 import (
28
+ create_visualization_points,
29
+ create_visualization_stl,
30
+ camera_from_bounds,
31
+ bounds_from_points,
32
+ convert_vtp_to_glb,
33
+ convert_vtp_to_stl,
34
+ time_function,
35
+ print_timing,
36
+ mesh_get_variable,
37
+ mph_to_ms,
38
+ get_boundary_conditions_text,
39
+ compute_confidence_score,
40
+ compute_cosine_score,
41
+ decimate_mesh,
42
+ )
43
+
44
+ # ========= trame =========
45
+ from trame.app import TrameApp
46
+ from trame.decorators import change
47
+ from trame.ui.vuetify3 import SinglePageLayout
48
+ from trame.widgets import vuetify3 as v3, html
49
+ from trame_vtk.widgets.vtk import VtkRemoteView
50
+
51
+ # ========= VTK =========
52
+ from vtkmodules.vtkRenderingCore import (
53
+ vtkRenderer,
54
+ vtkRenderWindow,
55
+ vtkPolyDataMapper,
56
+ vtkActor,
57
+ vtkRenderWindowInteractor,
58
+ )
59
+ from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor
60
+ from vtkmodules.vtkIOGeometry import vtkSTLReader
61
+ from vtkmodules.vtkFiltersCore import vtkTriangleFilter
62
+ from vtkmodules.vtkCommonCore import vtkLookupTable
63
+ from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
64
+
65
+ # ========= Writable paths / caches =========
66
+ DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
67
+ os.makedirs(DATA_DIR, exist_ok=True)
68
+ os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)
69
+
70
+ GEOM_DIR = os.path.join(DATA_DIR, "geometry")
71
+ SOLN_DIR = os.path.join(DATA_DIR, "solution")
72
+ WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")
73
+ for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
74
+ os.makedirs(d, exist_ok=True)
75
+
76
+ HF_DIR = os.path.join(DATA_DIR, "hf_home")
77
+ os.environ.setdefault("HF_HOME", HF_DIR)
78
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
79
+ os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
80
+ os.makedirs(HF_DIR, exist_ok=True)
81
+ for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
82
+ if not os.access(p, os.W_OK):
83
+ raise RuntimeError(f"Not writable: {p}")
84
+
85
+ # ========= Auto-decimation ladder =========
86
+ def auto_target_reduction(num_cells: int) -> float:
87
+ if num_cells <= 10_000:
88
+ return 0.0
89
+ elif num_cells <= 20_000:
90
+ return 0.2
91
+ elif num_cells <= 50_000:
92
+ return 0.4
93
+ elif num_cells <= 100_000:
94
+ return 0.5
95
+ elif num_cells <= 500_000:
96
+ return 0.6
97
+ elif num_cells < 1_000_000:
98
+ return 0.8
99
+ else:
100
+ return 0.9
101
+
102
+ # ========= Registry / choices =========
103
+ REGISTRY = {
104
+ "Incompressible flow inside artery": {
105
+ "repo_id": "ansysresearch/pretrained_models",
106
+ "config": "configs/app_configs/Incompressible flow inside artery/config.yaml",
107
+ "model_attr": "ansysLPFMs",
108
+ "checkpoints": {"best_val": "ckpt_artery.pt"},
109
+ "out_variable": ["pressure", "x_velocity", "y_velocity", "z_velocity"],
110
+ },
111
+ "Vehicle crash analysis": {
112
+ "repo_id": "ansysresearch/pretrained_models",
113
+ "config": "configs/app_configs/Vehicle crash analysis/config.yaml",
114
+ "model_attr": "ansysLPFMs",
115
+ "checkpoints": {"best_val": "ckpt_vehiclecrash.pt"},
116
+ "out_variable": [
117
+ "impact_force",
118
+ "deformation",
119
+ "energy_absorption",
120
+ "x_displacement",
121
+ "y_displacement",
122
+ "z_displacement",
123
+ ],
124
+ },
125
+ "Compressible flow over plane": {
126
+ "repo_id": "ansysresearch/pretrained_models",
127
+ "config": "configs/app_configs/Compressible flow over plane/config.yaml",
128
+ "model_attr": "ansysLPFMs",
129
+ "checkpoints": {"best_val": "ckpt_plane_transonic_v3.pt"},
130
+ "out_variable": ["pressure"],
131
+ },
132
+ "Incompressible flow over car": {
133
+ "repo_id": "ansysresearch/pretrained_models",
134
+ "config": "configs/app_configs/Incompressible flow over car/config.yaml",
135
+ "model_attr": "ansysLPFMs",
136
+ "checkpoints": {"best_val": "ckpt_cadillac_v3.pt"},
137
+ "out_variable": ["pressure"],
138
+ },
139
+ }
140
+
141
+ def variables_for(dataset: str):
142
+ spec = REGISTRY.get(dataset, {})
143
+ ov = spec.get("out_variable")
144
+ if isinstance(ov, str):
145
+ return [ov]
146
+ if isinstance(ov, (list, tuple)):
147
+ return list(ov)
148
+ return list(spec.get("checkpoints", {}).keys())
149
+
150
+ ANALYSIS_TYPE_MAPPING = {
151
+ "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
152
+ "Internal flow": ["Incompressible flow inside artery"],
153
+ "Crash analysis": ["Vehicle crash analysis"],
154
+ }
155
+ ANALYSIS_TYPE = list(ANALYSIS_TYPE_MAPPING.keys())
156
+ DEFAULT_ANALYSIS_TYPE = "External flow"
157
+ DEFAULT_DATASET = "Incompressible flow over car"
158
+ VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
159
+ DEFAULT_VARIABLE = VAR_CHOICES0[0] if VAR_CHOICES0 else None
160
+
161
+ # ========= Simple cache =========
162
+ class GeometryCache:
163
+ def __init__(self):
164
+ self.original_mesh = None # uploaded, cleaned (normals), BEFORE user re-decimation
165
+ self.current_mesh = None # what the app is actually using right now
166
+
167
+ GEOMETRY_CACHE = GeometryCache()
168
+
169
+ # ========= Model store =========
170
+ class ModelStore:
171
+ def __init__(self):
172
+ self._cache = {}
173
+
174
+ def _build(self, dataset: str, progress_cb=None):
175
+ def tick(x):
176
+ if progress_cb:
177
+ try:
178
+ progress_cb(int(x))
179
+ except:
180
+ pass
181
+
182
+ if dataset in self._cache:
183
+ tick(12)
184
+ return self._cache[dataset]
185
+
186
+ print(f"🔧 Building model for {dataset}")
187
+ start_time = time.time()
188
+ try:
189
+ spec = REGISTRY[dataset]
190
+ repo_id = spec["repo_id"]
191
+ ckpt_name = spec["checkpoints"]["best_val"]
192
+
193
+ tick(6)
194
+ t0 = time.time()
195
+ ckpt_path_hf = hf_hub_download(
196
+ repo_id=repo_id,
197
+ filename=ckpt_name,
198
+ repo_type="model",
199
+ local_dir=HF_DIR,
200
+ local_dir_use_symlinks=False,
201
+ )
202
+ print_timing("Model checkpoint download", t0)
203
+ tick(8)
204
+
205
+ t0 = time.time()
206
+ ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
207
+ os.makedirs(ckpt_local_dir, exist_ok=True)
208
+ ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
209
+ if not os.path.exists(ckpt_path):
210
+ shutil.copy(ckpt_path_hf, ckpt_path)
211
+ print_timing("Local model copy setup", t0)
212
+ tick(9)
213
+
214
+ t0 = time.time()
215
+ cfg_path = spec["config"]
216
+ if not os.path.exists(cfg_path):
217
+ raise FileNotFoundError(f"Missing config: {cfg_path}")
218
+ cfg = OmegaConf.load(cfg_path)
219
+ cfg.save_latent = True
220
+ print_timing("Configuration loading", t0)
221
+ tick(11)
222
+
223
+ t0 = time.time()
224
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
225
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
226
+ accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
227
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
228
+ print_timing("Device init", t0)
229
+ tick(12)
230
+
231
+ t0 = time.time()
232
+ import models
233
+ model_cls_name = spec["model_attr"]
234
+ if not hasattr(models, model_cls_name):
235
+ raise ValueError(f"Model '{model_cls_name}' not found")
236
+ model = getattr(models, model_cls_name)(cfg).to(device)
237
+ print_timing("Model build", t0)
238
+ tick(14)
239
+
240
+ t0 = time.time()
241
+ state = torch.load(ckpt_path, map_location=device)
242
+ model.load_state_dict(state)
243
+ model.eval()
244
+ print_timing("Weights load", t0)
245
+ tick(15)
246
+
247
+ result = (cfg, model, device, accelerator)
248
+ self._cache[dataset] = result
249
+ print_timing(f"Total model build for {dataset}", start_time)
250
+ return result
251
+ except Exception as e:
252
+ print_timing(f"Model build failed for {dataset}", e)
253
+ raise RuntimeError(f"Failed to load model for dataset '{dataset}': {e}")
254
+
255
+ def get(self, dataset: str, variable: str, progress_cb=None):
256
+ return self._build(dataset, progress_cb=progress_cb)
257
+
258
+ MODEL_STORE = ModelStore()
259
+
260
+ # ========= Inference pipeline =========
261
+ def _variable_index(dataset: str, variable: str) -> int:
262
+ ov = REGISTRY[dataset]["out_variable"]
263
+ return 0 if isinstance(ov, str) else ov.index(variable)
264
+
265
+ @time_function("Mesh Processing")
266
+ def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
267
+ jpath = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
268
+ json_data = json.load(open(jpath, "r"))
269
+
270
+ pts = np.asarray(mesh.points, dtype=np.float32)
271
+ N = pts.shape[0]
272
+ rng = np.random.default_rng(42)
273
+ idx = rng.permutation(N)
274
+ points = pts[idx]
275
+ tgt_np = mesh_get_variable(mesh, variable, N)[idx]
276
+ pos = torch.from_numpy(points)
277
+ target = torch.from_numpy(tgt_np).unsqueeze(-1)
278
+
279
+ if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
280
+ if "freestream_velocity" in boundary_conditions:
281
+ inlet_x_velocity = torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
282
+ inlet_x_velocity = inlet_x_velocity.repeat(N, 1)[idx]
283
+ pos = torch.cat((pos, inlet_x_velocity), dim=1)
284
+
285
+ if getattr(cfg, "input_normalization", None) == "shift_axis":
286
+ coords = pos[:, :3].clone()
287
+ coords[:, 0] = coords[:, 0] - coords[:, 0].min()
288
+ coords[:, 2] = coords[:, 2] - coords[:, 2].min()
289
+ y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
290
+ coords[:, 1] = coords[:, 1] - y_center
291
+ pos[:, :3] = coords
292
+
293
+ if getattr(cfg, "pos_embed_sincos", False):
294
+ mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
295
+ maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
296
+ pos = 1000.0 * (pos - mins) / (maxs - mins)
297
+ pos = torch.clamp(pos, 0, 1000)
298
+
299
+ cosine_score = compute_cosine_score(mesh, dataset)
300
+ return pos, target, points, cosine_score
301
+
302
+ @time_function("Inference")
303
+ def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
304
+ def p(v):
305
+ if progress_cb:
306
+ try:
307
+ progress_cb(int(v))
308
+ except Exception:
309
+ pass
310
+
311
+ if GEOMETRY_CACHE.current_mesh is None:
312
+ raise ValueError("No geometry loaded")
313
+
314
+ p(5)
315
+ cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
316
+ p(15)
317
+
318
+ pos, target, points, cosine_score = process_mesh_fast(
319
+ GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
320
+ )
321
+ p(25)
322
+
323
+ confidence_score = 0.0
324
+ try:
325
+ if dataset not in ["Incompressible flow inside artery"]:
326
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
327
+ latent_features = get_single_latent(
328
+ mesh_path=geom_path,
329
+ config_path=os.path.join("configs/app_configs/", dataset, "config.yaml"),
330
+ device=device,
331
+ custom_velocity=boundary_conditions["freestream_velocity"] if boundary_conditions else None,
332
+ use_training_velocity=False,
333
+ model=model,
334
+ )
335
+
336
+ embedding_path = os.path.join("configs/app_configs/", dataset, "pca_embedding.npy")
337
+ pca_reducer_path = os.path.join("configs/app_configs/", dataset, "pca_reducer.pkl")
338
+ scaler_path = os.path.join("configs/app_configs/", dataset, "pca_scaler.pkl")
339
+
340
+ embedding = np.load(embedding_path)
341
+ pca_reducer = pickle.load(open(pca_reducer_path, "rb"))
342
+ scaler = pickle.load(open(scaler_path, "rb"))
343
+
344
+ train_pair_dists = pairwise_distances(embedding)
345
+ sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0
346
+
347
+ n_points, n_features = latent_features.shape
348
+ np.random.seed(42)
349
+ target_len = int(pca_reducer.n_features_in_ / 256)
350
+ if n_points > target_len:
351
+ latent_features = latent_features[np.random.choice(n_points, target_len, replace=False)]
352
+ elif n_points < target_len:
353
+ num_extra = target_len - n_points
354
+ extra_indices = np.random.choice(n_points, num_extra, replace=True)
355
+ latent_features = np.vstack([latent_features, latent_features[extra_indices]])
356
+
357
+ latent_features = latent_features.flatten()
358
+
359
+ confidence_score = compute_confidence_score(
360
+ latent_features, embedding, scaler, pca_reducer, sigma
361
+ )
362
+ except Exception:
363
+ confidence_score = 0.0
364
+
365
+ data = {
366
+ "input_pos": pos.unsqueeze(0).to(device),
367
+ "output_feat": target.unsqueeze(0).to(device),
368
+ }
369
+
370
+ with torch.no_grad():
371
+ inp = data["input_pos"]
372
+ _, N, _ = inp.shape
373
+ chunk = int(getattr(cfg, "num_points", 10000))
374
+
375
+ if getattr(cfg, "chunked_eval", False) and chunk < N:
376
+ input_pos = data["input_pos"]
377
+ chunk_size = cfg.num_points
378
+ out_chunks = []
379
+ total = (N + chunk_size - 1) // chunk_size
380
+
381
+ for k, i in enumerate(range(0, N, chunk_size)):
382
+ ch = input_pos[:, i : i + chunk_size, :]
383
+ n_valid = ch.shape[1]
384
+
385
+ if n_valid < chunk_size:
386
+ pad = input_pos[:, : chunk_size - n_valid, :]
387
+ ch = torch.cat([ch, pad], dim=1)
388
+
389
+ data["input_pos"] = ch
390
+ out_chunk = model(data)
391
+ if isinstance(out_chunk, (tuple, list)):
392
+ out_chunk = out_chunk[0]
393
+ out_chunks.append(out_chunk[:, :n_valid, :])
394
+
395
+ p(25 + 60 * (k + 1) / max(1, total))
396
+
397
+ outputs = torch.cat(out_chunks, dim=1)
398
+
399
+ else:
400
+ p(40)
401
+ outputs = model(data)
402
+ if isinstance(outputs, (tuple, list)):
403
+ outputs = outputs[0]
404
+ if torch.cuda.is_available():
405
+ torch.cuda.synchronize()
406
+ p(85)
407
+
408
+ vi = _variable_index(dataset, variable)
409
+ pred = outputs[0, :, vi : vi + 1]
410
+
411
+ if getattr(cfg, "normalization", "") == "std_norm":
412
+ fp = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
413
+ j = json.load(open(fp, "r"))
414
+ mu = torch.tensor(float(j["scalars"][variable]["mean"]), device=pred.device)
415
+ sd = torch.tensor(float(j["scalars"][variable]["std"]), device=pred.device)
416
+ pred = pred * sd + mu
417
+
418
+ pred_np = pred.squeeze().detach().cpu().numpy()
419
+ tgt_np = target.squeeze().numpy()
420
+
421
+ pred_t = torch.from_numpy(pred_np).unsqueeze(-1)
422
+ tgt_t = torch.from_numpy(tgt_np).unsqueeze(-1)
423
+ rel_l2 = torch.mean(
424
+ torch.norm(pred_t.squeeze(-1) - tgt_t.squeeze(-1), p=2, dim=-1)
425
+ / torch.norm(tgt_t.squeeze(-1), p=2, dim=-1)
426
+ )
427
+ tgt_mean = torch.mean(tgt_t)
428
+ ss_tot = torch.sum((tgt_t - tgt_mean) ** 2)
429
+ ss_res = torch.sum((tgt_t - pred_t) ** 2)
430
+ r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)
431
+
432
+ p(100)
433
+ return {
434
+ "points": np.asarray(points),
435
+ "pred": np.asarray(pred_np),
436
+ "tgt": np.asarray(tgt_np),
437
+ "cosine_score": float(cosine_score),
438
+ "confidence_score": float(confidence_score),
439
+ "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
440
+ "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
441
+ "rel_l2": float(rel_l2.item()),
442
+ "r_squared": float(r2.item()),
443
+ }
444
+
445
+ # ========= VTK helpers =========
446
+ def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
447
+ r = vtkSTLReader()
448
+ r.SetFileName(stl_path)
449
+ r.Update()
450
+ tri = vtkTriangleFilter()
451
+ tri.SetInputConnection(r.GetOutputPort())
452
+ tri.Update()
453
+ m = vtkPolyDataMapper()
454
+ m.SetInputConnection(tri.GetOutputPort())
455
+ a = vtkActor()
456
+ a.SetMapper(m)
457
+ a.GetProperty().SetColor(*color)
458
+ return a
459
+
460
+ def build_jet_lut(vmin, vmax):
461
+ lut = vtkLookupTable()
462
+ lut.SetRange(float(vmin), float(vmax))
463
+ lut.SetNumberOfTableValues(256)
464
+ lut.Build()
465
+ cmap = cm.get_cmap("jet", 256)
466
+ for i in range(256):
467
+ r_, g_, b_, _ = cmap(i)
468
+ lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
469
+ return lut
470
+
471
+ def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
472
+ r = vtkSTLReader()
473
+ r.SetFileName(stl_path)
474
+ r.Update()
475
+ poly = r.GetOutput()
476
+
477
+ stl_pts = nps.vtk_to_numpy(poly.GetPoints().GetData())
478
+ tree = cKDTree(points_xyz)
479
+ _, nn_idx = tree.query(stl_pts, k=1)
480
+ scalars = np.asarray(pred_vals, dtype=np.float32)[nn_idx]
481
+
482
+ vtk_arr = nps.numpy_to_vtk(scalars, deep=True)
483
+ vtk_arr.SetName(array_name)
484
+ poly.GetPointData().AddArray(vtk_arr)
485
+ poly.GetPointData().SetActiveScalars(array_name)
486
+
487
+ mapper = vtkPolyDataMapper()
488
+ mapper.SetInputData(poly)
489
+ mapper.SetScalarModeToUsePointData()
490
+ mapper.ScalarVisibilityOn()
491
+ mapper.SetScalarRange(float(vmin), float(vmax))
492
+ if lut is None:
493
+ lut = build_jet_lut(vmin, vmax)
494
+ mapper.SetLookupTable(lut)
495
+ mapper.UseLookupTableScalarRangeOn()
496
+
497
+ actor = vtkActor()
498
+ actor.SetMapper(mapper)
499
+ return actor
500
+
501
+ def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
502
+ existing = []
503
+ ca = renderer.GetActors2D()
504
+ ca.InitTraversal()
505
+ for _ in range(ca.GetNumberOfItems()):
506
+ a = ca.GetNextItem()
507
+ if isinstance(a, vtkScalarBarActor):
508
+ existing.append(a)
509
+ for a in existing:
510
+ renderer.RemoveActor2D(a)
511
+
512
+ sbar = vtkScalarBarActor()
513
+ sbar.SetLookupTable(lut)
514
+ sbar.SetOrientationToVertical()
515
+ sbar.SetLabelFormat(label_fmt)
516
+ sbar.SetNumberOfLabels(int(n_labels))
517
+ sbar.SetTitle(title)
518
+ sbar.SetPosition(0.92, 0.05)
519
+ sbar.SetPosition2(0.06, 0.90)
520
+
521
+ tp = sbar.GetTitleTextProperty()
522
+ tp.SetColor(1, 1, 1)
523
+ tp.SetBold(True)
524
+ tp.SetFontSize(22)
525
+ lp = sbar.GetLabelTextProperty()
526
+ lp.SetColor(1, 1, 1)
527
+ lp.SetFontSize(18)
528
+
529
+ renderer.AddActor2D(sbar)
530
+ return sbar
531
+
532
+ # ---------- Small helpers ----------
533
+ def poly_count(mesh: pv.PolyData) -> int:
534
+ if hasattr(mesh, "n_faces_strict"):
535
+ return mesh.n_faces_strict
536
+ return mesh.n_cells
537
+
538
+ def md_to_html(txt: str) -> str:
539
+ if not txt:
540
+ return ""
541
+ safe = _html.escape(txt)
542
+ safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
543
+ return "<br/>".join(safe.splitlines())
544
+
545
+ def bc_text_right(dataset: str) -> str:
546
+ if dataset == "Incompressible flow over car":
547
+ return (
548
+ "<b>Reference Density:</b> 1.225 kg/m³<br><br>"
549
+ "<b>Reference Viscosity:</b> 1.789e-5 Pa·s<br><br>"
550
+ "<b>Operating Pressure:</b> 101325 Pa"
551
+ )
552
+ if dataset == "Compressible flow over plane":
553
+ return (
554
+ "<b>Ambient Temperature:</b> 218 K<br><br>"
555
+ "<b>Cruising velocity:</b> 250.0 m/s or 560 mph"
556
+ )
557
+ return ""
558
+
559
+ def bc_text_left(dataset: str) -> str:
560
+ if dataset == "Compressible flow over plane":
561
+ return (
562
+ "<b>Reference Density:</b> 0.36 kg/m³<br><br>"
563
+ "<b>Reference viscosity:</b> 1.716e-05 kg/(m·s)<br><br>"
564
+ "<b>Operating Pressure:</b> 23842 Pa"
565
+ )
566
+ return ""
567
+
568
+ # =====================================================================
569
+ # ======================= APP =======================================
570
+ # =====================================================================
571
+ class PFMDemo(TrameApp):
572
+ def __init__(self, server=None):
573
+ super().__init__(server)
574
+
575
+ # ---------------- VTK RENDERERS ----------------
576
+ self.ren_geom = vtkRenderer()
577
+ self.ren_geom.SetBackground(0.10, 0.16, 0.22)
578
+ self.rw_geom = vtkRenderWindow()
579
+ self.rw_geom.SetOffScreenRendering(1)
580
+ self.rw_geom.AddRenderer(self.ren_geom)
581
+ self.rwi_geom = vtkRenderWindowInteractor()
582
+ self.rwi_geom.SetRenderWindow(self.rw_geom)
583
+ self.rwi_geom.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
584
+ try:
585
+ self.rwi_geom.Initialize()
586
+ self.rwi_geom.Enable()
587
+ except Exception:
588
+ pass
589
+
590
+ self.ren_pred = vtkRenderer()
591
+ self.ren_pred.SetBackground(0.10, 0.16, 0.22)
592
+ self.rw_pred = vtkRenderWindow()
593
+ self.rw_pred.SetOffScreenRendering(1)
594
+ self.rw_pred.AddRenderer(self.ren_pred)
595
+ self.rwi_pred = vtkRenderWindowInteractor()
596
+ self.rwi_pred.SetRenderWindow(self.rw_pred)
597
+ self.rwi_pred.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
598
+ try:
599
+ self.rwi_pred.Initialize()
600
+ self.rwi_pred.Enable()
601
+ except Exception:
602
+ pass
603
+
604
+ self.scalar_bar = None
605
+
606
+ # timers / flags
607
+ self._predict_t0 = None
608
+ self._infer_thread = None
609
+ self._pre_upload_thread = None
610
+ self._infer_heartbeat_on = False
611
+ self._loop = None
612
+
613
+ # ---------------- TRAME STATE ----------------
614
+ s = self.state
615
+ s.theme_dark = True
616
+
617
+ s.analysis_types = ANALYSIS_TYPE
618
+ s.analysis_type = DEFAULT_ANALYSIS_TYPE
619
+ s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE]
620
+ s.dataset = DEFAULT_DATASET
621
+ s.variable_choices = variables_for(DEFAULT_DATASET)
622
+ s.variable = s.variable_choices[0] if s.variable_choices else None
623
+
624
+ # dialog (still kept)
625
+ s.show_decimation_dialog = False
626
+ s.decim_override_enabled = False
627
+ s.decim_override_mode = "medium"
628
+ s.decim_override_custom = 0.5
629
+
630
+ # menu decimation defaults
631
+ s.decim_enable = False # user MUST toggle to override auto
632
+ s.decim_target = 0.5
633
+ s.decim_min_faces = 5000 # <= important: 0 so small meshes can be reduced
634
+ s.decim_max_faces = int(1e7)
635
+
636
+ # register controller properly
637
+ # self.server.controller.decimate_again = self.decimate_again
638
+ # self.server.controller.add("decimate_again", self.decimate_again)
639
+ ctrl = self.server.controller
640
+
641
+ # ✅ this actually registers the trigger
642
+ ctrl.add("decimate_again", self.decimate_again)
643
+ ctrl.add("reset_mesh", self.reset_mesh)
644
+
645
+
646
+
647
+ s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car")
648
+ s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane")
649
+ s.velocity_mph = 45.0
650
+
651
+ s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET)
652
+ s.bc_left = bc_text_left(DEFAULT_DATASET)
653
+ s.bc_right = bc_text_right(DEFAULT_DATASET)
654
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
655
+
656
+ s.stats_html = "👋 Upload a geometry, then click Predict."
657
+ s.upload = None
658
+
659
+ # upload
660
+ s.is_uploading = False
661
+ s.pm_upload = 0
662
+ s.pm_elapsed_upload = 0.0
663
+ s.upload_msg = ""
664
+
665
+ # predict
666
+ s.is_predicting = False
667
+ s.predict_progress = 0
668
+ s.predict_msg = ""
669
+ s.predict_elapsed = 0.0
670
+ s.predict_est_total = 0.0
671
+ s.pm_infer = 0
672
+ s.pm_elapsed_infer = 0.0
673
+
674
+ self._build_ui()
675
+
676
+ def _ensure_loop(self):
677
+ if self._loop is not None:
678
+ return self._loop
679
+ try:
680
+ loop = asyncio.get_event_loop()
681
+ except RuntimeError:
682
+ loop = asyncio.new_event_loop()
683
+ asyncio.set_event_loop(loop)
684
+ self._loop = loop
685
+ return loop
686
+
687
+ def _run_coro(self, coro):
688
+ loop = self._ensure_loop()
689
+ if loop.is_running():
690
+ return asyncio.ensure_future(coro, loop=loop)
691
+ return loop.run_until_complete(coro)
692
+
693
+ async def _flush_async(self):
694
+ try:
695
+ self.server.state.flush()
696
+ except Exception:
697
+ pass
698
+ await asyncio.sleep(0)
699
+
700
+ def _build_ui(self):
701
+ ctrl = self.server.controller
702
+ with SinglePageLayout(self.server, full_height=True) as layout:
703
+ layout.title.set_text("") # clear
704
+ layout.title.hide = True # hide default
705
+ with layout.toolbar:
706
+ with v3.VContainer(
707
+ fluid=True,
708
+ style=(
709
+ "max-width: 1800px;" # overall width
710
+ "margin: 0 auto;" # center it
711
+ "padding: 0 8px;" # ← left/right margin
712
+ "box-sizing: border-box;"
713
+ ),
714
+ ):
715
+ v3.VSpacer()
716
+ html.Div(
717
+ "Ansys: Physics Foundation Model (Powered by Trame)",
718
+ style=(
719
+ "width:100%;"
720
+ "text-align:center;"
721
+ "font-size:34px;"
722
+ "font-weight:900;"
723
+ "letter-spacing:0.4px;"
724
+ "line-height:1.2;"
725
+ ),
726
+ )
727
+ v3.VSpacer()
728
+
729
+ # toolbar
730
+ with layout.toolbar:
731
+ # ← same margin container for the second toolbar row
732
+ with v3.VContainer(
733
+ fluid=True,
734
+ style=(
735
+ "max-width: 1800px;"
736
+ "margin: 0 auto;"
737
+ "padding: 0 8px;"
738
+ "box-sizing: border-box;"
739
+ ),
740
+ ):
741
+ v3.VSwitch(
742
+ v_model=("theme_dark",),
743
+ label="Dark Theme",
744
+ inset=True,
745
+ density="compact",
746
+ hide_details=True,
747
+ )
748
+ v3.VSpacer()
749
+
750
+ with layout.content:
751
+ html.Style("""
752
+ /* Small side padding for the whole app */
753
+ .v-application__wrap {
754
+ padding-left: 8px;
755
+ padding-right: 8px;
756
+ padding-bottom: 8px;
757
+ }
758
+
759
+ :root {
760
+ --pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
761
+ --pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace;
762
+ }
763
+
764
+ html, body, .v-application {
765
+ margin: 0;
766
+ padding: 0;
767
+ font-family: var(--pfm-font-ui) !important;
768
+ font-weight: 500;
769
+ letter-spacing: .25px;
770
+ -webkit-font-smoothing: antialiased;
771
+ -moz-osx-font-smoothing: grayscale;
772
+ text-rendering: optimizeLegibility;
773
+ line-height: 1.5;
774
+ font-size: 15.5px;
775
+ color: #ECEFF4;
776
+ }
777
+
778
+ /* ... keep all your other typography / button / slider styles here ... */
779
+
780
+ .v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; }
781
+ .v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; }
782
+
783
+ /* (rest of your .pfm-* classes unchanged) */
784
+ """)
785
+
786
+ # html.Style("""
787
+ # .v-theme--dark { background: #1F232B !important; }
788
+ # .v-theme--light { background: #f5f6f8 !important; }
789
+ # .v-theme--dark .pfm-card { background: #23272F !important; color: #fff !important; }
790
+ # .v-theme--light .pfm-card { background: #ffffff !important; color: #1f232b !important; }
791
+ # .v-theme--dark .pfm-viewer { background: #15171d !important; }
792
+ # .v-theme--light .pfm-viewer { background: #e9edf3 !important; }
793
+ # .pfm-card { border-radius: 16px !important; box-shadow: 0 6px 24px rgba(0,0,0,0.12); }
794
+ # .pfm-progress .v-progress-linear { height: 22px !important; border-radius: 999px !important; }
795
+ # .pfm-btn-big.v-btn {
796
+ # height: 48px !important;
797
+ # font-size: 18px !important;
798
+ # font-weight: 600 !important;
799
+ # letter-spacing: 1.2px;
800
+ # text-transform: none !important;
801
+ # border-radius: 999px !important;
802
+ # }
803
+ # .pfm-viewer { min-height: 420px; height: 650px !important; border-radius: 16px; }
804
+ # """)
805
+ html.Link(
806
+ rel="stylesheet",
807
+ href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap",
808
+ )
809
+
810
+ with v3.VThemeProvider(theme=("theme_dark ? 'dark' : 'light'",)):
811
+ with v3.VContainer(
812
+ fluid=True,
813
+ class_="pa-6",
814
+ style=(
815
+ "max-width: 2200px;" # max width of content
816
+ "margin: 8px auto 16px auto;" # top / left-right / bottom
817
+ "padding: 0 8px;" # inner left/right padding
818
+ "box-sizing: border-box;"
819
+ "background: rgba(255,255,255,0.02);"
820
+ "border-radius: 16px;"
821
+ ),
822
+ ):
823
+
824
+
825
+ # 1) Physics Application
826
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
827
+ html.Div(
828
+ "🧪 <b>Physics Application</b>",
829
+ style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;",
830
+ )
831
+ html.Div(
832
+ "Select the type of analysis",
833
+ style="font-size:24px;opacity:.82;margin-bottom:18px;",
834
+ )
835
+ toggle = v3.VBtnToggle(
836
+ v_model=("analysis_type", self.state.analysis_type),
837
+ class_="mt-1",
838
+ mandatory=True,
839
+ rounded=True,
840
+ )
841
+ # with toggle:
842
+ # for at in ANALYSIS_TYPE:
843
+ # v3.VBtn(
844
+ # at,
845
+ # value=at,
846
+ # variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
847
+ # class_="mr-2 pfm-toggle-xxl",
848
+ # )
849
+
850
+ with toggle:
851
+ for at in ANALYSIS_TYPE:
852
+ v3.VBtn(
853
+ at,
854
+ value=at,
855
+ variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
856
+ class_="mr-2 pfm-toggle-xxl",
857
+ style=(
858
+ "font-size:18px;"
859
+ "font-weight:800;"
860
+ "letter-spacing:0.4px;"
861
+ "text-transform:none;"
862
+ ),
863
+ )
864
+
865
+ # 2) Dataset + Variable
866
+ with v3.VRow(dense=True, class_="mb-3"):
867
+ with v3.VCol(cols=6):
868
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
869
+ html.Div(
870
+ "🧩 Sub Application",
871
+ style="font-weight:700;font-size:24px;margin-bottom:14px;",
872
+ )
873
+ v3.VSelect(
874
+ v_model=("dataset", self.state.dataset),
875
+ items=("dataset_choices", self.state.dataset_choices),
876
+ hide_details=True,
877
+ density="comfortable",
878
+ style=(
879
+ "font-size:24px;"
880
+ "font-weight:800;"
881
+ "height:56px;"
882
+ "display:flex;"
883
+ "align-items:center;"
884
+ ),
885
+ class_="pfm-big-select-subapp",
886
+ menu_props={"content_class": "pfm-subapp-list"}, # <— key for dropdown items
887
+ )
888
+ # v3.VSelect(
889
+ # v_model=("dataset", self.state.dataset),
890
+ # items=("dataset_choices", self.state.dataset_choices),
891
+ # hide_details=True,
892
+ # density="comfortable",
893
+ # class_="pfm-big-select-subapp pfm-subapp-list",
894
+ # style="font-size:21px;",
895
+ # )
896
+ with v3.VCol(cols=6):
897
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
898
+ html.Div(
899
+ "📊 Variable to Predict",
900
+ style="font-weight:700;font-size:20px;margin-bottom:14px;",
901
+ )
902
+
903
+ v3.VSelect(
904
+ v_model=("variable", self.state.variable),
905
+ items=("variable_choices", self.state.variable_choices),
906
+ hide_details=True,
907
+ density="comfortable",
908
+ class_="pfm-var-select",
909
+ style=(
910
+ "font-size:20px;"
911
+ "font-weight:800;"
912
+ "height:56px;"
913
+ "display:flex;"
914
+ "align-items:center;"
915
+ ),
916
+ menu_props={"content_class": "pfm-var-list"},
917
+ )
918
+ # v3.VSelect(
919
+ # v_model=("variable", self.state.variable),
920
+ # items=("variable_choices", self.state.variable_choices),
921
+ # hide_details=True,
922
+ # density="comfortable",
923
+ # style="font-size:16px;",
924
+ # )
925
+
926
+ # 3) Boundary Conditions
927
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
928
+ html.Div(
929
+ "🧱 Boundary Conditions",
930
+ style="font-weight:700;font-size:22px;margin-bottom:16px;",
931
+ )
932
+
933
+ # two columns: Left = velocity controls, Right = reference text
934
+ with v3.VRow(class_="align-start", dense=True):
935
+ # ---- LEFT: velocity slider / readout ----
936
+ with v3.VCol(cols=7, class_="pfm-vel"):
937
+ html.Div(
938
+ "🚗 Velocity (mph)",
939
+ class_="pfm-vel-title",
940
+ style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;",
941
+ )
942
+ html.Div(
943
+ "Set the inlet velocity in miles per hour",
944
+ class_="pfm-vel-sub",
945
+ style="margin-bottom:10px;font-size:20px;opacity:.95;",
946
+ )
947
+ v3.VSlider(
948
+ v_model=("velocity_mph", self.state.velocity_mph),
949
+ min=30.0, max=80.0, step=0.1,
950
+ thumb_label=True,
951
+ v_if=("show_velocity",),
952
+ style="height:54px;margin-top:12px;max-width:540px;",
953
+ class_="mt-3 mb-3 pfm-vel-slider",
954
+ )
955
+ html.Div(
956
+ "{{ velocity_mph.toFixed(0) }} / 80 "
957
+ "<span style='opacity:.95'>"
958
+ "({{ (velocity_mph * 0.44704).toFixed(2) }} m/s)</span>",
959
+ v_if=("show_velocity",),
960
+ class_="pfm-vel-readout",
961
+ style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:6px;",
962
+ )
963
+
964
+ # ---- RIGHT: fixed reference values (HTML from bc_text_right / bc_text_left) ----
965
+ with v3.VCol(cols=5, class_="pfm-bc-right"):
966
+ html.Div(
967
+ v_html=("bc_text_html", ""),
968
+ style=(
969
+ "margin-top:6px;"
970
+ "font-size:18px;"
971
+ "line-height:1.7;"
972
+ "min-width:260px;"
973
+ "max-width:360px;"
974
+ "text-align:left;"
975
+ ),
976
+ )
977
+
978
+
979
+
980
+ # 4) Two viewers
981
+ with v3.VRow(style="margin-top: 24px;"):
982
+ # LEFT = upload
983
+ with v3.VCol(cols=6):
984
+ with v3.VRow(class_="align-center justify-space-between mb-2"):
985
+ html.Div(
986
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📤 Input Geometry</span>",
987
+ )
988
+
989
+ # ✅ working gear menu
990
+ with v3.VMenu(
991
+ location="bottom end",
992
+ close_on_content_click=False,
993
+ offset="4 8",
994
+ ):
995
+ # activator slot MUST expose { props } and we MUST bind them to the button
996
+ with v3.Template(v_slot_activator="{ props }"):
997
+ with v3.VBtn(
998
+ icon=True,
999
+ variant="text",
1000
+ density="comfortable",
1001
+ style="min-width:32px;",
1002
+ v_bind="props", # 👈 this is the key
1003
+ ):
1004
+ v3.VIcon("mdi-cog", size="22")
1005
+
1006
+ # menu content
1007
+ with v3.VCard(class_="pa-4", style="min-width: 280px;"):
1008
+ html.Div("<b>Mesh decimation</b>", class_="mb-3", style="font-size:14px;")
1009
+
1010
+ v3.VSwitch(
1011
+ v_model=("decim_enable",),
1012
+ label="Enable decimation",
1013
+ inset=True,
1014
+ hide_details=True,
1015
+ class_="mb-4",
1016
+ )
1017
+
1018
+ html.Div(
1019
+ "Target reduction (fraction of faces to remove)",
1020
+ class_="mb-1",
1021
+ style="font-size:12px;color:#9ca3af;",
1022
+ )
1023
+ v3.VSlider(
1024
+ v_model=("decim_target",),
1025
+ min=0.0,
1026
+ max=0.999,
1027
+ step=0.001,
1028
+ hide_details=True,
1029
+ class_="mb-2",
1030
+ )
1031
+ html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3")
1032
+
1033
+ with v3.VRow(dense=True, class_="mb-3"):
1034
+ with v3.VCol(cols=6):
1035
+ html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1036
+ v3.VTextField(
1037
+ v_model=("decim_min_faces",),
1038
+ type="number",
1039
+ density="compact",
1040
+ hide_details=True,
1041
+ )
1042
+ with v3.VCol(cols=6):
1043
+ html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1044
+ v3.VTextField(
1045
+ v_model=("decim_max_faces",),
1046
+ type="number",
1047
+ density="compact",
1048
+ hide_details=True,
1049
+ )
1050
+
1051
+ v3.VBtn(
1052
+ "Apply to current mesh",
1053
+ block=True,
1054
+ color="primary",
1055
+ class_="mt-2",
1056
+ click=self.decimate_again,
1057
+ )
1058
+
1059
+ v3.VBtn(
1060
+ "Reset to original mesh",
1061
+ block=True,
1062
+ variant="tonal",
1063
+ class_="mt-2",
1064
+ click=self.reset_mesh,  # calls the reset_mesh handler defined below
1065
+ )
1066
+
1067
+
1068
+ v3.VFileInput(
1069
+ label="Select 3D File",
1070
+ style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;",
1071
+ multiple=False,
1072
+ show_size=True,
1073
+ accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb",
1074
+ v_model=("upload", None),
1075
+ clearable=True,
1076
+ )
1077
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1078
+ self.view_geom = VtkRemoteView(
1079
+ self.rw_geom,
1080
+ interactive=True,
1081
+ interactive_ratio=1,
1082
+ server=self.server,
1083
+ )
1084
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1085
+ rounded=True, elevation=3):
1086
+ html.Div("<b>Upload</b>", style="font-size:18px;")
1087
+ v3.VProgressLinear(
1088
+ v_model=("pm_upload", 0),
1089
+ height=22,
1090
+ style="margin-top:10px;margin-bottom:10px;",
1091
+ color="primary",
1092
+ rounded=True,
1093
+ )
1094
+ html.Div(
1095
+ "{{ upload_msg }} — {{ pm_elapsed_upload.toFixed(2) }}s",
1096
+ style="font-size:14px;",
1097
+ )
1098
+ v3.VBtn(
1099
+ "🗑️ CLEAR",
1100
+ block=True,
1101
+ variant="tonal",
1102
+ class_="mt-3 pfm-btn-big",
1103
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1104
+ click=self.clear,
1105
+ )
1106
+ # RIGHT = prediction
1107
+ with v3.VCol(cols=6):
1108
+ html.Div(
1109
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📈 Prediction Results</span>",
1110
+ style="margin-bottom:10px;",
1111
+ )
1112
+ html.Div(
1113
+ v_html=("stats_html",),
1114
+ class_="mb-3",
1115
+ style="font-size:20px;line-height:1.4;",
1116
+ )
1117
+ v3.VProgressLinear(
1118
+ v_model=("predict_progress", 0),
1119
+ height=22,
1120
+ style="margin-top:6px;margin-bottom:12px;",
1121
+ color="primary",
1122
+ rounded=True,
1123
+ indeterminate=("predict_progress < 10",),
1124
+ v_show=("is_predicting",),
1125
+ )
1126
+ html.Div(
1127
+ "Predicting: {{ predict_progress }}%",
1128
+ style="font-size:14px;margin-bottom:10px;",
1129
+ v_show=("is_predicting",),
1130
+ )
1131
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1132
+ self.view_pred = VtkRemoteView(
1133
+ self.rw_pred,
1134
+ interactive=True,
1135
+ interactive_ratio=1,
1136
+ server=self.server,
1137
+ )
1138
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1139
+ rounded=True, elevation=3):
1140
+ html.Div("<b>Inference</b>", style="font-size:18px;")
1141
+ v3.VProgressLinear(
1142
+ v_model=("pm_infer", 0),
1143
+ height=22,
1144
+ style="margin-top:10px;margin-bottom:10px;",
1145
+ color="success",
1146
+ rounded=True,
1147
+ )
1148
+ html.Div(
1149
+ "{{ pm_infer }}% — {{ pm_elapsed_infer.toFixed(2) }}s",
1150
+ style="font-size:14px;",
1151
+ )
1152
+ v3.VBtn(
1153
+ "🚀 PREDICT",
1154
+ block=True,
1155
+ color="primary",
1156
+ class_="mt-3 pfm-btn-big",
1157
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1158
+ click=self.predict,
1159
+ )
1160
+
1161
+ layout.on_ready = self._first_paint
1162
+
1163
+ def _first_paint(self, **_):
1164
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1165
+ try:
1166
+ rw.Render()
1167
+ except Exception:
1168
+ pass
1169
+ view.update()
1170
+
1171
+ # ---------------------------------------------------------
1172
+ # UPLOAD (async)
1173
+ # ---------------------------------------------------------
1174
+ def _write_upload_to_disk(self, payload) -> str:
1175
+ if payload is None:
1176
+ raise ValueError("No file payload")
1177
+ if isinstance(payload, (list, tuple)):
1178
+ payload = payload[0]
1179
+ if isinstance(payload, str):
1180
+ return payload
1181
+ if not isinstance(payload, dict):
1182
+ raise ValueError(f"Unsupported payload: {type(payload)}")
1183
+ if payload.get("path"):
1184
+ return payload["path"]
1185
+ name = payload.get("name") or "upload"
1186
+ content = payload.get("content")
1187
+ if isinstance(content, str) and content.startswith("data:"):
1188
+ content = content.split(",", 1)[1]
1189
+ raw = base64.b64decode(content) if isinstance(content, str) else bytes(content)
1190
+ os.makedirs(GEOM_DIR, exist_ok=True)
1191
+ file_path = os.path.join(GEOM_DIR, name)
1192
+ with open(file_path, "wb") as f:
1193
+ f.write(raw)
1194
+ return file_path
1195
+
1196
+ def _pre_upload_spinner_loop(self):
1197
+ s = self.state
1198
+ phase = 1
1199
+ while self._pre_upload_on and not self._upload_actual_started and s.is_uploading:
1200
+ s.pm_upload = max(1, min(9, phase))
1201
+ s.upload_msg = "Initializing upload..."
1202
+ try:
1203
+ self.server.state.flush()
1204
+ except Exception:
1205
+ pass
1206
+ phase = 1 if phase >= 9 else phase + 1
1207
+ time.sleep(0.15)
1208
+
1209
+ def _start_pre_upload_spinner(self):
1210
+ if self._pre_upload_thread and self._pre_upload_thread.is_alive():
1211
+ return
1212
+ self._upload_actual_started = False
1213
+ self._pre_upload_on = True
1214
+ self._pre_upload_thread = threading.Thread(
1215
+ target=self._pre_upload_spinner_loop, daemon=True
1216
+ )
1217
+ self._pre_upload_thread.start()
1218
+
1219
+ def _stop_pre_upload_spinner(self):
1220
+ self._pre_upload_on = False
1221
+ self._pre_upload_thread = None
1222
+
1223
+ async def _fake_upload_bump(self, stop_event: asyncio.Event):
1224
+ s = self.state
1225
+ while not stop_event.is_set() and s.pm_upload < 9:
1226
+ s.pm_upload += 1
1227
+ await self._flush_async()
1228
+ await asyncio.sleep(0.05)
1229
+
1230
+ async def _upload_worker_async(self, upload):
1231
+ s = self.state
1232
+ loop = self._ensure_loop()
1233
+ t0 = time.time()
1234
+
1235
+ s.is_uploading = True
1236
+ s.upload_msg = "Starting upload..."
1237
+ s.pm_elapsed_upload = 0.0
1238
+
1239
+ s.pm_upload = 5
1240
+ self.server.state.flush()
1241
+ await asyncio.sleep(0)
1242
+
1243
+ fake_stop = asyncio.Event()
1244
+ fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop))
1245
+
1246
+ try:
1247
+ self._upload_actual_started = True
1248
+ self._stop_pre_upload_spinner()
1249
+
1250
+ if not fake_stop.is_set():
1251
+ fake_stop.set()
1252
+ try:
1253
+ await fake_task
1254
+ except asyncio.CancelledError:
1255
+ pass
1256
+
1257
+ s.upload_msg = "Writing file to disk..."
1258
+ s.pm_upload = 10
1259
+ s.pm_elapsed_upload = time.time() - t0
1260
+ await self._flush_async()
1261
+ file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload)
1262
+
1263
+ s.upload_msg = "Reading mesh..."
1264
+ s.pm_upload = 20
1265
+ s.pm_elapsed_upload = time.time() - t0
1266
+ await self._flush_async()
1267
+ mesh = await loop.run_in_executor(None, pv.read, file_path)
1268
+
1269
+ # 3) decimation (auto first)
1270
+ try:
1271
+ nf = poly_count(mesh)
1272
+ except Exception:
1273
+ nf = mesh.n_cells
1274
+
1275
+ auto_tr = float(auto_target_reduction(nf))
1276
+
1277
+ # reflect auto in UI
1278
+ s.decim_target = auto_tr
1279
+ s.decim_min_faces = 5000 # <= allow decimation even for 27k faces
1280
+ s.decim_max_faces = int(1e7)
1281
+
1282
+ target = auto_tr
1283
+ min_faces = 5000
1284
+ max_faces = int(1e7)
1285
+
1286
+ # user override
1287
+ if self.state.decim_enable:
1288
+ target = float(self.state.decim_target or 0.0)
1289
+ min_faces = int(self.state.decim_min_faces or 5000)
1290
+ max_faces = int(self.state.decim_max_faces or 1e7)
1291
+
1292
+ if target > 0.0:
1293
+ s.upload_msg = f"Decimating mesh ({target:.3f})..."
1294
+ s.pm_upload = max(s.pm_upload, 45)
1295
+ s.pm_elapsed_upload = time.time() - t0
1296
+ await self._flush_async()
1297
+
1298
+ dec_cfg = {
1299
+ "enabled": True,
1300
+ "method": "pro",
1301
+ "target_reduction": target,
1302
+ "min_faces": min_faces,
1303
+ "max_faces": max_faces,
1304
+ }
1305
+ mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg)
1306
+
1307
+ # 4) normals + save
1308
+ s.upload_msg = "Preparing geometry..."
1309
+ s.pm_upload = 75
1310
+ s.pm_elapsed_upload = time.time() - t0
1311
+ await self._flush_async()
1312
+
1313
+ def _normals_and_save(m):
1314
+ m_fixed = m.compute_normals(
1315
+ consistent_normals=True,
1316
+ auto_orient_normals=True,
1317
+ point_normals=True,
1318
+ cell_normals=False,
1319
+ inplace=False,
1320
+ )
1321
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1322
+ m_fixed.save(geom_path_)
1323
+ return geom_path_, m_fixed
1324
+
1325
+ geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
1326
+
1327
+ # 5) update viewer
1328
+ self.ren_geom.RemoveAllViewProps()
1329
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1330
+ self.ren_geom.ResetCamera()
1331
+ try:
1332
+ self.rw_geom.Render()
1333
+ except Exception:
1334
+ pass
1335
+ self.view_geom.update()
1336
+ # GEOMETRY_CACHE.current_mesh = mesh_fixed
1337
+ GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True)
1338
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1339
+
1340
+ s.upload_msg = "✅ Geometry ready."
1341
+ s.pm_upload = 100
1342
+ s.pm_elapsed_upload = time.time() - t0
1343
+ await self._flush_async()
1344
+
1345
+ except Exception as e:
1346
+ s.upload_msg = f"❌ Upload failed: {e}"
1347
+ s.pm_upload = 0
1348
+ s.pm_elapsed_upload = time.time() - t0
1349
+ await self._flush_async()
1350
+ finally:
1351
+ s.is_uploading = False
1352
+ s.pm_elapsed_upload = time.time() - t0
1353
+ await self._flush_async()
1354
+
1355
+ if not fake_stop.is_set():
1356
+ fake_stop.set()
1357
+ if not fake_task.done():
1358
+ fake_task.cancel()
1359
+ try:
1360
+ await fake_task
1361
+ except Exception:
1362
+ pass
1363
+
1364
+ @change("upload")
1365
+ def _on_upload_change(self, upload, **_):
1366
+ if not upload:
1367
+ return
1368
+ self._run_coro(self._upload_worker_async(upload))
1369
+
1370
+ def decimate_again(self):
1371
+ self._run_coro(self._decimate_again_async())
1372
+
1373
+ async def _decimate_again_async(self):
1374
+ s = self.state
1375
+ loop = self._ensure_loop()
1376
+
1377
+ if GEOMETRY_CACHE.current_mesh is None:
1378
+ s.upload_msg = "No mesh to re-decimate"
1379
+ await self._flush_async()
1380
+ return
1381
+
1382
+ try:
1383
+ target = float(s.decim_target)
1384
+ except Exception:
1385
+ target = 0.0
1386
+
1387
+ try:
1388
+ min_faces = int(s.decim_min_faces)
1389
+ except Exception:
1390
+ min_faces = 5000
1391
+
1392
+ try:
1393
+ max_faces = int(s.decim_max_faces)
1394
+ except Exception:
1395
+ max_faces = int(1e7)
1396
+
1397
+ if (not s.decim_enable) or target <= 0.0:
1398
+ s.upload_msg = "Decimation disabled"
1399
+ await self._flush_async()
1400
+ return
1401
+
1402
+ s.upload_msg = f"Re-decimating ({target:.3f})..."
1403
+ await self._flush_async()
1404
+
1405
+ dec_cfg = {
1406
+ "enabled": True,
1407
+ "method": "pro",
1408
+ "target_reduction": target,
1409
+ "min_faces": min_faces,
1410
+ "max_faces": max_faces,
1411
+ }
1412
+
1413
+ mesh = await loop.run_in_executor(None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg)
1414
+
1415
+ def _normals_and_save(m):
1416
+ m_fixed = m.compute_normals(
1417
+ consistent_normals=True,
1418
+ auto_orient_normals=True,
1419
+ point_normals=True,
1420
+ cell_normals=False,
1421
+ inplace=False,
1422
+ )
1423
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1424
+ m_fixed.save(geom_path_)
1425
+ return geom_path_, m_fixed
1426
+
1427
+ geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
1428
+
1429
+ self.ren_geom.RemoveAllViewProps()
1430
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1431
+ self.ren_geom.ResetCamera()
1432
+ try:
1433
+ self.rw_geom.Render()
1434
+ except Exception:
1435
+ pass
1436
+ self.view_geom.update()
1437
+
1438
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1439
+ s.upload_msg = "✅ Re-decimated"
1440
+ await self._flush_async()
1441
+
1442
+ def reset_mesh(self):
1443
+ self._run_coro(self._reset_mesh_async())
1444
+
1445
+ async def _reset_mesh_async(self):
1446
+ s = self.state
1447
+
1448
+ if GEOMETRY_CACHE.original_mesh is None:
1449
+ s.upload_msg = "No original mesh to reset to"
1450
+ await self._flush_async()
1451
+ return
1452
+
1453
+ # use the saved original
1454
+ orig = GEOMETRY_CACHE.original_mesh
1455
+
1456
+ # save it again as current
1457
+ GEOMETRY_CACHE.current_mesh = orig
1458
+
1459
+ # write to disk (so the STL on disk matches the viewer)
1460
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
1461
+ orig.save(geom_path)
1462
+
1463
+ # update viewer
1464
+ self.ren_geom.RemoveAllViewProps()
1465
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1466
+ self.ren_geom.ResetCamera()
1467
+ try:
1468
+ self.rw_geom.Render()
1469
+ except Exception:
1470
+ pass
1471
+ self.view_geom.update()
1472
+
1473
+ s.upload_msg = "↩️ Reset to original mesh"
1474
+ await self._flush_async()
1475
+
1476
+ # ---------------------------------------------------------
1477
+ # prediction
1478
+ # ---------------------------------------------------------
1479
+ def _start_infer_heartbeat(self):
1480
+ if self._infer_thread and self._infer_thread.is_alive():
1481
+ return
1482
+
1483
+ def loop_fn():
1484
+ while self._infer_heartbeat_on:
1485
+ if self.state.is_predicting and self._predict_t0 is not None:
1486
+ self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0)
1487
+ try:
1488
+ self.server.state.flush()
1489
+ except Exception:
1490
+ pass
1491
+ time.sleep(0.12)
1492
+
1493
+ self._infer_heartbeat_on = True
1494
+ self._infer_thread = threading.Thread(target=loop_fn, daemon=True)
1495
+ self._infer_thread.start()
1496
+
1497
+ def _stop_infer_heartbeat(self):
1498
+ self._infer_heartbeat_on = False
1499
+ self._infer_thread = None
1500
+
1501
+ async def _predict_worker_async(self):
1502
+ s = self.state
1503
+ loop = self._ensure_loop()
1504
+ t0 = time.time()
1505
+
1506
+ if GEOMETRY_CACHE.current_mesh is None:
1507
+ s.predict_msg = "❌ Please upload geometry first"
1508
+ s.is_predicting = False
1509
+ await self._flush_async()
1510
+ return
1511
+
1512
+ s.is_predicting = True
1513
+ s.predict_progress = 1
1514
+ s.pm_infer = 1
1515
+ s.predict_msg = "Preparing inference..."
1516
+ self._predict_t0 = time.time()
1517
+ self._start_infer_heartbeat()
1518
+ await self._flush_async()
1519
+
1520
+ try:
1521
+ dataset = s.dataset
1522
+ variable = s.variable
1523
+ boundary = (
1524
+ {"freestream_velocity": mph_to_ms(s.velocity_mph)}
1525
+ if dataset == "Incompressible flow over car"
1526
+ else None
1527
+ )
1528
+
1529
+ s.predict_msg = "Loading model/checkpoint..."
1530
+ s.predict_progress = 5
1531
+ s.pm_infer = 5
1532
+ await self._flush_async()
1533
+
1534
+ cfg, model, device, _ = await loop.run_in_executor(
1535
+ None, MODEL_STORE.get, dataset, variable, None
1536
+ )
1537
+
1538
+ s.predict_msg = "Processing mesh for inference..."
1539
+ s.predict_progress = 35
1540
+ s.pm_infer = 35
1541
+ await self._flush_async()
1542
+
1543
+ def _run_full():
1544
+ return run_inference_fast(
1545
+ dataset,
1546
+ variable,
1547
+ boundary_conditions=boundary,
1548
+ progress_cb=None,
1549
+ )
1550
+ viz = await loop.run_in_executor(None, _run_full)
1551
+
1552
+ s.predict_msg = "Preparing visualization..."
1553
+ s.predict_progress = 85
1554
+ s.pm_infer = 85
1555
+ await self._flush_async()
1556
+
1557
+ stl_path = os.path.join(GEOM_DIR, "geometry.stl")
1558
+ vmin = float(np.min(viz["pred"]))
1559
+ vmax = float(np.max(viz["pred"]))
1560
+
1561
+ if os.path.exists(stl_path):
1562
+ _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path)
1563
+ lut = build_jet_lut(vmin, vmax)
1564
+ colored_actor = color_actor_with_scalars_from_prediction(
1565
+ stl_path,
1566
+ viz["points"],
1567
+ viz["pred"],
1568
+ variable,
1569
+ vmin,
1570
+ vmax,
1571
+ lut=lut,
1572
+ )
1573
+ self.ren_pred.AddActor(colored_actor)
1574
+
1575
+ units = {
1576
+ "pressure": "Pa",
1577
+ "x_velocity": "m/s",
1578
+ "y_velocity": "m/s",
1579
+ "z_velocity": "m/s",
1580
+ }.get(variable, "")
1581
+ title = f"{variable} ({units})" if units else variable
1582
+ self.scalar_bar = add_or_update_scalar_bar(
1583
+ self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8
1584
+ )
1585
+
1586
+ src_cam = self.ren_geom.GetActiveCamera()
1587
+ dst_cam = self.ren_pred.GetActiveCamera()
1588
+ if src_cam is not None and dst_cam is not None:
1589
+ dst_cam.SetPosition(src_cam.GetPosition())
1590
+ dst_cam.SetFocalPoint(src_cam.GetFocalPoint())
1591
+ dst_cam.SetViewUp(src_cam.GetViewUp())
1592
+ dst_cam.SetParallelScale(src_cam.GetParallelScale())
1593
+ cr = src_cam.GetClippingRange()
1594
+ dst_cam.SetClippingRange(cr)
1595
+
1596
+ try:
1597
+ self.rw_pred.Render()
1598
+ except Exception:
1599
+ pass
1600
+ self.view_pred.update()
1601
+
1602
+ raw_vmin = float(np.min(viz["pred"]))
1603
+ raw_vmax = float(np.max(viz["pred"]))
1604
+
1605
+ s.stats_html = (
1606
+ f"<b>{variable} min:</b>{raw_vmin:.3e} "
1607
+ f"<b>max:</b> {raw_vmax:.3e} "
1608
+ f"<b>Confidence:</b> {viz['confidence_score']:.4f}"
1609
+ )
1610
+
1611
+ s.predict_msg = "✅ Prediction complete."
1612
+ s.predict_progress = 100
1613
+ s.pm_infer = 100
1614
+ s.predict_elapsed = time.time() - t0
1615
+ s.pm_elapsed_infer = s.predict_elapsed
1616
+ await self._flush_async()
1617
+
1618
+ except Exception as e:
1619
+ s.predict_msg = f"❌ Prediction failed: {e}"
1620
+ s.predict_progress = 0
1621
+ s.pm_infer = 0
1622
+ await self._flush_async()
1623
+ finally:
1624
+ s.is_predicting = False
1625
+ self._stop_infer_heartbeat()
1626
+ await self._flush_async()
1627
+
1628
+ @time_function("Inference and Visualization")
1629
+ def predict(self, *_):
1630
+ self._run_coro(self._predict_worker_async())
1631
+
1632
+ # ---------------------------------------------------------
1633
+ # dataset wiring
1634
+ # ---------------------------------------------------------
1635
+ @change("analysis_type")
1636
+ def _on_analysis_type_change(self, analysis_type=None, **_):
1637
+ ds_list = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", [])
1638
+ default_ds = ds_list[0] if ds_list else None
1639
+ self.state.dataset_choices = ds_list
1640
+ if default_ds and self.state.dataset != default_ds:
1641
+ self.state.dataset = default_ds
1642
+ elif self.state.dataset:
1643
+ self._apply_dataset(self.state.dataset)
1644
+
1645
+ @change("dataset")
1646
+ def _on_dataset_change(self, dataset=None, **_):
1647
+ if not dataset:
1648
+ return
1649
+ self._apply_dataset(dataset)
1650
+
1651
+ def _apply_dataset(self, ds: str):
1652
+ s = self.state
1653
+ opts = variables_for(ds) if ds else []
1654
+ s.variable_choices = opts
1655
+ s.variable = opts[0] if opts else None
1656
+
1657
+ s.show_velocity = (ds == "Incompressible flow over car")
1658
+ s.is_plane = (ds == "Compressible flow over plane")
1659
+
1660
+ s.bc_text = get_boundary_conditions_text(ds)
1661
+ s.bc_left = bc_text_left(ds)
1662
+ s.bc_right = bc_text_right(ds)
1663
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
1664
+
1665
+ # ---------------------------------------------------------
1666
+ # clear
1667
+ # ---------------------------------------------------------
1668
+ def clear(self, *_):
1669
+ for d in [GEOM_DIR, SOLN_DIR]:
1670
+ if os.path.exists(d):
1671
+ shutil.rmtree(d)
1672
+ os.makedirs(d, exist_ok=True)
1673
+ s = self.state
1674
+ s.stats_html = "🧹 Cleared. Upload again."
1675
+ s.is_uploading = False
1676
+ s.pm_upload = 0
1677
+ s.upload_msg = ""
1678
+ s.pm_elapsed_upload = 0.0
1679
+ s.is_predicting = False
1680
+ s.predict_progress = 0
1681
+ s.predict_msg = ""
1682
+ s.pm_infer = 0
1683
+ s.pm_elapsed_infer = 0.0
1684
+ self.ren_geom.RemoveAllViewProps()
1685
+ self.ren_pred.RemoveAllViewProps()
1686
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1687
+ try:
1688
+ rw.Render()
1689
+ except Exception:
1690
+ pass
1691
+ view.update()
1692
+
1693
+ # ---------- main ----------
1694
+ def main():
1695
+ app = PFMDemo()
1696
+ app.server.controller.add("decimate_again", app.decimate_again)
1697
+ app.server.controller.add("reset_mesh", app.reset_mesh)
1698
+ app.server.start(7872)
1699
+
1700
+ if __name__ == "__main__":
1701
+ main()
Main_app_trame2.py ADDED
@@ -0,0 +1,1766 @@
1
+ # ========= Headless / Offscreen safety (before any VTK import) =========
2
+ import os
3
+ os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
4
+ os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
5
+ os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
6
+ os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
7
+ os.environ.setdefault("DISPLAY", "")
8
+
9
+ # ========= Core setup =========
10
+ import shutil, time, tempfile, json, base64, threading, re, html as _html, asyncio
11
+ import numpy as np
12
+ import torch
13
+ import pyvista as pv
14
+ from scipy.spatial import cKDTree
15
+ from vtk.util import numpy_support as nps
16
+ import matplotlib.cm as cm
17
+
18
+ from omegaconf import OmegaConf
19
+ from huggingface_hub import hf_hub_download
20
+ from accelerate import Accelerator
21
+ from accelerate.utils import DistributedDataParallelKwargs
22
+ import pickle
23
+ from sklearn.metrics import pairwise_distances
24
+ from train import get_single_latent
25
+ from sklearn.neighbors import NearestNeighbors
26
+
27
+ from utils.app_utils2 import (
28
+ create_visualization_points,
29
+ create_visualization_stl,
30
+ camera_from_bounds,
31
+ bounds_from_points,
32
+ convert_vtp_to_glb,
33
+ convert_vtp_to_stl,
34
+ time_function,
35
+ print_timing,
36
+ mesh_get_variable,
37
+ mph_to_ms,
38
+ get_boundary_conditions_text,
39
+ compute_confidence_score,
40
+ compute_cosine_score,
41
+ decimate_mesh,
42
+ )
43
+
44
+ # ========= trame =========
45
+ from trame.app import TrameApp
46
+ from trame.decorators import change
47
+ from trame.ui.vuetify3 import SinglePageLayout
48
+ from trame.widgets import vuetify3 as v3, html
49
+ from trame_vtk.widgets.vtk import VtkRemoteView
50
+
51
+ # ========= VTK =========
52
+ from vtkmodules.vtkRenderingCore import (
53
+ vtkRenderer,
54
+ vtkRenderWindow,
55
+ vtkPolyDataMapper,
56
+ vtkActor,
57
+ vtkRenderWindowInteractor,
58
+ )
59
+ from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor
60
+ from vtkmodules.vtkIOGeometry import vtkSTLReader
61
+ from vtkmodules.vtkFiltersCore import vtkTriangleFilter
62
+ from vtkmodules.vtkCommonCore import vtkLookupTable
63
+ from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
64
+
65
+ # ========= Writable paths / caches =========
66
+ DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
67
+ os.makedirs(DATA_DIR, exist_ok=True)
68
+ os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)
69
+
70
+ GEOM_DIR = os.path.join(DATA_DIR, "geometry")
71
+ SOLN_DIR = os.path.join(DATA_DIR, "solution")
72
+ WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")
73
+ for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
74
+ os.makedirs(d, exist_ok=True)
75
+
76
+ HF_DIR = os.path.join(DATA_DIR, "hf_home")
77
+ os.environ.setdefault("HF_HOME", HF_DIR)
78
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
79
+ os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
80
+ os.makedirs(HF_DIR, exist_ok=True)
81
+ for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
82
+ if not os.access(p, os.W_OK):
83
+ raise RuntimeError(f"Not writable: {p}")
84
+
85
+ # ========= Auto-decimation ladder =========
86
+ def auto_target_reduction(num_cells: int) -> float:
87
+ if num_cells <= 10_000:
88
+ return 0.0
89
+ elif num_cells <= 20_000:
90
+ return 0.2
91
+ elif num_cells <= 50_000:
92
+ return 0.4
93
+ elif num_cells <= 100_000:
94
+ return 0.5
95
+ elif num_cells <= 500_000:
96
+ return 0.6
97
+ elif num_cells < 1_000_000:
98
+ return 0.8
99
+ else:
100
+ return 0.9
101
+
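+ # Usage sketch (values follow the ladder above):
+ #   >>> auto_target_reduction(8_000)       # small mesh: no decimation
+ #   0.0
+ #   >>> auto_target_reduction(250_000)
+ #   0.6
+ #   >>> auto_target_reduction(2_000_000)   # very large mesh
+ #   0.9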
102
+ # ========= Registry / choices =========
103
+ REGISTRY = {
104
+ "Incompressible flow inside artery": {
105
+ "repo_id": "ansysresearch/pretrained_models",
106
+ "config": "configs/app_configs/Incompressible flow inside artery/config.yaml",
107
+ "model_attr": "ansysLPFMs",
108
+ "checkpoints": {"best_val": "ckpt_artery.pt"},
109
+ "out_variable": ["pressure", "x_velocity", "y_velocity", "z_velocity"],
110
+ },
111
+ "Vehicle crash analysis": {
112
+ "repo_id": "ansysresearch/pretrained_models",
113
+ "config": "configs/app_configs/Vehicle crash analysis/config.yaml",
114
+ "model_attr": "ansysLPFMs",
115
+ "checkpoints": {"best_val": "ckpt_vehiclecrash.pt"},
116
+ "out_variable": [
117
+ "impact_force",
118
+ "deformation",
119
+ "energy_absorption",
120
+ "x_displacement",
121
+ "y_displacement",
122
+ "z_displacement",
123
+ ],
124
+ },
125
+ "Compressible flow over plane": {
126
+ "repo_id": "ansysresearch/pretrained_models",
127
+ "config": "configs/app_configs/Compressible flow over plane/config.yaml",
128
+ "model_attr": "ansysLPFMs",
129
+ "checkpoints": {"best_val": "ckpt_plane_transonic_v3.pt"},
130
+ "out_variable": ["pressure"],
131
+ },
132
+ "Incompressible flow over car": {
133
+ "repo_id": "ansysresearch/pretrained_models",
134
+ "config": "configs/app_configs/Incompressible flow over car/config.yaml",
135
+ "model_attr": "ansysLPFMs",
136
+ "checkpoints": {"best_val": "ckpt_cadillac_v3.pt"},
137
+ "out_variable": ["pressure"],
138
+ },
139
+ }
140
+
141
+ def variables_for(dataset: str):
142
+ spec = REGISTRY.get(dataset, {})
143
+ ov = spec.get("out_variable")
144
+ if isinstance(ov, str):
145
+ return [ov]
146
+ if isinstance(ov, (list, tuple)):
147
+ return list(ov)
148
+ return list(spec.get("checkpoints", {}).keys())
149
+
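+ # Usage sketch (with the REGISTRY entries above):
+ #   >>> variables_for("Incompressible flow inside artery")
+ #   ['pressure', 'x_velocity', 'y_velocity', 'z_velocity']
+ #   >>> variables_for("Incompressible flow over car")
+ #   ['pressure']
+ #   >>> variables_for("unknown dataset")
+ #   []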
150
+ ANALYSIS_TYPE_MAPPING = {
151
+ "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
152
+ "Internal flow": ["Incompressible flow inside artery"],
153
+ "Crash analysis": ["Vehicle crash analysis"],
154
+ }
155
+ ANALYSIS_TYPE = list(ANALYSIS_TYPE_MAPPING.keys())
156
+ DEFAULT_ANALYSIS_TYPE = "External flow"
157
+ DEFAULT_DATASET = "Incompressible flow over car"
158
+ VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
159
+ DEFAULT_VARIABLE = VAR_CHOICES0[0] if VAR_CHOICES0 else None
160
+
161
+ # ========= Simple cache =========
162
+ class GeometryCache:
163
+ def __init__(self):
164
+ self.original_mesh = None # uploaded, cleaned (normals), BEFORE user re-decimation
165
+ self.current_mesh = None # what the app is actually using right now
166
+
167
+ GEOMETRY_CACHE = GeometryCache()
168
+
169
+ # ========= Model store =========
170
+ class ModelStore:
171
+ def __init__(self):
172
+ self._cache = {}
173
+
174
+ def _build(self, dataset: str, progress_cb=None):
175
+ def tick(x):
176
+ if progress_cb:
177
+ try:
178
+ progress_cb(int(x))
179
+ except Exception:
180
+ pass
181
+
182
+ if dataset in self._cache:
183
+ tick(12)
184
+ return self._cache[dataset]
185
+
186
+ print(f"🔧 Building model for {dataset}")
187
+ start_time = time.time()
188
+ try:
189
+ spec = REGISTRY[dataset]
190
+ repo_id = spec["repo_id"]
191
+ ckpt_name = spec["checkpoints"]["best_val"]
192
+
193
+ tick(6)
194
+ t0 = time.time()
195
+ ckpt_path_hf = hf_hub_download(
196
+ repo_id=repo_id,
197
+ filename=ckpt_name,
198
+ repo_type="model",
199
+ local_dir=HF_DIR,
200
+ local_dir_use_symlinks=False,
201
+ )
202
+ print_timing("Model checkpoint download", t0)
203
+ tick(8)
204
+
205
+ t0 = time.time()
206
+ ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
207
+ os.makedirs(ckpt_local_dir, exist_ok=True)
208
+ ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
209
+ if not os.path.exists(ckpt_path):
210
+ shutil.copy(ckpt_path_hf, ckpt_path)
211
+ print_timing("Local model copy setup", t0)
212
+ tick(9)
213
+
214
+ t0 = time.time()
215
+ cfg_path = spec["config"]
216
+ if not os.path.exists(cfg_path):
217
+ raise FileNotFoundError(f"Missing config: {cfg_path}")
218
+ cfg = OmegaConf.load(cfg_path)
219
+ cfg.save_latent = True
220
+ print_timing("Configuration loading", t0)
221
+ tick(11)
222
+
223
+ t0 = time.time()
224
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
225
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
226
+ accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
227
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
228
+ print_timing("Device init", t0)
229
+ tick(12)
230
+
231
+ t0 = time.time()
232
+ import models
233
+ model_cls_name = spec["model_attr"]
234
+ if not hasattr(models, model_cls_name):
235
+ raise ValueError(f"Model '{model_cls_name}' not found")
236
+ model = getattr(models, model_cls_name)(cfg).to(device)
237
+ print_timing("Model build", t0)
238
+ tick(14)
239
+
240
+ t0 = time.time()
241
+ state = torch.load(ckpt_path, map_location=device)
242
+ model.load_state_dict(state)
243
+ model.eval()
244
+ print_timing("Weights load", t0)
245
+ tick(15)
246
+
247
+ result = (cfg, model, device, accelerator)
248
+ self._cache[dataset] = result
249
+ print_timing(f"Total model build for {dataset}", start_time)
250
+ return result
251
+ except Exception as e:
252
+ print_timing(f"Model build failed for {dataset}", e)
253
+ raise RuntimeError(f"Failed to load model for dataset '{dataset}': {e}")
254
+
255
+ def get(self, dataset: str, variable: str, progress_cb=None):
256
+ return self._build(dataset, progress_cb=progress_cb)
257
+
258
+ MODEL_STORE = ModelStore()
259
+
260
+ # ========= Inference pipeline =========
261
+ def _variable_index(dataset: str, variable: str) -> int:
262
+ ov = REGISTRY[dataset]["out_variable"]
263
+ return 0 if isinstance(ov, str) else ov.index(variable)
264
+
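+ # Example (with the REGISTRY above):
+ #   _variable_index("Vehicle crash analysis", "deformation")    -> 1
+ #   _variable_index("Compressible flow over plane", "pressure") -> 0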
265
+ @time_function("Mesh Processing")
266
+ def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
267
+ jpath = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
268
+ json_data = json.load(open(jpath, "r"))
269
+
270
+ pts = np.asarray(mesh.points, dtype=np.float32)
271
+ N = pts.shape[0]
272
+ rng = np.random.default_rng(42)
273
+ idx = rng.permutation(N)
274
+ points = pts[idx]
275
+ tgt_np = mesh_get_variable(mesh, variable, N)[idx]
276
+ pos = torch.from_numpy(points)
277
+ target = torch.from_numpy(tgt_np).unsqueeze(-1)
278
+
279
+ if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
280
+ if "freestream_velocity" in boundary_conditions:
281
+ inlet_x_velocity = torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
282
+ inlet_x_velocity = inlet_x_velocity.repeat(N, 1)[idx]
283
+ pos = torch.cat((pos, inlet_x_velocity), dim=1)
284
+
285
+ if getattr(cfg, "input_normalization", None) == "shift_axis":
286
+ coords = pos[:, :3].clone()
287
+ coords[:, 0] = coords[:, 0] - coords[:, 0].min()
288
+ coords[:, 2] = coords[:, 2] - coords[:, 2].min()
289
+ y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
290
+ coords[:, 1] = coords[:, 1] - y_center
291
+ pos[:, :3] = coords
292
+
293
+ if getattr(cfg, "pos_embed_sincos", False):
294
+ mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
295
+ maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
296
+ pos = 1000.0 * (pos - mins) / (maxs - mins)
297
+ pos = torch.clamp(pos, 0, 1000)
298
+
299
+ cosine_score = compute_cosine_score(mesh, dataset)
300
+ return pos, target, points, cosine_score
301
+
302
+ @time_function("Inference")
303
+ def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
304
+ def p(v):
305
+ if progress_cb:
306
+ try:
307
+ progress_cb(int(v))
308
+ except Exception:
309
+ pass
310
+
311
+ if GEOMETRY_CACHE.current_mesh is None:
312
+ raise ValueError("No geometry loaded")
313
+
314
+ p(5)
315
+ cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
316
+ p(15)
317
+
318
+ pos, target, points, cosine_score = process_mesh_fast(
319
+ GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
320
+ )
321
+ p(25)
322
+
323
+ confidence_score = 0.0
324
+ try:
325
+ if dataset not in ["Incompressible flow inside artery"]:
326
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
327
+ latent_features = get_single_latent(
328
+ mesh_path=geom_path,
329
+ config_path=os.path.join("configs/app_configs/", dataset, "config.yaml"),
330
+ device=device,
331
+ custom_velocity=boundary_conditions["freestream_velocity"] if boundary_conditions else None,
332
+ use_training_velocity=False,
333
+ model=model,
334
+ )
335
+
336
+ embedding_path = os.path.join("configs/app_configs/", dataset, "pca_embedding.npy")
337
+ pca_reducer_path = os.path.join("configs/app_configs/", dataset, "pca_reducer.pkl")
338
+ scaler_path = os.path.join("configs/app_configs/", dataset, "pca_scaler.pkl")
339
+
340
+ embedding = np.load(embedding_path)
341
+ pca_reducer = pickle.load(open(pca_reducer_path, "rb"))
342
+ scaler = pickle.load(open(scaler_path, "rb"))
343
+
344
+ train_pair_dists = pairwise_distances(embedding)
345
+ sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0
346
+
347
+ n_points, n_features = latent_features.shape
348
+ np.random.seed(42)
349
+ target_len = int(pca_reducer.n_features_in_ / 256)
350
+ if n_points > target_len:
351
+ latent_features = latent_features[np.random.choice(n_points, target_len, replace=False)]
352
+ elif n_points < target_len:
353
+ num_extra = target_len - n_points
354
+ extra_indices = np.random.choice(n_points, num_extra, replace=True)
355
+ latent_features = np.vstack([latent_features, latent_features[extra_indices]])
356
+
357
+ latent_features = latent_features.flatten()
358
+
359
+ confidence_score = compute_confidence_score(
360
+ latent_features, embedding, scaler, pca_reducer, sigma
361
+ )
362
+ except Exception:
363
+ confidence_score = 0.0
364
+
365
+ data = {
366
+ "input_pos": pos.unsqueeze(0).to(device),
367
+ "output_feat": target.unsqueeze(0).to(device),
368
+ }
369
+
370
+ with torch.no_grad():
371
+ inp = data["input_pos"]
372
+ _, N, _ = inp.shape
373
+ chunk = int(getattr(cfg, "num_points", 10000))
374
+
375
+ if getattr(cfg, "chunked_eval", False) and chunk < N:
376
+ input_pos = data["input_pos"]
377
+ chunk_size = cfg.num_points
378
+ out_chunks = []
379
+ total = (N + chunk_size - 1) // chunk_size
380
+
381
+ for k, i in enumerate(range(0, N, chunk_size)):
382
+ ch = input_pos[:, i : i + chunk_size, :]
383
+ n_valid = ch.shape[1]
384
+
385
+ if n_valid < chunk_size:
386
+ pad = input_pos[:, : chunk_size - n_valid, :]
387
+ ch = torch.cat([ch, pad], dim=1)
388
+
389
+ data["input_pos"] = ch
390
+ out_chunk = model(data)
391
+ if isinstance(out_chunk, (tuple, list)):
392
+ out_chunk = out_chunk[0]
393
+ out_chunks.append(out_chunk[:, :n_valid, :])
394
+
395
+ p(25 + 60 * (k + 1) / max(1, total))
396
+
397
+ outputs = torch.cat(out_chunks, dim=1)
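+ # Note: the last chunk is padded with points recycled from the start of the
+ # array so every forward pass sees exactly `chunk_size` points; the padded
+ # outputs are dropped again via `out_chunk[:, :n_valid, :]` before concatenation.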
398
+
399
+ else:
400
+ p(40)
401
+ outputs = model(data)
402
+ if isinstance(outputs, (tuple, list)):
403
+ outputs = outputs[0]
404
+ if torch.cuda.is_available():
405
+ torch.cuda.synchronize()
406
+ p(85)
407
+
408
+ vi = _variable_index(dataset, variable)
409
+ pred = outputs[0, :, vi : vi + 1]
410
+
411
+ if getattr(cfg, "normalization", "") == "std_norm":
412
+ fp = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
413
+ j = json.load(open(fp, "r"))
414
+ mu = torch.tensor(float(j["scalars"][variable]["mean"]), device=pred.device)
415
+ sd = torch.tensor(float(j["scalars"][variable]["std"]), device=pred.device)
416
+ pred = pred * sd + mu
417
+
418
+ pred_np = pred.squeeze().detach().cpu().numpy()
419
+ tgt_np = target.squeeze().numpy()
420
+
421
+ pred_t = torch.from_numpy(pred_np).unsqueeze(-1)
422
+ tgt_t = torch.from_numpy(tgt_np).unsqueeze(-1)
423
+ rel_l2 = torch.mean(
424
+ torch.norm(pred_t.squeeze(-1) - tgt_t.squeeze(-1), p=2, dim=-1)
425
+ / torch.norm(tgt_t.squeeze(-1), p=2, dim=-1)
426
+ )
427
+ tgt_mean = torch.mean(tgt_t)
428
+ ss_tot = torch.sum((tgt_t - tgt_mean) ** 2)
429
+ ss_res = torch.sum((tgt_t - pred_t) ** 2)
430
+ r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)
431
+
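+ # For reference, the metrics computed above are:
+ #   rel_l2 = ||pred - tgt||_2 / ||tgt||_2
+ #   R^2    = 1 - SS_res / SS_tot,  SS_res = sum((tgt - pred)^2),
+ #                                  SS_tot = sum((tgt - mean(tgt))^2)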
432
+ p(100)
433
+ return {
434
+ "points": np.asarray(points),
435
+ "pred": np.asarray(pred_np),
436
+ "tgt": np.asarray(tgt_np),
437
+ "cosine_score": float(cosine_score),
438
+ "confidence_score": float(confidence_score),
439
+ "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
440
+ "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
441
+ "rel_l2": float(rel_l2.item()),
442
+ "r_squared": float(r2.item()),
443
+ }
444
+
445
+ # ========= VTK helpers =========
446
+ def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
447
+ r = vtkSTLReader()
448
+ r.SetFileName(stl_path)
449
+ r.Update()
450
+ tri = vtkTriangleFilter()
451
+ tri.SetInputConnection(r.GetOutputPort())
452
+ tri.Update()
453
+ m = vtkPolyDataMapper()
454
+ m.SetInputConnection(tri.GetOutputPort())
455
+ a = vtkActor()
456
+ a.SetMapper(m)
457
+ a.GetProperty().SetColor(*color)
458
+ return a
459
+
460
+ def build_jet_lut(vmin, vmax):
461
+ lut = vtkLookupTable()
462
+ lut.SetRange(float(vmin), float(vmax))
463
+ lut.SetNumberOfTableValues(256)
464
+ lut.Build()
465
+ cmap = cm.get_cmap("jet", 256)
466
+ for i in range(256):
467
+ r_, g_, b_, _ = cmap(i)
468
+ lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
469
+ return lut
470
+
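+ # Note: `cm.get_cmap(name, lut)` was deprecated in Matplotlib 3.7 and removed in
+ # 3.9. On newer Matplotlib an equivalent lookup is (assuming Matplotlib >= 3.6):
+ #   cmap = matplotlib.colormaps["jet"].resampled(256)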
471
+ def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
472
+ r = vtkSTLReader()
473
+ r.SetFileName(stl_path)
474
+ r.Update()
475
+ poly = r.GetOutput()
476
+
477
+ stl_pts = nps.vtk_to_numpy(poly.GetPoints().GetData())
478
+ tree = cKDTree(points_xyz)
479
+ _, nn_idx = tree.query(stl_pts, k=1)
480
+ scalars = np.asarray(pred_vals, dtype=np.float32)[nn_idx]
481
+
482
+ vtk_arr = nps.numpy_to_vtk(scalars, deep=True)
483
+ vtk_arr.SetName(array_name)
484
+ poly.GetPointData().AddArray(vtk_arr)
485
+ poly.GetPointData().SetActiveScalars(array_name)
486
+
487
+ mapper = vtkPolyDataMapper()
488
+ mapper.SetInputData(poly)
489
+ mapper.SetScalarModeToUsePointData()
490
+ mapper.ScalarVisibilityOn()
491
+ mapper.SetScalarRange(float(vmin), float(vmax))
492
+ if lut is None:
493
+ lut = build_jet_lut(vmin, vmax)
494
+ mapper.SetLookupTable(lut)
495
+ mapper.UseLookupTableScalarRangeOn()
496
+
497
+ actor = vtkActor()
498
+ actor.SetMapper(mapper)
499
+ return actor
500
+
501
+ def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
502
+ existing = []
503
+ ca = renderer.GetActors2D()
504
+ ca.InitTraversal()
505
+ for _ in range(ca.GetNumberOfItems()):
506
+ a = ca.GetNextItem()
507
+ if isinstance(a, vtkScalarBarActor):
508
+ existing.append(a)
509
+ for a in existing:
510
+ renderer.RemoveActor2D(a)
511
+
512
+ sbar = vtkScalarBarActor()
513
+ sbar.SetLookupTable(lut)
514
+ sbar.SetOrientationToVertical()
515
+ sbar.SetLabelFormat(label_fmt)
516
+ sbar.SetNumberOfLabels(int(n_labels))
517
+ sbar.SetTitle(title)
518
+ sbar.SetPosition(0.92, 0.05)
519
+ sbar.SetPosition2(0.06, 0.90)
520
+
521
+ tp = sbar.GetTitleTextProperty()
522
+ tp.SetColor(1, 1, 1)
523
+ tp.SetBold(True)
524
+ tp.SetFontSize(22)
525
+ lp = sbar.GetLabelTextProperty()
526
+ lp.SetColor(1, 1, 1)
527
+ lp.SetFontSize(18)
528
+
529
+ renderer.AddActor2D(sbar)
530
+ return sbar
531
+
532
+ # ---------- Small helpers ----------
533
+ def poly_count(mesh: pv.PolyData) -> int:
534
+ if hasattr(mesh, "n_faces_strict"):
535
+ return mesh.n_faces_strict
536
+ return mesh.n_cells
537
+
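+ # `n_faces_strict` only exists on newer PyVista releases (assumption: >= 0.43);
+ # the `n_cells` fallback keeps older versions working for surface meshes.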
538
+ def md_to_html(txt: str) -> str:
539
+ if not txt:
540
+ return ""
541
+ safe = _html.escape(txt)
542
+ safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
543
+ return "<br/>".join(safe.splitlines())
544
+
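+ # Example (illustrative):
+ #   md_to_html("**Inlet velocity:** 30 m/s\nOutlet: 0 Pa gauge")
+ #   -> "<b>Inlet velocity:</b> 30 m/s<br/>Outlet: 0 Pa gauge"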
545
+ def bc_text_right(dataset: str) -> str:
546
+ if dataset == "Incompressible flow over car":
547
+ return (
548
+ "<b>Reference Density:</b> 1.225 kg/m³<br><br>"
549
+ "<b>Reference Viscosity:</b> 1.789e-5 Pa·s<br><br>"
550
+ "<b>Operating Pressure:</b> 101325 Pa"
551
+ )
552
+ if dataset == "Compressible flow over plane":
553
+ return (
554
+ "<b>Ambient Temperature:</b> 218 K<br><br>"
555
+ "<b>Cruising velocity:</b> 250.0 m/s or 560 mph"
556
+ )
557
+ return ""
558
+
559
+ def bc_text_left(dataset: str) -> str:
560
+ if dataset == "Compressible flow over plane":
561
+ return (
562
+ "<b>Reference Density:</b> 0.36 kg/m³<br><br>"
563
+ "<b>Reference viscosity:</b> 1.716e-05 kg/(m·s)<br><br>"
564
+ "<b>Operating Pressure:</b> 23842 Pa"
565
+ )
566
+ return ""
567
+
568
+ # =====================================================================
569
+ # ======================= APP =======================================
570
+ # =====================================================================
571
+ class PFMDemo(TrameApp):
572
+ def __init__(self, server=None):
573
+ super().__init__(server)
574
+
575
+ # ---------------- VTK RENDERERS ----------------
576
+ self.ren_geom = vtkRenderer()
577
+ self.ren_geom.SetBackground(0.10, 0.16, 0.22)
578
+ self.rw_geom = vtkRenderWindow()
579
+ self.rw_geom.SetOffScreenRendering(1)
580
+ self.rw_geom.AddRenderer(self.ren_geom)
581
+ self.rwi_geom = vtkRenderWindowInteractor()
582
+ self.rwi_geom.SetRenderWindow(self.rw_geom)
583
+ self.rwi_geom.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
584
+ try:
585
+ self.rwi_geom.Initialize()
586
+ self.rwi_geom.Enable()
587
+ except Exception:
588
+ pass
589
+
590
+ self.ren_pred = vtkRenderer()
591
+ self.ren_pred.SetBackground(0.10, 0.16, 0.22)
592
+ self.rw_pred = vtkRenderWindow()
593
+ self.rw_pred.SetOffScreenRendering(1)
594
+ self.rw_pred.AddRenderer(self.ren_pred)
595
+ self.rwi_pred = vtkRenderWindowInteractor()
596
+ self.rwi_pred.SetRenderWindow(self.rw_pred)
597
+ self.rwi_pred.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
598
+ try:
599
+ self.rwi_pred.Initialize()
600
+ self.rwi_pred.Enable()
601
+ except Exception:
602
+ pass
603
+
604
+ self.scalar_bar = None
605
+
606
+ # timers / flags
607
+ self._predict_t0 = None
608
+ self._infer_thread = None
609
+ self._pre_upload_thread = None
610
+ self._infer_heartbeat_on = False
611
+ self._loop = None
612
+
613
+ # ---------------- TRAME STATE ----------------
614
+ s = self.state
615
+ s.theme_dark = True
616
+
617
+ s.analysis_types = ANALYSIS_TYPE
618
+ s.analysis_type = DEFAULT_ANALYSIS_TYPE
619
+ s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE]
620
+ s.dataset = DEFAULT_DATASET
621
+ s.variable_choices = variables_for(DEFAULT_DATASET)
622
+ s.variable = s.variable_choices[0] if s.variable_choices else None
623
+
624
+ # dialog (still kept)
625
+ s.show_decimation_dialog = False
626
+ s.decim_override_enabled = False
627
+ s.decim_override_mode = "medium"
628
+ s.decim_override_custom = 0.5
629
+
630
+ # menu decimation defaults
631
+ s.decim_enable = False # user MUST toggle to override auto
632
+ s.decim_target = 0.5
633
+ s.decim_min_faces = 5000 # lower bound on face count kept after decimation
634
+ s.decim_max_faces = int(1e7)
635
+
636
+ # register controller properly
637
+ # self.server.controller.decimate_again = self.decimate_again
638
+ # self.server.controller.add("decimate_again", self.decimate_again)
639
+ ctrl = self.server.controller
640
+
641
+ # ✅ this actually registers the trigger
642
+ ctrl.add("decimate_again", self.decimate_again)
643
+ ctrl.add("reset_mesh", self.reset_mesh)
644
+
645
+
646
+
647
+ s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car")
648
+ s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane")
649
+ s.velocity_mph = 45.0
650
+
651
+ s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET)
652
+ s.bc_left = bc_text_left(DEFAULT_DATASET)
653
+ s.bc_right = bc_text_right(DEFAULT_DATASET)
654
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
655
+
656
+ s.stats_html = "👋 Upload a geometry, then click Predict."
657
+ s.upload = None
658
+
659
+ # upload
660
+ s.is_uploading = False
661
+ s.pm_upload = 0
662
+ s.pm_elapsed_upload = 0.0
663
+ s.upload_msg = ""
664
+
665
+ # predict
666
+ s.is_predicting = False
667
+ s.predict_progress = 0
668
+ s.predict_msg = ""
669
+ s.predict_elapsed = 0.0
670
+ s.predict_est_total = 0.0
671
+ s.pm_infer = 0
672
+ s.pm_elapsed_infer = 0.0
673
+
674
+ self._build_ui()
675
+
676
+ def _ensure_loop(self):
677
+ if self._loop is not None:
678
+ return self._loop
679
+ try:
680
+ loop = asyncio.get_event_loop()
681
+ except RuntimeError:
682
+ loop = asyncio.new_event_loop()
683
+ asyncio.set_event_loop(loop)
684
+ self._loop = loop
685
+ return loop
686
+
687
+ def _run_coro(self, coro):
688
+ loop = self._ensure_loop()
689
+ if loop.is_running():
690
+ return asyncio.ensure_future(coro, loop=loop)
691
+ return loop.run_until_complete(coro)
692
+
693
+ async def _flush_async(self):
694
+ try:
695
+ self.server.state.flush()
696
+ except Exception:
697
+ pass
698
+ await asyncio.sleep(0)
699
+
700
+ def _build_ui(self):
701
+ ctrl = self.server.controller
702
+ with SinglePageLayout(self.server, full_height=True) as layout:
703
+ layout.title.set_text("") # clear
704
+ layout.title.hide = True # hide default
705
+ with layout.toolbar:
706
+ with v3.VContainer(
707
+ fluid=True,
708
+ style=(
709
+ "max-width: 1800px;" # overall width
710
+ "margin: 0 auto;" # center it
711
+ "padding: 0 8px;" # ← left/right margin
712
+ "box-sizing: border-box;"
713
+ ),
714
+ ):
715
+ v3.VSpacer()
716
+ html.Div(
717
+ "Ansys: Physics Foundation Model",
718
+ style=(
719
+ "width:100%;"
720
+ "text-align:center;"
721
+ "font-size:34px;"
722
+ "font-weight:900;"
723
+ "letter-spacing:0.4px;"
724
+ "line-height:1.2;"
725
+ ),
726
+ )
727
+ v3.VSpacer()
728
+
729
+ # toolbar
730
+ with layout.toolbar:
731
+ # ← same margin container for the second toolbar row
732
+ with v3.VContainer(
733
+ fluid=True,
734
+ style=(
735
+ "max-width: 1800px;"
736
+ "margin: 0 auto;"
737
+ "padding: 0 8px;"
738
+ "box-sizing: border-box;"
739
+ ),
740
+ ):
741
+ v3.VSwitch(
742
+ v_model=("theme_dark",),
743
+ label="Dark Theme",
744
+ inset=True,
745
+ density="compact",
746
+ hide_details=True,
747
+ )
748
+ v3.VSpacer()
749
+
750
+ with layout.content:
751
+ html.Style("""
752
+ /* Small side padding for the whole app */
753
+ .v-application__wrap {
754
+ padding-left: 8px;
755
+ padding-right: 8px;
756
+ padding-bottom: 8px;
757
+ }
758
+
759
+ :root {
760
+ --pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
761
+ --pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace;
762
+ }
763
+
764
+ html, body, .v-application {
765
+ margin: 0;
766
+ padding: 0;
767
+ font-family: var(--pfm-font-ui) !important;
768
+ font-weight: 500;
769
+ letter-spacing: .25px;
770
+ -webkit-font-smoothing: antialiased;
771
+ -moz-osx-font-smoothing: grayscale;
772
+ text-rendering: optimizeLegibility;
773
+ line-height: 1.5;
774
+ font-size: 15.5px;
775
+ color: #ECEFF4;
776
+ }
777
+
778
+ /* ... keep all your other typography / button / slider styles here ... */
779
+
780
+ .v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; }
781
+ .v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; }
782
+
783
+ /* (rest of your .pfm-* classes unchanged) */
784
+ """)
785
+
786
+ # html.Style("""
787
+ # .v-theme--dark { background: #1F232B !important; }
788
+ # .v-theme--light { background: #f5f6f8 !important; }
789
+ # .v-theme--dark .pfm-card { background: #23272F !important; color: #fff !important; }
790
+ # .v-theme--light .pfm-card { background: #ffffff !important; color: #1f232b !important; }
791
+ # .v-theme--dark .pfm-viewer { background: #15171d !important; }
792
+ # .v-theme--light .pfm-viewer { background: #e9edf3 !important; }
793
+ # .pfm-card { border-radius: 16px !important; box-shadow: 0 6px 24px rgba(0,0,0,0.12); }
794
+ # .pfm-progress .v-progress-linear { height: 22px !important; border-radius: 999px !important; }
795
+ # .pfm-btn-big.v-btn {
796
+ # height: 48px !important;
797
+ # font-size: 18px !important;
798
+ # font-weight: 600 !important;
799
+ # letter-spacing: 1.2px;
800
+ # text-transform: none !important;
801
+ # border-radius: 999px !important;
802
+ # }
803
+ # .pfm-viewer { min-height: 420px; height: 650px !important; border-radius: 16px; }
804
+ # """)
805
+ html.Link(
806
+ rel="stylesheet",
807
+ href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap",
808
+ )
809
+
810
+ with v3.VThemeProvider(theme=("theme_dark ? 'dark' : 'light'",)):
811
+ with v3.VContainer(
812
+ fluid=True,
813
+ class_="pa-6",
814
+ style=(
815
+ "max-width: 2200px;" # max width of content
816
+ "margin: 8px auto 16px auto;" # top / left-right / bottom
817
+ "padding: 0 8px;" # inner left/right padding
818
+ "box-sizing: border-box;"
819
+ "background: rgba(255,255,255,0.02);"
820
+ "border-radius: 16px;"
821
+ ),
822
+ ):
823
+
824
+
825
+ # 1) Physics Application
826
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
827
+ html.Div(
828
+ "🧪 <b>Physics Application</b>",
829
+ style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;",
830
+ )
831
+ html.Div(
832
+ "Select the type of analysis",
833
+ style="font-size:24px;opacity:.82;margin-bottom:18px;",
834
+ )
835
+ toggle = v3.VBtnToggle(
836
+ v_model=("analysis_type", self.state.analysis_type),
837
+ class_="mt-1",
838
+ mandatory=True,
839
+ rounded=True,
840
+ )
841
+ # with toggle:
842
+ # for at in ANALYSIS_TYPE:
843
+ # v3.VBtn(
844
+ # at,
845
+ # value=at,
846
+ # variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
847
+ # class_="mr-2 pfm-toggle-xxl",
848
+ # )
849
+
850
+ with toggle:
851
+ for at in ANALYSIS_TYPE:
852
+ v3.VBtn(
853
+ at,
854
+ value=at,
855
+ variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
856
+ class_="mr-2 pfm-toggle-xxl",
857
+ style=(
858
+ "font-size:18px;"
859
+ "font-weight:800;"
860
+ "letter-spacing:0.4px;"
861
+ "text-transform:none;"
862
+ ),
863
+ )
864
+
865
+ # 2) Dataset + Variable
866
+ with v3.VRow(dense=True, class_="mb-3"):
867
+ with v3.VCol(cols=6):
868
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
869
+ html.Div(
870
+ "🧩 Sub Application",
871
+ style="font-weight:700;font-size:24px;margin-bottom:14px;",
872
+ )
873
+ v3.VSelect(
874
+ v_model=("dataset", self.state.dataset),
875
+ items=("dataset_choices", self.state.dataset_choices),
876
+ hide_details=True,
877
+ density="comfortable",
878
+ style=(
879
+ "font-size:24px;"
880
+ "font-weight:800;"
881
+ "height:56px;"
882
+ "display:flex;"
883
+ "align-items:center;"
884
+ ),
885
+ class_="pfm-big-select-subapp",
886
+ menu_props={"content_class": "pfm-subapp-list"}, # <— key for dropdown items
887
+ )
888
+ # v3.VSelect(
889
+ # v_model=("dataset", self.state.dataset),
890
+ # items=("dataset_choices", self.state.dataset_choices),
891
+ # hide_details=True,
892
+ # density="comfortable",
893
+ # class_="pfm-big-select-subapp pfm-subapp-list",
894
+ # style="font-size:21px;",
895
+ # )
896
+ with v3.VCol(cols=6):
897
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
898
+ html.Div(
899
+ "📊 Variable to Predict",
900
+ style="font-weight:700;font-size:20px;margin-bottom:14px;",
901
+ )
902
+
903
+ v3.VSelect(
904
+ v_model=("variable", self.state.variable),
905
+ items=("variable_choices", self.state.variable_choices),
906
+ hide_details=True,
907
+ density="comfortable",
908
+ class_="pfm-var-select",
909
+ style=(
910
+ "font-size:20px;"
911
+ "font-weight:800;"
912
+ "height:56px;"
913
+ "display:flex;"
914
+ "align-items:center;"
915
+ ),
916
+ menu_props={"content_class": "pfm-var-list"},
917
+ )
918
+ # v3.VSelect(
919
+ # v_model=("variable", self.state.variable),
920
+ # items=("variable_choices", self.state.variable_choices),
921
+ # hide_details=True,
922
+ # density="comfortable",
923
+ # style="font-size:16px;",
924
+ # )
925
+
926
+ # 3) Boundary Conditions
927
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
928
+ html.Div(
929
+ "🧱 Boundary Conditions",
930
+ style="font-weight:700;font-size:22px;margin-bottom:16px;",
931
+ )
932
+
933
+ # two columns: Left = velocity controls, Right = reference text
934
+ with v3.VRow(class_="align-start", dense=True):
935
+ # ---- LEFT: velocity slider / readout ----
936
+ with v3.VCol(cols=7, class_="pfm-vel"):
937
+ html.Div(
938
+ "🚗 Velocity (mph)",
939
+ class_="pfm-vel-title",
940
+ style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;",
941
+ )
942
+ html.Div(
943
+ "Set the inlet velocity in miles per hour",
944
+ class_="pfm-vel-sub",
945
+ style="margin-bottom:10px;font-size:20px;opacity:.95;",
946
+ )
947
+ v3.VSlider(
948
+ v_model=("velocity_mph", self.state.velocity_mph),
949
+ min=30.0, max=80.0, step=0.1,
950
+ thumb_label=True,
951
+ v_if=("show_velocity",),
952
+ style="height:54px;margin-top:12px;max-width:540px;",
953
+ class_="mt-3 mb-3 pfm-vel-slider",
954
+ )
955
+ html.Div(
956
+ "{{ velocity_mph.toFixed(0) }} / 80 "
957
+ "<span style='opacity:.95'>"
958
+ "({{ (velocity_mph * 0.44704).toFixed(2) }} m/s)</span>",
959
+ v_if=("show_velocity",),
960
+ class_="pfm-vel-readout",
961
+ style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:6px;",
962
+ )
963
+
964
+ # ---- RIGHT: fixed reference values (HTML from bc_text_right / bc_text_left) ----
965
+ with v3.VCol(cols=5, class_="pfm-bc-right"):
966
+ html.Div(
967
+ v_html=("bc_text_html", ""),
968
+ style=(
969
+ "margin-top:6px;"
970
+ "font-size:18px;"
971
+ "line-height:1.7;"
972
+ "min-width:260px;"
973
+ "max-width:360px;"
974
+ "text-align:left;"
975
+ ),
976
+ )
977
+
978
+
979
+
980
+ # 4) Two viewers
981
+ with v3.VRow(style="margin-top: 24px;"):
982
+ # LEFT = upload
983
+ with v3.VCol(cols=6):
984
+ with v3.VRow(class_="align-center justify-space-between mb-2"):
985
+ html.Div(
986
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📤 Input Geometry</span>",
987
+ )
988
+
989
+ # ✅ working gear menu
990
+ with v3.VMenu(
991
+ location="bottom end",
992
+ close_on_content_click=False,
993
+ offset="4 8",
994
+ ):
995
+ # activator slot MUST expose { props } and we MUST bind them to the button
996
+ with v3.Template(v_slot_activator="{ props }"):
997
+ with v3.VBtn(
998
+ icon=True,
999
+ variant="text",
1000
+ density="comfortable",
1001
+ style="min-width:32px;",
1002
+ v_bind="props", # 👈 this is the key
1003
+ ):
1004
+ v3.VIcon("mdi-cog", size="22")
1005
+
1006
+ # menu content
1007
+ with v3.VCard(class_="pa-4", style="min-width: 280px;"):
1008
+ html.Div("<b>Mesh decimation</b>", class_="mb-3", style="font-size:14px;")
1009
+
1010
+ v3.VSwitch(
1011
+ v_model=("decim_enable",),
1012
+ label="Enable decimation",
1013
+ inset=True,
1014
+ hide_details=True,
1015
+ class_="mb-4",
1016
+ )
1017
+
1018
+ html.Div(
1019
+ "Target reduction (fraction of faces to remove)",
1020
+ class_="mb-1",
1021
+ style="font-size:12px;color:#9ca3af;",
1022
+ )
1023
+ v3.VSlider(
1024
+ v_model=("decim_target",),
1025
+ min=0.0,
1026
+ max=0.999,
1027
+ step=0.001,
1028
+ hide_details=True,
1029
+ class_="mb-2",
1030
+ )
1031
+ html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3")
1032
+
1033
+ with v3.VRow(dense=True, class_="mb-3"):
1034
+ with v3.VCol(cols=6):
1035
+ html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1036
+ v3.VTextField(
1037
+ v_model=("decim_min_faces",),
1038
+ type="number",
1039
+ density="compact",
1040
+ hide_details=True,
1041
+ )
1042
+ with v3.VCol(cols=6):
1043
+ html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1044
+ v3.VTextField(
1045
+ v_model=("decim_max_faces",),
1046
+ type="number",
1047
+ density="compact",
1048
+ hide_details=True,
1049
+ )
1050
+
1051
+ v3.VBtn(
1052
+ "Apply to current mesh",
1053
+ block=True,
1054
+ color="primary",
1055
+ class_="mt-2",
1056
+ click=self.decimate_again,
1057
+ )
1058
+
1059
+ v3.VBtn(
1060
+ "Reset to original mesh",
1061
+ block=True,
1062
+ variant="tonal",
1063
+ class_="mt-2",
1064
+ click=self.reset_mesh, # 👈 restores the cached original (pre-decimation) mesh
1065
+ )
1066
+
1067
+
1068
+ v3.VFileInput(
1069
+ label="Select 3D File",
1070
+ style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;",
1071
+ multiple=False,
1072
+ show_size=True,
1073
+ accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb",
1074
+ v_model=("upload", None),
1075
+ clearable=True,
1076
+ )
1077
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1078
+ self.view_geom = VtkRemoteView(
1079
+ self.rw_geom,
1080
+ interactive=True,
1081
+ interactive_ratio=1,
1082
+ server=self.server,
1083
+ )
1084
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1085
+ rounded=True, elevation=3):
1086
+ html.Div("<b>Upload</b>", style="font-size:18px;")
1087
+
1088
+ # progress bar: only while uploading
1089
+ v3.VProgressLinear(
1090
+ v_model=("pm_upload", 0),
1091
+ height=22,
1092
+ style="margin-top:10px;margin-bottom:10px;",
1093
+ color="primary",
1094
+ rounded=True,
1095
+ v_show=("is_uploading",), # 👈 bar disappears after upload
1096
+ )
1097
+
1098
+ # text: percentage + time + message, only while uploading
1099
+ html.Div(
1100
+ "{{ pm_upload }}% — {{ pm_elapsed_upload.toFixed(2) }}s — {{ upload_msg }}",
1101
+ style="font-size:14px;",
1102
+ v_show=("is_uploading",), # 👈 hide text after completion
1103
+ )
1104
+
1105
+ v3.VBtn(
1106
+ "🗑️ CLEAR",
1107
+ block=True,
1108
+ variant="tonal",
1109
+ class_="mt-3 pfm-btn-big",
1110
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1111
+ click=self.clear,
1112
+ )
1113
+ # RIGHT = prediction
1114
+ with v3.VCol(cols=6):
1115
+ html.Div(
1116
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📈 Prediction Results</span>",
1117
+ style="margin-bottom:10px;",
1118
+ )
1119
+ html.Div(
1120
+ v_html=("stats_html",),
1121
+ class_="mb-3",
1122
+ style="font-size:20px;line-height:1.4;",
1123
+ )
1124
+ # v3.VProgressLinear(
1125
+ # v_model=("predict_progress", 0),
1126
+ # height=22,
1127
+ # style="margin-top:6px;margin-bottom:12px;",
1128
+ # color="primary",
1129
+ # rounded=True,
1130
+ # indeterminate=("predict_progress < 10",),
1131
+ # v_show=("is_predicting",),
1132
+ # )
1133
+ # html.Div(
1134
+ # "Predicting: {{ predict_progress }}%",
1135
+ # style="font-size:14px;margin-bottom:10px;",
1136
+ # v_show=("is_predicting",),
1137
+ # )
1138
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1139
+ self.view_pred = VtkRemoteView(
1140
+ self.rw_pred,
1141
+ interactive=True,
1142
+ interactive_ratio=1,
1143
+ server=self.server,
1144
+ )
1145
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1146
+ rounded=True, elevation=3):
1147
+ html.Div("<b>Inference</b>", style="font-size:18px;")
1148
+
1149
+ # 🔴 OLD: v_model=("predict_progress", 0), indeterminate=...
1150
+ # 🟢 NEW: drive the bar with pm_infer; it only shows as indeterminate until progress starts
1151
+ v3.VProgressLinear(
1152
+ v_model=("pm_infer", 0),
1153
+ height=22,
1154
+ style="margin-top:6px;margin-bottom:12px;",
1155
+ color="success",
1156
+ rounded=True,
1157
+ indeterminate=("predict_progress <= 0",),
1158
+ v_show=("is_predicting",), # 👈 bar only visible while predicting
1159
+ )
1160
+
1161
+ # text line: % + elapsed time + current stage message
1162
+ html.Div(
1163
+ "{{ pm_infer }}% — {{ pm_elapsed_infer.toFixed(2) }}s — {{ predict_msg }}",
1164
+ style="font-size:14px;margin-bottom:10px;",
1165
+ # ❗ if you want the *text* to also disappear at the end, keep v_show;
1166
+ # if you want the final "✅ Prediction complete — 1.23s" to stay, REMOVE v_show
1167
+ v_show=("is_predicting",),
1168
+ )
1169
+ v3.VBtn(
1170
+ "🚀 PREDICT",
1171
+ block=True,
1172
+ color="primary",
1173
+ class_="mt-3 pfm-btn-big",
1174
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1175
+ click=self.predict,
1176
+ )
1177
+
1178
+ layout.on_ready = self._first_paint
1179
+
1180
+ def _first_paint(self, **_):
1181
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1182
+ try:
1183
+ rw.Render()
1184
+ except Exception:
1185
+ pass
1186
+ view.update()
1187
+
1188
+ # ---------------------------------------------------------
1189
+ # UPLOAD (async)
1190
+ # ---------------------------------------------------------
1191
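+ # Normalize the VFileInput payload: it may arrive as a list, an existing
+ # path string, or a dict with a base64 "content" field; always return a path.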
+ def _write_upload_to_disk(self, payload) -> str:
1192
+ if payload is None:
1193
+ raise ValueError("No file payload")
1194
+ if isinstance(payload, (list, tuple)):
1195
+ payload = payload[0]
1196
+ if isinstance(payload, str):
1197
+ return payload
1198
+ if not isinstance(payload, dict):
1199
+ raise ValueError(f"Unsupported payload: {type(payload)}")
1200
+ if payload.get("path"):
1201
+ return payload["path"]
1202
+ name = payload.get("name") or "upload"
1203
+ content = payload.get("content")
1204
+ if isinstance(content, str) and content.startswith("data:"):
1205
+ content = content.split(",", 1)[1]
1206
+ raw = base64.b64decode(content) if isinstance(content, str) else bytes(content)
1207
+ os.makedirs(GEOM_DIR, exist_ok=True)
1208
+ file_path = os.path.join(GEOM_DIR, name)
1209
+ with open(file_path, "wb") as f:
1210
+ f.write(raw)
1211
+ return file_path
1212
+
1213
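+ # Background thread that cycles the upload bar between 1-9% so the UI shows
+ # activity before the real upload handler starts reporting progress.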
+ def _pre_upload_spinner_loop(self):
1214
+ s = self.state
1215
+ phase = 1
1216
+ while self._pre_upload_on and not self._upload_actual_started and s.is_uploading:
1217
+ s.pm_upload = max(1, min(9, phase))
1218
+ s.upload_msg = "Initializing upload..."
1219
+ try:
1220
+ self.server.state.flush()
1221
+ except Exception:
1222
+ pass
1223
+ phase = 1 if phase >= 9 else phase + 1
1224
+ time.sleep(0.15)
1225
+
1226
+ def _start_pre_upload_spinner(self):
1227
+ if self._pre_upload_thread and self._pre_upload_thread.is_alive():
1228
+ return
1229
+ self._upload_actual_started = False
1230
+ self._pre_upload_on = True
1231
+ self._pre_upload_thread = threading.Thread(
1232
+ target=self._pre_upload_spinner_loop, daemon=True
1233
+ )
1234
+ self._pre_upload_thread.start()
1235
+
1236
+ def _stop_pre_upload_spinner(self):
1237
+ self._pre_upload_on = False
1238
+ self._pre_upload_thread = None
1239
+
1240
+ async def _fake_upload_bump(self, stop_event: asyncio.Event):
1241
+ s = self.state
1242
+ while not stop_event.is_set() and s.pm_upload < 9:
1243
+ s.pm_upload += 1
1244
+ await self._flush_async()
1245
+ await asyncio.sleep(0.05)
1246
+
1247
+ async def _upload_worker_async(self, upload):
1248
+ s = self.state
1249
+ loop = self._ensure_loop()
1250
+ t0 = time.time()
1251
+
1252
+ s.is_uploading = True
1253
+ s.upload_msg = "Starting upload..."
1254
+ s.pm_elapsed_upload = 0.0
1255
+
1256
+ s.pm_upload = 5
1257
+ self.server.state.flush()
1258
+ await asyncio.sleep(0)
1259
+
1260
+ fake_stop = asyncio.Event()
1261
+ fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop))
1262
+
1263
+ try:
1264
+ self._upload_actual_started = True
1265
+ self._stop_pre_upload_spinner()
1266
+
1267
+ if not fake_stop.is_set():
1268
+ fake_stop.set()
1269
+ try:
1270
+ await fake_task
1271
+ except asyncio.CancelledError:
1272
+ pass
1273
+
1274
+ s.upload_msg = "Writing file to disk..."
1275
+ s.pm_upload = 10
1276
+ s.pm_elapsed_upload = time.time() - t0
1277
+ await self._flush_async()
1278
+ file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload)
1279
+
1280
+ s.upload_msg = "Reading mesh..."
1281
+ s.pm_upload = 20
1282
+ s.pm_elapsed_upload = time.time() - t0
1283
+ await self._flush_async()
1284
+ mesh = await loop.run_in_executor(None, pv.read, file_path)
1285
+
1286
+ # 3) decimation (auto first)
1287
+ try:
1288
+ nf = poly_count(mesh)
1289
+ except Exception:
1290
+ nf = mesh.n_cells
1291
+
1292
+ auto_tr = float(auto_target_reduction(nf))
1293
+
1294
+ # reflect auto in UI
1295
+ s.decim_target = auto_tr
1296
+ s.decim_min_faces = 5000 # <= floor kept low so even ~27k-face meshes can still be decimated
1297
+ s.decim_max_faces = int(1e7)
1298
+
1299
+ target = auto_tr
1300
+ min_faces = 5000
1301
+ max_faces = int(1e7)
1302
+
1303
+ # user override
1304
+ if self.state.decim_enable:
1305
+ target = float(self.state.decim_target or 0.0)
1306
+ min_faces = int(self.state.decim_min_faces or 5000)
1307
+ max_faces = int(self.state.decim_max_faces or 1e7)
1308
+
1309
+ if target > 0.0:
1310
+ s.upload_msg = f"Decimating mesh ({target:.3f})..."
1311
+ s.pm_upload = max(s.pm_upload, 45)
1312
+ s.pm_elapsed_upload = time.time() - t0
1313
+ await self._flush_async()
1314
+
1315
+ dec_cfg = {
1316
+ "enabled": True,
1317
+ "method": "pro",
1318
+ "target_reduction": target,
1319
+ "min_faces": min_faces,
1320
+ "max_faces": max_faces,
1321
+ }
1322
+ mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg)
1323
+
1324
+ # 4) normals + save
1325
+ s.upload_msg = "Preparing geometry..."
1326
+ s.pm_upload = 75
1327
+ s.pm_elapsed_upload = time.time() - t0
1328
+ await self._flush_async()
1329
+
1330
+ def _normals_and_save(m):
1331
+ m_fixed = m.compute_normals(
1332
+ consistent_normals=True,
1333
+ auto_orient_normals=True,
1334
+ point_normals=True,
1335
+ cell_normals=False,
1336
+ inplace=False,
1337
+ )
1338
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1339
+ m_fixed.save(geom_path_)
1340
+ return geom_path_, m_fixed
1341
+
1342
+ geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
1343
+
1344
+ # 5) update viewer
1345
+ self.ren_geom.RemoveAllViewProps()
1346
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1347
+ self.ren_geom.ResetCamera()
1348
+ try:
1349
+ self.rw_geom.Render()
1350
+ except Exception:
1351
+ pass
1352
+ self.view_geom.update()
1353
+ # GEOMETRY_CACHE.current_mesh = mesh_fixed
1354
+ GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True)
1355
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1356
+
1357
+ s.upload_msg = "✅ Geometry ready."
1358
+ s.pm_upload = 100
1359
+ s.pm_elapsed_upload = time.time() - t0
1360
+ await self._flush_async()
1361
+
1362
+ except Exception as e:
1363
+ s.upload_msg = f"❌ Upload failed: {e}"
1364
+ s.pm_upload = 0
1365
+ s.pm_elapsed_upload = time.time() - t0
1366
+ await self._flush_async()
1367
+ finally:
1368
+ s.is_uploading = False
1369
+ s.pm_elapsed_upload = time.time() - t0
1370
+ await self._flush_async()
1371
+
1372
+ if not fake_stop.is_set():
1373
+ fake_stop.set()
1374
+ if not fake_task.done():
1375
+ fake_task.cancel()
1376
+ try:
1377
+ await fake_task
1378
+ except Exception:
1379
+ pass
1380
+
1381
+ @change("upload")
1382
+ def _on_upload_change(self, upload, **_):
1383
+ if not upload:
1384
+ return
1385
+ self._run_coro(self._upload_worker_async(upload))
1386
+
1387
+ def decimate_again(self):
1388
+ self._run_coro(self._decimate_again_async())
1389
+
1390
+ async def _decimate_again_async(self):
1391
+ s = self.state
1392
+ loop = self._ensure_loop()
1393
+
1394
+ if GEOMETRY_CACHE.current_mesh is None:
1395
+ # nothing to decimate
1396
+ s.upload_msg = "No mesh to re-decimate"
1397
+ await self._flush_async()
1398
+ return
1399
+
1400
+ # --- start "upload-like" progress for manual decimation ---
1401
+ t0 = time.time()
1402
+ s.is_uploading = True
1403
+ s.pm_upload = 5
1404
+ s.pm_elapsed_upload = 0.0
1405
+ s.upload_msg = "Starting mesh re-decimation..."
1406
+ await self._flush_async()
1407
+
1408
+ try:
1409
+ # read parameters from UI
1410
+ try:
1411
+ target = float(s.decim_target)
1412
+ except Exception:
1413
+ target = 0.0
1414
+
1415
+ try:
1416
+ min_faces = int(s.decim_min_faces)
1417
+ except Exception:
1418
+ min_faces = 5000
1419
+
1420
+ try:
1421
+ max_faces = int(s.decim_max_faces)
1422
+ except Exception:
1423
+ max_faces = int(1e7)
1424
+
1425
+ if (not s.decim_enable) or target <= 0.0:
1426
+ s.upload_msg = "Decimation disabled"
1427
+ s.pm_upload = 0
1428
+ s.pm_elapsed_upload = time.time() - t0
1429
+ await self._flush_async()
1430
+ return
1431
+
1432
+ # --- bump before heavy decimation call ---
1433
+ s.upload_msg = f"Re-decimating mesh ({target:.3f})..."
1434
+ s.pm_upload = 25
1435
+ s.pm_elapsed_upload = time.time() - t0
1436
+ await self._flush_async()
1437
+
1438
+ dec_cfg = {
1439
+ "enabled": True,
1440
+ "method": "pro",
1441
+ "target_reduction": target,
1442
+ "min_faces": min_faces,
1443
+ "max_faces": max_faces,
1444
+ }
1445
+
1446
+ # heavy work on executor
1447
+ mesh = await loop.run_in_executor(
1448
+ None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg
1449
+ )
1450
+
1451
+ # --- normals + save ---
1452
+ s.upload_msg = "Recomputing normals & saving..."
1453
+ s.pm_upload = 70
1454
+ s.pm_elapsed_upload = time.time() - t0
1455
+ await self._flush_async()
1456
+
1457
+ def _normals_and_save(m):
1458
+ m_fixed = m.compute_normals(
1459
+ consistent_normals=True,
1460
+ auto_orient_normals=True,
1461
+ point_normals=True,
1462
+ cell_normals=False,
1463
+ inplace=False,
1464
+ )
1465
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1466
+ m_fixed.save(geom_path_)
1467
+ return geom_path_, m_fixed
1468
+
1469
+ geom_path, mesh_fixed = await loop.run_in_executor(
1470
+ None, _normals_and_save, mesh
1471
+ )
1472
+
1473
+ # --- update viewer ---
1474
+ s.upload_msg = "Updating viewer..."
1475
+ s.pm_upload = 90
1476
+ s.pm_elapsed_upload = time.time() - t0
1477
+ await self._flush_async()
1478
+
1479
+ self.ren_geom.RemoveAllViewProps()
1480
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1481
+ self.ren_geom.ResetCamera()
1482
+ try:
1483
+ self.rw_geom.Render()
1484
+ except Exception:
1485
+ pass
1486
+ self.view_geom.update()
1487
+
1488
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1489
+
1490
+ # --- final bump ---
1491
+ s.upload_msg = "✅ Re-decimated"
1492
+ s.pm_upload = 100
1493
+ s.pm_elapsed_upload = time.time() - t0
1494
+ await self._flush_async()
1495
+
1496
+ except Exception as e:
1497
+ s.upload_msg = f"❌ Re-decimation failed: {e}"
1498
+ s.pm_upload = 0
1499
+ s.pm_elapsed_upload = time.time() - t0
1500
+ await self._flush_async()
1501
+ finally:
1502
+ # hide bar + text after we’re done
1503
+ s.is_uploading = False
1504
+ await self._flush_async()
1505
+
1506
+
1507
+ def reset_mesh(self):
1508
+ self._run_coro(self._reset_mesh_async())
1509
+
1510
+ async def _reset_mesh_async(self):
1511
+ s = self.state
1512
+
1513
+ if GEOMETRY_CACHE.original_mesh is None:
1514
+ s.upload_msg = "No original mesh to reset to"
1515
+ await self._flush_async()
1516
+ return
1517
+
1518
+ # use the saved original
1519
+ orig = GEOMETRY_CACHE.original_mesh
1520
+
1521
+ # save it again as current
1522
+ GEOMETRY_CACHE.current_mesh = orig
1523
+
1524
+ # write to disk (so the STL on disk matches the viewer)
1525
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
1526
+ orig.save(geom_path)
1527
+
1528
+ # update viewer
1529
+ self.ren_geom.RemoveAllViewProps()
1530
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1531
+ self.ren_geom.ResetCamera()
1532
+ try:
1533
+ self.rw_geom.Render()
1534
+ except Exception:
1535
+ pass
1536
+ self.view_geom.update()
1537
+
1538
+ s.upload_msg = "↩️ Reset to original mesh"
1539
+ await self._flush_async()
1540
+
1541
+ # ---------------------------------------------------------
1542
+ # prediction
1543
+ # ---------------------------------------------------------
1544
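+ # Lightweight thread that keeps the elapsed-time readout ticking while the
+ # blocking inference work runs in the executor.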
+ def _start_infer_heartbeat(self):
1545
+ if self._infer_thread and self._infer_thread.is_alive():
1546
+ return
1547
+
1548
+ def loop_fn():
1549
+ while self._infer_heartbeat_on:
1550
+ if self.state.is_predicting and self._predict_t0 is not None:
1551
+ self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0)
1552
+ try:
1553
+ self.server.state.flush()
1554
+ except Exception:
1555
+ pass
1556
+ time.sleep(0.12)
1557
+
1558
+ self._infer_heartbeat_on = True
1559
+ self._infer_thread = threading.Thread(target=loop_fn, daemon=True)
1560
+ self._infer_thread.start()
1561
+
1562
+ def _stop_infer_heartbeat(self):
1563
+ self._infer_heartbeat_on = False
1564
+ self._infer_thread = None
1565
+
1566
+ async def _predict_worker_async(self):
1567
+ s = self.state
1568
+ loop = self._ensure_loop()
1569
+ t0 = time.time()
1570
+
1571
+ if GEOMETRY_CACHE.current_mesh is None:
1572
+ s.predict_msg = "❌ Please upload geometry first"
1573
+ s.is_predicting = False
1574
+ await self._flush_async()
1575
+ return
1576
+
1577
+ s.is_predicting = True
1578
+ s.predict_progress = 1
1579
+ s.pm_infer = 1
1580
+ s.predict_msg = "Preparing inference..."
1581
+ self._predict_t0 = time.time()
1582
+ self._start_infer_heartbeat()
1583
+ await self._flush_async()
1584
+
1585
+ try:
1586
+ dataset = s.dataset
1587
+ variable = s.variable
1588
+ boundary = (
1589
+ {"freestream_velocity": mph_to_ms(s.velocity_mph)}
1590
+ if dataset == "Incompressible flow over car"
1591
+ else None
1592
+ )
1593
+
1594
+ s.predict_msg = "Loading model/checkpoint..."
1595
+ s.predict_progress = 5
1596
+ s.pm_infer = 5
1597
+ await self._flush_async()
1598
+
1599
+ cfg, model, device, _ = await loop.run_in_executor(
1600
+ None, MODEL_STORE.get, dataset, variable, None
1601
+ )
1602
+
1603
+ s.predict_msg = "Processing mesh for inference..."
1604
+ s.predict_progress = 35
1605
+ s.pm_infer = 35
1606
+ await self._flush_async()
1607
+
1608
+ def _run_full():
1609
+ return run_inference_fast(
1610
+ dataset,
1611
+ variable,
1612
+ boundary_conditions=boundary,
1613
+ progress_cb=None,
1614
+ )
1615
+ viz = await loop.run_in_executor(None, _run_full)
1616
+
1617
+ s.predict_msg = "Preparing visualization..."
1618
+ s.predict_progress = 85
1619
+ s.pm_infer = 85
1620
+ await self._flush_async()
1621
+
1622
+ stl_path = os.path.join(GEOM_DIR, "geometry.stl")
1623
+ vmin = float(np.min(viz["pred"]))
1624
+ vmax = float(np.max(viz["pred"]))
1625
+
1626
+ if os.path.exists(stl_path):
1627
+ _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path)
1628
+ lut = build_jet_lut(vmin, vmax)
1629
+ colored_actor = color_actor_with_scalars_from_prediction(
1630
+ stl_path,
1631
+ viz["points"],
1632
+ viz["pred"],
1633
+ variable,
1634
+ vmin,
1635
+ vmax,
1636
+ lut=lut,
1637
+ )
1638
+ self.ren_pred.AddActor(colored_actor)
1639
+
1640
+ units = {
1641
+ "pressure": "Pa",
1642
+ "x_velocity": "m/s",
1643
+ "y_velocity": "m/s",
1644
+ "z_velocity": "m/s",
1645
+ }.get(variable, "")
1646
+ title = f"{variable} ({units})" if units else variable
1647
+ self.scalar_bar = add_or_update_scalar_bar(
1648
+ self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8
1649
+ )
1650
+
1651
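+ # Mirror the geometry viewer's camera on the prediction viewer so both
+ # panes show the model from the same angle.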
+ src_cam = self.ren_geom.GetActiveCamera()
1652
+ dst_cam = self.ren_pred.GetActiveCamera()
1653
+ if src_cam is not None and dst_cam is not None:
1654
+ dst_cam.SetPosition(src_cam.GetPosition())
1655
+ dst_cam.SetFocalPoint(src_cam.GetFocalPoint())
1656
+ dst_cam.SetViewUp(src_cam.GetViewUp())
1657
+ dst_cam.SetParallelScale(src_cam.GetParallelScale())
1658
+ cr = src_cam.GetClippingRange()
1659
+ dst_cam.SetClippingRange(cr)
1660
+
1661
+ try:
1662
+ self.rw_pred.Render()
1663
+ except Exception:
1664
+ pass
1665
+ self.view_pred.update()
1666
+
1667
+ raw_vmin = float(np.min(viz["pred"]))
1668
+ raw_vmax = float(np.max(viz["pred"]))
1669
+
1670
+ s.stats_html = (
1671
+ f"<b>{variable} min:</b>{raw_vmin:.3e} "
1672
+ f"<b>max:</b> {raw_vmax:.3e} "
1673
+ f"<b>Confidence:</b> {viz['confidence_score']:.4f}"
1674
+ )
1675
+
1676
+ s.predict_msg = "✅ Prediction complete."
1677
+ s.predict_progress = 100
1678
+ s.pm_infer = 100
1679
+ s.predict_elapsed = time.time() - t0
1680
+ s.pm_elapsed_infer = s.predict_elapsed
1681
+ await self._flush_async()
1682
+
1683
+ except Exception as e:
1684
+ s.predict_msg = f"❌ Prediction failed: {e}"
1685
+ s.predict_progress = 0
1686
+ s.pm_infer = 0
1687
+ await self._flush_async()
1688
+ finally:
1689
+ s.is_predicting = False
1690
+ self._stop_infer_heartbeat()
1691
+ await self._flush_async()
1692
+
1693
+ @time_function("Inference and Visualization")
1694
+ def predict(self, *_):
1695
+ self._run_coro(self._predict_worker_async())
1696
+
1697
+ # ---------------------------------------------------------
1698
+ # dataset wiring
1699
+ # ---------------------------------------------------------
1700
+ @change("analysis_type")
1701
+ def _on_analysis_type_change(self, analysis_type=None, **_):
1702
+ ds_list = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", [])
1703
+ default_ds = ds_list[0] if ds_list else None
1704
+ self.state.dataset_choices = ds_list
1705
+ if default_ds and self.state.dataset != default_ds:
1706
+ self.state.dataset = default_ds
1707
+ elif self.state.dataset:
1708
+ self._apply_dataset(self.state.dataset)
1709
+
1710
+ @change("dataset")
1711
+ def _on_dataset_change(self, dataset=None, **_):
1712
+ if not dataset:
1713
+ return
1714
+ self._apply_dataset(dataset)
1715
+
1716
+ def _apply_dataset(self, ds: str):
1717
+ s = self.state
1718
+ opts = variables_for(ds) if ds else []
1719
+ s.variable_choices = opts
1720
+ s.variable = opts[0] if opts else None
1721
+
1722
+ s.show_velocity = (ds == "Incompressible flow over car")
1723
+ s.is_plane = (ds == "Compressible flow over plane")
1724
+
1725
+ s.bc_text = get_boundary_conditions_text(ds)
1726
+ s.bc_left = bc_text_left(ds)
1727
+ s.bc_right = bc_text_right(ds)
1728
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
1729
+
1730
+ # ---------------------------------------------------------
1731
+ # clear
1732
+ # ---------------------------------------------------------
1733
+ def clear(self, *_):
1734
+ for d in [GEOM_DIR, SOLN_DIR]:
1735
+ if os.path.exists(d):
1736
+ shutil.rmtree(d)
1737
+ os.makedirs(d, exist_ok=True)
1738
+ s = self.state
1739
+ s.stats_html = "🧹 Cleared. Upload again."
1740
+ s.is_uploading = False
1741
+ s.pm_upload = 0
1742
+ s.upload_msg = ""
1743
+ s.pm_elapsed_upload = 0.0
1744
+ s.is_predicting = False
1745
+ s.predict_progress = 0
1746
+ s.predict_msg = ""
1747
+ s.pm_infer = 0
1748
+ s.pm_elapsed_infer = 0.0
1749
+ self.ren_geom.RemoveAllViewProps()
1750
+ self.ren_pred.RemoveAllViewProps()
1751
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1752
+ try:
1753
+ rw.Render()
1754
+ except Exception:
1755
+ pass
1756
+ view.update()
1757
+
1758
+ # ---------- main ----------
1759
+ def main():
1760
+ app = PFMDemo()
1761
+ app.server.controller.add("decimate_again", app.decimate_again)
1762
+ app.server.controller.add("reset_mesh", app.reset_mesh)
1763
+ app.server.start(7872)
1764
+
1765
+ if __name__ == "__main__":
1766
+ main()
README.md ADDED
@@ -0,0 +1,8 @@
1
+ title: AnsysLPFM App
2
+ emoji: 💻
3
+ colorFrom: yellow
4
+ colorTo: gray
5
+ sdk: docker
6
+ pinned: false
7
+ license: mit
8
+ short_description: Ansys Research foundation models for physics
ULIP ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 5e7c0da470fc16717030ec4116b0f81d4d2b4823
app.py ADDED
@@ -0,0 +1,1766 @@
1
+ # ========= Headless / Offscreen safety (before any VTK import) =========
2
+ import os
3
+ os.environ.setdefault("VTK_DEFAULT_RENDER_WINDOW_OFFSCREEN", "1")
4
+ os.environ.setdefault("LIBGL_ALWAYS_SOFTWARE", "1")
5
+ os.environ.setdefault("MESA_LOADER_DRIVER_OVERRIDE", "llvmpipe")
6
+ os.environ.setdefault("MESA_GL_VERSION_OVERRIDE", "3.3")
7
+ os.environ.setdefault("DISPLAY", "")
8
+
9
+ # ========= Core setup =========
10
+ import shutil, time, tempfile, json, base64, threading, re, html as _html, asyncio
11
+ import numpy as np
12
+ import torch
13
+ import pyvista as pv
14
+ from scipy.spatial import cKDTree
15
+ from vtk.util import numpy_support as nps
16
+ import matplotlib.cm as cm
17
+
18
+ from omegaconf import OmegaConf
19
+ from huggingface_hub import hf_hub_download
20
+ from accelerate import Accelerator
21
+ from accelerate.utils import DistributedDataParallelKwargs
22
+ import pickle
23
+ from sklearn.metrics import pairwise_distances
24
+ from train import get_single_latent
25
+ from sklearn.neighbors import NearestNeighbors
26
+
27
+ from utils.app_utils2 import (
28
+ create_visualization_points,
29
+ create_visualization_stl,
30
+ camera_from_bounds,
31
+ bounds_from_points,
32
+ convert_vtp_to_glb,
33
+ convert_vtp_to_stl,
34
+ time_function,
35
+ print_timing,
36
+ mesh_get_variable,
37
+ mph_to_ms,
38
+ get_boundary_conditions_text,
39
+ compute_confidence_score,
40
+ compute_cosine_score,
41
+ decimate_mesh,
42
+ )
43
+
44
+ # ========= trame =========
45
+ from trame.app import TrameApp
46
+ from trame.decorators import change
47
+ from trame.ui.vuetify3 import SinglePageLayout
48
+ from trame.widgets import vuetify3 as v3, html
49
+ from trame_vtk.widgets.vtk import VtkRemoteView
50
+
51
+ # ========= VTK =========
52
+ from vtkmodules.vtkRenderingCore import (
53
+ vtkRenderer,
54
+ vtkRenderWindow,
55
+ vtkPolyDataMapper,
56
+ vtkActor,
57
+ vtkRenderWindowInteractor,
58
+ )
59
+ from vtkmodules.vtkRenderingAnnotation import vtkScalarBarActor
60
+ from vtkmodules.vtkIOGeometry import vtkSTLReader
61
+ from vtkmodules.vtkFiltersCore import vtkTriangleFilter
62
+ from vtkmodules.vtkCommonCore import vtkLookupTable
63
+ from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
64
+
65
+ # ========= Writable paths / caches =========
66
+ DATA_DIR = os.path.join(tempfile.gettempdir(), "appdata")
67
+ os.makedirs(DATA_DIR, exist_ok=True)
68
+ os.environ.setdefault("MPLCONFIGDIR", DATA_DIR)
69
+
70
+ GEOM_DIR = os.path.join(DATA_DIR, "geometry")
71
+ SOLN_DIR = os.path.join(DATA_DIR, "solution")
72
+ WEIGHTS_DIR = os.path.join(DATA_DIR, "weights")
73
+ for d in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR):
74
+ os.makedirs(d, exist_ok=True)
75
+
76
+ HF_DIR = os.path.join(DATA_DIR, "hf_home")
77
+ os.environ.setdefault("HF_HOME", HF_DIR)
78
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", HF_DIR)
79
+ os.environ.setdefault("TRANSFORMERS_CACHE", HF_DIR)
80
+ os.makedirs(HF_DIR, exist_ok=True)
81
+ for p in (GEOM_DIR, SOLN_DIR, WEIGHTS_DIR, HF_DIR):
82
+ if not os.access(p, os.W_OK):
83
+ raise RuntimeError(f"Not writable: {p}")
84
+
85
+ # ========= Auto-decimation ladder =========
86
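+ # Map face count to a decimation fraction (e.g. a 27k-face mesh gets 0.4,
+ # i.e. roughly 40% of faces removed); meshes under 10k faces are untouched.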
+ def auto_target_reduction(num_cells: int) -> float:
87
+ if num_cells <= 10_000:
88
+ return 0.0
89
+ elif num_cells <= 20_000:
90
+ return 0.2
91
+ elif num_cells <= 50_000:
92
+ return 0.4
93
+ elif num_cells <= 100_000:
94
+ return 0.5
95
+ elif num_cells <= 500_000:
96
+ return 0.6
97
+ elif num_cells < 1_000_000:
98
+ return 0.8
99
+ else:
100
+ return 0.9
101
+
102
+ # ========= Registry / choices =========
103
+ REGISTRY = {
104
+ "Incompressible flow inside artery": {
105
+ "repo_id": "ansysresearch/pretrained_models",
106
+ "config": "configs/app_configs/Incompressible flow inside artery/config.yaml",
107
+ "model_attr": "ansysLPFMs",
108
+ "checkpoints": {"best_val": "ckpt_artery.pt"},
109
+ "out_variable": ["pressure", "x_velocity", "y_velocity", "z_velocity"],
110
+ },
111
+ "Vehicle crash analysis": {
112
+ "repo_id": "ansysresearch/pretrained_models",
113
+ "config": "configs/app_configs/Vehicle crash analysis/config.yaml",
114
+ "model_attr": "ansysLPFMs",
115
+ "checkpoints": {"best_val": "ckpt_vehiclecrash.pt"},
116
+ "out_variable": [
117
+ "impact_force",
118
+ "deformation",
119
+ "energy_absorption",
120
+ "x_displacement",
121
+ "y_displacement",
122
+ "z_displacement",
123
+ ],
124
+ },
125
+ "Compressible flow over plane": {
126
+ "repo_id": "ansysresearch/pretrained_models",
127
+ "config": "configs/app_configs/Compressible flow over plane/config.yaml",
128
+ "model_attr": "ansysLPFMs",
129
+ "checkpoints": {"best_val": "ckpt_plane_transonic_v3.pt"},
130
+ "out_variable": ["pressure"],
131
+ },
132
+ "Incompressible flow over car": {
133
+ "repo_id": "ansysresearch/pretrained_models",
134
+ "config": "configs/app_configs/Incompressible flow over car/config.yaml",
135
+ "model_attr": "ansysLPFMs",
136
+ "checkpoints": {"best_val": "ckpt_cadillac_v3.pt"},
137
+ "out_variable": ["pressure"],
138
+ },
139
+ }
140
+
141
+ def variables_for(dataset: str):
142
+ spec = REGISTRY.get(dataset, {})
143
+ ov = spec.get("out_variable")
144
+ if isinstance(ov, str):
145
+ return [ov]
146
+ if isinstance(ov, (list, tuple)):
147
+ return list(ov)
148
+ return list(spec.get("checkpoints", {}).keys())
149
+
150
+ ANALYSIS_TYPE_MAPPING = {
151
+ "External flow": ["Incompressible flow over car", "Compressible flow over plane"],
152
+ "Internal flow": ["Incompressible flow inside artery"],
153
+ "Crash analysis": ["Vehicle crash analysis"],
154
+ }
155
+ ANALYSIS_TYPE = list(ANALYSIS_TYPE_MAPPING.keys())
156
+ DEFAULT_ANALYSIS_TYPE = "External flow"
157
+ DEFAULT_DATASET = "Incompressible flow over car"
158
+ VAR_CHOICES0 = variables_for(DEFAULT_DATASET)
159
+ DEFAULT_VARIABLE = VAR_CHOICES0[0] if VAR_CHOICES0 else None
160
+
161
+ # ========= Simple cache =========
162
+ class GeometryCache:
163
+ def __init__(self):
164
+ self.original_mesh = None # uploaded, cleaned (normals), BEFORE user re-decimation
165
+ self.current_mesh = None # what the app is actually using right now
166
+
167
+ GEOMETRY_CACHE = GeometryCache()
168
+
169
+ # ========= Model store =========
170
+ class ModelStore:
171
+ def __init__(self):
172
+ self._cache = {}
173
+
174
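+ # Build (or return the cached) (cfg, model, device, accelerator) tuple for a
+ # dataset; the checkpoint is pulled from the Hugging Face Hub on first use.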
+ def _build(self, dataset: str, progress_cb=None):
175
+ def tick(x):
176
+ if progress_cb:
177
+ try:
178
+ progress_cb(int(x))
179
+ except Exception:
180
+ pass
181
+
182
+ if dataset in self._cache:
183
+ tick(12)
184
+ return self._cache[dataset]
185
+
186
+ print(f"🔧 Building model for {dataset}")
187
+ start_time = time.time()
188
+ try:
189
+ spec = REGISTRY[dataset]
190
+ repo_id = spec["repo_id"]
191
+ ckpt_name = spec["checkpoints"]["best_val"]
192
+
193
+ tick(6)
194
+ t0 = time.time()
195
+ ckpt_path_hf = hf_hub_download(
196
+ repo_id=repo_id,
197
+ filename=ckpt_name,
198
+ repo_type="model",
199
+ local_dir=HF_DIR,
200
+ local_dir_use_symlinks=False,
201
+ )
202
+ print_timing("Model checkpoint download", t0)
203
+ tick(8)
204
+
205
+ t0 = time.time()
206
+ ckpt_local_dir = os.path.join(WEIGHTS_DIR, dataset)
207
+ os.makedirs(ckpt_local_dir, exist_ok=True)
208
+ ckpt_path = os.path.join(ckpt_local_dir, ckpt_name)
209
+ if not os.path.exists(ckpt_path):
210
+ shutil.copy(ckpt_path_hf, ckpt_path)
211
+ print_timing("Local model copy setup", t0)
212
+ tick(9)
213
+
214
+ t0 = time.time()
215
+ cfg_path = spec["config"]
216
+ if not os.path.exists(cfg_path):
217
+ raise FileNotFoundError(f"Missing config: {cfg_path}")
218
+ cfg = OmegaConf.load(cfg_path)
219
+ cfg.save_latent = True
220
+ print_timing("Configuration loading", t0)
221
+ tick(11)
222
+
223
+ t0 = time.time()
224
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(cfg, "gpu_id", 0))
225
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
226
+ accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
227
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
228
+ print_timing("Device init", t0)
229
+ tick(12)
230
+
231
+ t0 = time.time()
232
+ import models
233
+ model_cls_name = spec["model_attr"]
234
+ if not hasattr(models, model_cls_name):
235
+ raise ValueError(f"Model '{model_cls_name}' not found")
236
+ model = getattr(models, model_cls_name)(cfg).to(device)
237
+ print_timing("Model build", t0)
238
+ tick(14)
239
+
240
+ t0 = time.time()
241
+ state = torch.load(ckpt_path, map_location=device)
242
+ model.load_state_dict(state)
243
+ model.eval()
244
+ print_timing("Weights load", t0)
245
+ tick(15)
246
+
247
+ result = (cfg, model, device, accelerator)
248
+ self._cache[dataset] = result
249
+ print_timing(f"Total model build for {dataset}", start_time)
250
+ return result
251
+ except Exception as e:
252
+ print_timing(f"Model build failed for {dataset}", e)
253
+ raise RuntimeError(f"Failed to load model for dataset '{dataset}': {e}")
254
+
255
+ def get(self, dataset: str, variable: str, progress_cb=None):
256
+ return self._build(dataset, progress_cb=progress_cb)
257
+
258
+ MODEL_STORE = ModelStore()
259
+
260
+ # ========= Inference pipeline =========
261
+ def _variable_index(dataset: str, variable: str) -> int:
262
+ ov = REGISTRY[dataset]["out_variable"]
263
+ return 0 if isinstance(ov, str) else ov.index(variable)
264
+
265
+ @time_function("Mesh Processing")
266
+ def process_mesh_fast(mesh: pv.DataSet, cfg, variable, dataset, boundary_conditions=None):
267
+ jpath = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
268
+ json_data = json.load(open(jpath, "r"))
269
+
270
+ pts = np.asarray(mesh.points, dtype=np.float32)
271
+ N = pts.shape[0]
272
+ rng = np.random.default_rng(42)
273
+ idx = rng.permutation(N)
274
+ points = pts[idx]
275
+ tgt_np = mesh_get_variable(mesh, variable, N)[idx]
276
+ pos = torch.from_numpy(points)
277
+ target = torch.from_numpy(tgt_np).unsqueeze(-1)
278
+
279
+ if getattr(cfg, "diff_input_velocity", False) and boundary_conditions is not None:
280
+ if "freestream_velocity" in boundary_conditions:
281
+ inlet_x_velocity = torch.tensor(boundary_conditions["freestream_velocity"]).float().reshape(1, 1)
282
+ inlet_x_velocity = inlet_x_velocity.repeat(N, 1)[idx]
283
+ pos = torch.cat((pos, inlet_x_velocity), dim=1)
284
+
285
+ if getattr(cfg, "input_normalization", None) == "shift_axis":
286
+ coords = pos[:, :3].clone()
287
+ coords[:, 0] = coords[:, 0] - coords[:, 0].min()
288
+ coords[:, 2] = coords[:, 2] - coords[:, 2].min()
289
+ y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
290
+ coords[:, 1] = coords[:, 1] - y_center
291
+ pos[:, :3] = coords
292
+
293
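+ # Rescale coordinates into the [0, 1000] range used by the sin/cos positional
+ # embedding, based on the training-set mesh bounds in full_transform_params.json.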
+ if getattr(cfg, "pos_embed_sincos", False):
294
+ mins = torch.tensor(json_data["mesh_stats"]["min"], dtype=torch.float32)
295
+ maxs = torch.tensor(json_data["mesh_stats"]["max"], dtype=torch.float32)
296
+ pos = 1000.0 * (pos - mins) / (maxs - mins)
297
+ pos = torch.clamp(pos, 0, 1000)
298
+
299
+ cosine_score = compute_cosine_score(mesh, dataset)
300
+ return pos, target, points, cosine_score
301
+
302
+ @time_function("Inference")
303
+ def run_inference_fast(dataset: str, variable: str, boundary_conditions=None, progress_cb=None):
304
+ def p(v):
305
+ if progress_cb:
306
+ try:
307
+ progress_cb(int(v))
308
+ except Exception:
309
+ pass
310
+
311
+ if GEOMETRY_CACHE.current_mesh is None:
312
+ raise ValueError("No geometry loaded")
313
+
314
+ p(5)
315
+ cfg, model, device, _ = MODEL_STORE.get(dataset, variable, progress_cb=p)
316
+ p(15)
317
+
318
+ pos, target, points, cosine_score = process_mesh_fast(
319
+ GEOMETRY_CACHE.current_mesh, cfg, variable, dataset, boundary_conditions
320
+ )
321
+ p(25)
322
+
323
+ confidence_score = 0.0
324
+ try:
325
+ if dataset not in ["Incompressible flow inside artery"]:
326
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
327
+ latent_features = get_single_latent(
328
+ mesh_path=geom_path,
329
+ config_path=os.path.join("configs/app_configs/", dataset, "config.yaml"),
330
+ device=device,
331
+ custom_velocity=boundary_conditions["freestream_velocity"] if boundary_conditions else None,
332
+ use_training_velocity=False,
333
+ model=model,
334
+ )
335
+
336
+ embedding_path = os.path.join("configs/app_configs/", dataset, "pca_embedding.npy")
337
+ pca_reducer_path = os.path.join("configs/app_configs/", dataset, "pca_reducer.pkl")
338
+ scaler_path = os.path.join("configs/app_configs/", dataset, "pca_scaler.pkl")
339
+
340
+ embedding = np.load(embedding_path)
341
+ pca_reducer = pickle.load(open(pca_reducer_path, "rb"))
342
+ scaler = pickle.load(open(scaler_path, "rb"))
343
+
344
+ train_pair_dists = pairwise_distances(embedding)
345
+ sigma = float(np.median(train_pair_dists)) if train_pair_dists.size > 0 else 1.0
346
+
347
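+ # Sub/over-sample the per-point latent features so the flattened vector
+ # matches the input length the PCA reducer was fit on before scoring it
+ # against the training embedding.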
+ n_points, n_features = latent_features.shape
348
+ np.random.seed(42)
349
+ target_len = int(pca_reducer.n_features_in_ / 256)
350
+ if n_points > target_len:
351
+ latent_features = latent_features[np.random.choice(n_points, target_len, replace=False)]
352
+ elif n_points < target_len:
353
+ num_extra = target_len - n_points
354
+ extra_indices = np.random.choice(n_points, num_extra, replace=True)
355
+ latent_features = np.vstack([latent_features, latent_features[extra_indices]])
356
+
357
+ latent_features = latent_features.flatten()
358
+
359
+ confidence_score = compute_confidence_score(
360
+ latent_features, embedding, scaler, pca_reducer, sigma
361
+ )
362
+ except Exception:
363
+ confidence_score = 0.0
364
+
365
+ data = {
366
+ "input_pos": pos.unsqueeze(0).to(device),
367
+ "output_feat": target.unsqueeze(0).to(device),
368
+ }
369
+
370
+ with torch.no_grad():
371
+ inp = data["input_pos"]
372
+ _, N, _ = inp.shape
373
+ chunk = int(getattr(cfg, "num_points", 10000))
374
+
375
+ if getattr(cfg, "chunked_eval", False) and chunk < N:
376
+ input_pos = data["input_pos"]
377
+ chunk_size = cfg.num_points
378
+ out_chunks = []
379
+ total = (N + chunk_size - 1) // chunk_size
380
+
381
+ for k, i in enumerate(range(0, N, chunk_size)):
382
+ ch = input_pos[:, i : i + chunk_size, :]
383
+ n_valid = ch.shape[1]
384
+
385
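+ # Pad a short final chunk with points from the start of the cloud so the
+ # model always sees a full-size batch; only the valid rows are kept below.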
+ if n_valid < chunk_size:
386
+ pad = input_pos[:, : chunk_size - n_valid, :]
387
+ ch = torch.cat([ch, pad], dim=1)
388
+
389
+ data["input_pos"] = ch
390
+ out_chunk = model(data)
391
+ if isinstance(out_chunk, (tuple, list)):
392
+ out_chunk = out_chunk[0]
393
+ out_chunks.append(out_chunk[:, :n_valid, :])
394
+
395
+ p(25 + 60 * (k + 1) / max(1, total))
396
+
397
+ outputs = torch.cat(out_chunks, dim=1)
398
+
399
+ else:
400
+ p(40)
401
+ outputs = model(data)
402
+ if isinstance(outputs, (tuple, list)):
403
+ outputs = outputs[0]
404
+ if torch.cuda.is_available():
405
+ torch.cuda.synchronize()
406
+ p(85)
407
+
408
+ vi = _variable_index(dataset, variable)
409
+ pred = outputs[0, :, vi : vi + 1]
410
+
411
+ if getattr(cfg, "normalization", "") == "std_norm":
412
+ fp = os.path.join("configs/app_configs/", dataset, "full_transform_params.json")
413
+ j = json.load(open(fp, "r"))
414
+ mu = torch.tensor(float(j["scalars"][variable]["mean"]), device=pred.device)
415
+ sd = torch.tensor(float(j["scalars"][variable]["std"]), device=pred.device)
416
+ pred = pred * sd + mu
417
+
418
+ pred_np = pred.squeeze().detach().cpu().numpy()
419
+ tgt_np = target.squeeze().numpy()
420
+
421
+ pred_t = torch.from_numpy(pred_np).unsqueeze(-1)
422
+ tgt_t = torch.from_numpy(tgt_np).unsqueeze(-1)
423
+ rel_l2 = torch.mean(
424
+ torch.norm(pred_t.squeeze(-1) - tgt_t.squeeze(-1), p=2, dim=-1)
425
+ / torch.norm(tgt_t.squeeze(-1), p=2, dim=-1)
426
+ )
427
+ tgt_mean = torch.mean(tgt_t)
428
+ ss_tot = torch.sum((tgt_t - tgt_mean) ** 2)
429
+ ss_res = torch.sum((tgt_t - pred_t) ** 2)
430
+ r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else torch.tensor(0.0)
431
+
432
+ p(100)
433
+ return {
434
+ "points": np.asarray(points),
435
+ "pred": np.asarray(pred_np),
436
+ "tgt": np.asarray(tgt_np),
437
+ "cosine_score": float(cosine_score),
438
+ "confidence_score": float(confidence_score),
439
+ "abs_err": float(np.mean(np.abs(pred_np - tgt_np))),
440
+ "mse_err": float(np.mean((pred_np - tgt_np) ** 2)),
441
+ "rel_l2": float(rel_l2.item()),
442
+ "r_squared": float(r2.item()),
443
+ }
444
+
445
+ # ========= VTK helpers =========
446
+ def make_actor_from_stl(stl_path: str, color=(0.85, 0.85, 0.85)):
447
+ r = vtkSTLReader()
448
+ r.SetFileName(stl_path)
449
+ r.Update()
450
+ tri = vtkTriangleFilter()
451
+ tri.SetInputConnection(r.GetOutputPort())
452
+ tri.Update()
453
+ m = vtkPolyDataMapper()
454
+ m.SetInputConnection(tri.GetOutputPort())
455
+ a = vtkActor()
456
+ a.SetMapper(m)
457
+ a.GetProperty().SetColor(*color)
458
+ return a
459
+
460
+ def build_jet_lut(vmin, vmax):
461
+ lut = vtkLookupTable()
462
+ lut.SetRange(float(vmin), float(vmax))
463
+ lut.SetNumberOfTableValues(256)
464
+ lut.Build()
465
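+ # NOTE: cm.get_cmap() is deprecated in recent Matplotlib releases;
+ # matplotlib.colormaps["jet"].resampled(256) is the newer equivalent.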
+ cmap = cm.get_cmap("jet", 256)
466
+ for i in range(256):
467
+ r_, g_, b_, _ = cmap(i)
468
+ lut.SetTableValue(i, float(r_), float(g_), float(b_), 1.0)
469
+ return lut
470
+
471
+ def color_actor_with_scalars_from_prediction(stl_path, points_xyz, pred_vals, array_name, vmin, vmax, lut=None):
472
+ r = vtkSTLReader()
473
+ r.SetFileName(stl_path)
474
+ r.Update()
475
+ poly = r.GetOutput()
476
+
477
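+ # Transfer predictions onto the STL vertices via nearest-neighbor lookup,
+ # since the prediction points do not share the STL's vertex ordering.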
+ stl_pts = nps.vtk_to_numpy(poly.GetPoints().GetData())
478
+ tree = cKDTree(points_xyz)
479
+ _, nn_idx = tree.query(stl_pts, k=1)
480
+ scalars = np.asarray(pred_vals, dtype=np.float32)[nn_idx]
481
+
482
+ vtk_arr = nps.numpy_to_vtk(scalars, deep=True)
483
+ vtk_arr.SetName(array_name)
484
+ poly.GetPointData().AddArray(vtk_arr)
485
+ poly.GetPointData().SetActiveScalars(array_name)
486
+
487
+ mapper = vtkPolyDataMapper()
488
+ mapper.SetInputData(poly)
489
+ mapper.SetScalarModeToUsePointData()
490
+ mapper.ScalarVisibilityOn()
491
+ mapper.SetScalarRange(float(vmin), float(vmax))
492
+ if lut is None:
493
+ lut = build_jet_lut(vmin, vmax)
494
+ mapper.SetLookupTable(lut)
495
+ mapper.UseLookupTableScalarRangeOn()
496
+
497
+ actor = vtkActor()
498
+ actor.SetMapper(mapper)
499
+ return actor
500
+
501
+ def add_or_update_scalar_bar(renderer, lut, title, label_fmt="%-0.2f", n_labels=8):
502
+ existing = []
503
+ ca = renderer.GetActors2D()
504
+ ca.InitTraversal()
505
+ for _ in range(ca.GetNumberOfItems()):
506
+ a = ca.GetNextItem()
507
+ if isinstance(a, vtkScalarBarActor):
508
+ existing.append(a)
509
+ for a in existing:
510
+ renderer.RemoveActor2D(a)
511
+
512
+ sbar = vtkScalarBarActor()
513
+ sbar.SetLookupTable(lut)
514
+ sbar.SetOrientationToVertical()
515
+ sbar.SetLabelFormat(label_fmt)
516
+ sbar.SetNumberOfLabels(int(n_labels))
517
+ sbar.SetTitle(title)
518
+ sbar.SetPosition(0.92, 0.05)
519
+ sbar.SetPosition2(0.06, 0.90)
520
+
521
+ tp = sbar.GetTitleTextProperty()
522
+ tp.SetColor(1, 1, 1)
523
+ tp.SetBold(True)
524
+ tp.SetFontSize(22)
525
+ lp = sbar.GetLabelTextProperty()
526
+ lp.SetColor(1, 1, 1)
527
+ lp.SetFontSize(18)
528
+
529
+ renderer.AddActor2D(sbar)
530
+ return sbar
531
+
532
+ # ---------- Small helpers ----------
533
+ def poly_count(mesh: pv.PolyData) -> int:
534
+ if hasattr(mesh, "n_faces_strict"):
535
+ return mesh.n_faces_strict
536
+ return mesh.n_cells
537
+
538
+ def md_to_html(txt: str) -> str:
539
+ if not txt:
540
+ return ""
541
+ safe = _html.escape(txt)
542
+ safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
543
+ return "<br/>".join(safe.splitlines())
544
+
545
+ def bc_text_right(dataset: str) -> str:
546
+ if dataset == "Incompressible flow over car":
547
+ return (
548
+ "<b>Reference Density:</b> 1.225 kg/m³<br><br>"
549
+ "<b>Reference Viscosity:</b> 1.789e-5 Pa·s<br><br>"
550
+ "<b>Operating Pressure:</b> 101325 Pa"
551
+ )
552
+ if dataset == "Compressible flow over plane":
553
+ return (
554
+ "<b>Ambient Temperature:</b> 218 K<br><br>"
555
+ "<b>Cruising velocity:</b> 250.0 m/s or 560 mph"
556
+ )
557
+ return ""
558
+
559
+ def bc_text_left(dataset: str) -> str:
560
+ if dataset == "Compressible flow over plane":
561
+ return (
562
+ "<b>Reference Density:</b> 0.36 kg/m³<br><br>"
563
+ "<b>Reference viscosity:</b> 1.716e-05 kg/(m·s)<br><br>"
564
+ "<b>Operating Pressure:</b> 23842 Pa"
565
+ )
566
+ return ""
567
+
568
+ # =====================================================================
569
+ # ======================= APP =======================================
570
+ # =====================================================================
571
+ class PFMDemo(TrameApp):
572
+ def __init__(self, server=None):
573
+ super().__init__(server)
574
+
575
+ # ---------------- VTK RENDERERS ----------------
576
+ self.ren_geom = vtkRenderer()
577
+ self.ren_geom.SetBackground(0.10, 0.16, 0.22)
578
+ self.rw_geom = vtkRenderWindow()
579
+ self.rw_geom.SetOffScreenRendering(1)
580
+ self.rw_geom.AddRenderer(self.ren_geom)
581
+ self.rwi_geom = vtkRenderWindowInteractor()
582
+ self.rwi_geom.SetRenderWindow(self.rw_geom)
583
+ self.rwi_geom.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
584
+ try:
585
+ self.rwi_geom.Initialize()
586
+ self.rwi_geom.Enable()
587
+ except Exception:
588
+ pass
589
+
590
+ self.ren_pred = vtkRenderer()
591
+ self.ren_pred.SetBackground(0.10, 0.16, 0.22)
592
+ self.rw_pred = vtkRenderWindow()
593
+ self.rw_pred.SetOffScreenRendering(1)
594
+ self.rw_pred.AddRenderer(self.ren_pred)
595
+ self.rwi_pred = vtkRenderWindowInteractor()
596
+ self.rwi_pred.SetRenderWindow(self.rw_pred)
597
+ self.rwi_pred.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
598
+ try:
599
+ self.rwi_pred.Initialize()
600
+ self.rwi_pred.Enable()
601
+ except Exception:
602
+ pass
603
+
604
+ self.scalar_bar = None
605
+
606
+ # timers / flags
607
+ self._predict_t0 = None
608
+ self._infer_thread = None
609
+ self._pre_upload_thread = None
610
+ self._infer_heartbeat_on = False
611
+ self._loop = None
612
+
613
+ # ---------------- TRAME STATE ----------------
614
+ s = self.state
615
+ s.theme_dark = True
616
+
617
+ s.analysis_types = ANALYSIS_TYPE
618
+ s.analysis_type = DEFAULT_ANALYSIS_TYPE
619
+ s.dataset_choices = ANALYSIS_TYPE_MAPPING[DEFAULT_ANALYSIS_TYPE]
620
+ s.dataset = DEFAULT_DATASET
621
+ s.variable_choices = variables_for(DEFAULT_DATASET)
622
+ s.variable = s.variable_choices[0] if s.variable_choices else None
623
+
624
+ # dialog (still kept)
625
+ s.show_decimation_dialog = False
626
+ s.decim_override_enabled = False
627
+ s.decim_override_mode = "medium"
628
+ s.decim_override_custom = 0.5
629
+
630
+ # menu decimation defaults
631
+ s.decim_enable = False # user MUST toggle to override auto
632
+ s.decim_target = 0.5
633
+ s.decim_min_faces = 5000 # floor on face count; kept low so small meshes can still be decimated
634
+ s.decim_max_faces = int(1e7)
635
+
636
+ # register controller properly
637
+ # self.server.controller.decimate_again = self.decimate_again
638
+ # self.server.controller.add("decimate_again", self.decimate_again)
639
+ ctrl = self.server.controller
640
+
641
+ # ✅ this actually registers the trigger
642
+ ctrl.add("decimate_again", self.decimate_again)
643
+ ctrl.add("reset_mesh", self.reset_mesh)
644
+
645
+
646
+
647
+ s.show_velocity = (DEFAULT_DATASET == "Incompressible flow over car")
648
+ s.is_plane = (DEFAULT_DATASET == "Compressible flow over plane")
649
+ s.velocity_mph = 45.0
650
+
651
+ s.bc_text = get_boundary_conditions_text(DEFAULT_DATASET)
652
+ s.bc_left = bc_text_left(DEFAULT_DATASET)
653
+ s.bc_right = bc_text_right(DEFAULT_DATASET)
654
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
655
+
656
+ s.stats_html = "👋 Upload a geometry, then click Predict."
657
+ s.upload = None
658
+
659
+ # upload
660
+ s.is_uploading = False
661
+ s.pm_upload = 0
662
+ s.pm_elapsed_upload = 0.0
663
+ s.upload_msg = ""
664
+
665
+ # predict
666
+ s.is_predicting = False
667
+ s.predict_progress = 0
668
+ s.predict_msg = ""
669
+ s.predict_elapsed = 0.0
670
+ s.predict_est_total = 0.0
671
+ s.pm_infer = 0
672
+ s.pm_elapsed_infer = 0.0
673
+
674
+ self._build_ui()
675
+
676
+ def _ensure_loop(self):
677
+ if self._loop is not None:
678
+ return self._loop
679
+ try:
680
+ loop = asyncio.get_event_loop()
681
+ except RuntimeError:
682
+ loop = asyncio.new_event_loop()
683
+ asyncio.set_event_loop(loop)
684
+ self._loop = loop
685
+ return loop
686
+
687
+ def _run_coro(self, coro):
688
+ loop = self._ensure_loop()
689
+ if loop.is_running():
690
+ return asyncio.ensure_future(coro, loop=loop)
691
+ return loop.run_until_complete(coro)
692
+
693
+ async def _flush_async(self):
694
+ try:
695
+ self.server.state.flush()
696
+ except Exception:
697
+ pass
698
+ await asyncio.sleep(0)
699
+
700
+ def _build_ui(self):
701
+ ctrl = self.server.controller
702
+ with SinglePageLayout(self.server, full_height=True) as layout:
703
+ layout.title.set_text("") # clear
704
+ layout.title.hide = True # hide default
705
+ with layout.toolbar:
706
+ with v3.VContainer(
707
+ fluid=True,
708
+ style=(
709
+ "max-width: 1800px;" # overall width
710
+ "margin: 0 auto;" # center it
711
+ "padding: 0 8px;" # ← left/right margin
712
+ "box-sizing: border-box;"
713
+ ),
714
+ ):
715
+ v3.VSpacer()
716
+ html.Div(
717
+ "Ansys: Physics Foundation Model",
718
+ style=(
719
+ "width:100%;"
720
+ "text-align:center;"
721
+ "font-size:34px;"
722
+ "font-weight:900;"
723
+ "letter-spacing:0.4px;"
724
+ "line-height:1.2;"
725
+ ),
726
+ )
727
+ v3.VSpacer()
728
+
729
+ # toolbar
730
+ with layout.toolbar:
731
+ # ← same margin container for the second toolbar row
732
+ with v3.VContainer(
733
+ fluid=True,
734
+ style=(
735
+ "max-width: 1800px;"
736
+ "margin: 0 auto;"
737
+ "padding: 0 8px;"
738
+ "box-sizing: border-box;"
739
+ ),
740
+ ):
741
+ v3.VSwitch(
742
+ v_model=("theme_dark",),
743
+ label="Dark Theme",
744
+ inset=True,
745
+ density="compact",
746
+ hide_details=True,
747
+ )
748
+ v3.VSpacer()
749
+
750
+ with layout.content:
751
+ html.Style("""
752
+ /* Small side padding for the whole app */
753
+ .v-application__wrap {
754
+ padding-left: 8px;
755
+ padding-right: 8px;
756
+ padding-bottom: 8px;
757
+ }
758
+
759
+ :root {
760
+ --pfm-font-ui: 'Inter', 'IBM Plex Sans', 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
761
+ --pfm-font-mono: 'JetBrains Mono', 'IBM Plex Mono', monospace;
762
+ }
763
+
764
+ html, body, .v-application {
765
+ margin: 0;
766
+ padding: 0;
767
+ font-family: var(--pfm-font-ui) !important;
768
+ font-weight: 500;
769
+ letter-spacing: .25px;
770
+ -webkit-font-smoothing: antialiased;
771
+ -moz-osx-font-smoothing: grayscale;
772
+ text-rendering: optimizeLegibility;
773
+ line-height: 1.5;
774
+ font-size: 15.5px;
775
+ color: #ECEFF4;
776
+ }
777
+
778
+ /* ... keep all your other typography / button / slider styles here ... */
779
+
780
+ .v-theme--dark { background-color: #14171C !important; color: #ECEFF4 !important; }
781
+ .v-theme--light { background-color: #F6F7FA !important; color: #1F1F1F !important; }
782
+
783
+ /* (rest of your .pfm-* classes unchanged) */
784
+ """)
785
+
786
+ # html.Style("""
787
+ # .v-theme--dark { background: #1F232B !important; }
788
+ # .v-theme--light { background: #f5f6f8 !important; }
789
+ # .v-theme--dark .pfm-card { background: #23272F !important; color: #fff !important; }
790
+ # .v-theme--light .pfm-card { background: #ffffff !important; color: #1f232b !important; }
791
+ # .v-theme--dark .pfm-viewer { background: #15171d !important; }
792
+ # .v-theme--light .pfm-viewer { background: #e9edf3 !important; }
793
+ # .pfm-card { border-radius: 16px !important; box-shadow: 0 6px 24px rgba(0,0,0,0.12); }
794
+ # .pfm-progress .v-progress-linear { height: 22px !important; border-radius: 999px !important; }
795
+ # .pfm-btn-big.v-btn {
796
+ # height: 48px !important;
797
+ # font-size: 18px !important;
798
+ # font-weight: 600 !important;
799
+ # letter-spacing: 1.2px;
800
+ # text-transform: none !important;
801
+ # border-radius: 999px !important;
802
+ # }
803
+ # .pfm-viewer { min-height: 420px; height: 650px !important; border-radius: 16px; }
804
+ # """)
805
+ html.Link(
806
+ rel="stylesheet",
807
+ href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap",
808
+ )
809
+
810
+ with v3.VThemeProvider(theme=("theme_dark ? 'dark' : 'light'",)):
811
+ with v3.VContainer(
812
+ fluid=True,
813
+ class_="pa-6",
814
+ style=(
815
+ "max-width: 2200px;" # max width of content
816
+ "margin: 8px auto 16px auto;" # top / left-right / bottom
817
+ "padding: 0 8px;" # inner left/right padding
818
+ "box-sizing: border-box;"
819
+ "background: rgba(255,255,255,0.02);"
820
+ "border-radius: 16px;"
821
+ ),
822
+ ):
823
+
824
+
825
+ # 1) Physics Application
826
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
827
+ html.Div(
828
+ "🧪 <b>Physics Application</b>",
829
+ style="font-size:28px;font-weight:700;letter-spacing:1.1px;margin-bottom:10px;",
830
+ )
831
+ html.Div(
832
+ "Select the type of analysis",
833
+ style="font-size:24px;opacity:.82;margin-bottom:18px;",
834
+ )
835
+ toggle = v3.VBtnToggle(
836
+ v_model=("analysis_type", self.state.analysis_type),
837
+ class_="mt-1",
838
+ mandatory=True,
839
+ rounded=True,
840
+ )
841
+ # with toggle:
842
+ # for at in ANALYSIS_TYPE:
843
+ # v3.VBtn(
844
+ # at,
845
+ # value=at,
846
+ # variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
847
+ # class_="mr-2 pfm-toggle-xxl",
848
+ # )
849
+
850
+ with toggle:
851
+ for at in ANALYSIS_TYPE:
852
+ v3.VBtn(
853
+ at,
854
+ value=at,
855
+ variant=(f"analysis_type===`{at}` ? 'elevated' : 'tonal'"),
856
+ class_="mr-2 pfm-toggle-xxl",
857
+ style=(
858
+ "font-size:18px;"
859
+ "font-weight:800;"
860
+ "letter-spacing:0.4px;"
861
+ "text-transform:none;"
862
+ ),
863
+ )
864
+
865
+ # 2) Dataset + Variable
866
+ with v3.VRow(dense=True, class_="mb-3"):
867
+ with v3.VCol(cols=6):
868
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
869
+ html.Div(
870
+ "🧩 Sub Application",
871
+ style="font-weight:700;font-size:24px;margin-bottom:14px;",
872
+ )
873
+ v3.VSelect(
874
+ v_model=("dataset", self.state.dataset),
875
+ items=("dataset_choices", self.state.dataset_choices),
876
+ hide_details=True,
877
+ density="comfortable",
878
+ style=(
879
+ "font-size:24px;"
880
+ "font-weight:800;"
881
+ "height:56px;"
882
+ "display:flex;"
883
+ "align-items:center;"
884
+ ),
885
+ class_="pfm-big-select-subapp",
886
+ menu_props={"content_class": "pfm-subapp-list"}, # <— key for dropdown items
887
+ )
888
+ # v3.VSelect(
889
+ # v_model=("dataset", self.state.dataset),
890
+ # items=("dataset_choices", self.state.dataset_choices),
891
+ # hide_details=True,
892
+ # density="comfortable",
893
+ # class_="pfm-big-select-subapp pfm-subapp-list",
894
+ # style="font-size:21px;",
895
+ # )
896
+ with v3.VCol(cols=6):
897
+ with v3.VSheet(class_="pa-6 pfm-card", rounded=True, elevation=3):
898
+ html.Div(
899
+ "📊 Variable to Predict",
900
+ style="font-weight:700;font-size:20px;margin-bottom:14px;",
901
+ )
902
+
903
+ v3.VSelect(
904
+ v_model=("variable", self.state.variable),
905
+ items=("variable_choices", self.state.variable_choices),
906
+ hide_details=True,
907
+ density="comfortable",
908
+ class_="pfm-var-select",
909
+ style=(
910
+ "font-size:20px;"
911
+ "font-weight:800;"
912
+ "height:56px;"
913
+ "display:flex;"
914
+ "align-items:center;"
915
+ ),
916
+ menu_props={"content_class": "pfm-var-list"},
917
+ )
918
+ # v3.VSelect(
919
+ # v_model=("variable", self.state.variable),
920
+ # items=("variable_choices", self.state.variable_choices),
921
+ # hide_details=True,
922
+ # density="comfortable",
923
+ # style="font-size:16px;",
924
+ # )
925
+
926
+ # 3) Boundary Conditions
927
+ with v3.VSheet(class_="pa-6 mb-4 pfm-card", rounded=True, elevation=3):
928
+ html.Div(
929
+ "🧱 Boundary Conditions",
930
+ style="font-weight:700;font-size:22px;margin-bottom:16px;",
931
+ )
932
+
933
+ # two columns: Left = velocity controls, Right = reference text
934
+ with v3.VRow(class_="align-start", dense=True):
935
+ # ---- LEFT: velocity slider / readout ----
936
+ with v3.VCol(cols=7, class_="pfm-vel"):
937
+ html.Div(
938
+ "🚗 Velocity (mph)",
939
+ class_="pfm-vel-title",
940
+ style="margin-bottom:8px;font-weight:800;font-size:21px;letter-spacing:.3px;",
941
+ )
942
+ html.Div(
943
+ "Set the inlet velocity in miles per hour",
944
+ class_="pfm-vel-sub",
945
+ style="margin-bottom:10px;font-size:20px;opacity:.95;",
946
+ )
947
+ v3.VSlider(
948
+ v_model=("velocity_mph", self.state.velocity_mph),
949
+ min=30.0, max=80.0, step=0.1,
950
+ thumb_label=True,
951
+ v_if=("show_velocity",),
952
+ style="height:54px;margin-top:12px;max-width:540px;",
953
+ class_="mt-3 mb-3 pfm-vel-slider",
954
+ )
955
+ html.Div(
956
+ "{{ velocity_mph.toFixed(0) }} / 80 "
957
+ "<span style='opacity:.95'>"
958
+ "({{ (velocity_mph * 0.44704).toFixed(2) }} m/s)</span>",
959
+ v_if=("show_velocity",),
960
+ class_="pfm-vel-readout",
961
+ style="font-size:18px;font-weight:900;letter-spacing:.3px;margin-top:6px;",
962
+ )
963
+
964
+ # ---- RIGHT: fixed reference values (HTML from bc_text_right / bc_text_left) ----
965
+ with v3.VCol(cols=5, class_="pfm-bc-right"):
966
+ html.Div(
967
+ v_html=("bc_text_html", ""),
968
+ style=(
969
+ "margin-top:6px;"
970
+ "font-size:18px;"
971
+ "line-height:1.7;"
972
+ "min-width:260px;"
973
+ "max-width:360px;"
974
+ "text-align:left;"
975
+ ),
976
+ )
977
+
978
+
979
+
980
+ # 4) Two viewers
981
+ with v3.VRow(style="margin-top: 24px;"):
982
+ # LEFT = upload
983
+ with v3.VCol(cols=6):
984
+ with v3.VRow(class_="align-center justify-space-between mb-2"):
985
+ html.Div(
986
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📤 Input Geometry</span>",
987
+ )
988
+
989
+ # ✅ working gear menu
990
+ with v3.VMenu(
991
+ location="bottom end",
992
+ close_on_content_click=False,
993
+ offset="4 8",
994
+ ):
995
+ # activator slot MUST expose { props } and we MUST bind them to the button
996
+ with v3.Template(v_slot_activator="{ props }"):
997
+ with v3.VBtn(
998
+ icon=True,
999
+ variant="text",
1000
+ density="comfortable",
1001
+ style="min-width:32px;",
1002
+ v_bind="props", # 👈 this is the key
1003
+ ):
1004
+ v3.VIcon("mdi-cog", size="22")
1005
+
1006
+ # menu content
1007
+ with v3.VCard(class_="pa-4", style="min-width: 280px;"):
1008
+ html.Div("<b>Mesh decimation</b>", class_="mb-3", style="font-size:14px;")
1009
+
1010
+ v3.VSwitch(
1011
+ v_model=("decim_enable",),
1012
+ label="Enable decimation",
1013
+ inset=True,
1014
+ hide_details=True,
1015
+ class_="mb-4",
1016
+ )
1017
+
1018
+ html.Div(
1019
+ "Target reduction (fraction of faces to remove)",
1020
+ class_="mb-1",
1021
+ style="font-size:12px;color:#9ca3af;",
1022
+ )
1023
+ v3.VSlider(
1024
+ v_model=("decim_target",),
1025
+ min=0.0,
1026
+ max=0.999,
1027
+ step=0.001,
1028
+ hide_details=True,
1029
+ class_="mb-2",
1030
+ )
1031
+ html.Div("{{ decim_target.toFixed(3) }}", style="font-size:11px;", class_="mb-3")
1032
+
1033
+ with v3.VRow(dense=True, class_="mb-3"):
1034
+ with v3.VCol(cols=6):
1035
+ html.Div("Min faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1036
+ v3.VTextField(
1037
+ v_model=("decim_min_faces",),
1038
+ type="number",
1039
+ density="compact",
1040
+ hide_details=True,
1041
+ )
1042
+ with v3.VCol(cols=6):
1043
+ html.Div("Max faces", style="font-size:11px;color:#9ca3af;", class_="mb-1")
1044
+ v3.VTextField(
1045
+ v_model=("decim_max_faces",),
1046
+ type="number",
1047
+ density="compact",
1048
+ hide_details=True,
1049
+ )
1050
+
1051
+ v3.VBtn(
1052
+ "Apply to current mesh",
1053
+ block=True,
1054
+ color="primary",
1055
+ class_="mt-2",
1056
+ click=self.decimate_again,
1057
+ )
1058
+
1059
+ v3.VBtn(
1060
+ "Reset to original mesh",
1061
+ block=True,
1062
+ variant="tonal",
1063
+ class_="mt-2",
1064
+ click=self.reset_mesh, # 👈 will call the controller you added
1065
+ )
1066
+
1067
+
1068
+ v3.VFileInput(
1069
+ label="Select 3D File",
1070
+ style="font-size:17px;padding:12px;height:50px;margin-bottom:20px;",
1071
+ multiple=False,
1072
+ show_size=True,
1073
+ accept=".stl,.vtk,.vtp,.ply,.obj,.vtu,.glb",
1074
+ v_model=("upload", None),
1075
+ clearable=True,
1076
+ )
1077
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1078
+ self.view_geom = VtkRemoteView(
1079
+ self.rw_geom,
1080
+ interactive=True,
1081
+ interactive_ratio=1,
1082
+ server=self.server,
1083
+ )
1084
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1085
+ rounded=True, elevation=3):
1086
+ html.Div("<b>Upload</b>", style="font-size:18px;")
1087
+
1088
+ # progress bar: only while uploading
1089
+ v3.VProgressLinear(
1090
+ v_model=("pm_upload", 0),
1091
+ height=22,
1092
+ style="margin-top:10px;margin-bottom:10px;",
1093
+ color="primary",
1094
+ rounded=True,
1095
+ v_show=("is_uploading",), # 👈 bar disappears after upload
1096
+ )
1097
+
1098
+ # text: percentage + time + message, only while uploading
1099
+ html.Div(
1100
+ "{{ pm_upload }}% — {{ pm_elapsed_upload.toFixed(2) }}s — {{ upload_msg }}",
1101
+ style="font-size:14px;",
1102
+ v_show=("is_uploading",), # 👈 hide text after completion
1103
+ )
1104
+
1105
+ v3.VBtn(
1106
+ "🗑️ CLEAR",
1107
+ block=True,
1108
+ variant="tonal",
1109
+ class_="mt-3 pfm-btn-big",
1110
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1111
+ click=self.clear,
1112
+ )
1113
+ # RIGHT = prediction
1114
+ with v3.VCol(cols=6):
1115
+ html.Div(
1116
+ "<span style='font-size:26px;font-weight:700;letter-spacing:1.1px;'>📈 Prediction Results</span>",
1117
+ style="margin-bottom:10px;",
1118
+ )
1119
+ html.Div(
1120
+ v_html=("stats_html",),
1121
+ class_="mb-3",
1122
+ style="font-size:20px;line-height:1.4;",
1123
+ )
1124
+ # v3.VProgressLinear(
1125
+ # v_model=("predict_progress", 0),
1126
+ # height=22,
1127
+ # style="margin-top:6px;margin-bottom:12px;",
1128
+ # color="primary",
1129
+ # rounded=True,
1130
+ # indeterminate=("predict_progress < 10",),
1131
+ # v_show=("is_predicting",),
1132
+ # )
1133
+ # html.Div(
1134
+ # "Predicting: {{ predict_progress }}%",
1135
+ # style="font-size:14px;margin-bottom:10px;",
1136
+ # v_show=("is_predicting",),
1137
+ # )
1138
+ with v3.VSheet(height=620, rounded=True, class_="pa-0 pfm-viewer"):
1139
+ self.view_pred = VtkRemoteView(
1140
+ self.rw_pred,
1141
+ interactive=True,
1142
+ interactive_ratio=1,
1143
+ server=self.server,
1144
+ )
1145
+ with v3.VSheet(class_="mt-3 pa-4 pfm-card pfm-progress",
1146
+ rounded=True, elevation=3):
1147
+ html.Div("<b>Inference</b>", style="font-size:18px;")
1148
+
1149
+ # Inference progress bar: bound to pm_infer, visible only while predicting,
1150
+ # and indeterminate until the first real progress update arrives.
1151
+ v3.VProgressLinear(
1152
+ v_model=("pm_infer", 0),
1153
+ height=22,
1154
+ style="margin-top:6px;margin-bottom:12px;",
1155
+ color="success",
1156
+ rounded=True,
1157
+ indeterminate=("predict_progress <= 0",),
1158
+ v_show=("is_predicting",), # 👈 bar only visible while predicting
1159
+ )
1160
+
1161
+ # text line: % + elapsed time + current stage message
1162
+ html.Div(
1163
+ "{{ pm_infer }}% — {{ pm_elapsed_infer.toFixed(2) }}s — {{ predict_msg }}",
1164
+ style="font-size:14px;margin-bottom:10px;",
1165
+ # ❗ if you want the *text* to also disappear at the end, keep v_show;
1166
+ # if you want the final "✅ Prediction complete — 1.23s" to stay, REMOVE v_show
1167
+ v_show=("is_predicting",),
1168
+ )
1169
+ v3.VBtn(
1170
+ "🚀 PREDICT",
1171
+ block=True,
1172
+ color="primary",
1173
+ class_="mt-3 pfm-btn-big",
1174
+ style="--v-btn-height:38px;--v-btn-size:1.35rem;padding:0 32px;",
1175
+ click=self.predict,
1176
+ )
1177
+
1178
+ layout.on_ready = self._first_paint
1179
+
1180
+ def _first_paint(self, **_):
1181
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1182
+ try:
1183
+ rw.Render()
1184
+ except Exception:
1185
+ pass
1186
+ view.update()
1187
+
1188
+ # ---------------------------------------------------------
1189
+ # UPLOAD (async)
1190
+ # ---------------------------------------------------------
1191
+ def _write_upload_to_disk(self, payload) -> str:
1192
+ if payload is None:
1193
+ raise ValueError("No file payload")
1194
+ if isinstance(payload, (list, tuple)):
1195
+ payload = payload[0]
1196
+ if isinstance(payload, str):
1197
+ return payload
1198
+ if not isinstance(payload, dict):
1199
+ raise ValueError(f"Unsupported payload: {type(payload)}")
1200
+ if payload.get("path"):
1201
+ return payload["path"]
1202
+ name = payload.get("name") or "upload"
1203
+ content = payload.get("content")
1204
+ if isinstance(content, str) and content.startswith("data:"):
1205
+ content = content.split(",", 1)[1]
1206
+ raw = base64.b64decode(content) if isinstance(content, str) else bytes(content)
1207
+ os.makedirs(GEOM_DIR, exist_ok=True)
1208
+ file_path = os.path.join(GEOM_DIR, name)
1209
+ with open(file_path, "wb") as f:
1210
+ f.write(raw)
1211
+ return file_path
1212
+
1213
+ def _pre_upload_spinner_loop(self):
1214
+ s = self.state
1215
+ phase = 1
1216
+ while self._pre_upload_on and not self._upload_actual_started and s.is_uploading:
1217
+ s.pm_upload = max(1, min(9, phase))
1218
+ s.upload_msg = "Initializing upload..."
1219
+ try:
1220
+ self.server.state.flush()
1221
+ except Exception:
1222
+ pass
1223
+ phase = 1 if phase >= 9 else phase + 1
1224
+ time.sleep(0.15)
1225
+
1226
+ def _start_pre_upload_spinner(self):
1227
+ if self._pre_upload_thread and self._pre_upload_thread.is_alive():
1228
+ return
1229
+ self._upload_actual_started = False
1230
+ self._pre_upload_on = True
1231
+ self._pre_upload_thread = threading.Thread(
1232
+ target=self._pre_upload_spinner_loop, daemon=True
1233
+ )
1234
+ self._pre_upload_thread.start()
1235
+
1236
+ def _stop_pre_upload_spinner(self):
1237
+ self._pre_upload_on = False
1238
+ self._pre_upload_thread = None
1239
+
1240
+ async def _fake_upload_bump(self, stop_event: asyncio.Event):
1241
+ s = self.state
1242
+ while not stop_event.is_set() and s.pm_upload < 9:
1243
+ s.pm_upload += 1
1244
+ await self._flush_async()
1245
+ await asyncio.sleep(0.05)
1246
+
1247
+ async def _upload_worker_async(self, upload):
1248
+ s = self.state
1249
+ loop = self._ensure_loop()
1250
+ t0 = time.time()
1251
+
1252
+ s.is_uploading = True
1253
+ s.upload_msg = "Starting upload..."
1254
+ s.pm_elapsed_upload = 0.0
1255
+
1256
+ s.pm_upload = 5
1257
+ self.server.state.flush()
1258
+ await asyncio.sleep(0)
1259
+
1260
+ fake_stop = asyncio.Event()
1261
+ fake_task = asyncio.create_task(self._fake_upload_bump(fake_stop))
1262
+
1263
+ try:
1264
+ self._upload_actual_started = True
1265
+ self._stop_pre_upload_spinner()
1266
+
1267
+ if not fake_stop.is_set():
1268
+ fake_stop.set()
1269
+ try:
1270
+ await fake_task
1271
+ except asyncio.CancelledError:
1272
+ pass
1273
+
1274
+ s.upload_msg = "Writing file to disk..."
1275
+ s.pm_upload = 10
1276
+ s.pm_elapsed_upload = time.time() - t0
1277
+ await self._flush_async()
1278
+ file_path = await loop.run_in_executor(None, self._write_upload_to_disk, upload)
1279
+
1280
+ s.upload_msg = "Reading mesh..."
1281
+ s.pm_upload = 20
1282
+ s.pm_elapsed_upload = time.time() - t0
1283
+ await self._flush_async()
1284
+ mesh = await loop.run_in_executor(None, pv.read, file_path)
1285
+
1286
+ # 3) decimation (auto first)
1287
+ try:
1288
+ nf = poly_count(mesh)
1289
+ except Exception:
1290
+ nf = mesh.n_cells
1291
+
1292
+ auto_tr = float(auto_target_reduction(nf))
1293
+
1294
+ # reflect auto in UI
1295
+ s.decim_target = auto_tr
1296
+ s.decim_min_faces = 5000 # low floor so even ~27k-face meshes can still be decimated
1297
+ s.decim_max_faces = int(1e7)
1298
+
1299
+ target = auto_tr
1300
+ min_faces = 5000
1301
+ max_faces = int(1e7)
1302
+
1303
+ # user override
1304
+ if self.state.decim_enable:
1305
+ target = float(self.state.decim_target or 0.0)
1306
+ min_faces = int(self.state.decim_min_faces or 5000)
1307
+ max_faces = int(self.state.decim_max_faces or 1e7)
1308
+
1309
+ if target > 0.0:
1310
+ s.upload_msg = f"Decimating mesh ({target:.3f})..."
1311
+ s.pm_upload = max(s.pm_upload, 45)
1312
+ s.pm_elapsed_upload = time.time() - t0
1313
+ await self._flush_async()
1314
+
1315
+ dec_cfg = {
1316
+ "enabled": True,
1317
+ "method": "pro",
1318
+ "target_reduction": target,
1319
+ "min_faces": min_faces,
1320
+ "max_faces": max_faces,
1321
+ }
1322
+ mesh = await loop.run_in_executor(None, decimate_mesh, mesh, dec_cfg)
1323
+
1324
+ # 4) normals + save
1325
+ s.upload_msg = "Preparing geometry..."
1326
+ s.pm_upload = 75
1327
+ s.pm_elapsed_upload = time.time() - t0
1328
+ await self._flush_async()
1329
+
1330
+ def _normals_and_save(m):
1331
+ m_fixed = m.compute_normals(
1332
+ consistent_normals=True,
1333
+ auto_orient_normals=True,
1334
+ point_normals=True,
1335
+ cell_normals=False,
1336
+ inplace=False,
1337
+ )
1338
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1339
+ m_fixed.save(geom_path_)
1340
+ return geom_path_, m_fixed
1341
+
1342
+ geom_path, mesh_fixed = await loop.run_in_executor(None, _normals_and_save, mesh)
1343
+
1344
+ # 5) update viewer
1345
+ self.ren_geom.RemoveAllViewProps()
1346
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1347
+ self.ren_geom.ResetCamera()
1348
+ try:
1349
+ self.rw_geom.Render()
1350
+ except Exception:
1351
+ pass
1352
+ self.view_geom.update()
1353
+ # GEOMETRY_CACHE.current_mesh = mesh_fixed
1354
+ GEOMETRY_CACHE.original_mesh = mesh_fixed.copy(deep=True)
1355
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1356
+
1357
+ s.upload_msg = "✅ Geometry ready."
1358
+ s.pm_upload = 100
1359
+ s.pm_elapsed_upload = time.time() - t0
1360
+ await self._flush_async()
1361
+
1362
+ except Exception as e:
1363
+ s.upload_msg = f"❌ Upload failed: {e}"
1364
+ s.pm_upload = 0
1365
+ s.pm_elapsed_upload = time.time() - t0
1366
+ await self._flush_async()
1367
+ finally:
1368
+ s.is_uploading = False
1369
+ s.pm_elapsed_upload = time.time() - t0
1370
+ await self._flush_async()
1371
+
1372
+ if not fake_stop.is_set():
1373
+ fake_stop.set()
1374
+ if not fake_task.done():
1375
+ fake_task.cancel()
1376
+ try:
1377
+ await fake_task
1378
+ except Exception:
1379
+ pass
1380
+
1381
+ @change("upload")
1382
+ def _on_upload_change(self, upload, **_):
1383
+ if not upload:
1384
+ return
1385
+ self._run_coro(self._upload_worker_async(upload))
1386
+
1387
+ def decimate_again(self):
1388
+ self._run_coro(self._decimate_again_async())
1389
+
1390
+ async def _decimate_again_async(self):
1391
+ s = self.state
1392
+ loop = self._ensure_loop()
1393
+
1394
+ if GEOMETRY_CACHE.current_mesh is None:
1395
+ # nothing to decimate
1396
+ s.upload_msg = "No mesh to re-decimate"
1397
+ await self._flush_async()
1398
+ return
1399
+
1400
+ # --- start "upload-like" progress for manual decimation ---
1401
+ t0 = time.time()
1402
+ s.is_uploading = True
1403
+ s.pm_upload = 5
1404
+ s.pm_elapsed_upload = 0.0
1405
+ s.upload_msg = "Starting mesh re-decimation..."
1406
+ await self._flush_async()
1407
+
1408
+ try:
1409
+ # read parameters from UI
1410
+ try:
1411
+ target = float(s.decim_target)
1412
+ except Exception:
1413
+ target = 0.0
1414
+
1415
+ try:
1416
+ min_faces = int(s.decim_min_faces)
1417
+ except Exception:
1418
+ min_faces = 5000
1419
+
1420
+ try:
1421
+ max_faces = int(s.decim_max_faces)
1422
+ except Exception:
1423
+ max_faces = int(1e7)
1424
+
1425
+ if (not s.decim_enable) or target <= 0.0:
1426
+ s.upload_msg = "Decimation disabled"
1427
+ s.pm_upload = 0
1428
+ s.pm_elapsed_upload = time.time() - t0
1429
+ await self._flush_async()
1430
+ return
1431
+
1432
+ # --- bump before heavy decimation call ---
1433
+ s.upload_msg = f"Re-decimating mesh ({target:.3f})..."
1434
+ s.pm_upload = 25
1435
+ s.pm_elapsed_upload = time.time() - t0
1436
+ await self._flush_async()
1437
+
1438
+ dec_cfg = {
1439
+ "enabled": True,
1440
+ "method": "pro",
1441
+ "target_reduction": target,
1442
+ "min_faces": min_faces,
1443
+ "max_faces": max_faces,
1444
+ }
1445
+
1446
+ # heavy work on executor
1447
+ mesh = await loop.run_in_executor(
1448
+ None, decimate_mesh, GEOMETRY_CACHE.current_mesh, dec_cfg
1449
+ )
1450
+
1451
+ # --- normals + save ---
1452
+ s.upload_msg = "Recomputing normals & saving..."
1453
+ s.pm_upload = 70
1454
+ s.pm_elapsed_upload = time.time() - t0
1455
+ await self._flush_async()
1456
+
1457
+ def _normals_and_save(m):
1458
+ m_fixed = m.compute_normals(
1459
+ consistent_normals=True,
1460
+ auto_orient_normals=True,
1461
+ point_normals=True,
1462
+ cell_normals=False,
1463
+ inplace=False,
1464
+ )
1465
+ geom_path_ = os.path.join(GEOM_DIR, "geometry.stl")
1466
+ m_fixed.save(geom_path_)
1467
+ return geom_path_, m_fixed
1468
+
1469
+ geom_path, mesh_fixed = await loop.run_in_executor(
1470
+ None, _normals_and_save, mesh
1471
+ )
1472
+
1473
+ # --- update viewer ---
1474
+ s.upload_msg = "Updating viewer..."
1475
+ s.pm_upload = 90
1476
+ s.pm_elapsed_upload = time.time() - t0
1477
+ await self._flush_async()
1478
+
1479
+ self.ren_geom.RemoveAllViewProps()
1480
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1481
+ self.ren_geom.ResetCamera()
1482
+ try:
1483
+ self.rw_geom.Render()
1484
+ except Exception:
1485
+ pass
1486
+ self.view_geom.update()
1487
+
1488
+ GEOMETRY_CACHE.current_mesh = mesh_fixed
1489
+
1490
+ # --- final bump ---
1491
+ s.upload_msg = "✅ Re-decimated"
1492
+ s.pm_upload = 100
1493
+ s.pm_elapsed_upload = time.time() - t0
1494
+ await self._flush_async()
1495
+
1496
+ except Exception as e:
1497
+ s.upload_msg = f"❌ Re-decimation failed: {e}"
1498
+ s.pm_upload = 0
1499
+ s.pm_elapsed_upload = time.time() - t0
1500
+ await self._flush_async()
1501
+ finally:
1502
+ # hide bar + text after we’re done
1503
+ s.is_uploading = False
1504
+ await self._flush_async()
1505
+
1506
+
1507
+ def reset_mesh(self):
1508
+ self._run_coro(self._reset_mesh_async())
1509
+
1510
+ async def _reset_mesh_async(self):
1511
+ s = self.state
1512
+
1513
+ if GEOMETRY_CACHE.original_mesh is None:
1514
+ s.upload_msg = "No original mesh to reset to"
1515
+ await self._flush_async()
1516
+ return
1517
+
1518
+ # use the saved original
1519
+ orig = GEOMETRY_CACHE.original_mesh
1520
+
1521
+ # save it again as current
1522
+ GEOMETRY_CACHE.current_mesh = orig
1523
+
1524
+ # write to disk (so the STL on disk matches the viewer)
1525
+ geom_path = os.path.join(GEOM_DIR, "geometry.stl")
1526
+ orig.save(geom_path)
1527
+
1528
+ # update viewer
1529
+ self.ren_geom.RemoveAllViewProps()
1530
+ self.ren_geom.AddActor(make_actor_from_stl(geom_path))
1531
+ self.ren_geom.ResetCamera()
1532
+ try:
1533
+ self.rw_geom.Render()
1534
+ except Exception:
1535
+ pass
1536
+ self.view_geom.update()
1537
+
1538
+ s.upload_msg = "↩️ Reset to original mesh"
1539
+ await self._flush_async()
1540
+
1541
+ # ---------------------------------------------------------
1542
+ # prediction
1543
+ # ---------------------------------------------------------
1544
+ def _start_infer_heartbeat(self):
1545
+ if self._infer_thread and self._infer_thread.is_alive():
1546
+ return
1547
+
1548
+ def loop_fn():
1549
+ while self._infer_heartbeat_on:
1550
+ if self.state.is_predicting and self._predict_t0 is not None:
1551
+ self.state.pm_elapsed_infer = max(0.0, time.time() - self._predict_t0)
1552
+ try:
1553
+ self.server.state.flush()
1554
+ except Exception:
1555
+ pass
1556
+ time.sleep(0.12)
1557
+
1558
+ self._infer_heartbeat_on = True
1559
+ self._infer_thread = threading.Thread(target=loop_fn, daemon=True)
1560
+ self._infer_thread.start()
1561
+
1562
+ def _stop_infer_heartbeat(self):
1563
+ self._infer_heartbeat_on = False
1564
+ self._infer_thread = None
1565
+
1566
+ async def _predict_worker_async(self):
1567
+ s = self.state
1568
+ loop = self._ensure_loop()
1569
+ t0 = time.time()
1570
+
1571
+ if GEOMETRY_CACHE.current_mesh is None:
1572
+ s.predict_msg = "❌ Please upload geometry first"
1573
+ s.is_predicting = False
1574
+ await self._flush_async()
1575
+ return
1576
+
1577
+ s.is_predicting = True
1578
+ s.predict_progress = 1
1579
+ s.pm_infer = 1
1580
+ s.predict_msg = "Preparing inference..."
1581
+ self._predict_t0 = time.time()
1582
+ self._start_infer_heartbeat()
1583
+ await self._flush_async()
1584
+
1585
+ try:
1586
+ dataset = s.dataset
1587
+ variable = s.variable
1588
+ boundary = (
1589
+ {"freestream_velocity": mph_to_ms(s.velocity_mph)}
1590
+ if dataset == "Incompressible flow over car"
1591
+ else None
1592
+ )
1593
+
1594
+ s.predict_msg = "Loading model/checkpoint..."
1595
+ s.predict_progress = 5
1596
+ s.pm_infer = 5
1597
+ await self._flush_async()
1598
+
1599
+ cfg, model, device, _ = await loop.run_in_executor(
1600
+ None, MODEL_STORE.get, dataset, variable, None
1601
+ )
1602
+
1603
+ s.predict_msg = "Processing mesh for inference..."
1604
+ s.predict_progress = 35
1605
+ s.pm_infer = 35
1606
+ await self._flush_async()
1607
+
1608
+ def _run_full():
1609
+ return run_inference_fast(
1610
+ dataset,
1611
+ variable,
1612
+ boundary_conditions=boundary,
1613
+ progress_cb=None,
1614
+ )
1615
+ viz = await loop.run_in_executor(None, _run_full)
1616
+
1617
+ s.predict_msg = "Preparing visualization..."
1618
+ s.predict_progress = 85
1619
+ s.pm_infer = 85
1620
+ await self._flush_async()
1621
+
1622
+ stl_path = os.path.join(GEOM_DIR, "geometry.stl")
1623
+ vmin = float(np.min(viz["pred"]))
1624
+ vmax = float(np.max(viz["pred"]))
1625
+
1626
+ if os.path.exists(stl_path):
1627
+ _tmp_trimesh, vmin, vmax = create_visualization_stl(viz, stl_path)
1628
+ lut = build_jet_lut(vmin, vmax)
1629
+ colored_actor = color_actor_with_scalars_from_prediction(
1630
+ stl_path,
1631
+ viz["points"],
1632
+ viz["pred"],
1633
+ variable,
1634
+ vmin,
1635
+ vmax,
1636
+ lut=lut,
1637
+ )
1638
+ self.ren_pred.AddActor(colored_actor)
1639
+
1640
+ units = {
1641
+ "pressure": "Pa",
1642
+ "x_velocity": "m/s",
1643
+ "y_velocity": "m/s",
1644
+ "z_velocity": "m/s",
1645
+ }.get(variable, "")
1646
+ title = f"{variable} ({units})" if units else variable
1647
+ self.scalar_bar = add_or_update_scalar_bar(
1648
+ self.ren_pred, lut, title, label_fmt="%-0.2f", n_labels=8
1649
+ )
1650
+
1651
+ src_cam = self.ren_geom.GetActiveCamera()
1652
+ dst_cam = self.ren_pred.GetActiveCamera()
1653
+ if src_cam is not None and dst_cam is not None:
1654
+ dst_cam.SetPosition(src_cam.GetPosition())
1655
+ dst_cam.SetFocalPoint(src_cam.GetFocalPoint())
1656
+ dst_cam.SetViewUp(src_cam.GetViewUp())
1657
+ dst_cam.SetParallelScale(src_cam.GetParallelScale())
1658
+ cr = src_cam.GetClippingRange()
1659
+ dst_cam.SetClippingRange(cr)
1660
+
1661
+ try:
1662
+ self.rw_pred.Render()
1663
+ except Exception:
1664
+ pass
1665
+ self.view_pred.update()
1666
+
1667
+ raw_vmin = float(np.min(viz["pred"]))
1668
+ raw_vmax = float(np.max(viz["pred"]))
1669
+
1670
+ s.stats_html = (
1671
+ f"<b>{variable} min:</b> {raw_vmin:.3e} "
1672
+ f"<b>max:</b> {raw_vmax:.3e} "
1673
+ f"<b>Confidence:</b> {viz['confidence_score']:.4f}"
1674
+ )
1675
+
1676
+ s.predict_msg = "✅ Prediction complete."
1677
+ s.predict_progress = 100
1678
+ s.pm_infer = 100
1679
+ s.predict_elapsed = time.time() - t0
1680
+ s.pm_elapsed_infer = s.predict_elapsed
1681
+ await self._flush_async()
1682
+
1683
+ except Exception as e:
1684
+ s.predict_msg = f"❌ Prediction failed: {e}"
1685
+ s.predict_progress = 0
1686
+ s.pm_infer = 0
1687
+ await self._flush_async()
1688
+ finally:
1689
+ s.is_predicting = False
1690
+ self._stop_infer_heartbeat()
1691
+ await self._flush_async()
1692
+
1693
+ @time_function("Inference and Visualization")
1694
+ def predict(self, *_):
1695
+ self._run_coro(self._predict_worker_async())
1696
+
1697
+ # ---------------------------------------------------------
1698
+ # dataset wiring
1699
+ # ---------------------------------------------------------
1700
+ @change("analysis_type")
1701
+ def _on_analysis_type_change(self, analysis_type=None, **_):
1702
+ ds_list = ANALYSIS_TYPE_MAPPING.get(analysis_type or "", [])
1703
+ default_ds = ds_list[0] if ds_list else None
1704
+ self.state.dataset_choices = ds_list
1705
+ if default_ds and self.state.dataset != default_ds:
1706
+ self.state.dataset = default_ds
1707
+ elif self.state.dataset:
1708
+ self._apply_dataset(self.state.dataset)
1709
+
1710
+ @change("dataset")
1711
+ def _on_dataset_change(self, dataset=None, **_):
1712
+ if not dataset:
1713
+ return
1714
+ self._apply_dataset(dataset)
1715
+
1716
+ def _apply_dataset(self, ds: str):
1717
+ s = self.state
1718
+ opts = variables_for(ds) if ds else []
1719
+ s.variable_choices = opts
1720
+ s.variable = opts[0] if opts else None
1721
+
1722
+ s.show_velocity = (ds == "Incompressible flow over car")
1723
+ s.is_plane = (ds == "Compressible flow over plane")
1724
+
1725
+ s.bc_text = get_boundary_conditions_text(ds)
1726
+ s.bc_left = bc_text_left(ds)
1727
+ s.bc_right = bc_text_right(ds)
1728
+ s.bc_text_html = s.bc_right or md_to_html(s.bc_text)
1729
+
1730
+ # ---------------------------------------------------------
1731
+ # clear
1732
+ # ---------------------------------------------------------
1733
+ def clear(self, *_):
1734
+ for d in [GEOM_DIR, SOLN_DIR]:
1735
+ if os.path.exists(d):
1736
+ shutil.rmtree(d)
1737
+ os.makedirs(d, exist_ok=True)
1738
+ s = self.state
1739
+ s.stats_html = "🧹 Cleared. Upload again."
1740
+ s.is_uploading = False
1741
+ s.pm_upload = 0
1742
+ s.upload_msg = ""
1743
+ s.pm_elapsed_upload = 0.0
1744
+ s.is_predicting = False
1745
+ s.predict_progress = 0
1746
+ s.predict_msg = ""
1747
+ s.pm_infer = 0
1748
+ s.pm_elapsed_infer = 0.0
1749
+ self.ren_geom.RemoveAllViewProps()
1750
+ self.ren_pred.RemoveAllViewProps()
1751
+ for rw, view in ((self.rw_geom, self.view_geom), (self.rw_pred, self.view_pred)):
1752
+ try:
1753
+ rw.Render()
1754
+ except Exception:
1755
+ pass
1756
+ view.update()
1757
+
1758
+ # ---------- main ----------
1759
+ def main():
1760
+ app = PFMDemo()
1761
+ app.server.controller.add("decimate_again", app.decimate_again)
1762
+ app.server.controller.add("reset_mesh", app.reset_mesh)
1763
+ app.server.start(7872)
1764
+
1765
+ if __name__ == "__main__":
1766
+ main()
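For reference, the velocity slider readout and the predict() path above rely on the same mph-to-m/s conversion: the template renders velocity_mph * 0.44704 and predict() passes mph_to_ms(s.velocity_mph) as the freestream velocity. A minimal sketch of that helper, assuming it simply applies the exact 0.44704 factor (the real mph_to_ms defined earlier in this file may differ); example only, not part of the commit:

    MPH_TO_MS = 0.44704  # 1 mph = 0.44704 m/s (exact)

    def mph_to_ms(velocity_mph: float) -> float:
        """Convert a speed in miles per hour to metres per second."""
        return float(velocity_mph) * MPH_TO_MS

    # The default slider value of 45 mph maps to 45 * 0.44704 = 20.1168 m/s,
    # matching the "(m/s)" readout rendered next to the slider.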
configs/DrivAerML/config.yaml ADDED
@@ -0,0 +1,40 @@
1
+ #data
2
+ dataset_name: DrivAerML
3
+ data_dir: /raid/ansysai/ajoglekar/Data/drivaerml_processed_data_surface
4
+ splits_file: datasets/DrivAerML/train_val_test_splits.json
5
+ # num_points: 40000
6
+ num_points: 16384
7
+ num_workers: 1
8
+ presampled_exists: True
9
+ # presampled_data_path: ../Data/drivaerml_processed_data_surface/presampled_val_test_data_40k # Directory will be created if presampled_exists is False
10
+ presampled_data_path: /raid/ansysai/ajoglekar/Data/drivaerml_processed_data_surface/presampled_val_test_data_16k # Directory will be created if presampled_exists is False
11
+
12
+ #model
13
+ indim: 3
14
+ outdim: 1
15
+ model: ansysLPFMs
16
+ hidden_dim: 256
17
+ n_heads: 8
18
+ n_decoder: 8
19
+ mlp_ratio: 2
20
+
21
+
22
+ #training
23
+ val_iter: 1
24
+ lr: 0.001
25
+ batch_size: 1
26
+ epochs: 500
27
+ optimizer:
28
+ type: AdamW
29
+ scheduler: OneCycleLR
30
+ # scheduler: LinearWarmupCosineAnnealingLR
31
+ num_processes: 1
32
+ max_grad_norm: 1.0
33
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
34
+ eval: True
35
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
36
+
37
+ #logging
38
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
39
+ test_name: "ptautocastwithloss_test_priyesh"
40
+ project_name: "DrivAerML"
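The chunked_eval flag above evaluates each sample in whole chunks of num_points, so the model never sees an undersized final chunk. A rough sketch of that chunking, assuming the remainder is simply dropped (the repository's evaluation loop may handle it differently); example only, not part of the commit:

    import numpy as np

    def make_eval_chunks(points: np.ndarray, num_points: int) -> list:
        """Split (N, 3) points into as many full num_points-sized chunks as fit."""
        n_chunks = points.shape[0] // num_points  # whole chunks only
        return [points[i * num_points:(i + 1) * num_points] for i in range(n_chunks)]

    # e.g. a 40_000-point sample with num_points=16384 yields 2 chunks of 16384 points
    chunks = make_eval_chunks(np.zeros((40_000, 3), dtype=np.float32), 16384)
    assert len(chunks) == 2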
configs/DriveAerNet/config.yaml ADDED
@@ -0,0 +1,57 @@
1
+ #data
2
+ dataset_name: DriveAerNet
3
+ test_name: "baseline_exel_test"
4
+ gpu_id: 0
5
+
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/1_Incompressible_external_flow/driveaer/1_VTK_surface
7
+ splits_file: datasets/${dataset_name}/train_val_test_splits_400
8
+ data_folder: driveaer
9
+
10
+ normalization: "std_norm"
11
+ physical_scale_for_test: True
12
+ press_mean: -93.194176
13
+ press_std: 111.9078
14
+ input_pos_mins: [-1.14437997, -1.02134001, 0.0]
15
+ input_pos_maxs: [4.08646011, 1.02122998, 1.75308001]
16
+
17
+ # num_points: 40000
18
+ num_points: 10005
19
+ num_workers: 1
20
+
21
+ #model
22
+ indim: 3
23
+ outdim: 1
24
+ model: ansysLPFMs
25
+ hidden_dim: 256
26
+ n_heads: 8
27
+ n_decoder: 8
28
+ mlp_ratio: 2
29
+
30
+
31
+ #training
32
+ val_iter: 1
33
+ lr: 0.001
34
+ batch_size: 1
35
+ epochs: 1500
36
+ optimizer:
37
+ type: AdamW
38
+ scheduler: OneCycleLR #OneCycleLR
39
+ loss_type: huber # options: mse, mae, huber
40
+
41
+ # scheduler: LinearWarmupCosineAnnealingLR
42
+ num_processes: 1
43
+ max_grad_norm: 1.0
44
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
45
+ eval: True
46
+
47
+
48
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
49
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
50
+
51
+ #logging
52
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
53
+
54
+ pos_embed_sincos: True
55
+
56
+
57
+ project_name: ${dataset_name}
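These configs use ${...} interpolation (for example splits_file: datasets/${dataset_name}/train_val_test_splits_400 and project_name: ${dataset_name}). A small sketch of how such references resolve, assuming an OmegaConf-style loader; the repository's own loading code (e.g. dataset_loader.py) may differ. Example only, not part of the commit:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/DriveAerNet/config.yaml")
    # Interpolations resolve on access:
    print(cfg.splits_file)   # datasets/DriveAerNet/train_val_test_splits_400
    print(cfg.project_name)  # DriveAerNet

    # Or resolve everything up front into a plain dict:
    resolved = OmegaConf.to_container(cfg, resolve=True)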
configs/app_configs/Compressible flow over plane/config.yaml ADDED
@@ -0,0 +1,58 @@
1
+ #data
2
+ dataset_name: plane_transonic
3
+ test_name: "baseline"
4
+ gpu_id: 0
5
+ epochs: 300
6
+ data_dir: Data/${dataset_name}/1_VTK_surface/
7
+
8
+ json_file: ${data_dir}/params.json
9
+
10
+ splits_file: ${data_dir}/
11
+ data_folder: ${dataset_name}
12
+ input_normalization: none
13
+ normalization: "std_norm"
14
+ norm_vars: "pressure"
15
+ physical_scale_for_test: True
16
+ diff_input_velocity: false
17
+
18
+ # num_points: 40000
19
+ num_points: 30000 # 30000
20
+ num_workers: 1
21
+
22
+ #model
23
+ indim: 3
24
+ outdim: 1
25
+ model: ansysLPFMs
26
+ hidden_dim: 256
27
+ n_heads: 8
28
+ n_decoder: 8
29
+ mlp_ratio: 2
30
+
31
+
32
+ #training
33
+ val_iter: 1
34
+ lr: 0.001
35
+ batch_size: 1
36
+
37
+ optimizer:
38
+ type: AdamW
39
+ scheduler: OneCycleLR #OneCycleLR
40
+ loss_type: huber # options: mse, mae, huber
41
+
42
+ # scheduler: LinearWarmupCosineAnnealingLR
43
+ num_processes: 1
44
+ max_grad_norm: 1.0
45
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
46
+ eval: False
47
+
48
+
49
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
50
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
51
+
52
+ #logging
53
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
54
+
55
+ pos_embed_sincos: False
56
+
57
+
58
+ project_name: ${dataset_name}
configs/app_configs/Compressible flow over plane/full_transform_params.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ 7.045689443e-06,
5
+ -41.00179814,
6
+ -3.982239961
7
+ ],
8
+ "max": [
9
+ 77.2491069,
10
+ 41.00179814,
11
+ 17.56380467
12
+ ],
13
+ "mean": [
14
+ 25.50086541120836,
15
+ 2.7514762491576034e-14,
16
+ 1.3779085606454433
17
+ ],
18
+ "std": [
19
+ 14.20950999685366,
20
+ 8.283627425964518,
21
+ 2.0614220403323915
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "pressure": {
26
+ "min": -18114.4549,
27
+ "max": 16682.95124,
28
+ "mean": -1659.761628946487,
29
+ "std": 3701.8950182476274
30
+ }
31
+ },
32
+ "boundary_conditions_max": {
33
+ "fluid_density": 1.184,
34
+ "inlet_velocity": 30.0,
35
+ "viscosity_dynamic": 1.847e-05
36
+ }
37
+ }
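The matching config sets normalization: "std_norm" and physical_scale_for_test: True, so the pressure statistics above are presumably used to standardize training targets and to map predictions back to physical units at test time. A minimal sketch under that assumption (the actual transform code may add further steps); example only, not part of the commit:

    import json
    import numpy as np

    with open("configs/app_configs/Compressible flow over plane/full_transform_params.json") as f:
        params = json.load(f)

    p_mean = params["scalars"]["pressure"]["mean"]  # ~ -1659.76 Pa
    p_std = params["scalars"]["pressure"]["std"]    # ~ 3701.90 Pa

    def to_normalized(pressure: np.ndarray) -> np.ndarray:
        return (pressure - p_mean) / p_std

    def to_physical(pressure_norm: np.ndarray) -> np.ndarray:
        # inverse transform for predictions when physical_scale_for_test is True
        return pressure_norm * p_std + p_mean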
configs/app_configs/Compressible flow over plane/pca_embedding.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e59a7f72be4b38f18c86e643a8b859558eb6c6afca0801cadc265eb4a0fbd992
3
+ size 824
configs/app_configs/Compressible flow over plane/pca_reducer.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b511c4d5368bc70e0e79a090183eb12a17c5c806cb30c2b4bac33f20ad038c6
3
+ size 30720772
configs/app_configs/Compressible flow over plane/pca_scaler.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa67ca7a567c3ba2799c46168878e91a1e70e2ba8a17e71eed70edb550c09ff6
3
+ size 61440498
configs/app_configs/Compressible flow over plane/train_dist.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f34f5e7107b3cc05c1e5691c7cb78021f7d78b17ab9f8d10474743f9713c128d
3
+ size 126618
configs/app_configs/Incompressible flow inside artery/config.yaml ADDED
@@ -0,0 +1,39 @@
1
+ dataset_name: artery
2
+ test_name: baseline_w_sincos_v1
3
+ gpu_id: 0
4
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/0_Incompressible_internal_flow/artery/1_VTK_surface
5
+ json_file: ${data_dir}/../full_transform_params.json
6
+ splits_file: ${data_dir}/../
7
+ data_folder: ${dataset_name}
8
+
9
+ normalization: "std_norm"
10
+ norm_vars: "pressure"
11
+ physical_scale_for_test: True
12
+ diff_input_velocity: false
13
+
14
+ # num_points: 40000
15
+ num_points: 10000 #15600
16
+ num_workers: 1
17
+ indim: 3
18
+ outdim: 4
19
+ model: ansysLPFMs
20
+ hidden_dim: 256
21
+ n_heads: 8
22
+ n_decoder: 8
23
+ mlp_ratio: 2
24
+ val_iter: 1
25
+ lr: 0.001
26
+ batch_size: 1
27
+ epochs: 300
28
+ optimizer:
29
+ type: AdamW
30
+ scheduler: OneCycleLR
31
+ loss_type: huber
32
+ num_processes: 1
33
+ max_grad_norm: 1.0
34
+ mixed_precision: true
35
+ eval: false
36
+ chunked_eval: true
37
+ train_ckpt_load: false
38
+ pos_embed_sincos: true
39
+ project_name: ${dataset_name}
configs/app_configs/Incompressible flow inside artery/full_transform_params.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ 0.025,
5
+ -1.204112099,
6
+ -1.222944777
7
+ ],
8
+ "max": [
9
+ 9.975,
10
+ 1.199037878,
11
+ 1.197492383
12
+ ],
13
+ "mean": [
14
+ 5.000000000023512,
15
+ -0.0050070103447356455,
16
+ 0.0028514114550009695
17
+ ],
18
+ "std": [
19
+ 2.8867160058379673,
20
+ 0.2890868787382987,
21
+ 0.29065190404236046
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "pressure": {
26
+ "min": -4.88198,
27
+ "max": 23.20724818437452,
28
+ "mean": 3.2158044839488764,
29
+ "std": 3.9818469542284554
30
+ },
31
+ "x_velocity": {
32
+ "min": -1.070354126,
33
+ "max": 5.2829312809790085,
34
+ "mean": 1.0341798972778575,
35
+ "std": 0.8492951431118362
36
+ },
37
+ "y_velocity": {
38
+ "min": -1.0771796258219577,
39
+ "max": 1.0766134297782608,
40
+ "mean": -0.00017382796364451712,
41
+ "std": 0.2074833123027587
42
+ },
43
+ "z_velocity": {
44
+ "min": -1.109028642942948,
45
+ "max": 1.1098764213186418,
46
+ "mean": 0.0005453850642336799,
47
+ "std": 0.2126440513763686
48
+ }
49
+ }
50
+ }
configs/app_configs/Incompressible flow inside artery/train_dist.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5bfb9b4197f81455710966c8e5bd8352368542b2193a5b50708ede030914725
3
+ size 126618
configs/app_configs/Incompressible flow over car/config.yaml ADDED
@@ -0,0 +1,59 @@
1
+ #data
2
+ dataset_name: cadillac
3
+ test_name: "baseline-variable-velocity-long"
4
+ gpu_id: 0
5
+ epochs: 1000
6
+ data_dir: Data/${dataset_name}/1_VTK_surface/
7
+
8
+ json_file: ${data_dir}/params.json
9
+
10
+ splits_file: ${data_dir}/
11
+ data_folder: ${dataset_name}
12
+ input_normalization: shift_axis
13
+ normalization: "std_norm"
14
+ norm_vars: "pressure"
15
+ physical_scale_for_test: True
16
+
17
+ diff_input_velocity: True # If true, inlet_x_velocity is added as an input feature
18
+
19
+ # num_points: 40000
20
+ num_points: 30000 #30000 #30000
21
+ num_workers: 1
22
+
23
+ #model
24
+ indim: 4
25
+ outdim: 1
26
+ model: ansysLPFMs
27
+ hidden_dim: 256
28
+ n_heads: 8
29
+ n_decoder: 8
30
+ mlp_ratio: 2
31
+
32
+
33
+ #training
34
+ val_iter: 1
35
+ lr: 0.001
36
+ batch_size: 1
37
+
38
+ optimizer:
39
+ type: AdamW
40
+ scheduler: OneCycleLR #OneCycleLR
41
+ loss_type: huber # options: mse, mae, huber
42
+
43
+ # scheduler: LinearWarmupCosineAnnealingLR
44
+ num_processes: 1
45
+ max_grad_norm: 1.0
46
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
47
+ eval: False
48
+
49
+
50
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
51
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
52
+
53
+ #logging
54
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
55
+
56
+ pos_embed_sincos: False
57
+
58
+
59
+ project_name: ${dataset_name}
configs/app_configs/Incompressible flow over car/full_transform_params.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ -3.4729325771331787,
5
+ -1.36235511302948,
6
+ -0.09825284779071808
7
+ ],
8
+ "max": [
9
+ 3.9517974853515625,
10
+ 1.36235511302948,
11
+ 2.7738349437713623
12
+ ],
13
+ "mean": [
14
+ 0.04156165739940759,
15
+ -1.009201088640187e-18,
16
+ 0.7471213421336442
17
+ ],
18
+ "std": [
19
+ 1.4202602139692266,
20
+ 0.637785422763593,
21
+ 0.4729766176454715
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "pressure": {
26
+ "min": -2440.914306640625,
27
+ "max": 778.2598266601562,
28
+ "mean": -82.46476859198359,
29
+ "std": 145.29254059628536
30
+ }
31
+ },
32
+ "boundary_conditions_max": {
33
+ "fluid_density": 1.184,
34
+ "inlet_velocity": 35.0,
35
+ "viscosity_dynamic": 1.847e-05
36
+ },
37
+ "max_num_points": 925282,
38
+ "mean_num_points": 662953.6781609196,
39
+ "min_num_points": 469670
40
+ }
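This dataset's config enables diff_input_velocity: True with indim: 4, i.e. the inlet x-velocity becomes a fourth per-point input feature, and boundary_conditions_max above records a maximum inlet_velocity of 35.0 (presumably m/s). A hypothetical illustration of that feature construction, scaling the requested inlet velocity by the recorded maximum; the dataset code may build this feature differently. Example only, not part of the commit:

    import numpy as np

    def add_inlet_velocity_feature(xyz: np.ndarray, inlet_velocity: float,
                                   inlet_velocity_max: float = 35.0) -> np.ndarray:
        """Append a constant, max-scaled inlet-velocity column to (N, 3) coordinates."""
        vel = np.full((xyz.shape[0], 1), inlet_velocity / inlet_velocity_max, dtype=xyz.dtype)
        return np.concatenate([xyz, vel], axis=1)  # (N, 3) -> (N, 4), matching indim: 4

    # e.g. the app's 45 mph default (~20.12 m/s) scaled against the 35 m/s training max
    feats = add_inlet_velocity_feature(np.zeros((8, 3), dtype=np.float32), 20.12)
    assert feats.shape == (8, 4)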
configs/app_configs/Incompressible flow over car/pca_embedding.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a453f26faab0fb3be15e9fb484ab8dcdf35bde71257c264d83a5d6faf1642d2
3
+ size 952
configs/app_configs/Incompressible flow over car/pca_reducer.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19f552e22f7c2f73d1ee88142d3862c28778dea68929f03a28254b9c2fe5f4ac
3
+ size 30720772
configs/app_configs/Incompressible flow over car/pca_scaler.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16de9f26dc9b6dcb9ce1175d3890b829c18504a88aafb1e63837c963d0375a89
3
+ size 61440498
configs/app_configs/Incompressible flow over car/train_dist.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d16da556fd26a61f40756d78507b268681d137d667bf134da93985c010c95a98
3
+ size 126618
configs/app_configs/Structural analysis of bracket/config.yaml ADDED
@@ -0,0 +1,58 @@
1
+ #data
2
+ dataset_name: deepjeb
3
+ test_name: "baseline"
4
+ gpu_id: 0
5
+ epochs: 300
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/2_Structural_linear_static/deepjeb/1_VTK_surface/
7
+
8
+ json_file: ${data_dir}/../full_transform_params.json
9
+
10
+ splits_file: ${data_dir}/../
11
+ data_folder: ${dataset_name}
12
+
13
+ normalization: "std_norm"
14
+ norm_vars: "von_mises_stress"
15
+ physical_scale_for_test: True
16
+
17
+
18
+ # num_points: 40000
19
+ num_points: 15000
20
+ num_workers: 1
21
+
22
+ #model
23
+ indim: 3
24
+ outdim: 4
25
+ model: ansysLPFMs
26
+ hidden_dim: 256
27
+ n_heads: 8
28
+ n_decoder: 8
29
+ mlp_ratio: 2
30
+
31
+
32
+ #training
33
+ val_iter: 1
34
+ lr: 0.001
35
+ batch_size: 1
36
+
37
+ optimizer:
38
+ type: AdamW
39
+ scheduler: OneCycleLR #OneCycleLR
40
+ loss_type: huber # options: mse, mae, huber
41
+
42
+ # scheduler: LinearWarmupCosineAnnealingLR
43
+ num_processes: 1
44
+ max_grad_norm: 1.0
45
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
46
+ eval: False
47
+
48
+
49
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
50
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
51
+
52
+ #logging
53
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
54
+
55
+ pos_embed_sincos: True
56
+
57
+
58
+ project_name: ${dataset_name}
configs/app_configs/Structural analysis of bracket/full_transform_params.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ -39.4561,
5
+ -164.829,
6
+ -0.660644
7
+ ],
8
+ "max": [
9
+ 73.00727,
10
+ 21.59413,
11
+ 65.88816
12
+ ],
13
+ "mean": [
14
+ 14.378833339963359,
15
+ -70.12945824215386,
16
+ 19.648179770356553
17
+ ],
18
+ "std": [
19
+ 25.41684127064301,
20
+ 49.219533224947526,
21
+ 14.945247349739933
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "von_mises_stress": {
26
+ "min": -379.1768394875346,
27
+ "max": 425.25385344927463,
28
+ "mean": 22.799127814174668,
29
+ "std": 77.70355231295902
30
+ },
31
+ "x_displacement": {
32
+ "min": -0.27920543884164056,
33
+ "max": 0.33591902881152996,
34
+ "mean": 0.028007646216518724,
35
+ "std": 0.059356218220308093
36
+ },
37
+ "y_displacement": {
38
+ "min": -0.10354705206573399,
39
+ "max": 0.10148109386521943,
40
+ "mean": -0.001005061685581641,
41
+ "std": 0.019455714218423757
42
+ },
43
+ "z_displacement": {
44
+ "min": -0.200961,
45
+ "max": 0.6233766340719438,
46
+ "mean": 0.08011871363216146,
47
+ "std": 0.10814779545698262
48
+ }
49
+ },
50
+ "boundary_conditions_max": {
51
+ "youngs_modulus": 113.8,
52
+ "poissons_ratio": 0.342,
53
+ "material_density": 0.00447,
54
+ "force_x": 0.0,
55
+ "force_y": 0.0,
56
+ "force_z": 35.6,
57
+ "torque_x": 0.0,
58
+ "torque_y": 0.0,
59
+ "torque_z": 0.0
60
+ }
61
+ }
configs/app_configs/Structural analysis of bracket/train_dist.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f62cc7f80d090474c39fd0021e6b430d0e94117f31b37b3b7de4edfa1a8dc10b
3
+ size 126618
configs/app_configs/Vehicle crash analysis/config.yaml ADDED
@@ -0,0 +1,58 @@
1
+ #data
2
+ dataset_name: deepjeb
3
+ test_name: "baseline"
4
+ gpu_id: 0
5
+ epochs: 300
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/2_Structural_linear_static/deepjeb/1_VTK_surface/
7
+
8
+ json_file: ${data_dir}/../full_transform_params.json
9
+
10
+ splits_file: ${data_dir}/../
11
+ data_folder: ${dataset_name}
12
+
13
+ normalization: "std_norm"
14
+ norm_vars: "von_mises_stress"
15
+ physical_scale_for_test: True
16
+
17
+
18
+ # num_points: 40000
19
+ num_points: 2000 #15000
20
+ num_workers: 1
21
+
22
+ #model
23
+ indim: 3
24
+ outdim: 4
25
+ model: ansysLPFMs
26
+ hidden_dim: 256
27
+ n_heads: 8
28
+ n_decoder: 8
29
+ mlp_ratio: 2
30
+
31
+
32
+ #training
33
+ val_iter: 1
34
+ lr: 0.001
35
+ batch_size: 1
36
+
37
+ optimizer:
38
+ type: AdamW
39
+ scheduler: OneCycleLR #OneCycleLR
40
+ loss_type: huber # options: mse, mae, huber
41
+
42
+ # scheduler: LinearWarmupCosineAnnealingLR
43
+ num_processes: 1
44
+ max_grad_norm: 1.0
45
+ mixed_precision: True # torch.autocast() currently defaults to fp16, which gave the best results for Transformer-based models.
46
+ eval: False
47
+
48
+
49
+ chunked_eval: True # With True (the default), evaluate each sample in the largest whole chunks of num_points that fit, to avoid a small final chunk
50
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
51
+
52
+ #logging
53
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
54
+
55
+ pos_embed_sincos: True
56
+
57
+
58
+ project_name: ${dataset_name}
configs/app_configs/Vehicle crash analysis/full_transform_params.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ -39.4561,
5
+ -164.829,
6
+ -0.660644
7
+ ],
8
+ "max": [
9
+ 73.00727,
10
+ 21.59413,
11
+ 65.88816
12
+ ],
13
+ "mean": [
14
+ 14.378833339963359,
15
+ -70.12945824215386,
16
+ 19.648179770356553
17
+ ],
18
+ "std": [
19
+ 25.41684127064301,
20
+ 49.219533224947526,
21
+ 14.945247349739933
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "von_mises_stress": {
26
+ "min": -379.1768394875346,
27
+ "max": 425.25385344927463,
28
+ "mean": 22.799127814174668,
29
+ "std": 77.70355231295902
30
+ },
31
+ "x_displacement": {
32
+ "min": -0.27920543884164056,
33
+ "max": 0.33591902881152996,
34
+ "mean": 0.028007646216518724,
35
+ "std": 0.059356218220308093
36
+ },
37
+ "y_displacement": {
38
+ "min": -0.10354705206573399,
39
+ "max": 0.10148109386521943,
40
+ "mean": -0.001005061685581641,
41
+ "std": 0.019455714218423757
42
+ },
43
+ "z_displacement": {
44
+ "min": -0.200961,
45
+ "max": 0.6233766340719438,
46
+ "mean": 0.08011871363216146,
47
+ "std": 0.10814779545698262
48
+ }
49
+ },
50
+ "boundary_conditions_max": {
51
+ "youngs_modulus": 113.8,
52
+ "poissons_ratio": 0.342,
53
+ "material_density": 0.00447,
54
+ "force_x": 0.0,
55
+ "force_y": 0.0,
56
+ "force_z": 35.6,
57
+ "torque_x": 0.0,
58
+ "torque_y": 0.0,
59
+ "torque_z": 0.0
60
+ }
61
+ }
configs/app_configs/Vehicle crash analysis/train_dist.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f62cc7f80d090474c39fd0021e6b430d0e94117f31b37b3b7de4edfa1a8dc10b
3
+ size 126618
configs/artery/config.yaml ADDED
@@ -0,0 +1,38 @@
1
+ dataset_name: artery
2
+ test_name: baseline_w_sincos_v1
3
+ gpu_id: 0
4
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/0_Incompressible_internal_flow/artery/1_VTK_surface
5
+ json_file: ${data_dir}/../full_transform_params.json
6
+ splits_file: ${data_dir}/../
7
+ data_folder: ${dataset_name}
8
+
9
+ normalization: "std_norm"
10
+ norm_vars: "pressure"
11
+ physical_scale_for_test: True
12
+
13
+ # num_points: 40000
14
+ num_points: 2000 #15600
15
+ num_workers: 1
16
+ indim: 3
17
+ outdim: 4
18
+ model: ansysLPFMs
19
+ hidden_dim: 256
20
+ n_heads: 8
21
+ n_decoder: 8
22
+ mlp_ratio: 2
23
+ val_iter: 1
24
+ lr: 0.001
25
+ batch_size: 1
26
+ epochs: 300
27
+ optimizer:
28
+ type: AdamW
29
+ scheduler: OneCycleLR
30
+ loss_type: huber
31
+ num_processes: 1
32
+ max_grad_norm: 1.0
33
+ mixed_precision: true
34
+ eval: false
35
+ chunked_eval: true
36
+ train_ckpt_load: false
37
+ pos_embed_sincos: true
38
+ project_name: ${dataset_name}
configs/artery/full_transform_params.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ 0.025,
5
+ -1.204112099,
6
+ -1.222944777
7
+ ],
8
+ "max": [
9
+ 9.975,
10
+ 1.199037878,
11
+ 1.197492383
12
+ ],
13
+ "mean": [
14
+ 5.000000000023512,
15
+ -0.0050070103447356455,
16
+ 0.0028514114550009695
17
+ ],
18
+ "std": [
19
+ 2.8867160058379673,
20
+ 0.2890868787382987,
21
+ 0.29065190404236046
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "pressure": {
26
+ "min": -4.88198,
27
+ "max": 23.20724818437452,
28
+ "mean": 3.2158044839488764,
29
+ "std": 3.9818469542284554
30
+ },
31
+ "x_velocity": {
32
+ "min": -1.070354126,
33
+ "max": 5.2829312809790085,
34
+ "mean": 1.0341798972778575,
35
+ "std": 0.8492951431118362
36
+ },
37
+ "y_velocity": {
38
+ "min": -1.0771796258219577,
39
+ "max": 1.0766134297782608,
40
+ "mean": -0.00017382796364451712,
41
+ "std": 0.2074833123027587
42
+ },
43
+ "z_velocity": {
44
+ "min": -1.109028642942948,
45
+ "max": 1.1098764213186418,
46
+ "mean": 0.0005453850642336799,
47
+ "std": 0.2126440513763686
48
+ }
49
+ }
50
+ }
configs/cadillac/config.yaml ADDED
@@ -0,0 +1,60 @@
1
+ #data
2
+ dataset_name: cadillac
3
+ test_name: "baseline-variable-velocity-axis_v2_long_test"
4
+ gpu_id: 1
5
+ epochs: 1000
6
+ # data_dir: /raid/ansysai/udbhav/alpha_Xdata/data_prep_transformer/${dataset_name}/1_VTK_surface/
7
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/data_prep_transformer/cadillac_v2/1_VTK_surface/
8
+
9
+
10
+ json_file: ${data_dir}/params.json
11
+
12
+ splits_file: ${data_dir}/
13
+ data_folder: ${dataset_name}
14
+
15
+ input_normalization: "shift_axis" # options: "min_max", "std_norm", "none"
16
+ normalization: "std_norm"
17
+ norm_vars: "pressure"
18
+ physical_scale_for_test: True
19
+
20
+ diff_input_velocity: True # If true, inlet_x_velocity is added as an input feature
21
+
22
+ # num_points: 40000
23
+ num_points: 30000
24
+ num_workers: 1
25
+
26
+ #model
27
+ indim: 4
28
+ outdim: 1
29
+ model: ansysLPFMs
30
+ hidden_dim: 256
31
+ n_heads: 8
32
+ n_decoder: 8
33
+ mlp_ratio: 2
34
+
35
+
36
+ #training
37
+ val_iter: 1
38
+ lr: 0.001
39
+ batch_size: 1
40
+
41
+ optimizer:
42
+ type: AdamW
43
+ scheduler: OneCycleLR #OneCycleLR
44
+ loss_type: huber # options: mse, mae, huber
45
+
46
+ # scheduler: LinearWarmupCosineAnnealingLR
47
+ num_processes: 1
48
+ max_grad_norm: 1.0
49
+ mixed_precision: True # By default fp16 is selected via torch.autocast(); fp16 gave the best results for Transformer-based models.
50
+ eval: True
51
+ save_latent: True
52
+
53
+
54
+ chunked_eval: True # When True, evaluation uses as many full chunks of size num_points as fit in each sample, avoiding a small final chunk
55
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
56
+
57
+ pos_embed_sincos: False
58
+
59
+
60
+ project_name: ${dataset_name}
configs/deepjeb/config.yaml ADDED
@@ -0,0 +1,58 @@
1
+ #data
2
+ dataset_name: deepjeb
3
+ test_name: "baseline"
4
+ gpu_id: 0
5
+ epochs: 300
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/2_Structural_linear_static/deepjeb/1_VTK_surface/
7
+
8
+ json_file: ${data_dir}/../full_transform_params.json
9
+
10
+ splits_file: ${data_dir}/../
11
+ data_folder: ${dataset_name}
12
+
13
+ normalization: "std_norm"
14
+ norm_vars: "von_mises_stress"
15
+ physical_scale_for_test: True
16
+
17
+
18
+ # num_points: 40000
19
+ num_points: 2000 #15000
20
+ num_workers: 1
21
+
22
+ #model
23
+ indim: 3
24
+ outdim: 4
25
+ model: ansysLPFMs
26
+ hidden_dim: 256
27
+ n_heads: 8
28
+ n_decoder: 8
29
+ mlp_ratio: 2
30
+
31
+
32
+ #training
33
+ val_iter: 1
34
+ lr: 0.001
35
+ batch_size: 1
36
+
37
+ optimizer:
38
+ type: AdamW
39
+ scheduler: OneCycleLR #OneCycleLR
40
+ loss_type: huber # options: mse, mae, huber
41
+
42
+ # scheduler: LinearWarmupCosineAnnealingLR
43
+ num_processes: 1
44
+ max_grad_norm: 1.0
45
+ mixed_precision: True # By default fp16 is selected via torch.autocast(); fp16 gave the best results for Transformer-based models.
46
+ eval: False
47
+
48
+
49
+ chunked_eval: True # When True, evaluation uses as many full chunks of size num_points as fit in each sample, avoiding a small final chunk
50
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
51
+
52
+ #logging
53
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
54
+
55
+ pos_embed_sincos: True
56
+
57
+
58
+ project_name: ${dataset_name}
configs/deepjeb/full_transform_params.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ -39.4561,
5
+ -164.829,
6
+ -0.660644
7
+ ],
8
+ "max": [
9
+ 73.00727,
10
+ 21.59413,
11
+ 65.88816
12
+ ],
13
+ "mean": [
14
+ 14.378833339963359,
15
+ -70.12945824215386,
16
+ 19.648179770356553
17
+ ],
18
+ "std": [
19
+ 25.41684127064301,
20
+ 49.219533224947526,
21
+ 14.945247349739933
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "von_mises_stress": {
26
+ "min": -379.1768394875346,
27
+ "max": 425.25385344927463,
28
+ "mean": 22.799127814174668,
29
+ "std": 77.70355231295902
30
+ },
31
+ "x_displacement": {
32
+ "min": -0.27920543884164056,
33
+ "max": 0.33591902881152996,
34
+ "mean": 0.028007646216518724,
35
+ "std": 0.059356218220308093
36
+ },
37
+ "y_displacement": {
38
+ "min": -0.10354705206573399,
39
+ "max": 0.10148109386521943,
40
+ "mean": -0.001005061685581641,
41
+ "std": 0.019455714218423757
42
+ },
43
+ "z_displacement": {
44
+ "min": -0.200961,
45
+ "max": 0.6233766340719438,
46
+ "mean": 0.08011871363216146,
47
+ "std": 0.10814779545698262
48
+ }
49
+ },
50
+ "boundary_conditions_max": {
51
+ "youngs_modulus": 113.8,
52
+ "poissons_ratio": 0.342,
53
+ "material_density": 0.00447,
54
+ "force_x": 0.0,
55
+ "force_y": 0.0,
56
+ "force_z": 35.6,
57
+ "torque_x": 0.0,
58
+ "torque_y": 0.0,
59
+ "torque_z": 0.0
60
+ }
61
+ }
configs/driveaerpp/config.yaml ADDED
@@ -0,0 +1,56 @@
1
+ #data
2
+ dataset_name: driveaerpp
3
+ test_name: "baseline_long_1500ep"
4
+ gpu_id: 0
5
+ epochs: 1500
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/1_Incompressible_external_flow/driveaernetpp/1_VTK_surface
7
+
8
+ json_file: ${data_dir}/../full_transform_params.json
9
+
10
+ splits_file: ${data_dir}/../
11
+ data_folder: ${dataset_name}
12
+
13
+ normalization: "std_norm"
14
+ norm_vars: "pressure"
15
+ physical_scale_for_test: False
16
+
17
+ num_points: 2000 #10000
18
+ num_workers: 1
19
+
20
+ #model
21
+ indim: 3
22
+ outdim: 1
23
+ model: ansysLPFMs
24
+ hidden_dim: 256
25
+ n_heads: 8
26
+ n_decoder: 8
27
+ mlp_ratio: 2
28
+
29
+
30
+ #training
31
+ val_iter: 1
32
+ lr: 0.001
33
+ batch_size: 1
34
+
35
+ optimizer:
36
+ type: AdamW
37
+ scheduler: OneCycleLR #OneCycleLR
38
+ loss_type: huber # options: mse, mae, huber
39
+
40
+ # scheduler: LinearWarmupCosineAnnealingLR
41
+ num_processes: 1
42
+ max_grad_norm: 1.0
43
+ mixed_precision: True # By default fp16 is selected via torch.autocast(); fp16 gave the best results for Transformer-based models.
44
+ eval: False
45
+
46
+
47
+ chunked_eval: True # When True, evaluation uses as many full chunks of size num_points as fit in each sample, avoiding a small final chunk
48
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
49
+
50
+ #logging
51
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
52
+
53
+ pos_embed_sincos: True
54
+
55
+
56
+ project_name: ${dataset_name}
configs/driveaerpp/full_transform_params.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "mesh_stats": {
3
+ "min": [
4
+ -1.151419997215271,
5
+ -1.1943000555038452,
6
+ 0.0
7
+ ],
8
+ "max": [
9
+ 4.204110145568848,
10
+ 1.1943000555038452,
11
+ 1.76187002658844
12
+ ],
13
+ "mean": [
14
+ 0.03162362053990364,
15
+ -0.004585757851600647,
16
+ 0.009171515703201294
17
+ ],
18
+ "std": [
19
+ 0.30713993310928345,
20
+ 0.0957680270075798,
21
+ 0.13543644547462463
22
+ ]
23
+ },
24
+ "scalars": {
25
+ "pressure": {
26
+ "min": -27552.19921875,
27
+ "max": 6858.10986328125,
28
+ "mean": -93.42745971679688,
29
+ "std": 120.5966567993164
30
+ }
31
+ },
32
+ "boundary_conditions_max": {
33
+ "fluid_density": 1.184,
34
+ "inlet_velocity": 30.0,
35
+ "viscosity_dynamic": 1.847e-05
36
+ }
37
+ }
configs/elasticity/config.yaml ADDED
@@ -0,0 +1,32 @@
1
+ #data
2
+ data_dir: ../Data/elasticity
3
+ dataset_name: elasticity
4
+ ntrain: 1000
5
+ ntest: 200
6
+
7
+ #model
8
+ indim: 2
9
+ outdim: 1
10
+ model: ansysLPFMs
11
+ hidden_dim: 128
12
+ n_heads: 8
13
+ n_decoder: 8
14
+ mlp_ratio: 1
15
+
16
+
17
+ #training
18
+ val_iter: 1
19
+ lr: 0.001
20
+ batch_size: 1
21
+ epochs: 500
22
+ optimizer:
23
+ type: AdamW
24
+ scheduler: OneCycleLR
25
+ num_processes: 1
26
+ max_grad_norm: 1.0
27
+ eval: False
28
+
29
+
30
+ #logging
31
+ test_name: "OCLR_float32_A100_test"
32
+ project_name: "elasticity"
configs/plane_engine_test/config.yaml ADDED
@@ -0,0 +1,57 @@
1
+ #data
2
+ dataset_name: plane_engine_test
3
+ test_name: "sin_cos_4mn_long_wo_sincos_v3"
4
+ gpu_id: 1
5
+ epochs: 2500
6
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/xgm_data/data_prep_transformer/1_Incompressible_external_flow/plane/1_VTK_surface
7
+
8
+ json_file: ${data_dir}/../full_transform_params.json
9
+
10
+ splits_file: ${data_dir}/../
11
+ data_folder: ${dataset_name}
12
+
13
+ normalization: "std_norm"
14
+ norm_vars: "pressure"
15
+ physical_scale_for_test: True
16
+
17
+ # num_points: 40000
18
+ num_points: 16035
19
+ num_workers: 1
20
+
21
+ #model
22
+ indim: 3
23
+ outdim: 1
24
+ model: ansysLPFMs
25
+ hidden_dim: 256
26
+ n_heads: 8
27
+ n_decoder: 8
28
+ mlp_ratio: 2
29
+
30
+
31
+ #training
32
+ val_iter: 1
33
+ lr: 0.001
34
+ batch_size: 1
35
+
36
+ optimizer:
37
+ type: AdamW
38
+ scheduler: OneCycleLR #OneCycleLR
39
+ loss_type: huber # options: mse, mae, huber
40
+
41
+ # scheduler: LinearWarmupCosineAnnealingLR
42
+ num_processes: 1
43
+ max_grad_norm: 1.0
44
+ mixed_precision: True # By default fp16 is selected via torch.autocast(); fp16 gave the best results for Transformer-based models.
45
+ eval: False
46
+
47
+
48
+ chunked_eval: True # When True, evaluation uses as many full chunks of size num_points as fit in each sample, avoiding a small final chunk
49
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
50
+
51
+ #logging
52
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
53
+
54
+ pos_embed_sincos: False
55
+
56
+
57
+ project_name: ${dataset_name}
configs/plane_transonic/config.yaml ADDED
@@ -0,0 +1,62 @@
1
+ #data
2
+ dataset_name: plane_transonic
3
+ test_name: "baseline_ep1000_v2_100_long_test"
4
+ gpu_id: 0
5
+ epochs: 300
6
+ # data_dir: /raid/ansysai/udbhav/alpha_Xdata/data_prep_transformer/${dataset_name}/1_VTK_surface/
7
+ data_dir: /raid/ansysai/udbhav/alpha_Xdata/data_prep_transformer/plane_transonic_v2/1_VTK_surface/
8
+
9
+
10
+ json_file: ${data_dir}/params.json
11
+
12
+ splits_file: ${data_dir}/
13
+ data_folder: ${dataset_name}
14
+
15
+ input_normalization: "none" # options: "min_max", "std_norm", "none"
16
+ normalization: "std_norm"
17
+ norm_vars: "pressure"
18
+ physical_scale_for_test: True
19
+ diff_input_velocity: false
20
+
21
+ # num_points: 40000
22
+ num_points: 30000
23
+ num_workers: 1
24
+
25
+ #model
26
+ indim: 3
27
+ outdim: 1
28
+ model: ansysLPFMs
29
+ hidden_dim: 256
30
+ n_heads: 8
31
+ n_decoder: 8
32
+ mlp_ratio: 2
33
+
34
+
35
+ #training
36
+ val_iter: 1
37
+ lr: 0.001
38
+ batch_size: 1
39
+
40
+ optimizer:
41
+ type: AdamW
42
+ scheduler: OneCycleLR #OneCycleLR
43
+ loss_type: huber # options: mse, mae, huber
44
+
45
+ # scheduler: LinearWarmupCosineAnnealingLR
46
+ num_processes: 1
47
+ max_grad_norm: 1.0
48
+ mixed_precision: True # By default fp16 is selected via torch.autocast(); fp16 gave the best results for Transformer-based models.
49
+ eval: true
50
+ save_latent: true
51
+
52
+
53
+ chunked_eval: True # When True, evaluation uses as many full chunks of size num_points as fit in each sample, avoiding a small final chunk
54
+ train_ckpt_load: False ## Will load best model if ckpt_load is false
55
+
56
+ #logging
57
+ # test_name: "Final_surface_only_OCLR_3p9M_float32_A100"
58
+
59
+ pos_embed_sincos: False
60
+
61
+
62
+ project_name: ${dataset_name}
configs/shapenet_car_pv/config.yaml ADDED
@@ -0,0 +1,37 @@
1
+ #data
2
+ data_dir: /raid/ansysai/ajoglekar/Data/shapenet_car/mlcfd_data/training_data
3
+ save_dir: /raid/ansysai/pkakka/6-Transformers/comparePhysicsLM/Data/shapenet_car/mlcfd_data/preprocessed_pv_data
4
+ preprocessed: True
5
+ dataset_name: shapenet_car_pv
6
+ fold_id: 0
7
+ num_workers: 1
8
+ train_num_points_volume: 4096 #Training volume as well with shapenet?
9
+ train_num_points_surface: 3586
10
+
11
+ #model
12
+ indim: 3
13
+ outdim: 1
14
+ vdim: 3
15
+ sdim: 1
16
+ model: ansysLPFMs
17
+ hidden_dim: 256 ## Time series and decoder GPT?
18
+ n_heads: 8
19
+ n_decoder: 8 ## Any intuition on how many decoders/heads to use?
20
+ mlp_ratio: 2
21
+
22
+ #training
23
+ val_iter: 1
24
+ lr: 0.001
25
+ batch_size: 1
26
+ epochs: 1
27
+ optimizer:
28
+ type: AdamW
29
+ scheduler: OneCycleLR
30
+ num_processes: 1
31
+ max_grad_norm: 1.0
32
+ mixed_precision: True # If True, fp16 is selected by default via torch.autocast(); fp16 gave the best results for Transformer-based models. Can be changed to bf16 via the dtype argument if required.
33
+ eval: False
34
+
35
+ #logging
36
+ test_name: "surface_only_float32_A100_test"
37
+ project_name: "ShapeNetCar_PV"
dataset_loader.py ADDED
@@ -0,0 +1,188 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import pyvista as pv
7
+ import json
8
+ import glob
9
+
10
+ class Data_loader(Dataset):
11
+ def __init__(self, cfg, split, epoch_seed=None, mode='train'):
12
+ """
13
+ data_dir: parent directory
14
+ split: list of sample ids (as read from the train/test split files), one id per geometry folder
15
+ num_points: number of points to sample per geometry
16
+ epoch_seed: seed for random sampling (for training)
17
+ mode: 'train', 'val', or 'test'
18
+ """
19
+ self.data_dir = cfg.data_dir
20
+ self.split = split
21
+ self.num_points = cfg.num_points
22
+ self.epoch_seed = epoch_seed
23
+ self.mode = mode
24
+ self.cfg = cfg
25
+ self.meshes = []
26
+ self.mesh_names = []
27
+ for idx in split:
28
+ # Locate the sample folder named after the split id
29
+ folder = os.path.join(self.data_dir, f"{idx}")
30
+ if not os.path.exists(folder):
31
+ raise FileNotFoundError(f"No folder matching '{idx}' found in {self.data_dir}")
32
+
33
+ # Find the {idx}.vtp file inside the folder
34
+ vtp_files = glob.glob(os.path.join(folder, f"{idx}.vtp"))
35
+ if not vtp_files:
36
+ raise FileNotFoundError(f"No file matching '{idx}.vtp' found in {folder}")
37
+ vtp_file = vtp_files[0]
38
+ mesh = pv.read(vtp_file)
39
+ self.meshes.append(mesh)
40
+ self.mesh_names.append(os.path.splitext(os.path.basename(vtp_file))[0])
41
+ # For validation chunking
42
+ self.val_indices = None
43
+ self.val_chunk_ptr = 0
44
+ with open(cfg.json_file, "r") as f:
45
+ self.json_data = json.load(f)
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch_seed = epoch
49
+ self.val_indices = None
50
+ self.val_chunk_ptr = 0
51
+
52
+ def __len__(self):
53
+ if self.mode == 'train':
54
+ return len(self.meshes)
55
+ elif self.mode == 'val':
56
+ return len(self.meshes)
57
+ elif self.mode == 'test':
58
+ # One item per mesh; the full (permuted) mesh is returned by __getitem__ and chunked downstream
+ return len(self.meshes)
62
+ else:
63
+ raise ValueError(f"Unknown mode: {self.mode}")
64
+
65
+ def __getitem__(self, idx):
66
+
67
+ if self.mode == 'train' or self.mode == 'val':
68
+ # Each item is a geometry, sample num_points randomly
69
+ mesh = self.meshes[idx]
70
+ n_pts = mesh.points.shape[0]
71
+ seed = idx if self.epoch_seed is None else self.epoch_seed + idx # fall back to a per-sample seed if set_epoch was never called
+ rng = np.random.default_rng(seed)
72
+ indices = rng.choice(n_pts, self.num_points, replace=False)
73
+ pos = mesh.points
74
+ pos = torch.tensor(pos, dtype=torch.float32)
75
+ pressure = torch.tensor( mesh["pressure"][indices], dtype=torch.float32).unsqueeze(-1)
76
+
77
+ if self.cfg.normalization == "std_norm":
78
+ target = (pressure - self.json_data["scalars"]["pressure"]["mean"]) / self.json_data["scalars"]["pressure"]["std"]
+ else:
+ target = pressure # no normalization requested: keep the physical-scale pressure
79
+
80
+ if self.cfg.diff_input_velocity:
81
+ inlet_x_vel = torch.tensor( mesh["inlet_x_velocity"], dtype=torch.float32).unsqueeze(-1)
82
+ pos = torch.cat((pos,inlet_x_vel),dim = 1)
83
+
84
+ if self.cfg.input_normalization == "shift_axis":
85
+ coords = pos[:,:3].clone()
86
+
87
+ # Shift x: set minimum x (front bumper) to 0
88
+ coords[:, 0] = coords[:, 0] - coords[:, 0].min()
89
+
90
+ # Shift z: set minimum z (ground) to 0
91
+ coords[:, 2] = coords[:, 2] - coords[:, 2].min()
92
+
93
+ # Shift y: center about 0 (left/right symmetry)
94
+ y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
95
+ coords[:, 1] = coords[:, 1] - y_center
96
+
97
+ pos[:,:3] = coords
98
+
99
+
100
+ if self.cfg.pos_embed_sincos:
101
+
102
+ if self.cfg.diff_input_velocity:
103
+ raise Exception("pos_embed_sincos not supported with diff_input_velocity=True")
104
+
105
+ input_pos_mins = torch.tensor(self.json_data["mesh_stats"]["min"])
106
+ input_pos_maxs = torch.tensor(self.json_data["mesh_stats"]["max"])
107
+ pos = 1000*(pos - input_pos_mins) / (input_pos_maxs - input_pos_mins)
108
+ assert torch.all(pos >= 0)
109
+ assert torch.all(pos <= 1000)
110
+
111
+ pos = pos[indices]
112
+
113
+ return {"input_pos": pos, "output_feat": target ,"data_id": self.mesh_names[idx]}
114
+
115
+ elif self.mode == 'test':
116
+ # For each mesh in test, scramble all points and return the full mesh
117
+ mesh = self.meshes[idx]
118
+ n_pts = mesh.points.shape[0]
119
+ seed = idx if self.epoch_seed is None else self.epoch_seed + idx # fall back to a per-sample seed if set_epoch was never called
+ rng = np.random.default_rng(seed)
120
+ indices = rng.permutation(n_pts)
121
+ pos = mesh.points
122
+
123
+ pos = torch.tensor(pos, dtype=torch.float32)
124
+ pressure = torch.tensor( mesh["pressure"][indices], dtype=torch.float32).unsqueeze(-1)
125
+
126
+ if self.cfg.normalization == "std_norm":
127
+ target = (pressure - self.json_data["scalars"]["pressure"]["mean"]) / self.json_data["scalars"]["pressure"]["std"]
+ else:
+ target = pressure # no normalization requested: keep the physical-scale pressure
128
+
129
+ if hasattr(self.cfg, "diff_input_velocity") and self.cfg.diff_input_velocity:
130
+
131
+ inlet_x_vel = torch.tensor( mesh["inlet_x_velocity"], dtype=torch.float32).unsqueeze(-1)
132
+ pos = torch.cat((pos,inlet_x_vel),dim = 1)
133
+
134
+ if self.cfg.input_normalization == "shift_axis":
135
+
136
+ coords = pos[:,:3].clone()
137
+
138
+ # Shift x: set minimum x (front bumper) to 0
139
+ coords[:, 0] = coords[:, 0] - coords[:, 0].min()
140
+
141
+ # Shift z: set minimum z (ground) to 0
142
+ coords[:, 2] = coords[:, 2] - coords[:, 2].min()
143
+
144
+ # Shift y: center about 0 (left/right symmetry)
145
+ y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
146
+ coords[:, 1] = coords[:, 1] - y_center
147
+
148
+ pos[:,:3] = coords
149
+
150
+ if self.cfg.pos_embed_sincos:
151
+
152
+ if hasattr(self.cfg, "diff_input_velocity") and self.cfg.diff_input_velocity:
153
+ raise Exception("pos_embed_sincos not supported with diff_input_velocity=True")
154
+
155
+ input_pos_mins = torch.tensor(self.json_data["mesh_stats"]["min"])
156
+ input_pos_maxs = torch.tensor(self.json_data["mesh_stats"]["max"])
157
+ pos = 1000*(pos - input_pos_mins) / (input_pos_maxs - input_pos_mins)
158
+ assert torch.all(pos >= 0)
159
+ assert torch.all(pos <= 1000)
160
+
161
+ pos = pos[indices]
162
+
163
+ return {"input_pos": pos, "output_feat": target ,"data_id": self.mesh_names[idx],"physical_coordinates":mesh.points[indices]}
164
+
165
+ else:
166
+ raise ValueError(f"Unknown mode: {self.mode}")
167
+
168
+ def get_dataloaders(cfg):
169
+
170
+
171
+ with open(os.path.join(cfg.splits_file, "train.txt")) as f:
172
+ train_split = [line.strip() for line in f if line.strip()]
173
+ with open(os.path.join(cfg.splits_file, "test.txt")) as f:
174
+ val_split = [line.strip() for line in f if line.strip()]
175
+ with open(os.path.join(cfg.splits_file, "test.txt")) as f:
176
+ test_split = [line.strip() for line in f if line.strip()]
177
+ print("Indices in test_split:", test_split[:5]) # Print first 5 indices for verification
178
+
179
+
180
+ train_dataset = Data_loader(cfg, train_split, mode='train')
181
+ val_dataset = Data_loader(cfg, val_split, mode='val')
182
+ test_dataset = Data_loader(cfg, test_split, mode='test')
183
+
184
+ train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
185
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
186
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
187
+
188
+ return train_loader, val_loader, test_loader
datasets/DrivAerML/README.md ADDED
@@ -0,0 +1,15 @@
1
+ # Downloading dataset
2
+ Run download_dataset.sh after setting LOCAL_DIR (the path where the files should be downloaded). <br>
+ Currently only the STL files (drivaer_i.stl) and the surface point/field files (boundary_i.vtp) are downloaded. <br>
4
+
5
+ # Preprocess dataset
6
+ Run preprocess_data.py after setting input_dir (path to the downloaded data) and output_dir (where the processed data should be saved) in data_preprocessing.yaml. This creates .npy files in output_dir. <br>
+ Currently only surface quantities are stored in the .npy files. Change model_type in data_preprocessing.yaml to 'volume' or 'combined' to store the other quantities. <br>
+ The dataset does not provide train/val/test splits; splits_creation.py can be used to create and save the splits in train_val_test_splits.json. <br>
9
+
10
+ # Dataset and Dataloader
11
+ dataset_drivaerml.py contains the Dataset class; use the get_dataloaders function to build the DataLoaders (a usage sketch follows below). <br>
+ For the first run, set the config argument 'presampled_exists': False. <br>
+ Set the 'num_points' config argument; that many points are sampled from each sample and stored under the path given by the 'presampled_data_path' config argument. <br>
14
+
15
+ Refer to https://huggingface.co/datasets/neashton/drivaerml and https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino for further details.
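As a quick orientation, the sketch below wires together the steps described in this README. The config keys are the ones read by dataset_drivaerml.py; the use of OmegaConf and the concrete paths are assumptions for illustration only.

```python
# Minimal usage sketch (assumed: OmegaConf config object, illustrative paths).
from omegaconf import OmegaConf
from dataset_drivaerml import get_dataloaders

cfg = OmegaConf.create({
    "data_dir": "../../../Data/drivaerml_processed_data_surface",  # output_dir of preprocess_data.py
    "splits_file": "train_val_test_splits.json",                   # created by splits_creation.py
    "num_points": 2000,                # points sampled per geometry
    "batch_size": 1,
    "num_workers": 1,
    "chunked_eval": True,              # evaluate test samples in full-size chunks
    "presampled_exists": False,        # first run: presampled data is created
    "presampled_data_path": "presampled_data.npy",  # directory name is derived by stripping the extension
})

train_loader, val_loader, test_loader = get_dataloaders(cfg)
batch = next(iter(train_loader))
print(batch["input_pos"].shape, batch["output_feat"].shape)  # (B, num_points, 3), (B, num_points, 1)
```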
datasets/DrivAerML/data_preprocessing.yaml ADDED
@@ -0,0 +1,20 @@
1
+ #data
2
+ dataset_name: DrivAerML
3
+ data_processor: # Data processor configurable parameters
4
+ output_dir: ../../../Data/drivaerml_processed_data_surface # path to where processed data should be saved
5
+ input_dir: ../../../Data/drivaerml_data # path to where raw downloaded data is stored
6
+ num_processors: 12
7
+
8
+ variables:
9
+ surface:
10
+ solution:
11
+ # The following is for AWS DrivAer dataset.
12
+ pMeanTrim: scalar
13
+ wallShearStressMeanTrim: vector
14
+ volume:
15
+ solution:
16
+ # The following is for AWS DrivAer dataset.
17
+ UMeanTrim: vector
18
+ pMeanTrim: scalar
19
+ nutMeanTrim: scalar
20
+ model_type: surface
datasets/DrivAerML/dataset_drivaerml.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ DrivAerML Dataset with Memory-Efficient Presampling Support
3
+
4
+ This dataset implements presampling functionality for training, validation and test data.
5
+ The presampling feature ensures consistent results across different runs, with memory-efficient
6
+ on-demand loading.
7
+
8
+ Presampling Workflow:
9
+ 1. Set presampled=False in config to create presampled training, validation and test data
10
+ 2. The system creates fixed samples for all splits
11
+ 3. Each run's presampled data is saved as individual files in a directory structure
12
+ 4. Set presampled=True in config to use the saved presampled data for future runs
13
+
14
+ Directory Structure:
15
+ presampled_data_path/
16
+ ├── train/
17
+ │ ├── run_1.npy
18
+ │ ├── run_2.npy
19
+ │ └── ...
20
+ ├── validation/
21
+ │ ├── run_1.npy
22
+ │ ├── run_2.npy
23
+ │ └── ...
24
+ └── test/
25
+ ├── run_1.npy
26
+ ├── run_2.npy
27
+ └── ...
28
+
29
+ Configuration Parameters:
30
+ - presampled: Boolean flag to control whether to use presampled data
31
+ - presampled_data_path: Base path where presampled data directory is created
32
+
33
+ Usage:
34
+ - First run: Set presampled=False to create presampled data
35
+ - Subsequent runs: Set presampled=True to use existing presampled data
36
+ """
37
+
38
+ import os
39
+ import numpy as np
40
+ import torch
41
+ from torch.utils.data import Dataset, DataLoader
42
+ from torch.utils.data import default_collate
43
+ import json
44
+ import re
45
+ from sklearn.cluster import KMeans
46
+ from sklearn.neighbors import NearestNeighbors
47
+
48
+ def create_presampled_data(cfg, splits, save_path):
49
+ """
50
+ Create presampled training, validation and test data with fixed random sampling.
51
+ Saves individual files for each run to enable on-demand loading.
52
+
53
+ Args:
54
+ cfg: Configuration object
55
+ splits: Dictionary containing train/validation/test splits
56
+ save_path: Base path for saving presampled data (directory will be created)
57
+ """
58
+ print("Creating presampled training, validation and test data...")
59
+
60
+ # Create directory structure for presampled data
61
+ base_dir = os.path.splitext(save_path)[0] # Remove .npy extension if present
62
+ os.makedirs(base_dir, exist_ok=True)
63
+
64
+ # Set seed for reproducible sampling
65
+ np.random.seed(0)
66
+
67
+ for split_type in ['train', 'validation', 'test']:
68
+ print(f"Processing {split_type} split...")
69
+ split_runs = splits[split_type]
70
+
71
+ # Create subdirectory for this split
72
+ split_dir = os.path.join(base_dir, split_type)
73
+ os.makedirs(split_dir, exist_ok=True)
74
+
75
+ for run_number in split_runs:
76
+ # Find the corresponding .npy file
77
+ for f in os.listdir(cfg.data_dir):
78
+ if f.endswith('.npy'):
79
+ match = re.search(r'run_(\d+)', f) ## Very inefficient?
80
+ if match and int(match.group(1)) == run_number:
81
+ npy_file_path = os.path.join(cfg.data_dir, f)
82
+
83
+ # Load the original data
84
+ data = np.load(npy_file_path, allow_pickle=True).item()
85
+ coordinates = data['surface_mesh_centers']
86
+ field = data['surface_fields']
87
+
88
+ # Sample points with fixed seed for reproducibility
89
+ sample_indices = np.random.choice(coordinates.shape[0], cfg.num_points, replace=False)
90
+ sampled_coordinates = coordinates[sample_indices, :]
91
+ sampled_field = field[sample_indices, :]
92
+
93
+ # Save individual presampled file for this run
94
+ presampled_run_data = {
95
+ 'surface_mesh_centers': sampled_coordinates,
96
+ 'surface_fields': sampled_field
97
+ }
98
+
99
+ run_file_path = os.path.join(split_dir, f'run_{run_number}.npy')
100
+ np.save(run_file_path, presampled_run_data)
101
+ break
102
+
103
+ print(f"Presampled data saved to directory: {base_dir}")
104
+ print(f"Structure: {base_dir}/{{train,validation,test}}/run_{{number}}.npy")
105
+
106
+ return base_dir
107
+
108
+ class DrivAerMLDataset(Dataset):
109
+ def __init__(self, cfg, splits = None, split_type = 'train', presampled = False, save_presampled_data_path = None):
110
+ """
111
+ Initializes the DrivAerMLDataset instance.
112
+
113
+ Args:
114
+ cfg: Configuration object containing data directory and number of points
115
+ splits: List of run numbers to include, if None includes all files
116
+ split_type: Type of split ('train', 'validation', 'test')
117
+ presampled: Whether to use presampled data
118
+ save_presampled_data_path: Base path to the presampled data directory
119
+ """
120
+ self.data_dir = cfg.data_dir
121
+ self.chunked_eval = cfg.chunked_eval
122
+ self.splits = splits
123
+
124
+ # Store only run numbers and create filename mapping for efficiency
125
+ self.run_numbers = []
126
+ self.original_filenames = {} # run_number -> original filename
127
+
128
+ for f in os.listdir(cfg.data_dir):
129
+ if f.endswith('.npy'):
130
+ match = re.search(r'run_(\d+)', f)
131
+ if match:
132
+ run_number = int(match.group(1))
133
+ if run_number in splits:
134
+ self.run_numbers.append(run_number)
135
+ self.original_filenames[run_number] = f
136
+
137
+ if len(self.run_numbers) == 0:
138
+ raise ValueError(f"No .npy files found in directory: {cfg.data_dir}")
139
+
140
+ self.num_points = cfg.num_points
141
+ self.split_type = split_type
142
+ self.presampled = presampled # Is there a script for non presampled dataloader?
143
+
144
+ # Set up presampled data directory path (but don't load data yet)
145
+ if self.presampled and save_presampled_data_path:
146
+ self.presampled_base_dir = os.path.splitext(save_presampled_data_path)[0]
147
+ self.presampled_split_dir = os.path.join(self.presampled_base_dir, self.split_type)
148
+ if not os.path.exists(self.presampled_split_dir):
149
+ raise FileNotFoundError(f"Presampled data directory not found: {self.presampled_split_dir}")
150
+
151
+ def __len__(self):
152
+ return len(self.run_numbers)
153
+
154
+ def __getitem__(self, idx):
155
+ run_number = self.run_numbers[idx]
156
+
157
+ if self.presampled:
158
+ # Load presampled data on-demand
159
+ presampled_file_path = os.path.join(self.presampled_split_dir, f'run_{run_number}.npy')
160
+ if os.path.exists(presampled_file_path):
161
+ data_dict = np.load(presampled_file_path, allow_pickle=True).item()
162
+ coordinates = data_dict['surface_mesh_centers']
163
+ field = data_dict['surface_fields'][:,0:1]
164
+ else:
165
+ raise FileNotFoundError(f"Presampled file not found: {presampled_file_path}")
166
+ else:
167
+ # Load original data
168
+ original_filename = self.original_filenames[run_number]
169
+ original_file_path = os.path.join(self.data_dir, original_filename)
170
+ data = np.load(original_file_path, allow_pickle=True).item()
171
+ coordinates = data['surface_mesh_centers']
172
+ field = data['surface_fields'][:,0:1]
173
+
174
+ # Random sampling
175
+ sample_indices = np.random.choice(coordinates.shape[0], self.num_points, replace=False)
176
+ coordinates = coordinates[sample_indices,:]
177
+ field = field[sample_indices,0:1]
178
+
179
+ if self.split_type == 'test' and self.chunked_eval:
180
+ # Load original data
181
+ original_filename = self.original_filenames[run_number]
182
+ original_file_path = os.path.join(self.data_dir, original_filename)
183
+ data = np.load(original_file_path, allow_pickle=True).item()
184
+ coordinates = data['surface_mesh_centers']
185
+ field = data['surface_fields'][:,0:1]
186
+ num_chunks = coordinates.shape[0]//self.num_points
187
+ indices = torch.randperm(coordinates.shape[0])[:self.num_points*num_chunks]
188
+ # indices = torch.randperm(coordinates.shape[0])
189
+ coordinates = coordinates[indices,:]
190
+ field = field[indices,0:1]
191
+
192
+
193
+ coordinates_tensor = torch.tensor(coordinates, dtype=torch.float32)
194
+ field_tensor = torch.tensor(field, dtype=torch.float32)
195
+
196
+ # Use mean-std normalization for coordinates
197
+ coordinates_tensor = (coordinates_tensor - INPUT_POS_MEAN) / INPUT_POS_STD
198
+ field_tensor = (field_tensor - PRESSURE_MEAN) / PRESSURE_STD
199
+
200
+ data = {'input_pos': coordinates_tensor, 'output_feat': field_tensor, 'output_pos': coordinates_tensor}
201
+
202
+ return data
203
+
204
+
205
+ def calculate_normalization_constants(dataloader):
206
+ """
207
+ Calculate normalization constants for both pressure values and coordinate ranges
208
+ across the entire training dataset.
209
+
210
+ Args:
211
+ dataloader: Training DataLoader
212
+
213
+ Returns:
214
+ tuple: (pressure_mean, pressure_std, coord_ranges, coord_mean, coord_std)
215
+ where coord_ranges = {'min_x', 'max_x', 'min_y', 'max_y', 'min_z', 'max_z'}
216
+ coord_mean = [mean_x, mean_y, mean_z]
217
+ coord_std = [std_x, std_y, std_z]
218
+ """
219
+ all_pressures = []
220
+ all_coordinates = [] # Store all coordinate points for mean/std calculation
221
+
222
+ # Initialize coordinate extremes
223
+ max_x = float('-inf')
224
+ max_y = float('-inf')
225
+ max_z = float('-inf')
226
+ min_x = float('inf')
227
+ min_y = float('inf')
228
+ min_z = float('inf')
229
+
230
+ print("Calculating normalization constants...")
231
+ for batch_idx, batch in enumerate(dataloader):
232
+ # Process pressure values
233
+ output_feat = batch['output_feat']
234
+ pressures = output_feat.flatten().numpy()
235
+ all_pressures.extend(pressures)
236
+
237
+ # Process coordinate ranges and collect all coordinates
238
+ input_pos = batch['input_pos']
239
+ # Convert tensor to numpy for coordinate calculations
240
+ input_pos_np = input_pos.numpy()
241
+
242
+ # Collect all coordinate points for mean/std calculation
243
+ # Reshape from (batch_size, num_points, 3) to (batch_size * num_points, 3)
244
+ coords_reshaped = input_pos_np.reshape(-1, 3)
245
+ all_coordinates.extend(coords_reshaped)
246
+
247
+ # Calculate coordinate ranges
248
+ max_x = max(max_x, np.max(input_pos_np[:,:,0]))
249
+ max_y = max(max_y, np.max(input_pos_np[:,:,1]))
250
+ max_z = max(max_z, np.max(input_pos_np[:,:,2]))
251
+ min_x = min(min_x, np.min(input_pos_np[:,:,0]))
252
+ min_y = min(min_y, np.min(input_pos_np[:,:,1]))
253
+ min_z = min(min_z, np.min(input_pos_np[:,:,2]))
254
+
255
+ if batch_idx % 10 == 0: # Print progress every 10 batches
256
+ print(f"Processed {batch_idx + 1} batches...")
257
+
258
+ # Convert to numpy arrays for efficient computation
259
+ all_pressures = np.array(all_pressures)
260
+ all_coordinates = np.array(all_coordinates) # Shape: (total_points, 3)
261
+
262
+ # Calculate pressure statistics
263
+ pressure_mean = np.mean(all_pressures)
264
+ pressure_std = np.std(all_pressures)
265
+
266
+ # Calculate coordinate statistics (mean and std for each dimension)
267
+ coord_mean = np.mean(all_coordinates, axis=0) # [mean_x, mean_y, mean_z]
268
+ coord_std = np.std(all_coordinates, axis=0) # [std_x, std_y, std_z]
269
+
270
+ # Store coordinate ranges
271
+ coord_ranges = {
272
+ 'min_x': min_x, 'max_x': max_x,
273
+ 'min_y': min_y, 'max_y': max_y,
274
+ 'min_z': min_z, 'max_z': max_z
275
+ }
276
+
277
+ # Print comprehensive statistics
278
+ print(f"\nPressure statistics from {len(all_pressures)} data points:")
279
+ print(f"Mean: {pressure_mean:.6f}")
280
+ print(f"Std: {pressure_std:.6f}")
281
+ print(f"Min: {np.min(all_pressures):.6f}")
282
+ print(f"Max: {np.max(all_pressures):.6f}")
283
+
284
+ print(f"\nCoordinate ranges:")
285
+ print(f"X: [{min_x:.6f}, {max_x:.6f}]")
286
+ print(f"Y: [{min_y:.6f}, {max_y:.6f}]")
287
+ print(f"Z: [{min_z:.6f}, {max_z:.6f}]")
288
+
289
+ print(f"\nCoordinate statistics for mean-std normalization from {len(all_coordinates)} data points:")
290
+ print(f"Mean: [{coord_mean[0]:.6f}, {coord_mean[1]:.6f}, {coord_mean[2]:.6f}]")
291
+ print(f"Std: [{coord_std[0]:.6f}, {coord_std[1]:.6f}, {coord_std[2]:.6f}]")
292
+
293
+ print(f"\nFor use in dataset file:")
294
+ print(f"INPUT_POS_MEAN = torch.tensor([{coord_mean[0]:.6f}, {coord_mean[1]:.6f}, {coord_mean[2]:.6f}])")
295
+ print(f"INPUT_POS_STD = torch.tensor([{coord_std[0]:.6f}, {coord_std[1]:.6f}, {coord_std[2]:.6f}])")
296
+
297
+ return pressure_mean, pressure_std, coord_ranges, coord_mean, coord_std
298
+
299
+
300
+ def seed_worker(worker_id):
301
+ worker_seed = torch.initial_seed() % 2**32
302
+ np.random.seed(worker_seed)
303
+
304
+ g = torch.Generator()
305
+ g.manual_seed(0)
306
+
307
+ def get_dataloaders(cfg):
308
+ splits = json.load(open(cfg.splits_file)) # How is train validation used in DrivAerML?
309
+
310
+ # Handle presampling logic
311
+ presampled_data_path = getattr(cfg, 'presampled_data_path', os.path.join(cfg.data_dir, 'presampled_val_test_data.npy'))
312
+ presampled_base_dir = os.path.splitext(presampled_data_path)[0]
313
+
314
+ if not cfg.presampled_exists:
315
+ # Create presampled data if it doesn't exist or if presampled=False
316
+ if not os.path.exists(presampled_base_dir):
317
+ print("=" * 60)
318
+ print("PRESAMPLING MODE: Creating presampled validation and test data...")
319
+ print(f"Presampled data will be saved to: {presampled_base_dir}")
320
+ create_presampled_data(cfg, splits, presampled_data_path)
321
+ print("Presampled data created successfully!")
322
+ print("You can now set presampled=True in config for future runs to use this presampled data.")
323
+ print("=" * 60)
324
+ else:
325
+ print(f"Presampled data directory already exists at: {presampled_base_dir}")
326
+ print("Using existing presampled data. Set presampled=True to use it in future runs.")
327
+
328
+ # presampled_exists=True but the presampled directory is missing: create it now
+ if cfg.presampled_exists and not os.path.exists(presampled_base_dir):
329
+ print(f"Warning: presampled=True but presampled data directory not found at {presampled_base_dir}")
330
+ print("Creating presampled data...")
331
+ create_presampled_data(cfg, splits, presampled_data_path)
332
+
333
+ use_presampled = True
334
+ print(f"Using presampled training, validation and test data from: {presampled_base_dir}")
335
+
336
+
337
+ train_dataset = DrivAerMLDataset(cfg, splits = splits['train'], split_type = 'train',
338
+ presampled = use_presampled, save_presampled_data_path = presampled_data_path)
339
+ val_dataset = DrivAerMLDataset(cfg, splits = splits['validation'], split_type = 'validation',
340
+ presampled = use_presampled, save_presampled_data_path = presampled_data_path)
341
+ test_dataset = DrivAerMLDataset(cfg, splits = splits['test'], split_type = 'test',
342
+ presampled = use_presampled, save_presampled_data_path = presampled_data_path)
343
+
344
+ collate_fn = None
345
+
346
+ train_dataloader = DataLoader(
347
+ train_dataset, batch_size=cfg.batch_size, shuffle=True,
348
+ drop_last=True, num_workers=cfg.num_workers, collate_fn=collate_fn,
349
+ worker_init_fn=seed_worker, generator=g
350
+ )
351
+ val_dataloader = DataLoader(
352
+ val_dataset, batch_size=cfg.batch_size, shuffle=True,
353
+ drop_last=True, num_workers=cfg.num_workers, collate_fn=collate_fn,
354
+ worker_init_fn=seed_worker, generator=g
355
+ )
356
+ test_dataloader = DataLoader(
357
+ test_dataset, batch_size=1, shuffle=False,
358
+ drop_last=False, num_workers=cfg.num_workers, collate_fn=collate_fn,
359
+ worker_init_fn=seed_worker, generator=g
360
+ )
361
+
362
+ # # Calculate normalization constants
363
+ # print('Calculating normalization constants...')
364
+ # pressure_mean, pressure_std, coord_ranges, coord_mean, coord_std = calculate_normalization_constants(train_dataloader)
365
+ # exit()
366
+
367
+ return train_dataloader, val_dataloader, test_dataloader
368
+
369
+
370
+
371
+ # Pressure statistics from openfoam surface train dataset (10k points sampled):
372
+ # Mean: -229.845718
373
+ # Std: 269.598572
374
+ # Min: -3651.057861
375
+ # Max: 859.160034
376
+
377
+ # Coordinate ranges:
378
+ # X: [-0.941836, 4.131968]
379
+ # Y: [-1.129535, 1.125530]
380
+ # Z: [-0.317549, 1.244577]
381
+
382
+ # Pressure statistics from full openfoam surface train dataset (3323811346 data points):
383
+ # Mean: -229.266983
384
+ # Std: 269.226807
385
+ # Min: -111492.804688
386
+ # Max: 6382.190918
387
+
388
+ # Coordinate ranges:
389
+ # X: [-0.942579, 4.132785]
390
+ # Y: [-1.131676, 1.131676]
391
+ # Z: [-0.317577, 1.244584]
392
+
393
+ # Coordinate statistics for mean-std normalization (computed from full dataset):
394
+ # Mean: [1.595103, 0.000000, 0.463503]
395
+ # Std: [1.434788, 0.801948, 0.440890]
396
+
397
+
398
+
399
+ # Pressure statistics from 6553600 data points:
400
+ # Mean: -0.003021
401
+ # Std: 1.002092
402
+ # Min: -14.342350
403
+ # Max: 4.157114
404
+
405
+ # Coordinate ranges:
406
+ # X: [-1.768229, 1.766621]
407
+ # Y: [-1.408318, 1.410171]
408
+ # Z: [-1.771534, 1.781146]
409
+
410
+ # Coordinate statistics for mean-std normalization from 6553600 data points:
411
+ # Mean: [-0.076668, -0.001889, -0.831090]
412
+ # Std: [0.968414, 0.882944, 0.858088]
413
+
414
+ # For use in dataset file:
415
+ # INPUT_POS_MEAN = torch.tensor([-0.076668, -0.001889, -0.831090])
416
+ # INPUT_POS_STD = torch.tensor([0.968414, 0.882944, 0.858088])
417
+
418
+ # # With full dataset - pressure normalization
419
+ # PRESSURE_MEAN = -229.266983
420
+ # PRESSURE_STD = 269.226807
421
+
422
+ # # Coordinate normalization using mean-std
423
+ # INPUT_POS_MEAN = torch.tensor([1.595103, 0.000000, 0.463503])
424
+ # INPUT_POS_STD = torch.tensor([1.434788, 0.801948, 0.440890])
425
+
426
+ # # Legacy min-max normalization (keep for reference but not used)
427
+ # input_pos_mins = torch.tensor([-0.942579, -1.131676, -0.317577])
428
+ # input_pos_maxs = torch.tensor([4.132785, 1.131676, 1.244584])
429
+
430
+
431
+ # With full dataset - pressure normalization
432
+ PRESSURE_MEAN = -229.266983
433
+ PRESSURE_STD = 269.226807
434
+
435
+ # Coordinate normalization using mean-std
436
+ INPUT_POS_MEAN = torch.tensor([1.490858, -0.001515, 0.099364])
437
+ INPUT_POS_STD = torch.tensor([1.388309, 0.706769, 0.380478])
438
+
439
+ # # Legacy min-max normalization (keep for reference but not used)
440
+ # input_pos_mins = torch.tensor([-0.942579, -1.131676, -0.317577])
441
+ # input_pos_maxs = torch.tensor([4.132785, 1.131676, 1.244584])
442
+
443
+
444
+ # Pressure normalization
445
+ # PRESSURE_MEAN = 0
446
+ # PRESSURE_STD = 1
447
+
448
+ # # Coordinate normalization using mean-std
449
+ # INPUT_POS_MEAN = torch.tensor([0, 0, 0])
450
+ # INPUT_POS_STD = torch.tensor([1, 1, 1])
451
+