fenghora committed on
Commit
84bb8a3
·
1 Parent(s): 93e001c
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +2 -35
  2. README.md +66 -12
  3. app.py +418 -0
  4. assets/teaser.png +3 -0
  5. color_report.json +30 -0
  6. data_toolkit/bpy_render.py +623 -0
  7. data_toolkit/color_glb.py +84 -0
  8. data_toolkit/color_img.py +129 -0
  9. data_toolkit/example_full_seg.py +45 -0
  10. data_toolkit/example_full_seg_w_2d_map.py +46 -0
  11. data_toolkit/example_interactive_seg.py +46 -0
  12. data_toolkit/glb_to_parts.py +17 -0
  13. data_toolkit/glb_to_vxz.py +87 -0
  14. data_toolkit/img_to_cond.py +66 -0
  15. data_toolkit/texturing_pipeline.json +64 -0
  16. data_toolkit/transforms.json +31 -0
  17. data_toolkit/vxz_to_slat.py +123 -0
  18. examples/00aee5c2fef743d69421bb642d446a5b.glb +3 -0
  19. examples/00aee5c2fef743d69421bb642d446a5b.png +3 -0
  20. examples/01b8043112e74366a21256d5e64398fb.glb +3 -0
  21. examples/01b8043112e74366a21256d5e64398fb.png +3 -0
  22. examples/0c070001a3904cd6809a31345475e930.glb +3 -0
  23. examples/0c070001a3904cd6809a31345475e930.png +3 -0
  24. examples/0c3ca2b32545416f8f1e6f0e87def1a6.glb +3 -0
  25. examples/0c3ca2b32545416f8f1e6f0e87def1a6.png +3 -0
  26. examples/1b3e8b99913442308aa989e3f87680b3.glb +3 -0
  27. examples/1b3e8b99913442308aa989e3f87680b3.png +3 -0
  28. examples/1c33b2e86c023a72905a5bea4ae713d0.glb +3 -0
  29. examples/1c33b2e86c023a72905a5bea4ae713d0.png +3 -0
  30. examples/1ca8ea337fbc4bcfbeb3c633bc4c43f0.glb +3 -0
  31. examples/1ca8ea337fbc4bcfbeb3c633bc4c43f0.png +3 -0
  32. examples/2260799ee4e342398b64ab4ce8af1559.glb +3 -0
  33. examples/2260799ee4e342398b64ab4ce8af1559.png +3 -0
  34. examples/2ae5cf2990c34e7db704f677de8de74c.glb +3 -0
  35. examples/2ae5cf2990c34e7db704f677de8de74c.png +3 -0
  36. examples/2ceb6778ac114101833e4c531544ada8.glb +3 -0
  37. examples/2ceb6778ac114101833e4c531544ada8.png +3 -0
  38. examples/4b57e73e82ab400aa307adac36ea0e5e.glb +3 -0
  39. examples/4b57e73e82ab400aa307adac36ea0e5e.png +3 -0
  40. inference_full.py +553 -0
  41. inference_full_ori.py +383 -0
  42. inference_interactive.py +435 -0
  43. inference_unified.py +473 -0
  44. requirements.txt +23 -0
  45. split.py +833 -0
  46. split_ori.py +686 -0
  47. train_full.py +227 -0
  48. train_interactive.py +287 -0
  49. train_unified.py +303 -0
  50. trellis2/__init__.py +6 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.glb filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,66 @@
- ---
- title: SegviGen
- emoji: ⚡
- colorFrom: gray
- colorTo: gray
- sdk: gradio
- sdk_version: 6.9.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SegviGen: Repurposing 3D Generative Model for Part Segmentation
+
+ ![teaser](assets/teaser.png)
+
+ ***SegviGen*** is a framework for 3D part segmentation that leverages the rich 3D structural and textural knowledge encoded in large-scale 3D generative models.
+ It learns to predict part-indicative colors while reconstructing geometry, and unifies three settings in one architecture: **interactive part segmentation**, **full segmentation**, and **2D segmentation map–guided full segmentation** with arbitrary granularity.
+
+
+ ## 🌟 Features
+ - **Repurposed 3D Generative Priors for Data Efficiency**: By reusing the rich structural and textural knowledge encoded in large-scale native 3D generative models, ***SegviGen*** learns 3D part segmentation with minimal task-specific supervision, requiring only **0.32%** of the training data.
+ - **Unified and Flexible Segmentation Settings**: Supports **interactive part segmentation**, **full segmentation**, and **2D segmentation map–guided full segmentation** with arbitrary part granularity under a single architecture.
+ - **State-of-the-Art Accuracy**: Consistently surpasses P3-SAM, delivering a **40%** gain in IoU@1 for single-click interaction on PartObjaverse-Tiny and PartNeXT, and a **15%** improvement in overall IoU for unguided full segmentation averaged across datasets.
+
+
+ ## 🔨 Installation
+
+ ### Prerequisites
+ - **System**: Linux
+ - **GPU**: An NVIDIA GPU with at least 24 GB of memory is required
+ - **Python**: 3.10
+
+ ### Installation Steps
+ 1. Create the environment of [TRELLIS.2](https://github.com/microsoft/TRELLIS.2):
+ ```sh
+ git clone -b main https://github.com/microsoft/TRELLIS.2.git --recursive
+ cd TRELLIS.2
+ ./setup.sh --new-env --basic --flash-attn --nvdiffrast --nvdiffrec --cumesh --o-voxel --flexgemm
+ ```
+
+ 2. Install the remaining requirements:
+ ```sh
+ pip install mathutils
+ pip install transformers==4.57.6  # https://github.com/microsoft/TRELLIS.2/issues/101
+ pip install bpy==4.0.0 --extra-index-url https://download.blender.org/pypi/
+ sudo apt-get install -y libsm6 libxrender1 libxext6
+ pip install --upgrade Pillow
+ ```
+
+ 3. If you want to train, also install:
+ ```sh
+ pip install pytorch_lightning
+ ```
+
+ ### Pretrained Weights
+
+ The checkpoints for **interactive part segmentation**, **full segmentation**, and **full segmentation with 2D guidance** are available on [Hugging Face](https://huggingface.co/Nelipot/tmp).
+
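app.py downloads these checkpoints automatically at startup; a minimal sketch of the same bootstrap logic (URLs copied from app.py), should you want to pre-fetch the full-segmentation weights yourself:

```python
import os
import urllib.request

# Mirrors the checkpoint bootstrap at the top of app.py.
os.makedirs("pretrained_model", exist_ok=True)
for name in ("full_seg.ckpt", "full_seg_w_2d_map.ckpt"):
    dst = os.path.join("pretrained_model", name)
    if not os.path.exists(dst):
        urllib.request.urlretrieve(
            f"https://huggingface.co/fenghora/SegviGen/resolve/main/{name}", dst
        )
```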
+ ## 📒 Usage
+
+ - `inference_interactive.py`: **Interactive part segmentation**
+ - `inference_full.py`: **Full segmentation** or **Full segmentation with 2D guidance**
+ - `inference_unified.py`: Unified model
+
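These scripts can also be driven programmatically. A minimal sketch of unguided full segmentation following the call path in app.py (the `item` keys are the ones app.py builds; the file paths are placeholders):

```python
import inference_full as inf

inf.PIPE.load_all_models()  # app.py does this once before serving

# For 2D-guided segmentation, app.py instead sets "2d_map": True, drops
# "transforms", and points "img" at the user-supplied segmentation map.
item = {
    "2d_map": False,
    "glb": "input.glb",                              # input mesh (placeholder path)
    "input_vxz": "input.vxz",                        # intermediate voxelization
    "transforms": "./data_toolkit/transforms.json",  # camera for the rendered view
    "img": "render.png",                             # where the render is written
    "export_glb": "segmented.glb",                   # segmented output
}
inf.inference_with_loaded_models("pretrained_model/full_seg.ckpt", item)
```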
+ ## Training
+
+ ### Data preparation
+
+ - `data_toolkit/example_interactive_seg.py`: **Interactive part segmentation**
+ - `data_toolkit/example_full_seg.py`: **Full segmentation**
+ - `data_toolkit/example_full_seg_w_2d_map.py`: **Full segmentation with 2D guidance**
+
+ ### Running training
+
+ - `train_interactive.py`: **Interactive part segmentation**
+ - `train_full.py`: **Full segmentation** or **Full segmentation with 2D guidance**
+ - `train_unified.py`: Unified model
app.py ADDED
@@ -0,0 +1,418 @@
+ import os
+ import urllib.request
+
+ os.makedirs("pretrained_model", exist_ok=True)
+
+ CKPT_FULL_SEG = "pretrained_model/full_seg.ckpt"
+ CKPT_W_2D_MAP = "pretrained_model/full_seg_w_2d_map.ckpt"
+
+ if not os.path.exists(CKPT_FULL_SEG):
+     urllib.request.urlretrieve(
+         "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg.ckpt",
+         CKPT_FULL_SEG,
+     )
+
+ if not os.path.exists(CKPT_W_2D_MAP):
+     urllib.request.urlretrieve(
+         "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg_w_2d_map.ckpt",
+         CKPT_W_2D_MAP,
+     )
+
+ import shutil
+ import traceback
+ from datetime import datetime
+ from pathlib import Path
+ from typing import List
+ import inference_full as inf
+ import split as splitter
+
+
+ TRANSFORMS_JSON = "./data_toolkit/transforms.json"
+
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ TMP_DIR = os.path.join(ROOT_DIR, "_tmp_gradio_seg")
+ EXAMPLES_CACHE_DIR = os.path.join(TMP_DIR, "examples_cache")
+ os.makedirs(TMP_DIR, exist_ok=True)
+ os.makedirs(EXAMPLES_CACHE_DIR, exist_ok=True)
+
+ os.environ["GRADIO_TEMP_DIR"] = TMP_DIR
+ os.environ["GRADIO_EXAMPLES_CACHE"] = EXAMPLES_CACHE_DIR
+
+ import gradio as gr
+
+ EXAMPLES_DIR = "examples"
+
+
+ def _ensure_dir(p: str):
+     os.makedirs(p, exist_ok=True)
+
+
+ def _normalize_path(x):
+     """
+     Compatible with different Gradio versions: File/Model3D might be str / dict / object.
+     """
+     if x is None:
+         return None
+     if isinstance(x, str):
+         return x
+     if isinstance(x, dict):
+         return x.get("name") or x.get("path") or x.get("data")
+     return getattr(x, "name", None) or getattr(x, "path", None) or None
+
+
+ def _raise_user_error(msg: str):
+     if hasattr(gr, "Error"):
+         raise gr.Error(msg)
+     raise RuntimeError(msg)
+
+
+ def _collect_examples(example_dir: str) -> List[List[str]]:
+     """
+     Scan example_dir for pairs: <name>.glb + <name>.png
+     Return a list of examples: [[glb_path, png_path], ...]
+     """
+     d = Path(example_dir)
+     if not d.is_dir():
+         return []
+
+     examples: List[List[str]] = []
+
+     # Search recursively in case you add subfolders later
+     glb_files = sorted(d.rglob("*.glb"))
+     for glb_path in glb_files:
+         png_path = glb_path.with_suffix(".png")
+         if png_path.is_file():
+             examples.append([str(glb_path), str(png_path)])
+         # If the png is missing, skip to keep examples consistent (2 inputs required)
+
+     return examples
+
+
+ # Build examples once at startup
+ FULL_SEG_EXAMPLES = _collect_examples(EXAMPLES_DIR)
+
+
+ def run_seg(glb_in, img_in):
+     """
+     Process button: generates the whole segmented GLB and displays it in the second box.
+     Returns: segmented_glb_path, segmented_glb_path (state)
+     """
+     try:
+         glb_path = _normalize_path(glb_in)
+         img_path = _normalize_path(img_in)
+
+         if glb_path is None or (not os.path.isfile(glb_path)):
+             _raise_user_error("Please upload a valid .glb file.")
+
+         _ensure_dir(TMP_DIR)
+         run_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+         workdir = os.path.join(TMP_DIR, run_id)
+         _ensure_dir(workdir)
+
+         in_glb = os.path.join(workdir, "input.glb")
+         shutil.copy(glb_path, in_glb)
+
+         out_glb = os.path.join(workdir, "segmented.glb")
+         in_vxz = os.path.join(workdir, "input.vxz")
+
+         # If an image is provided -> 2d_map=True; otherwise full segmentation (render_from_transforms)
+         if img_path is not None and os.path.isfile(img_path):
+             ckpt = CKPT_W_2D_MAP
+             in_img = os.path.join(workdir, "2d_map.png")
+             shutil.copy(img_path, in_img)
+             item = {
+                 "2d_map": True,
+                 "glb": in_glb,
+                 "input_vxz": in_vxz,
+                 "img": in_img,
+                 "export_glb": out_glb,
+             }
+         else:
+             ckpt = CKPT_FULL_SEG
+             render_img = os.path.join(workdir, "render.png")
+             item = {
+                 "2d_map": False,
+                 "glb": in_glb,
+                 "input_vxz": in_vxz,
+                 "transforms": TRANSFORMS_JSON,
+                 "img": render_img,
+                 "export_glb": out_glb,
+             }
+
+         inf.inference_with_loaded_models(ckpt, item)
+
+         if not os.path.isfile(out_glb):
+             _raise_user_error("Export failed: output glb not found.")
+
+         # Apply X90 rotation for whole segmented output
+         # _apply_root_x90_rotation_glb(out_glb)
+
+         return out_glb, out_glb
+
+     except Exception as e:
+         err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+         print(err)
+         raise
+
+
+ def run_refine_segmentation(
+     seg_glb_path_state,
+     color_quant_step,
+     palette_sample_pixels,
+     palette_min_pixels,
+     palette_max_colors,
+     palette_merge_dist,
+     samples_per_face,
+     flip_v,
+     uv_wrap_repeat,
+     transition_conf_thresh,
+     transition_prop_iters,
+     transition_neighbor_min,
+     small_component_action,
+     small_component_min_faces,
+     postprocess_iters,
+     min_faces_per_part,
+     bake_transforms,
+ ):
+     """
+     Segment button: splits the segmented GLB into a multi-part GLB and displays it in the fourth box.
+     """
+     try:
+         seg_glb_path = seg_glb_path_state if isinstance(seg_glb_path_state, str) else None
+         if (seg_glb_path is None) or (not os.path.isfile(seg_glb_path)):
+             _raise_user_error("Please run Segmentation first (the segmented GLB is missing).")
+
+         out_dir = os.path.dirname(seg_glb_path)
+         out_parts_glb = os.path.join(out_dir, "segmented_parts.glb")
+
+         splitter.split_glb_by_texture_palette_rgb(
+             in_glb_path=seg_glb_path,
+             out_glb_path=out_parts_glb,
+             min_faces_per_part=min_faces_per_part,
+             bake_transforms=bool(bake_transforms),
+             color_quant_step=color_quant_step,
+             palette_sample_pixels=palette_sample_pixels,
+             palette_min_pixels=palette_min_pixels,
+             palette_max_colors=palette_max_colors,
+             palette_merge_dist=palette_merge_dist,
+             samples_per_face=samples_per_face,
+             flip_v=flip_v,
+             uv_wrap_repeat=uv_wrap_repeat,
+             transition_conf_thresh=transition_conf_thresh,
+             transition_prop_iters=transition_prop_iters,
+             transition_neighbor_min=transition_neighbor_min,
+             small_component_action=small_component_action,
+             small_component_min_faces=small_component_min_faces,
+             postprocess_iters=postprocess_iters,
+             debug_print=True,
+         )
+
+         if not os.path.isfile(out_parts_glb):
+             _raise_user_error("Split failed: output parts glb not found.")
+
+         # If bake_transforms=False, the split output will not have the wrapper transform baked,
+         # so we would need to apply the X90 rotation fix:
+         # if (not bool(bake_transforms)) and APPLY_OUTPUT_X90_FIX:
+         #     _apply_root_x90_rotation_glb(out_parts_glb)
+
+         return out_parts_glb
+
+     except Exception as e:
+         err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+         print(err)
+         raise
+
+
+ CSS_TEXT = """
+ <style>
+ #in_glb { height: 520px !important; }
+ #seg_glb { height: 520px !important; }
+ #part_glb { height: 520px !important; }
+ #img { height: 520px !important; }
+ </style>
+ """
+
+ with gr.Blocks() as demo:
+     gr.HTML(CSS_TEXT)
+     gr.Markdown(
+         """
+         # SegviGen: Repurposing 3D Generative Model for Part Segmentation
+         """
+     )
+
+     # ---------------- 2x2 Layout ----------------
+     with gr.Row():
+         with gr.Column(scale=1, min_width=260):
+             in_glb = gr.Model3D(label="Input GLB", elem_id="in_glb")
+         with gr.Column(scale=1, min_width=260):
+             seg_glb = gr.Model3D(label="Processed GLB", elem_id="seg_glb")
+
+     with gr.Row():
+         with gr.Column(scale=1, min_width=260):
+             with gr.Accordion("2D Segmentation Map (Optional)", open=False):
+                 in_img = gr.Image(label="2D Segmentation Map", type="filepath", elem_id="img")
+
+             seg_btn = gr.Button("Process", variant="primary")
+
+             # ✅ Examples directly under the Process button
+             if FULL_SEG_EXAMPLES:
+                 gr.Examples(
+                     examples=FULL_SEG_EXAMPLES,
+                     inputs=[in_glb, in_img],
+                     label="Examples",
+                     examples_per_page=3,
+                     cache_examples=False,
+                 )
+             else:
+                 gr.Markdown(f"**No examples found** in: `{EXAMPLES_DIR}` (expected: `*.glb` + same-name `*.png`).")
+
+             with gr.Accordion("Advanced segmentation options", open=False):
+                 def _g(name, default):
+                     return getattr(splitter, name, default)
+
+                 color_quant_step = gr.Slider(
+                     1, 64, value=_g("COLOR_QUANT_STEP", 16), step=1, label="COLOR_QUANT_STEP"
+                 )
+                 gr.Markdown(
+                     "*COLOR_QUANT_STEP controls the RGB quantization step, where a larger value merges similar colors more aggressively and a smaller value preserves finer color differences.*"
+                 )
+
+                 palette_sample_pixels = gr.Number(
+                     value=_g("PALETTE_SAMPLE_PIXELS", 2_000_000), precision=0, label="PALETTE_SAMPLE_PIXELS"
+                 )
+                 gr.Markdown(
+                     "*PALETTE_SAMPLE_PIXELS sets the maximum number of sampled pixels used to estimate the palette, where more samples improve stability but increase runtime.*"
+                 )
+
+                 palette_min_pixels = gr.Number(
+                     value=_g("PALETTE_MIN_PIXELS", 500), precision=0, label="PALETTE_MIN_PIXELS"
+                 )
+                 gr.Markdown(
+                     "*PALETTE_MIN_PIXELS specifies the minimum pixel count required to keep a color in the palette, where a higher threshold suppresses noise but may discard small parts.*"
+                 )
+
+                 palette_max_colors = gr.Number(
+                     value=_g("PALETTE_MAX_COLORS", 256), precision=0, label="PALETTE_MAX_COLORS"
+                 )
+                 gr.Markdown(
+                     "*PALETTE_MAX_COLORS limits the maximum number of colors retained in the palette, where a larger limit yields finer partitions and a smaller limit enforces stronger merging.*"
+                 )
+
+                 palette_merge_dist = gr.Number(
+                     value=_g("PALETTE_MERGE_DIST", 32), precision=0, label="PALETTE_MERGE_DIST"
+                 )
+                 gr.Markdown(
+                     "*PALETTE_MERGE_DIST defines the distance threshold for merging nearby palette colors in RGB space, where a larger threshold merges near-duplicates more often and a smaller threshold keeps colors distinct.*"
+                 )
+
+                 samples_per_face = gr.Dropdown(
+                     choices=[1, 4], value=_g("SAMPLES_PER_FACE", 4), label="SAMPLES_PER_FACE"
+                 )
+                 gr.Markdown(
+                     "*SAMPLES_PER_FACE sets the number of UV samples per triangle used for label voting, where more samples improve robustness near boundaries but increase computation.*"
+                 )
+
+                 flip_v = gr.Checkbox(value=_g("FLIP_V", True), label="FLIP_V")
+                 gr.Markdown(
+                     "*FLIP_V toggles whether the V coordinate is flipped to match common glTF texture conventions; disable it only if the texture appears vertically inverted.*"
+                 )
+
+                 uv_wrap_repeat = gr.Checkbox(value=_g("UV_WRAP_REPEAT", True), label="UV_WRAP_REPEAT")
+                 gr.Markdown(
+                     "*UV_WRAP_REPEAT selects how out-of-range UVs are handled, either repeating via modulo or clamping to the unit interval; repeating is typically preferred for tiled textures.*"
+                 )
+
+                 transition_conf_thresh = gr.Slider(
+                     0.25, 1.0, value=float(_g("TRANSITION_CONF_THRESH", 1.0)), step=0.25, label="TRANSITION_CONF_THRESH"
+                 )
+                 gr.Markdown(
+                     "*TRANSITION_CONF_THRESH sets the confidence threshold for transition handling, where a higher value makes refinement more conservative and a lower value enables more aggressive smoothing.*"
+                 )
+
+                 transition_prop_iters = gr.Number(
+                     value=_g("TRANSITION_PROP_ITERS", 6), precision=0, label="TRANSITION_PROP_ITERS"
+                 )
+                 gr.Markdown(
+                     "*TRANSITION_PROP_ITERS specifies the number of propagation iterations used in transition refinement, where more iterations strengthen diffusion effects but increase runtime.*"
+                 )
+
+                 transition_neighbor_min = gr.Number(
+                     value=_g("TRANSITION_NEIGHBOR_MIN", 1), precision=0, label="TRANSITION_NEIGHBOR_MIN"
+                 )
+                 gr.Markdown(
+                     "*TRANSITION_NEIGHBOR_MIN requires a minimum number of supporting neighbors to propagate a label, where a higher requirement is more conservative and a lower requirement is more permissive.*"
+                 )
+
+                 small_component_action = gr.Dropdown(
+                     choices=["reassign", "drop"], value=_g("SMALL_COMPONENT_ACTION", "reassign"), label="SMALL_COMPONENT_ACTION"
+                 )
+                 gr.Markdown(
+                     "*SMALL_COMPONENT_ACTION determines how small connected components are handled, either reassigning them to neighboring labels or dropping them entirely.*"
+                 )
+
+                 small_component_min_faces = gr.Number(
+                     value=_g("SMALL_COMPONENT_MIN_FACES", 50), precision=0, label="SMALL_COMPONENT_MIN_FACES"
+                 )
+                 gr.Markdown(
+                     "*SMALL_COMPONENT_MIN_FACES defines the face count threshold used to classify a component as small, where a higher threshold merges or removes more fragments and a lower threshold preserves more small parts.*"
+                 )
+
+                 postprocess_iters = gr.Number(
+                     value=_g("POSTPROCESS_ITERS", 3), precision=0, label="POSTPROCESS_ITERS"
+                 )
+                 gr.Markdown(
+                     "*POSTPROCESS_ITERS sets the number of post-processing iterations, where more iterations produce stronger cleanup at the cost of additional computation.*"
+                 )
+
+                 min_faces_per_part = gr.Number(
+                     value=_g("MIN_FACES_PER_PART", 1), precision=0, label="MIN_FACES_PER_PART"
+                 )
+                 gr.Markdown(
+                     "*MIN_FACES_PER_PART enforces a minimum number of faces per exported part, where a larger value filters tiny outputs and a smaller value retains fine components.*"
+                 )
+
+                 bake_transforms = gr.Checkbox(value=_g("BAKE_TRANSFORMS", True), label="BAKE_TRANSFORMS")
+                 gr.Markdown(
+                     "*BAKE_TRANSFORMS controls whether scene-graph transforms are baked into geometry before splitting; enabling it improves consistency in world space and disabling it preserves node transforms.*"
+                 )
+
+         with gr.Column(scale=1, min_width=260):
+             refine_btn = gr.Button("Segment", variant="secondary")
+             part_glb = gr.Model3D(label="Segmented GLB", elem_id="part_glb")
+
+     # Hidden states
+     seg_glb_state = gr.State(None)
+
+     # ---------------- wiring ----------------
+     seg_btn.click(
+         fn=run_seg,
+         inputs=[in_glb, in_img],
+         outputs=[seg_glb, seg_glb_state],
+     )
+
+     refine_btn.click(
+         fn=run_refine_segmentation,
+         inputs=[
+             seg_glb_state,
+             color_quant_step,
+             palette_sample_pixels,
+             palette_min_pixels,
+             palette_max_colors,
+             palette_merge_dist,
+             samples_per_face,
+             flip_v,
+             uv_wrap_repeat,
+             transition_conf_thresh,
+             transition_prop_iters,
+             transition_neighbor_min,
+             small_component_action,
+             small_component_min_faces,
+             postprocess_iters,
+             min_faces_per_part,
+             bake_transforms,
+         ],
+         outputs=[part_glb],
+     )
+
+ if __name__ == "__main__":
+     inf.PIPE.load_all_models()
+     demo.launch(server_name="0.0.0.0", server_port=8012, share=False)
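A note on the wiring above: the expensive diffusion inference runs only inside `run_seg`, which caches its color-coded `segmented.glb` in `seg_glb_state`; `run_refine_segmentation` then re-runs only `split.split_glb_by_texture_palette_rgb` on that cached file, so the advanced splitting options can be retuned interactively without repeating inference. Launched directly with `python app.py`, the demo preloads all models via `inf.PIPE.load_all_models()` and serves on port 8012.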
assets/teaser.png ADDED

Git LFS Details

  • SHA256: af7a4739bea10cea49a2c70bf7ae86371876cb46676eaf8eb3b2adab04c71a8b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.39 MB
color_report.json ADDED
@@ -0,0 +1,30 @@
+ {
+     "input_glb": "/media/nfs/tmp_data/fenghr/SegviGen/data_toolkit/assets/output.glb",
+     "color_quant_step": 8,
+     "max_image_samples": 2000000,
+     "nodes": [
+         {
+             "node": "geometry_0",
+             "geom": "geometry_0",
+             "n_faces": 92502,
+             "n_verts": 65312,
+             "visual_type": "TextureVisuals",
+             "face_colors": null,
+             "vertex_colors": null,
+             "textures": [],
+             "material": {
+                 "base_color_factor_or_main_color": [
+                     255.0,
+                     255.0,
+                     255.0,
+                     255.0
+                 ]
+             }
+         }
+     ],
+     "summary": {
+         "total_face_color_entries": 0,
+         "total_vertex_color_entries": 0,
+         "total_texture_pixels_sampled": 0
+     }
+ }
data_toolkit/bpy_render.py ADDED
@@ -0,0 +1,623 @@
+ import bpy
+ import json
+ import math
+ import mathutils
+ import numpy as np
+
+
+ class BpyRenderer:
+     def __init__(self, resolution=512, engine="BLENDER_EEVEE", geo_mode=False, split_normal=False):
+         """
+         engine:
+           - "CYCLES"
+           - "BLENDER_EEVEE" (Blender 3.x common)
+           - "BLENDER_EEVEE_NEXT" (Blender 4.x common)
+           - "EEVEE" / "EEVEE_NEXT" (aliases, optional)
+         """
+         self.resolution = resolution
+         self.engine = engine
+         self.geo_mode = geo_mode
+         self.split_normal = split_normal
+         self.import_functions = self._setup_import_functions()
+
+     def _setup_import_functions(self):
+         import_functions = {
+             "obj": bpy.ops.wm.obj_import,
+             "glb": bpy.ops.import_scene.gltf,
+             "gltf": bpy.ops.import_scene.gltf,
+             "usd": bpy.ops.import_scene.usd,
+             "fbx": bpy.ops.import_scene.fbx,
+             "stl": bpy.ops.import_mesh.stl,
+             "usda": bpy.ops.import_scene.usda,
+             "dae": bpy.ops.wm.collada_import,
+             "ply": bpy.ops.wm.ply_import,
+             "abc": bpy.ops.wm.alembic_import,
+             "blend": bpy.ops.wm.append,
+         }
+         return import_functions
+
+     # -------------------------
+     # Engine helpers
+     # -------------------------
+     def _resolve_render_engine(self, requested: str) -> str:
+         """
+         Robustly set the render engine across Blender versions.
+         Blender 4.x may not accept "BLENDER_EEVEE" and instead uses "BLENDER_EEVEE_NEXT".
+         """
+         req = (requested or "").upper()
+
+         if req in {"EEVEE", "BLENDER_EEVEE"}:
+             candidates = ["BLENDER_EEVEE", "BLENDER_EEVEE_NEXT"]
+         elif req in {"EEVEE_NEXT", "BLENDER_EEVEE_NEXT"}:
+             candidates = ["BLENDER_EEVEE_NEXT", "BLENDER_EEVEE"]
+         elif req in {"CYCLES"}:
+             candidates = ["CYCLES"]
+         elif req in {"WORKBENCH", "BLENDER_WORKBENCH"}:
+             candidates = ["BLENDER_WORKBENCH"]
+         else:
+             candidates = [requested]
+
+         last_err = None
+         for eng in candidates:
+             try:
+                 bpy.context.scene.render.engine = eng
+                 return eng
+             except Exception as e:
+                 last_err = e
+                 continue
+
+         raise ValueError(f"Failed to set render engine from {candidates}. Last error: {last_err}")
+
+     def _init_eevee_settings(self, render_samples: int = 64):
+         """
+         EEVEE / EEVEE Next settings (close to huanngzh/bpy-renderer defaults).
+         """
+         scene = bpy.context.scene
+
+         # Render basics
+         scene.render.image_settings.file_format = "PNG"
+         scene.render.image_settings.color_mode = "RGBA"
+         scene.render.film_transparent = True
+
+         # EEVEE quality knobs
+         # In Blender, EEVEE settings live under scene.eevee.
+         # These fields are used by many scripts, including bpy-renderer.
+         if hasattr(scene, "eevee"):
+             try:
+                 scene.eevee.taa_render_samples = int(render_samples)
+             except Exception:
+                 pass
+             # These flags may not exist in every minor version; guard them.
+             for name, val in [
+                 ("use_gtao", True),
+                 ("use_ssr", True),
+                 ("use_bloom", True),
+             ]:
+                 if hasattr(scene.eevee, name):
+                     try:
+                         setattr(scene.eevee, name, val)
+                     except Exception:
+                         pass
+
+         # Normals quality (also in the bpy-renderer init)
+         if hasattr(scene.render, "use_high_quality_normals"):
+             try:
+                 scene.render.use_high_quality_normals = True
+             except Exception:
+                 pass
+
+     def _init_cycles_settings(self, render_samples: int = 128):
+         scene = bpy.context.scene
+
+         scene.render.image_settings.file_format = "PNG"
+         scene.render.image_settings.color_mode = "RGBA"
+         scene.render.film_transparent = True
+
+         scene.cycles.samples = int(render_samples)
+         scene.cycles.filter_type = "BOX"
+         scene.cycles.filter_width = 1
+         scene.cycles.diffuse_bounces = 1
+         scene.cycles.glossy_bounces = 1
+         scene.cycles.transparent_max_bounces = (3 if not self.geo_mode else 0)
+         scene.cycles.transmission_bounces = (3 if not self.geo_mode else 1)
+         scene.cycles.use_denoising = True
+
+         # GPU (best-effort)
+         try:
+             scene.cycles.device = "GPU"
+             bpy.context.preferences.addons["cycles"].preferences.get_devices()
+             bpy.context.preferences.addons["cycles"].preferences.compute_device_type = "CUDA"
+         except Exception:
+             pass
+
+     # -------------------------
+     # Public init
+     # -------------------------
+     def init_render_settings(self):
+         # Resolution
+         bpy.context.scene.render.resolution_x = self.resolution
+         bpy.context.scene.render.resolution_y = self.resolution
+         bpy.context.scene.render.resolution_percentage = 100
+
+         # Pick the engine robustly (EEVEE vs EEVEE_NEXT etc.)
+         actual_engine = self._resolve_render_engine(self.engine)
+
+         # Samples:
+         # - For geo_mode: keep minimal samples for speed
+         # - For RGB: moderate samples
+         if actual_engine == "CYCLES":
+             samples = 128 if not self.geo_mode else 1
+             self._init_cycles_settings(render_samples=samples)
+         else:
+             # EEVEE family
+             samples = 64 if not self.geo_mode else 1
+             self._init_eevee_settings(render_samples=samples)
+
+     def init_scene(self):
+         for obj in bpy.data.objects:
+             bpy.data.objects.remove(obj, do_unlink=True)
+         for material in bpy.data.materials:
+             bpy.data.materials.remove(material, do_unlink=True)
+         for texture in bpy.data.textures:
+             bpy.data.textures.remove(texture, do_unlink=True)
+         for image in bpy.data.images:
+             bpy.data.images.remove(image, do_unlink=True)
+
+     def init_camera(self):
+         cam = bpy.data.objects.new("Camera", bpy.data.cameras.new("Camera"))
+         bpy.context.collection.objects.link(cam)
+         bpy.context.scene.camera = cam
+         cam.data.sensor_height = cam.data.sensor_width = 32
+         cam_constraint = cam.constraints.new(type="TRACK_TO")
+         cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
+         cam_constraint.up_axis = "UP_Y"
+         cam_empty = bpy.data.objects.new("Empty", None)
+         cam_empty.location = (0, 0, 0)
+         bpy.context.scene.collection.objects.link(cam_empty)
+         cam_constraint.target = cam_empty
+         return cam
+
+     def init_lighting(self):
+         bpy.ops.object.select_all(action="DESELECT")
+         bpy.ops.object.select_by_type(type="LIGHT")
+         bpy.ops.object.delete()
+
+         default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
+         bpy.context.collection.objects.link(default_light)
+         default_light.data.energy = 1000
+         default_light.location = (4, 1, 6)
+         default_light.rotation_euler = (0, 0, 0)
+
+         top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
+         bpy.context.collection.objects.link(top_light)
+         top_light.data.energy = 10000
+         top_light.location = (0, 0, 10)
+         top_light.scale = (100, 100, 100)
+
+         bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
+         bpy.context.collection.objects.link(bottom_light)
+         bottom_light.data.energy = 1000
+         bottom_light.location = (0, 0, -10)
+         bottom_light.rotation_euler = (0, 0, 0)
+         return {"default_light": default_light, "top_light": top_light, "bottom_light": bottom_light}
+
+     def load_object(self, object_path):
+         file_extension = object_path.split(".")[-1].lower()
+         if file_extension not in self.import_functions:
+             raise ValueError(f"Unsupported file type: {file_extension}")
+         import_function = self.import_functions[file_extension]
+         print(f"Loading object from {object_path}")
+         if file_extension == "blend":
+             import_function(directory=object_path, link=False)
+         elif file_extension in {"glb", "gltf"}:
+             import_function(filepath=object_path, merge_vertices=True, import_shading="NORMALS")
+         else:
+             import_function(filepath=object_path)
+
+     def delete_invisible_objects(self):
+         bpy.ops.object.select_all(action="DESELECT")
+         for obj in bpy.context.scene.objects:
+             if obj.hide_viewport or obj.hide_render:
+                 obj.hide_viewport = False
+                 obj.hide_render = False
+                 obj.hide_select = False
+                 obj.select_set(True)
+         bpy.ops.object.delete()
+         invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
+         for col in invisible_collections:
+             bpy.data.collections.remove(col)
+
+     def split_mesh_normal(self):
+         bpy.ops.object.select_all(action="DESELECT")
+         objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
+         bpy.context.view_layer.objects.active = objs[0]
+         for obj in objs:
+             obj.select_set(True)
+         bpy.ops.object.mode_set(mode="EDIT")
+         bpy.ops.mesh.select_all(action="SELECT")
+         bpy.ops.mesh.split_normals()
+         bpy.ops.object.mode_set(mode="OBJECT")
+         bpy.ops.object.select_all(action="DESELECT")
+
+     def override_material(self):
+         new_mat = bpy.data.materials.new(name="Override0123456789")
+         new_mat.use_nodes = True
+         new_mat.node_tree.nodes.clear()
+         bsdf = new_mat.node_tree.nodes.new("ShaderNodeBsdfDiffuse")
+         bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
+         bsdf.inputs[1].default_value = 1
+         output = new_mat.node_tree.nodes.new("ShaderNodeOutputMaterial")
+         new_mat.node_tree.links.new(bsdf.outputs["BSDF"], output.inputs["Surface"])
+         bpy.context.scene.view_layers["View Layer"].material_override = new_mat
+
+     def scene_bbox(self):
+         bbox_min = (math.inf,) * 3
+         bbox_max = (-math.inf,) * 3
+         found = False
+         scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
+         for obj in scene_meshes:
+             found = True
+             for coord in obj.bound_box:
+                 coord = mathutils.Vector(coord)
+                 coord = obj.matrix_world @ coord
+                 bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
+                 bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
+         if not found:
+             raise RuntimeError("no objects in scene to compute bounding box for")
+         return mathutils.Vector(bbox_min), mathutils.Vector(bbox_max)
+
+     def normalize_scene(self):
+         scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
+         if len(scene_root_objects) > 1:
+             scene = bpy.data.objects.new("ParentEmpty", None)
+             bpy.context.scene.collection.objects.link(scene)
+             for obj in scene_root_objects:
+                 obj.parent = scene
+         else:
+             scene = scene_root_objects[0]
+
+         bbox_min, bbox_max = self.scene_bbox()
+         print(f"[INFO] Bounding box: {bbox_min}, {bbox_max}")
+         scale = 1 / max(bbox_max - bbox_min)
+         scene.scale = scene.scale * scale
+         bpy.context.view_layer.update()
+         bbox_min, bbox_max = self.scene_bbox()
+         offset = -(bbox_min + bbox_max) / 2
+         scene.matrix_world.translation += offset
+         bpy.ops.object.select_all(action="DESELECT")
+         return scale, offset
+
+     def set_camera_from_matrix(self, cam, transform_matrix):
+         matrix = mathutils.Matrix(transform_matrix)
+         cam.matrix_world = matrix
+         bpy.context.view_layer.update()
+
+     def render_from_transforms(self, file_path, transforms_json_path, output_path):
+         with open(transforms_json_path, "r") as f:
+             transforms_data = json.load(f)
+
+         self.init_render_settings()
+
+         # Load scene
+         if file_path.endswith(".blend"):
+             self.delete_invisible_objects()
+         else:
+             self.init_scene()
+             self.load_object(file_path)
+         if self.split_normal:
+             self.split_mesh_normal()
+         print("[INFO] Scene initialized.")
+
+         scale, offset = self.normalize_scene()
+         print(f"[INFO] Scene normalized with auto scale: {scale}, offset: {offset}")
+
+         cam = self.init_camera()
+         self.init_lighting()
+         print("[INFO] Camera and lighting initialized.")
+         if self.geo_mode:
+             self.override_material()
+
+         # NOTE: the transforms_json format is a list of per-view dicts.
+         transform_matrix = transforms_data[0]["transform_matrix"]
+         camera_angle_x = transforms_data[0].get("camera_angle_x", None)
+
+         self.set_camera_from_matrix(cam, transform_matrix)
+         if camera_angle_x is not None:
+             cam.data.lens = 16 / np.tan(camera_angle_x / 2)
+
+         bpy.context.scene.render.filepath = output_path
+         bpy.ops.render.render(write_still=True)
+         bpy.context.view_layer.update()
+
+
+ def render_from_transforms(
+     file_path,
+     transforms_json_path,
+     output_path,
+     resolution=512,
+     engine="BLENDER_EEVEE",
+     geo_mode=False,
+     split_normal=False,
+ ):
+     renderer = BpyRenderer(resolution=resolution, engine=engine, geo_mode=geo_mode, split_normal=split_normal)
+     return renderer.render_from_transforms(file_path, transforms_json_path, output_path)
+
+
+ if __name__ == "__main__":
+     file_path = "./assets/example.glb"
+     transforms_json_path = "transforms.json"
+     output_path = "./assets/img.png"
+
+     # Recommended:
+     # - engine="BLENDER_EEVEE" for Blender 3.x
+     # - engine="BLENDER_EEVEE_NEXT" for Blender 4.x
+     # This script automatically falls back between them.
+     render_from_transforms(
+         file_path=file_path,
+         transforms_json_path=transforms_json_path,
+         output_path=output_path,
+         resolution=512,
+         engine="BLENDER_EEVEE",
+     )
+
+
363
+ # import bpy
364
+ # import json
365
+ # import math
366
+ # import mathutils
367
+ # import numpy as np
368
+
369
+ # class BpyRenderer:
370
+ # def __init__(self, resolution=512, engine="CYCLES", geo_mode=False, split_normal=False):
371
+ # self.resolution = resolution
372
+ # self.engine = engine
373
+ # self.geo_mode = geo_mode
374
+ # self.split_normal = split_normal
375
+ # self.import_functions = self._setup_import_functions()
376
+
377
+ # def _setup_import_functions(self):
378
+ # import_functions = {
379
+ # "obj": bpy.ops.wm.obj_import,
380
+ # "glb": bpy.ops.import_scene.gltf,
381
+ # "gltf": bpy.ops.import_scene.gltf,
382
+ # "usd": bpy.ops.import_scene.usd,
383
+ # "fbx": bpy.ops.import_scene.fbx,
384
+ # "stl": bpy.ops.import_mesh.stl,
385
+ # "usda": bpy.ops.import_scene.usda,
386
+ # "dae": bpy.ops.wm.collada_import,
387
+ # "ply": bpy.ops.wm.ply_import,
388
+ # "abc": bpy.ops.wm.alembic_import,
389
+ # "blend": bpy.ops.wm.append,
390
+ # }
391
+ # return import_functions
392
+
393
+ # def init_render_settings(self):
394
+ # bpy.context.scene.render.engine = self.engine
395
+ # bpy.context.scene.render.resolution_x = self.resolution
396
+ # bpy.context.scene.render.resolution_y = self.resolution
397
+ # bpy.context.scene.render.resolution_percentage = 100
398
+ # bpy.context.scene.render.image_settings.file_format = "PNG"
399
+ # bpy.context.scene.render.image_settings.color_mode = "RGBA"
400
+ # bpy.context.scene.render.film_transparent = True
401
+ # if self.engine == "CYCLES":
402
+ # bpy.context.scene.render.engine = "CYCLES"
403
+ # bpy.context.scene.cycles.samples = 128 if not self.geo_mode else 1
404
+ # bpy.context.scene.cycles.filter_type = "BOX"
405
+ # bpy.context.scene.cycles.filter_width = 1
406
+ # bpy.context.scene.cycles.diffuse_bounces = 1
407
+ # bpy.context.scene.cycles.glossy_bounces = 1
408
+ # bpy.context.scene.cycles.transparent_max_bounces = (3 if not self.geo_mode else 0)
409
+ # bpy.context.scene.cycles.transmission_bounces = (3 if not self.geo_mode else 1)
410
+ # bpy.context.scene.cycles.use_denoising = True
411
+ # try:
412
+ # bpy.context.scene.cycles.device = "GPU"
413
+ # bpy.context.preferences.addons["cycles"].preferences.get_devices()
414
+ # bpy.context.preferences.addons["cycles"].preferences.compute_device_type = "CUDA"
415
+ # except:
416
+ # pass
417
+
418
+ # def init_scene(self):
419
+ # for obj in bpy.data.objects:
420
+ # bpy.data.objects.remove(obj, do_unlink=True)
421
+ # for material in bpy.data.materials:
422
+ # bpy.data.materials.remove(material, do_unlink=True)
423
+ # for texture in bpy.data.textures:
424
+ # bpy.data.textures.remove(texture, do_unlink=True)
425
+ # for image in bpy.data.images:
426
+ # bpy.data.images.remove(image, do_unlink=True)
427
+
428
+ # def init_camera(self):
429
+ # cam = bpy.data.objects.new("Camera", bpy.data.cameras.new("Camera"))
430
+ # bpy.context.collection.objects.link(cam)
431
+ # bpy.context.scene.camera = cam
432
+ # cam.data.sensor_height = cam.data.sensor_width = 32
433
+ # cam_constraint = cam.constraints.new(type="TRACK_TO")
434
+ # cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
435
+ # cam_constraint.up_axis = "UP_Y"
436
+ # cam_empty = bpy.data.objects.new("Empty", None)
437
+ # cam_empty.location = (0, 0, 0)
438
+ # bpy.context.scene.collection.objects.link(cam_empty)
439
+ # cam_constraint.target = cam_empty
440
+ # return cam
441
+
442
+ # def init_lighting(self):
443
+ # bpy.ops.object.select_all(action="DESELECT")
444
+ # bpy.ops.object.select_by_type(type="LIGHT")
445
+ # bpy.ops.object.delete()
446
+
447
+ # default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
448
+ # bpy.context.collection.objects.link(default_light)
449
+ # default_light.data.energy = 1000
450
+ # default_light.location = (4, 1, 6)
451
+ # default_light.rotation_euler = (0, 0, 0)
452
+
453
+ # top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
454
+ # bpy.context.collection.objects.link(top_light)
455
+ # top_light.data.energy = 10000
456
+ # top_light.location = (0, 0, 10)
457
+ # top_light.scale = (100, 100, 100)
458
+
459
+ # bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
460
+ # bpy.context.collection.objects.link(bottom_light)
461
+ # bottom_light.data.energy = 1000
462
+ # bottom_light.location = (0, 0, -10)
463
+ # bottom_light.rotation_euler = (0, 0, 0)
464
+ # return {"default_light": default_light, "top_light": top_light, "bottom_light": bottom_light}
465
+
466
+ # def load_object(self, object_path):
467
+ # file_extension = object_path.split(".")[-1].lower()
468
+ # if file_extension not in self.import_functions:
469
+ # raise ValueError(f"Unsupported file type: {file_extension}")
470
+ # import_function = self.import_functions[file_extension]
471
+ # print(f"Loading object from {object_path}")
472
+ # if file_extension == "blend":
473
+ # import_function(directory=object_path, link=False)
474
+ # elif file_extension in {"glb", "gltf"}:
475
+ # import_function(filepath=object_path, merge_vertices=True, import_shading="NORMALS")
476
+ # else:
477
+ # import_function(filepath=object_path)
478
+
479
+ # def delete_invisible_objects(self):
480
+ # bpy.ops.object.select_all(action="DESELECT")
481
+ # for obj in bpy.context.scene.objects:
482
+ # if obj.hide_viewport or obj.hide_render:
483
+ # obj.hide_viewport = False
484
+ # obj.hide_render = False
485
+ # obj.hide_select = False
486
+ # obj.select_set(True)
487
+ # bpy.ops.object.delete()
488
+ # invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
489
+ # for col in invisible_collections:
490
+ # bpy.data.collections.remove(col)
491
+
492
+ # def unhide_all_objects(self):
493
+ # for obj in bpy.context.scene.objects:
494
+ # obj.hide_set(False)
495
+
496
+ # def convert_to_meshes(self):
497
+ # bpy.ops.object.select_all(action="DESELECT")
498
+ # bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
499
+ # for obj in bpy.context.scene.objects:
500
+ # obj.select_set(True)
501
+ # bpy.ops.object.convert(target="MESH")
502
+
503
+ # def triangulate_meshes(self):
504
+ # bpy.ops.object.select_all(action="DESELECT")
505
+ # objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
506
+ # bpy.context.view_layer.objects.active = objs[0]
507
+ # for obj in objs:
508
+ # obj.select_set(True)
509
+ # bpy.ops.object.mode_set(mode="EDIT")
510
+ # bpy.ops.mesh.reveal()
511
+ # bpy.ops.mesh.select_all(action="SELECT")
512
+ # bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
513
+ # bpy.ops.object.mode_set(mode="OBJECT")
514
+ # bpy.ops.object.select_all(action="DESELECT")
515
+
516
+ # def split_mesh_normal(self):
517
+ # bpy.ops.object.select_all(action="DESELECT")
518
+ # objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
519
+ # bpy.context.view_layer.objects.active = objs[0]
520
+ # for obj in objs:
521
+ # obj.select_set(True)
522
+ # bpy.ops.object.mode_set(mode="EDIT")
523
+ # bpy.ops.mesh.select_all(action="SELECT")
524
+ # bpy.ops.mesh.split_normals()
525
+ # bpy.ops.object.mode_set(mode="OBJECT")
526
+ # bpy.ops.object.select_all(action="DESELECT")
527
+
528
+ # def override_material(self):
529
+ # new_mat = bpy.data.materials.new(name="Override0123456789")
530
+ # new_mat.use_nodes = True
531
+ # new_mat.node_tree.nodes.clear()
532
+ # bsdf = new_mat.node_tree.nodes.new("ShaderNodeBsdfDiffuse")
533
+ # bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
534
+ # bsdf.inputs[1].default_value = 1
535
+ # output = new_mat.node_tree.nodes.new("ShaderNodeOutputMaterial")
536
+ # new_mat.node_tree.links.new(bsdf.outputs["BSDF"], output.inputs["Surface"])
537
+ # bpy.context.scene.view_layers["View Layer"].material_override = new_mat
538
+
539
+ # def scene_bbox(self):
540
+ # bbox_min = (math.inf,) * 3
541
+ # bbox_max = (-math.inf,) * 3
542
+ # found = False
543
+ # scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
544
+ # for obj in scene_meshes:
545
+ # found = True
546
+ # for coord in obj.bound_box:
547
+ # coord = mathutils.Vector(coord)
548
+ # coord = obj.matrix_world @ coord
549
+ # bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
550
+ # bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
551
+ # if not found:
552
+ # raise RuntimeError("no objects in scene to compute bounding box for")
553
+ # return mathutils.Vector(bbox_min), mathutils.Vector(bbox_max)
554
+
555
+ # def normalize_scene(self):
556
+ # scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
557
+ # if len(scene_root_objects) > 1:
558
+ # scene = bpy.data.objects.new("ParentEmpty", None)
559
+ # bpy.context.scene.collection.objects.link(scene)
560
+ # for obj in scene_root_objects:
561
+ # obj.parent = scene
562
+ # else:
563
+ # scene = scene_root_objects[0]
564
+
565
+ # bbox_min, bbox_max = self.scene_bbox()
566
+ # print(f"[INFO] Bounding box: {bbox_min}, {bbox_max}")
567
+ # scale = 1 / max(bbox_max - bbox_min)
568
+ # scene.scale = scene.scale * scale
569
+ # bpy.context.view_layer.update()
570
+ # bbox_min, bbox_max = self.scene_bbox()
571
+ # offset = -(bbox_min + bbox_max) / 2
572
+ # scene.matrix_world.translation += offset
573
+ # bpy.ops.object.select_all(action="DESELECT")
574
+ # return scale, offset
575
+
576
+ # def set_camera_from_matrix(self, cam, transform_matrix):
577
+ # matrix = mathutils.Matrix(transform_matrix)
578
+ # cam.matrix_world = matrix
579
+ # bpy.context.view_layer.update()
580
+
581
+ # def render_from_transforms(self, file_path, transforms_json_path, output_path):
582
+ # with open(transforms_json_path, 'r') as f:
583
+ # transforms_data = json.load(f)
584
+
585
+ # self.init_render_settings()
586
+
587
+ # if file_path.endswith(".blend"):
588
+ # self.delete_invisible_objects()
589
+ # else:
590
+ # self.init_scene()
591
+ # self.load_object(file_path)
592
+ # if self.split_normal:
593
+ # self.split_mesh_normal()
594
+ # print("[INFO] Scene initialized.")
595
+
596
+ # scale, offset = self.normalize_scene()
597
+ # print(f"[INFO] Scene normalized with auto scale: {scale}, offset: {offset}")
598
+
599
+ # cam = self.init_camera()
600
+ # self.init_lighting()
601
+ # print("[INFO] Camera and lighting initialized.")
602
+ # if self.geo_mode:
603
+ # self.override_material()
604
+
605
+ # transform_matrix = transforms_data[0]["transform_matrix"]
606
+ # camera_angle_x = transforms_data[0]["camera_angle_x"]
607
+ # self.set_camera_from_matrix(cam, transform_matrix)
608
+ # if camera_angle_x is not None:
609
+ # cam.data.lens = 16 / np.tan(camera_angle_x / 2)
610
+
611
+ # bpy.context.scene.render.filepath = output_path
612
+ # bpy.ops.render.render(write_still=True)
613
+ # bpy.context.view_layer.update()
614
+
615
+ # def render_from_transforms(file_path, transforms_json_path, output_path, resolution=512, engine="CYCLES", geo_mode=False, split_normal=False):
616
+ # renderer = BpyRenderer(resolution=resolution, engine=engine, geo_mode=geo_mode, split_normal=split_normal)
617
+ # return renderer.render_from_transforms(file_path, transforms_json_path, output_path)
618
+
619
+ # if __name__ == "__main__":
620
+ # file_path = "./assets/example.glb"
621
+ # transforms_json_path = "transforms.json"
622
+ # output_path = "./assets/img.png"
623
+ # render_from_transforms(file_path=file_path, transforms_json_path=transforms_json_path, output_path=output_path)
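`render_from_transforms` only reads `transforms_data[0]["transform_matrix"]` and an optional `camera_angle_x`, so the expected `transforms.json` is a list of per-view dicts. The actual file shipped as `data_toolkit/transforms.json` is not shown in this view; the sketch below writes a hypothetical single-view file that satisfies the loader, with purely illustrative matrix and angle values:

```python
import json
import math

# Hypothetical single-view transforms.json: a camera 2 units along +Z looking
# toward the origin, with a 40-degree horizontal FOV (the same FOV that
# color_img.py hard-codes). The real values live in data_toolkit/transforms.json.
transforms = [
    {
        "camera_angle_x": math.radians(40.0),
        "transform_matrix": [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 2.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
    }
]
with open("transforms.json", "w") as f:
    json.dump(transforms, f, indent=4)
```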
data_toolkit/color_glb.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ import json
+ import random
+ import trimesh
+ import numpy as np
+
+ from trimesh.visual.material import PBRMaterial
+
+ def _load_as_single_mesh(part_path):
+     obj = trimesh.load(part_path, force="scene")
+     if isinstance(obj, trimesh.Scene):
+         dumped = obj.dump()
+         meshes = [m for m in dumped if isinstance(m, trimesh.Trimesh) and len(m.vertices) > 0]
+         return trimesh.util.concatenate(meshes)
+     if isinstance(obj, trimesh.Trimesh):
+         return obj
+
+ def set_mesh_solid_pbr(mesh, rgba_uint8=(255, 255, 255, 255), emissive=True):
+     rgb = np.array(rgba_uint8[:3], dtype=np.float32) / 255.0
+     a = float(rgba_uint8[3]) / 255.0
+     colors = np.tile(np.array(rgba_uint8, dtype=np.uint8), (len(mesh.vertices), 1))
+     mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh, vertex_colors=colors)
+     mat_kwargs = dict(
+         baseColorFactor=[float(rgb[0]), float(rgb[1]), float(rgb[2]), a],
+         metallicFactor=0.0,
+         roughnessFactor=1.0,
+     )
+     if emissive:
+         mat_kwargs["emissiveFactor"] = [float(rgb[0]), float(rgb[1]), float(rgb[2])]
+     mesh.visual.material = PBRMaterial(**mat_kwargs)
+     return mesh
+
+ def color_glb(parts_path, output_path, interactive):
+     part_meshes = []
+     for part_name in sorted(os.listdir(parts_path)):
+         part_path = os.path.join(parts_path, part_name)
+         part_meshes.append(_load_as_single_mesh(part_path))
+
+     if interactive:
+         for i in range(len(part_meshes)):
+             colors = [(255, 255, 255, 255) if j == i else (0, 0, 0, 255) for j in range(len(part_meshes))]
+             scene = trimesh.Scene()
+             for j, m in enumerate(part_meshes):
+                 mc = m.copy()
+                 set_mesh_solid_pbr(mc, rgba_uint8=colors[j], emissive=True)
+                 scene.add_geometry(mc, node_name=f"part_{j}", geom_name=f"geom_{j}")
+             os.makedirs(os.path.join(output_path, f"{i}"), exist_ok=True)
+             scene.export(os.path.join(output_path, f"{i}", "output.glb"))
+     else:
+         colors = []
+         for i in range(len(part_meshes)):
+             while True:
+                 rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 255)
+                 if rgb not in colors:
+                     colors.append(rgb)
+                     break
+
+         # colors_base = [  # Static colors
+         #     (0, 0, 0, 255),
+         #     (0, 0, 255, 255),
+         #     (0, 255, 0, 255),
+         #     (0, 255, 255, 255),
+         #     (255, 0, 0, 255),
+         #     (255, 0, 255, 255),
+         #     (255, 255, 0, 255),
+         #     (255, 255, 255, 255)
+         # ]
+         # colors = random.sample(colors_base, len(part_meshes))
+
+         with open(os.path.join(output_path, "colors.json"), "w", encoding="utf-8") as f:
+             json.dump([list(c) for c in colors], f, ensure_ascii=False, indent=4)
+
+         scene = trimesh.Scene()
+         for i, m in enumerate(part_meshes):
+             mc = m.copy()
+             set_mesh_solid_pbr(mc, rgba_uint8=colors[i], emissive=True)
+             scene.add_geometry(mc, node_name=f"part_{i}", geom_name=f"geom_{i}")
+         scene.export(os.path.join(output_path, "output.glb"))
+
+ if __name__ == "__main__":
+     parts_path = "./assets/parts"
+     output_path = "./assets/interactive_seg"
+     interactive = True
+     color_glb(parts_path, output_path, interactive)
data_toolkit/color_img.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import json
3
+
4
+ import numpy as np
5
+ import trimesh
6
+ import torch
7
+ import nvdiffrast.torch as nr
8
+ from PIL import Image
9
+
10
+ def build_projection_matrix(fov, width, height, z_near=0.01, z_far=100.0):
11
+ aspect = float(width) / float(height)
12
+ fov_y = 2.0 * np.arctan(np.tan(fov / 2.0) / aspect)
13
+ f = 1.0 / np.tan(fov_y / 2.0)
14
+ P = np.array([
15
+ [f / aspect, 0.0, 0.0, 0.0],
16
+ [0.0, f, 0.0, 0.0],
17
+ [0.0, 0.0, (z_far + z_near) / (z_near - z_far), (2.0 * z_far * z_near) / (z_near - z_far)],
18
+ [0.0, 0.0, -1.0, 0.0],
19
+ ], dtype=np.float32)
20
+ return P
21
+
22
+ def compute_bbox_center_and_scale_like_blender(vertices):
23
+ bbox_min = vertices.min(axis=0)
24
+ bbox_max = vertices.max(axis=0)
25
+ bbox_extents = bbox_max - bbox_min
26
+ scale = 1.0 / np.max(bbox_extents)
27
+ offset = -(bbox_min + bbox_max) / 2.0
28
+ return offset, scale
29
+
30
+ def _load_as_single_mesh(part_path):
31
+ obj = trimesh.load(part_path, force="scene")
32
+ if isinstance(obj, trimesh.Scene):
33
+ dumped = obj.dump()
34
+ meshes = [m for m in dumped if isinstance(m, trimesh.Trimesh) and len(m.vertices) > 0]
35
+ return trimesh.util.concatenate(meshes)
36
+ if isinstance(obj, trimesh.Trimesh):
37
+ return obj
38
+
39
+ def load_parts_from_directory(object_path):
40
+ per_part_vertices = []
41
+ per_part_faces = []
42
+ part_names = []
43
+ vertices_counts = []
44
+ vertex_offset = 0
45
+ for part_name in sorted(os.listdir(object_path)):
46
+ part_path = os.path.join(object_path, part_name)
47
+ mesh = _load_as_single_mesh(part_path)
48
+ v = mesh.vertices.astype(np.float32)
49
+ f = mesh.faces.astype(np.int32)
50
+ per_part_vertices.append(v)
51
+ per_part_faces.append(f + vertex_offset)
52
+ part_names.append(part_name)
53
+ vertices_counts.append(v.shape[0])
54
+ vertex_offset += v.shape[0]
55
+ return per_part_vertices, per_part_faces, part_names, vertices_counts
56
+
57
+ def save_png(path, array_uint8):
58
+ os.makedirs(os.path.dirname(path), exist_ok=True)
59
+ Image.fromarray(array_uint8, mode="RGB").save(path)
60
+
61
+ def render_views(glctx, V, F, C, output_path, transforms_path):
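+ # Rasterize flat per-vertex part colors with nvdiffrast from the first camera in transforms.json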
62
+ width = 512
63
+ height = 512
64
+ fov = 40.0 * np.pi / 180.0 # 40 deg horizontal FOV, matching camera_angle_x in transforms.json
65
+ V_t = torch.from_numpy(V).to(torch.float32).cuda()
66
+ F_t = torch.from_numpy(F).to(torch.int32).cuda()
67
+ C_t = torch.from_numpy(C).to(torch.float32).cuda()
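+ # Rotate the asset +90 deg about X before projecting (Y-up / Z-up convention change)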
68
+ theta = np.pi / 2.0
69
+ Gx = np.array([
70
+ [1.0, 0.0, 0.0, 0.0],
71
+ [0.0, np.cos(theta), -np.sin(theta), 0.0],
72
+ [0.0, np.sin(theta), np.cos(theta), 0.0],
73
+ [0.0, 0.0, 0.0, 1.0],
74
+ ], dtype=np.float32)
75
+ Gx_t = torch.from_numpy(Gx).to(torch.float32).cuda()
76
+
77
+ with open(transforms_path, "r") as f:
78
+ transforms = json.load(f)
79
+ cam_to_world = np.array(transforms[0]["transform_matrix"], dtype=np.float32)
80
+ world_to_cam = np.linalg.inv(cam_to_world)
81
+ P = build_projection_matrix(fov, width, height)
82
+
83
+ V_mat = torch.from_numpy(world_to_cam).to(torch.float32).cuda()
84
+ P_mat = torch.from_numpy(P).to(torch.float32).cuda()
85
+ M_t = torch.eye(4, dtype=torch.float32).cuda()
86
+ pos_h = torch.cat([V_t, torch.ones((V_t.shape[0], 1), dtype=torch.float32).cuda()], dim=1)
87
+ pos_clip = (P_mat @ V_mat @ M_t @ Gx_t) @ pos_h.t()
88
+ pos_clip = pos_clip.t().contiguous().unsqueeze(0)
89
+
90
+ rast, _ = nr.rasterize(glctx, pos_clip, F_t, resolution=[height, width])
91
+ feat, _ = nr.interpolate(C_t.unsqueeze(0), rast, F_t)
92
+ cov = rast[..., 3:4]
93
+ img = feat.clamp(0.0, 1.0)
94
+ bg = torch.ones_like(img)
95
+ out = img * (cov > 0) + bg * (cov <= 0)
96
+ out_np = (out[0].cpu().numpy() * 255.0).astype(np.uint8)
97
+ out_np = out_np[::-1, :, :]
98
+ save_png(output_path, out_np)
99
+
100
+ def color_img(object_path, output_path, transforms, colors_path):
101
+ per_part_vertices, per_part_faces, part_names, vertices_counts = load_parts_from_directory(object_path)
102
+ V = np.concatenate(per_part_vertices, axis=0).astype(np.float32)
103
+ F = np.concatenate(per_part_faces, axis=0).astype(np.int32)
104
+ offset, scale = compute_bbox_center_and_scale_like_blender(V)
105
+ V_scaled = V * scale
106
+ V_norm = V_scaled + offset[None, :]
107
+ V = V_norm
108
+
109
+ with open(colors_path, "r") as f:
110
+ external_colors = json.load(f)
111
+ color_map = {}
112
+ colors = []
113
+ for idx, part_name in enumerate(part_names):
114
+ rgb = external_colors[idx][:3]
115
+ color_map[part_name] = [int(rgb[0]), int(rgb[1]), int(rgb[2])]
116
+ num_v = vertices_counts[idx]
117
+ col = (np.array(rgb, dtype=np.float32) / 255.0)[None, :]
118
+ colors.append(np.repeat(col, repeats=num_v, axis=0))
119
+
120
+ C = np.concatenate(colors, axis=0).astype(np.float32)
121
+ glctx = nr.RasterizeCudaContext()
122
+ render_views(glctx, V, F, C, output_path, transforms)
123
+
124
+ if __name__ == "__main__":
125
+ object_path = "./assets/parts"
126
+ output_path = "./assets/full_seg_w_2d_map/2d_map.png"
127
+ transforms = "transforms.json"
128
+ colors_path = "./assets/full_seg_w_2d_map/colors.json"
129
+ color_img(object_path, output_path, transforms, colors_path)
data_toolkit/example_full_seg.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import sys
3
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
+ if ROOT_DIR not in sys.path:
5
+ sys.path.insert(0, ROOT_DIR)
6
+
7
+ from trellis2 import models
8
+ from color_glb import color_glb
9
+ from glb_to_vxz import glb_to_vxz
10
+ from vxz_to_slat import vxz_to_slat
11
+ from img_to_cond import img_to_cond
12
+ from glb_to_parts import glb_to_parts
13
+ from bpy_render import render_from_transforms
14
+ from trellis2.pipelines.rembg import BiRefNet
15
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
16
+
17
+ rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
18
+ rembg_model.cuda()
19
+ image_cond_model = DinoV3FeatureExtractor(model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
20
+ image_cond_model.cuda()
21
+
22
+ shape_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
23
+ tex_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
24
+
25
+ glb = "./assets/example.glb"
26
+ input_vxz = "./assets/input.vxz"
27
+ parts_path = "./assets/parts"
28
+ full_seg_path = "./assets/full_seg"
29
+ interactive = False
30
+
31
+ transforms = "transforms.json"
32
+ img = "./assets/img.png"
33
+ cond = "./assets/cond.pth"
34
+
35
+ output_glb_path = os.path.join(full_seg_path, "output.glb")
36
+ output_vxz_path = os.path.join(full_seg_path, "output.vxz")
37
+
38
+ glb_to_vxz(glb, input_vxz)
39
+ glb_to_parts(glb, parts_path)
40
+ color_glb(parts_path, full_seg_path, interactive)
41
+
42
+ render_from_transforms(glb, transforms, img)
43
+ img_to_cond(rembg_model, image_cond_model, img, cond)
44
+ glb_to_vxz(output_glb_path, output_vxz_path)
45
+ vxz_to_slat(shape_encoder, tex_encoder, input_vxz, output_vxz_path, full_seg_path, interactive)
data_toolkit/example_full_seg_w_2d_map.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import sys
3
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
+ if ROOT_DIR not in sys.path:
5
+ sys.path.insert(0, ROOT_DIR)
6
+
7
+ from trellis2 import models
8
+ from color_glb import color_glb
9
+ from color_img import color_img
10
+ from glb_to_vxz import glb_to_vxz
11
+ from vxz_to_slat import vxz_to_slat
12
+ from img_to_cond import img_to_cond
13
+ from glb_to_parts import glb_to_parts
14
+ from trellis2.pipelines.rembg import BiRefNet
15
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
16
+
17
+ rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
18
+ rembg_model.cuda()
19
+ image_cond_model = DinoV3FeatureExtractor(model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
20
+ image_cond_model.cuda()
21
+
22
+ shape_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
23
+ tex_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
24
+
25
+ glb = "./assets/example.glb"
26
+ input_vxz = "./assets/input.vxz"
27
+ parts_path = "./assets/parts"
28
+ full_seg_w_2d_map_path = "./assets/full_seg_w_2d_map"
29
+ interactive = False
30
+
31
+ transforms = "transforms.json"
32
+ img = "./assets/full_seg_w_2d_map/2d_map.png"
33
+ cond = "./assets/full_seg_w_2d_map/cond.pth"
34
+
35
+ colors = os.path.join(full_seg_w_2d_map_path, "colors.json")
36
+ output_glb_path = os.path.join(full_seg_w_2d_map_path, "output.glb")
37
+ output_vxz_path = os.path.join(full_seg_w_2d_map_path, "output.vxz")
38
+
39
+ glb_to_vxz(glb, input_vxz)
40
+ glb_to_parts(glb, parts_path)
41
+ color_glb(parts_path, full_seg_w_2d_map_path, interactive)
42
+
43
+ color_img(parts_path, img, transforms, colors)
44
+ img_to_cond(rembg_model, image_cond_model, img, cond)
45
+ glb_to_vxz(output_glb_path, output_vxz_path)
46
+ vxz_to_slat(shape_encoder, tex_encoder, input_vxz, output_vxz_path, full_seg_w_2d_map_path, interactive)
data_toolkit/example_interactive_seg.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import sys
3
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
+ if ROOT_DIR not in sys.path:
5
+ sys.path.insert(0, ROOT_DIR)
6
+
7
+ from trellis2 import models
8
+ from color_glb import color_glb
9
+ from glb_to_vxz import glb_to_vxz
10
+ from vxz_to_slat import vxz_to_slat
11
+ from img_to_cond import img_to_cond
12
+ from glb_to_parts import glb_to_parts
13
+ from bpy_render import render_from_transforms
14
+ from trellis2.pipelines.rembg import BiRefNet
15
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
16
+
17
+ rembg_model = BiRefNet(model_name="/root/autodl-tmp/RMBG-2.0")
18
+ rembg_model.cuda()
19
+ image_cond_model = DinoV3FeatureExtractor(model_name="/root/autodl-tmp/dinov3-vitl16-pretrain-lvd1689m")
20
+ image_cond_model.cuda()
21
+
22
+ shape_encoder = models.from_pretrained("/root/autodl-tmp/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
23
+ tex_encoder = models.from_pretrained("/root/autodl-tmp/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
24
+
25
+ glb = "./assets/example.glb"
26
+ input_vxz = "./assets/input.vxz"
27
+ parts_path = "./assets/parts"
28
+ interactive_seg_path = "./assets/interactive_seg"
29
+ interactive = True
30
+
31
+ transforms = "transforms.json"
32
+ img = "./assets/img.png"
33
+ cond = "./assets/cond.pth"
34
+
35
+ glb_to_vxz(glb, input_vxz)
36
+ glb_to_parts(glb, parts_path)
37
+ color_glb(parts_path, interactive_seg_path, interactive)
38
+
39
+ render_from_transforms(glb, transforms, img)
40
+ img_to_cond(rembg_model, image_cond_model, img, cond)
41
+ for part_name in sorted(os.listdir(interactive_seg_path)):
42
+ part_path = os.path.join(interactive_seg_path, part_name)
43
+ output_glb_path = os.path.join(part_path, "output.glb")
44
+ output_vxz_path = os.path.join(part_path, "output.vxz")
45
+ glb_to_vxz(output_glb_path, output_vxz_path)
46
+ vxz_to_slat(shape_encoder, tex_encoder, input_vxz, output_vxz_path, part_path, interactive)
data_toolkit/glb_to_parts.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ import trimesh
3
+
4
+ def glb_to_parts(glb_path, output_dir):
5
+ scene = trimesh.load(glb_path, force='scene')
6
+ os.makedirs(output_dir, exist_ok=True)
7
+ geometries = list(scene.geometry.values())
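+ # NOTE: geometries are exported in their local frames; node transforms in the scene graph are not baked in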
8
+ for idx, geometry in enumerate(geometries):
9
+ part_scene = trimesh.Scene()
10
+ part_scene.add_geometry(geometry)
11
+ output_path = os.path.join(output_dir, f"{idx}.glb")
12
+ part_scene.export(output_path)
13
+
14
+ if __name__ == "__main__":
15
+ glb_path = "./assets/example.glb"
16
+ output_dir = "./assets/parts"
17
+ glb_to_parts(glb_path, output_dir)
data_toolkit/glb_to_vxz.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ import trimesh
3
+ import o_voxel
4
+
5
+ from PIL import Image
6
+
7
+ def make_texture_square_pow2(img: Image.Image, target_size=None):
8
+ w, h = img.size
9
+ max_side = max(w, h)
10
+ pow2 = 1
11
+ while pow2 < max_side:
12
+ pow2 *= 2
13
+ if target_size is not None:
14
+ pow2 = target_size
15
+ pow2 = min(pow2, 2048)
16
+ return img.resize((pow2, pow2), Image.BILINEAR)
17
+
18
+ def preprocess_scene_textures(asset):
19
+ if not isinstance(asset, trimesh.Scene):
20
+ return asset
21
+ TEX_KEYS = ["baseColorTexture", "normalTexture", "metallicRoughnessTexture", "emissiveTexture", "occlusionTexture"]
22
+ for geom in asset.geometry.values():
23
+ visual = getattr(geom, "visual", None)
24
+ mat = getattr(visual, "material", None)
25
+ if mat is None:
26
+ continue
27
+ for key in TEX_KEYS:
28
+ if not hasattr(mat, key):
29
+ continue
30
+ tex = getattr(mat, key)
31
+ if tex is None:
32
+ continue
33
+ if isinstance(tex, Image.Image):
34
+ setattr(mat, key, make_texture_square_pow2(tex))
35
+ elif hasattr(tex, "image") and tex.image is not None:
36
+ img = tex.image
37
+ if not isinstance(img, Image.Image):
38
+ img = Image.fromarray(img)
39
+ tex.image = make_texture_square_pow2(img)
40
+ if hasattr(mat, "image") and mat.image is not None:
41
+ img = mat.image
42
+ if not isinstance(img, Image.Image):
43
+ img = Image.fromarray(img)
44
+ mat.image = make_texture_square_pow2(img)
45
+ return asset
46
+
47
+ def glb_to_vxz(glb_path, vxz_path):
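+ # Normalize the asset to the unit cube, extract a 512^3 flexible dual grid (geometry)
+ # and volumetric PBR attributes, sort both by serialized voxel id, and pack into a .vxz file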
48
+ asset = trimesh.load(glb_path, force='scene')
49
+ asset = preprocess_scene_textures(asset)
50
+ aabb = asset.bounding_box.bounds
51
+ center = (aabb[0] + aabb[1]) / 2
52
+ scale = 0.99999 / (aabb[1] - aabb[0]).max()
53
+ asset.apply_translation(-center)
54
+ asset.apply_scale(scale)
55
+ mesh = asset.to_mesh()
56
+ vertices = torch.from_numpy(mesh.vertices).float()
57
+ faces = torch.from_numpy(mesh.faces).long()
58
+
59
+ voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
60
+ vertices, faces, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
61
+ face_weight=1.0, boundary_weight=0.2, regularization_weight=1e-2, timing=False
62
+ )
63
+ vid = o_voxel.serialize.encode_seq(voxel_indices)
64
+ mapping = torch.argsort(vid)
65
+ voxel_indices = voxel_indices[mapping]
66
+ dual_vertices = dual_vertices[mapping]
67
+ intersected = intersected[mapping]
68
+
69
+ voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr(
70
+ asset, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], timing=False
71
+ )
72
+ vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat)
73
+ mapping_mat = torch.argsort(vid_mat)
74
+ attributes = {k: v[mapping_mat] for k, v in attributes.items()}
75
+
76
+ dual_vertices = dual_vertices * 512 - voxel_indices
77
+ dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8)
78
+ intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
79
+
80
+ attributes['dual_vertices'] = dual_vertices
81
+ attributes['intersected'] = intersected
82
+ o_voxel.io.write(vxz_path, voxel_indices, attributes)
83
+
84
+ if __name__ == "__main__":
85
+ glb_path = "./assets/example.glb"
86
+ vxz_path = "./assets/input.vxz"
87
+ glb_to_vxz(glb_path, vxz_path)
data_toolkit/img_to_cond.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import sys
3
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
+ if ROOT_DIR not in sys.path:
5
+ sys.path.insert(0, ROOT_DIR)
6
+
7
+ import torch
8
+ import numpy as np
9
+
10
+ from PIL import Image
11
+ from trellis2.pipelines.rembg import BiRefNet
12
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
13
+
14
+ def preprocess_image(rembg_model, input):
+ # Detect a usable alpha channel before any mode conversion; in the original
+ # ordering the image was flattened to RGB first, so the RGBA check below it
+ # could never fire and the matting model always ran.
+ has_alpha = False
+ if input.mode == 'RGBA':
+ alpha = np.array(input)[:, :, 3]
+ if not np.all(alpha == 255):
+ has_alpha = True
+ if input.mode != "RGB" and not has_alpha:
+ input = input.convert("RGB")
24
+ max_size = max(input.size)
25
+ scale = min(1, 1024 / max_size)
26
+ if scale < 1:
27
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
28
+ if has_alpha:
29
+ output = input
30
+ else:
31
+ input = input.convert('RGB')
32
+ output = rembg_model(input)
33
+ output_np = np.array(output)
34
+ alpha = output_np[:, :, 3]
35
+ bbox = np.argwhere(alpha > 0.8 * 255)
36
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
37
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
38
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
39
+ size = int(size * 1) # crop padding factor (1 = tight bounding box)
40
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
41
+ output = output.crop(bbox) # type: ignore
42
+ output = np.array(output).astype(np.float32) / 255
43
+ output = output[:, :, :3] * output[:, :, 3:4]
44
+ output = Image.fromarray((output * 255).astype(np.uint8))
45
+ return output
46
+
47
+ def get_cond(image_cond_model, image):
48
+ image_cond_model.image_size = 512
49
+ cond = image_cond_model(image)
50
+ neg_cond = torch.zeros_like(cond)
51
+ return {'cond': cond.cpu(), 'neg_cond': neg_cond.cpu()}
52
+
53
+ def img_to_cond(rembg_model, image_cond_model, image_path, save_path):
54
+ image = Image.open(image_path)
55
+ image = preprocess_image(rembg_model, image)
56
+ cond_dict = get_cond(image_cond_model, [image])
57
+ torch.save(cond_dict, save_path)
58
+
59
+ if __name__ == "__main__":
60
+ image_path = "./assets/img.png"
61
+ save_path = "./assets/cond.pth"
62
+ rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
63
+ rembg_model.cuda()
64
+ image_cond_model = DinoV3FeatureExtractor(model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
65
+ image_cond_model.cuda()
66
+ img_to_cond(rembg_model, image_cond_model, image_path, save_path)
data_toolkit/texturing_pipeline.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "name": "Trellis2TexturingPipeline",
3
+ "args": {
4
+ "models": {
5
+ "shape_slat_encoder": "ckpts/shape_enc_next_dc_f16c32_fp16",
6
+ "tex_slat_decoder": "ckpts/tex_dec_next_dc_f16c32_fp16",
7
+ "tex_slat_flow_model_512": "ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16",
8
+ "tex_slat_flow_model_1024": "ckpts/slat_flow_imgshape2tex_dit_1_3B_1024_bf16"
9
+ },
10
+ "shape_slat_normalization": {
11
+ "mean": [
12
+ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
13
+ -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
14
+ 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
15
+ -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
16
+ ],
17
+ "std": [
18
+ 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
19
+ 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
20
+ 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
21
+ 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
22
+ ]
23
+ },
24
+ "tex_slat_sampler": {
25
+ "name": "FlowEulerGuidanceIntervalSampler",
26
+ "args": {
27
+ "sigma_min": 1e-5
28
+ },
29
+ "params": {
30
+ "steps": 12,
31
+ "guidance_strength": 1.0,
32
+ "guidance_rescale": 0.0,
33
+ "guidance_interval": [0.6, 0.9],
34
+ "rescale_t": 3.0
35
+ }
36
+ },
37
+ "tex_slat_normalization": {
38
+ "mean": [
39
+ 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
40
+ 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
41
+ -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
42
+ 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
43
+ ],
44
+ "std": [
45
+ 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
46
+ 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
47
+ 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
48
+ 2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
49
+ ]
50
+ },
51
+ "image_cond_model": {
52
+ "name": "DinoV3FeatureExtractor",
53
+ "args": {
54
+ "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m"
55
+ }
56
+ },
57
+ "rembg_model": {
58
+ "name": "BiRefNet",
59
+ "args": {
60
+ "model_name": "briaai/RMBG-2.0"
61
+ }
62
+ }
63
+ }
64
+ }
data_toolkit/transforms.json ADDED
@@ -0,0 +1,31 @@
1
+ [
2
+ {
3
+ "camera_angle_x": 0.6981317007977318,
4
+ "transform_matrix": [
5
+ [
6
+ 0.8819212913513184,
7
+ 0.06494797021150589,
8
+ -0.46690112352371216,
9
+ -0.9338021874427795
10
+ ],
11
+ [
12
+ -0.4713967740535736,
13
+ 0.12150891125202179,
14
+ -0.8735106587409973,
15
+ -1.7470210790634155
16
+ ],
17
+ [
18
+ -4.881157167346828e-08,
19
+ 0.9904633164405823,
20
+ 0.13777753710746765,
21
+ 0.2755555510520935
22
+ ],
23
+ [
24
+ 0,
25
+ 0,
26
+ 0,
27
+ 1
28
+ ]
29
+ ]
30
+ }
31
+ ]
data_toolkit/vxz_to_slat.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import sys
3
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
+ if ROOT_DIR not in sys.path:
5
+ sys.path.insert(0, ROOT_DIR)
6
+
7
+ import torch
8
+ import o_voxel
9
+ import trellis2.modules.sparse as sp
10
+
11
+ from trellis2 import models
12
+
13
+ def vxz_to_latent_slat(shape_encoder, tex_encoder, vxz_path, return_foreground=False):
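+ # Encode a .vxz file into sparse shape / texture structured latents (SLATs). With
+ # return_foreground=True, also return the coords of pure-white voxels (the highlighted
+ # part) that are not shared with any non-white voxel.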
14
+ coords, data = o_voxel.io.read(vxz_path)
15
+ coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords], dim=1).cuda()
16
+ vertices = (data['dual_vertices'].cuda() / 255)
17
+ intersected = torch.cat([data['intersected'] % 2, data['intersected'] // 2 % 2, data['intersected'] // 4 % 2], dim=-1).bool().cuda()
18
+ vertices_sparse = sp.SparseTensor(vertices, coords)
19
+ intersected_sparse = sp.SparseTensor(intersected.float(), coords)
20
+ with torch.no_grad():
21
+ shape_slat = shape_encoder(vertices_sparse, intersected_sparse)
22
+ shape_slat = sp.SparseTensor(shape_slat.feats.cuda(), shape_slat.coords.cuda())
23
+
24
+ base_color = (data['base_color'] / 255)
25
+ metallic = (data['metallic'] / 255)
26
+ roughness = (data['roughness'] / 255)
27
+ alpha = (data['alpha'] / 255)
28
+ attr = torch.cat([base_color, metallic, roughness, alpha], dim=-1).float().cuda() * 2 - 1
29
+ with torch.no_grad():
30
+ tex_slat = tex_encoder(sp.SparseTensor(attr, coords))
31
+ if return_foreground:
32
+ mask = ((base_color == 1.0).sum(dim=1) == 3).cuda() # pure-white voxels (the highlighted part); moved to GPU to index attr/coords
33
+ neg_mask = ((base_color != 1.0).sum(dim=1) == 3).cuda()
34
+ tex_slat_foreground = tex_encoder(sp.SparseTensor(attr[mask], coords[mask]))
35
+ tex_slat_background = tex_encoder(sp.SparseTensor(attr[neg_mask], coords[neg_mask]))
36
+ foreground_coords = torch.unique(tex_slat_foreground.coords, dim=0)
37
+ background_coords = torch.unique(tex_slat_background.coords, dim=0)
38
+ N = background_coords.shape[0]
39
+ all_coords = torch.cat([background_coords, foreground_coords], dim=0)
40
+ _, inv = torch.unique(all_coords, dim=0, return_inverse=True)
41
+ inv_background = inv[:N]
42
+ inv_foreground = inv[N:]
43
+ keep = ~torch.isin(inv_foreground, inv_background)
44
+ foreground_coords = foreground_coords[keep]
45
+ if return_foreground:
46
+ return shape_slat, tex_slat, foreground_coords
47
+ else:
48
+ return shape_slat, tex_slat
49
+
50
+ def get_common_coords(slat1, slat2, slat3, slat4, foreground_coords_origin=None):
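+ # Intersect the coordinate sets of the four SLATs (keep voxels present in all of them);
+ # optionally also intersect the result with the foreground coords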
51
+ coords_list = [slat1.coords, slat2.coords, slat3.coords, slat4.coords]
52
+ xs = [torch.unique(x, dim=0) for x in coords_list]
53
+ all_coords = torch.cat(xs, dim=0)
54
+ uniq_coords, counts = torch.unique(all_coords, dim=0, return_counts=True)
55
+ common_coords = uniq_coords[counts == len(coords_list)].cuda()
56
+
57
+ if foreground_coords_origin is not None:
58
+ xs_foreground = [torch.unique(x, dim=0) for x in (common_coords, foreground_coords_origin)]
59
+ all_coords_foreground = torch.cat(xs_foreground, dim=0)
60
+ uniq_coords_foreground, counts_foreground = torch.unique(all_coords_foreground, dim=0, return_counts=True)
61
+ foreground_coords = uniq_coords_foreground[counts_foreground == 2].cuda()
62
+ return common_coords, foreground_coords
63
+ else:
64
+ return common_coords
65
+
66
+ def get_slat_by_common_coords(slat_origin, common_coords):
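+ # Gather the features of slat_origin at common_coords by matching coordinates
+ # through a shared torch.unique index map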
67
+ N = slat_origin.coords.shape[0]
68
+ all_coords = torch.cat([slat_origin.coords, common_coords], dim=0)
69
+ uniq_coords, inv = torch.unique(all_coords, dim=0, return_inverse=True)
70
+ inv_slat = inv[:N].cuda()
71
+ inv_common = inv[N:].cuda()
72
+ device = slat_origin.coords.device
73
+ idx_map = torch.full((uniq_coords.shape[0],), -1, dtype=torch.int32, device=device)
74
+ slat_idx = torch.arange(N, dtype=torch.int32, device=device)
75
+ idx_map.scatter_reduce_(0, inv_slat, slat_idx, reduce='amin', include_self=False)
76
+ idx_in_slat = idx_map[inv_common]
77
+ feats = slat_origin.feats[idx_in_slat]
78
+ return sp.SparseTensor(feats, common_coords)
79
+
80
+ def get_point(point_num, common_coords, foreground_coords):
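+ # Sample point_num random foreground voxels as positive point prompts, zero-padded to 10 slots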
81
+ device = common_coords.device
82
+ point_feats_coords = torch.zeros((10, 4), dtype=torch.int32, device=device)
83
+ point_labels = torch.zeros((10, 1), dtype=torch.int32, device=device)
84
+ foreground_idx = torch.randperm(foreground_coords.shape[0], device=device)[:point_num]
85
+ point_foreground = foreground_coords[foreground_idx]
86
+ if point_foreground.shape[0] != point_num:
87
+ return None
88
+ point_feats_coords[:point_num] = point_foreground
89
+ point_labels[:point_num] = 1
90
+ return {'point_feats': point_feats_coords.cpu(), 'point_labels': point_labels.cpu()}
91
+
92
+ def vxz_to_slat(shape_encoder, tex_encoder, input_vxz_path, output_vxz_path, save_dir, interactive):
93
+ input_shape_slat_origin, input_tex_slat_origin = vxz_to_latent_slat(shape_encoder, tex_encoder, input_vxz_path)
94
+ if interactive:
95
+ output_shape_slat_origin, output_tex_slat_origin, foreground_coords_origin = vxz_to_latent_slat(shape_encoder, tex_encoder, output_vxz_path, return_foreground=interactive)
96
+ common_coords, foreground_coords = get_common_coords(input_shape_slat_origin, input_tex_slat_origin, output_shape_slat_origin, output_tex_slat_origin, foreground_coords_origin)
97
+ else:
98
+ output_shape_slat_origin, output_tex_slat_origin = vxz_to_latent_slat(shape_encoder, tex_encoder, output_vxz_path)
99
+ common_coords = get_common_coords(input_shape_slat_origin, input_tex_slat_origin, output_shape_slat_origin, output_tex_slat_origin)
100
+
101
+ os.makedirs(save_dir, exist_ok=True)
102
+ shape_slat = get_slat_by_common_coords(input_shape_slat_origin, common_coords)
103
+ torch.save({"feats": shape_slat.feats.cpu(), "coords": shape_slat.coords.cpu()}, os.path.join(save_dir, "shape_slat.pth"))
104
+ input_tex_slat = get_slat_by_common_coords(input_tex_slat_origin, common_coords)
105
+ torch.save({"feats": input_tex_slat.feats.cpu(), "coords": input_tex_slat.coords.cpu()}, os.path.join(save_dir, "input_tex_slat.pth"))
106
+ output_tex_slat_gt = get_slat_by_common_coords(output_tex_slat_origin, common_coords)
107
+ torch.save({"feats": output_tex_slat_gt.feats.cpu(), "coords": output_tex_slat_gt.coords.cpu()}, os.path.join(save_dir, "output_tex_slat.pth"))
108
+
109
+ if interactive:
110
+ for point_num in range(1, 11):
111
+ input_points = get_point(point_num, common_coords, foreground_coords)
112
+ if input_points is None:
113
+ continue
114
+ torch.save(input_points, os.path.join(save_dir, f"point_{point_num}.pth"))
115
+
116
+ if __name__ == "__main__":
117
+ input_vxz_path = "./assets/input.vxz"
118
+ output_vxz_path = "./assets/interactive_seg/0/output.vxz"
119
+ save_dir = "./assets/interactive_seg/0"
120
+ interactive = True
121
+ shape_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
122
+ tex_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
123
+ vxz_to_slat(shape_encoder, tex_encoder, input_vxz_path, output_vxz_path, save_dir, interactive)
examples/00aee5c2fef743d69421bb642d446a5b.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2d10a086262ccdf1109d5eb2fa385efb005da7503df9d39e5ab5cce0aa7d236
3
+ size 1585072
examples/00aee5c2fef743d69421bb642d446a5b.png ADDED

Git LFS Details

  • SHA256: c3d6fce7abcd4633e2fc482ace7383cb7733d2d4cf783d2b8a15e554a603d1a1
  • Pointer size: 130 Bytes
  • Size of remote file: 18.8 kB
examples/01b8043112e74366a21256d5e64398fb.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da1ad73c6eeb2338ed7d982bb5307daaa2f5faab91e0f883de214960c739be19
3
+ size 467796
examples/01b8043112e74366a21256d5e64398fb.png ADDED

Git LFS Details

  • SHA256: f98e0e43027c9b56afab4330bf7e55a02e216bc5f75084d44ef51a13db93828c
  • Pointer size: 130 Bytes
  • Size of remote file: 14.5 kB
examples/0c070001a3904cd6809a31345475e930.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57cd98fb3ca54128942744a8090fd040487159107cc810acdd1c20594ceac9bb
3
+ size 1524860
examples/0c070001a3904cd6809a31345475e930.png ADDED

Git LFS Details

  • SHA256: a2a132202263b4a60976fd6ebe941ff66a05e93e3af0adec656db74f60a526c0
  • Pointer size: 130 Bytes
  • Size of remote file: 26.3 kB
examples/0c3ca2b32545416f8f1e6f0e87def1a6.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:984c76bea31761c1f600240e6212b51c407b4cb69c19af30352d11879274a3b9
3
+ size 649004
examples/0c3ca2b32545416f8f1e6f0e87def1a6.png ADDED

Git LFS Details

  • SHA256: 68405d719adb4311f82501e37d6d03a489dcac058ee7f5e631849f0fdb42010c
  • Pointer size: 130 Bytes
  • Size of remote file: 16.6 kB
examples/1b3e8b99913442308aa989e3f87680b3.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa19b47445f18cbd75d812078103d0ae9bbc0b1ad404eeac12e78c9ed6ab41a1
3
+ size 2320124
examples/1b3e8b99913442308aa989e3f87680b3.png ADDED

Git LFS Details

  • SHA256: 0487593d7c9955c09215d061e0762da6fd5a39baa2bc8b16deead86d3d1d0da3
  • Pointer size: 130 Bytes
  • Size of remote file: 36.6 kB
examples/1c33b2e86c023a72905a5bea4ae713d0.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f0589018f02feaafa00770b8e1fe434009ed5bd419127452edd87d5534fa51b
3
+ size 1164628
examples/1c33b2e86c023a72905a5bea4ae713d0.png ADDED

Git LFS Details

  • SHA256: 9cc0741e849e373067c37dc5024075d7d82d0ec45188ccb750d5150ee8b5d80a
  • Pointer size: 130 Bytes
  • Size of remote file: 23.6 kB
examples/1ca8ea337fbc4bcfbeb3c633bc4c43f0.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5417c20a8f8b7fad7df93f2c656d1a6d22c44c284512e2bbd1ad0f498b9f7287
3
+ size 1755800
examples/1ca8ea337fbc4bcfbeb3c633bc4c43f0.png ADDED

Git LFS Details

  • SHA256: fad497de1ec231989a7417972cf3b96fd80f86cc50062ea8872b9a45437891d8
  • Pointer size: 130 Bytes
  • Size of remote file: 19 kB
examples/2260799ee4e342398b64ab4ce8af1559.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69914b1a07a6e3866b1d88bbb6f8d33fc13d492315626165455de902be818cb2
3
+ size 515608
examples/2260799ee4e342398b64ab4ce8af1559.png ADDED

Git LFS Details

  • SHA256: d1adf1ce0c961521e33a8c788b4ca8b339b517d0d0fba496e5fdd2530d5376ce
  • Pointer size: 130 Bytes
  • Size of remote file: 29.2 kB
examples/2ae5cf2990c34e7db704f677de8de74c.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:902905bc4cc0563f242a7cc5f96ee92165a4912527fda52f2beaa43ab6e67367
3
+ size 468560
examples/2ae5cf2990c34e7db704f677de8de74c.png ADDED

Git LFS Details

  • SHA256: e2a0888afd86521ef3f480f8412507f8b6b00f2d73032f8f965328d5a2355526
  • Pointer size: 130 Bytes
  • Size of remote file: 24.8 kB
examples/2ceb6778ac114101833e4c531544ada8.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d0ad0049486e6f46ea3ec4df8ea0acf2d846a0f5cac22032a1b26cf2050dc3f
3
+ size 225560
examples/2ceb6778ac114101833e4c531544ada8.png ADDED

Git LFS Details

  • SHA256: 895eb10de66dd41cfc9506bc058f6102ce961c36f8ff485daa6362407aae25f2
  • Pointer size: 130 Bytes
  • Size of remote file: 20.5 kB
examples/4b57e73e82ab400aa307adac36ea0e5e.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89dabad77dffd5cfc26fd4e6f977fb9ed686fa165a950d0c35df0693262a9352
3
+ size 3001252
examples/4b57e73e82ab400aa307adac36ea0e5e.png ADDED

Git LFS Details

  • SHA256: fb3e98c8b1bf0c910491945aef599914ea45fab15b7edf6d8b6e36e032b23548
  • Pointer size: 130 Bytes
  • Size of remote file: 13.7 kB
inference_full.py ADDED
@@ -0,0 +1,553 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
4
+
5
+ import json
+ import math # used by _colorvisuals_to_texturevisuals (math.ceil / math.sqrt)
6
+ import torch
7
+ import trimesh
8
+ import o_voxel
9
+ import numpy as np
10
+ import torch.nn as nn
11
+ import trellis2.modules.sparse as sp
12
+
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+ from trellis2 import models
16
+ from collections import OrderedDict
17
+ from trellis2.pipelines.rembg import BiRefNet
18
+ from trellis2.representations import MeshWithVoxel
19
+ from data_toolkit.bpy_render import render_from_transforms
20
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
21
+
22
+
23
+ TRELLIS_PIPELINE_JSON = "data_toolkit/texturing_pipeline.json"
24
+ TRELLIS_TEX_FLOW = "microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16"
25
+ TRELLIS_SHAPE_ENC = "microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16"
26
+ TRELLIS_TEX_ENC = "microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16"
27
+ TRELLIS_SHAPE_DEC = "microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16"
28
+ TRELLIS_TEX_DEC = "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16"
29
+ DINO_PATH = "facebook/dinov3-vitl16-pretrain-lvd1689m"
30
+ # DINO_PATH = "/bj-mlp-buaa-prod/pretrained_weights/pretrained_weights/dinov3"
31
+
32
+
33
+ def _colorvisuals_to_texturevisuals(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
34
+ """
35
+ Convert ColorVisuals to TextureVisuals by baking per-face colors into a tiny atlas
36
+ and generating per-face UVs. Ensure the resulting material is PBRMaterial to satisfy
37
+ downstream GLTF/PBR-only pipelines.
38
+ """
39
+ if mesh.visual is None:
40
+ return mesh
41
+
42
+ # If already textured, just ensure PBR material
43
+ if isinstance(mesh.visual, trimesh.visual.texture.TextureVisuals):
44
+ mat = getattr(mesh.visual, "material", None)
45
+ if isinstance(mat, trimesh.visual.material.SimpleMaterial):
46
+ # Avoid side-effects if the mesh is shared elsewhere
47
+ mesh = mesh.copy()
48
+ try:
49
+ mesh.visual.material = mat.to_pbr()
50
+ except Exception:
51
+ # Fallback: construct a minimal PBR material from the image
52
+ mesh.visual.material = trimesh.visual.material.PBRMaterial(
53
+ baseColorTexture=mat.image
54
+ )
55
+ return mesh
56
+
57
+ # Only handle ColorVisuals here
58
+ if not isinstance(mesh.visual, trimesh.visual.color.ColorVisuals):
59
+ return mesh
60
+
61
+ F = int(len(mesh.faces))
62
+ if F <= 0:
63
+ return mesh
64
+
65
+ # ---- Get per-face RGBA (uint8) ----
66
+ face_rgba = None
67
+
68
+ # Prefer face colors if present
69
+ if hasattr(mesh.visual, "face_colors") and mesh.visual.face_colors is not None:
70
+ fc = np.asarray(mesh.visual.face_colors)
71
+ if fc.ndim == 2 and fc.shape[0] == F:
72
+ face_rgba = fc[:, :4].astype(np.uint8)
73
+
74
+ # Fallback: average vertex colors per face
75
+ if face_rgba is None and hasattr(mesh.visual, "vertex_colors") and mesh.visual.vertex_colors is not None:
76
+ vc = np.asarray(mesh.visual.vertex_colors)
77
+ if vc.ndim == 2 and vc.shape[0] == len(mesh.vertices):
78
+ tri = mesh.faces
79
+ vcol = vc[tri] # (F,3,4)
80
+ face_rgba = np.rint(vcol.mean(axis=1)).astype(np.uint8)
81
+
82
+ if face_rgba is None:
83
+ face_rgba = np.tile(np.array([[255, 255, 255, 255]], dtype=np.uint8), (F, 1))
84
+
85
+ grid = int(math.ceil(math.sqrt(F)))
86
+ img = np.zeros((grid, grid, 4), dtype=np.uint8)
87
+
88
+ for i in range(F):
89
+ x = i % grid
90
+ y = i // grid
91
+ if y >= grid:
92
+ break
93
+ img[y, x, :] = face_rgba[i]
94
+
95
+ pil_img = Image.fromarray(img, mode="RGBA")
96
+
97
+ v_new = mesh.vertices[mesh.faces].reshape(-1, 3)
98
+ f_new = np.arange(F * 3, dtype=np.int64).reshape(F, 3)
99
+
100
+ uv_new = np.zeros((F * 3, 2), dtype=np.float32)
101
+ for i in range(F):
102
+ x = i % grid
103
+ y = i // grid
104
+ u = (x + 0.5) / float(grid)
105
+ v = (y + 0.5) / float(grid)
106
+ uv_new[i * 3 : i * 3 + 3, 0] = u
107
+ uv_new[i * 3 : i * 3 + 3, 1] = v
108
+
109
+ pbr = trimesh.visual.material.PBRMaterial(
110
+ baseColorTexture=pil_img,
111
+ metallicFactor=0.0,
112
+ roughnessFactor=1.0,
113
+ doubleSided=True,
114
+ alphaMode="BLEND",
115
+ )
116
+ visual = trimesh.visual.texture.TextureVisuals(uv=uv_new, material=pbr)
117
+
118
+ out = trimesh.Trimesh(vertices=v_new, faces=f_new, visual=visual, process=False)
119
+ return out
120
+
121
+
122
+ def ensure_texture_visuals(asset):
123
+ """
124
+ Ensure all geometries in a Scene (or a single Trimesh) use TextureVisuals.
125
+ For ColorVisuals, we bake them into a synthetic atlas.
126
+ """
127
+ if isinstance(asset, trimesh.Scene):
128
+ # Replace geometry objects in-place; graph nodes still refer to geometry names
129
+ for geom_name, g in list(asset.geometry.items()):
130
+ if isinstance(g, trimesh.Trimesh):
131
+ asset.geometry[geom_name] = _colorvisuals_to_texturevisuals(g)
132
+ return asset
133
+
134
+ if isinstance(asset, trimesh.Trimesh):
135
+ return _colorvisuals_to_texturevisuals(asset)
136
+
137
+ return asset
138
+
139
+
140
+ class Sampler:
141
+ def _inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond):
142
+ t = torch.tensor([t*1000] * x_t.shape[0], dtype=torch.float32).cuda()
143
+ return model(x_t, tex_slat, shape_slat, t, cond, coords_len_list)
144
+
145
+ def guidance_inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale=0.0):
146
+ if guidance_strength == 1:
147
+ return self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['cond'])
148
+ elif guidance_strength == 0:
149
+ return self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['neg_cond'])
150
+ else:
151
+ pred_pos = self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['cond'])
152
+ pred_neg = self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['neg_cond'])
153
+ pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
154
+ if guidance_rescale > 0:
155
+ x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
156
+ x_0_cfg = self._pred_to_xstart(x_t, t, pred)
157
+ std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
158
+ std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
159
+ x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
160
+ x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
161
+ pred = self._xstart_to_pred(x_t, t, x_0)
162
+ return pred
163
+
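+ # NOTE (assumption): _pred_to_xstart / _xstart_to_pred are called above when
+ # guidance_rescale > 0 but are not defined in this class (the shipped config sets
+ # guidance_rescale = 0.0, so the branch is dormant). Minimal sketches under the
+ # rectified-flow parameterization used by sample_once, x_t = (1 - t) * x_0 + t * noise:
+ def _pred_to_xstart(self, x_t, t, pred_v):
+ return x_t - t * pred_v
+ 
+ def _xstart_to_pred(self, x_t, t, x_0):
+ return (x_t - x_0) / t
+ 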
164
+ def interval_inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, sampler_params):
165
+ guidance_strength = sampler_params['guidance_strength']
166
+ guidance_interval = sampler_params['guidance_interval']
167
+ guidance_rescale = sampler_params['guidance_rescale']
168
+ if guidance_interval[0] <= t <= guidance_interval[1]:
169
+ return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale)
170
+ else:
171
+ return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, 1, guidance_rescale)
172
+
173
+ @torch.no_grad()
174
+ def sample_once(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, t_prev, cond_dict, sampler_params):
175
+ pred_v = self.interval_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, sampler_params)
176
+ pred_x_prev = x_t - (t - t_prev) * pred_v
177
+ return pred_x_prev
178
+
179
+ @torch.no_grad()
180
+ def sample(self, model, noise, tex_slat, shape_slat, coords_len_list, cond_dict, sampler_params):
181
+ sample = noise
182
+ steps = sampler_params['steps']
183
+ rescale_t = sampler_params['rescale_t']
184
+ t_seq = np.linspace(1, 0, steps + 1)
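+ # Warp the uniform schedule (rescale_t > 1 concentrates steps near the high-noise end t = 1)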
185
+ t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
186
+ t_seq = t_seq.tolist()
187
+ t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
188
+ for t, t_prev in tqdm(t_pairs, desc="Sampling"):
189
+ sample = self.sample_once(model, sample, tex_slat, shape_slat, coords_len_list, t, t_prev, cond_dict, sampler_params)
190
+ return sample
191
+
192
+
193
+ class Gen3DSeg(nn.Module):
194
+ def __init__(self, flow_model):
195
+ super().__init__()
196
+ self.flow_model = flow_model
197
+
198
+ def forward(self, x_t, tex_slats, shape_slats, t, cond, coords_len_list):
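+ # For each sample, interleave the noisy texture SLAT with the reference (input) texture SLAT
+ # along the token axis and duplicate the shape SLAT to match; after the flow model, keep only
+ # the first half of each pair (hence the 2 * coords_len stride below) as the prediction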
199
+ input_tex_feats_list = []
200
+ input_tex_coords_list = []
201
+ shape_feats_list = []
202
+ shape_coords_list = []
203
+ begin = 0
204
+ for coords_len in coords_len_list:
205
+ end = begin + coords_len
206
+ input_tex_feats_list.append(x_t.feats[begin:end])
207
+ input_tex_feats_list.append(tex_slats.feats[begin:end])
208
+ input_tex_coords_list.append(x_t.coords[begin:end])
209
+ input_tex_coords_list.append(tex_slats.coords[begin:end])
210
+ shape_feats_list.append(shape_slats.feats[begin:end])
211
+ shape_feats_list.append(shape_slats.feats[begin:end])
212
+ shape_coords_list.append(shape_slats.coords[begin:end])
213
+ shape_coords_list.append(shape_slats.coords[begin:end])
214
+ begin = end
215
+ x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
216
+ shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
217
+
218
+ output_tex_slats = self.flow_model(x_t, t, cond, shape_slats)
219
+
220
+ output_tex_feats_list = []
221
+ output_tex_coords_list = []
222
+ begin = 0
223
+ for coords_len in coords_len_list:
224
+ end = begin + coords_len
225
+ output_tex_feats_list.append(output_tex_slats.feats[begin:end])
226
+ output_tex_coords_list.append(output_tex_slats.coords[begin:end])
227
+ begin = begin + 2 * coords_len
228
+ output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
229
+ return output_tex_slat
230
+
231
+
232
+ def make_texture_square_pow2(img: Image.Image, target_size=None):
233
+ w, h = img.size
234
+ max_side = max(w, h)
235
+ pow2 = 1
236
+ while pow2 < max_side:
237
+ pow2 *= 2
238
+ if target_size is not None:
239
+ pow2 = target_size
240
+ pow2 = min(pow2, 2048)
241
+ return img.resize((pow2, pow2), Image.BILINEAR)
242
+
243
+
244
+ def preprocess_scene_textures(asset):
245
+ if not isinstance(asset, trimesh.Scene):
246
+ return asset
247
+ TEX_KEYS = ["baseColorTexture", "normalTexture", "metallicRoughnessTexture", "emissiveTexture", "occlusionTexture"]
248
+ for geom in asset.geometry.values():
249
+ visual = getattr(geom, "visual", None)
250
+ mat = getattr(visual, "material", None)
251
+ if mat is None:
252
+ continue
253
+ for key in TEX_KEYS:
254
+ if not hasattr(mat, key):
255
+ continue
256
+ tex = getattr(mat, key)
257
+ if tex is None:
258
+ continue
259
+ if isinstance(tex, Image.Image):
260
+ setattr(mat, key, make_texture_square_pow2(tex))
261
+ elif hasattr(tex, "image") and tex.image is not None:
262
+ img = tex.image
263
+ if not isinstance(img, Image.Image):
264
+ img = Image.fromarray(img)
265
+ tex.image = make_texture_square_pow2(img)
266
+ if hasattr(mat, "image") and mat.image is not None:
267
+ img = mat.image
268
+ if not isinstance(img, Image.Image):
269
+ img = Image.fromarray(img)
270
+ mat.image = make_texture_square_pow2(img)
271
+ return asset
272
+
273
+
274
+ def process_glb_to_vxz(glb_path, vxz_path):
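+ # Same conversion as data_toolkit/glb_to_vxz.py, plus baking ColorVisuals into texture atlases first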
275
+ asset = trimesh.load(glb_path, force='scene')
276
+ asset = preprocess_scene_textures(asset)
277
+
278
+ asset = ensure_texture_visuals(asset)
279
+
280
+ aabb = asset.bounding_box.bounds
281
+ center = (aabb[0] + aabb[1]) / 2
282
+ scale = 0.99999 / (aabb[1] - aabb[0]).max()
283
+ asset.apply_translation(-center)
284
+ asset.apply_scale(scale)
285
+ mesh = asset.to_mesh()
286
+ vertices = torch.from_numpy(mesh.vertices).float()
287
+ faces = torch.from_numpy(mesh.faces).long()
288
+
289
+ voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
290
+ vertices, faces, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
291
+ face_weight=1.0, boundary_weight=0.2, regularization_weight=1e-2, timing=False
292
+ )
293
+ vid = o_voxel.serialize.encode_seq(voxel_indices)
294
+ mapping = torch.argsort(vid)
295
+ voxel_indices = voxel_indices[mapping]
296
+ dual_vertices = dual_vertices[mapping]
297
+ intersected = intersected[mapping]
298
+
299
+ voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr(
300
+ asset, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], timing=False
301
+ )
302
+ vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat)
303
+ mapping_mat = torch.argsort(vid_mat)
304
+ attributes = {k: v[mapping_mat] for k, v in attributes.items()}
305
+
306
+ dual_vertices = dual_vertices * 512 - voxel_indices
307
+ dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8)
308
+ intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
309
+
310
+ attributes['dual_vertices'] = dual_vertices
311
+ attributes['intersected'] = intersected
312
+ o_voxel.io.write(vxz_path, voxel_indices, attributes)
313
+
314
+
315
+ def vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, vxz_path):
316
+ coords, data = o_voxel.io.read(vxz_path)
317
+ coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords], dim=1).cuda()
318
+ vertices = (data['dual_vertices'].cuda() / 255)
319
+ intersected = torch.cat([data['intersected'] % 2, data['intersected'] // 2 % 2, data['intersected'] // 4 % 2], dim=-1).bool().cuda()
320
+ vertices_sparse = sp.SparseTensor(vertices, coords)
321
+ intersected_sparse = sp.SparseTensor(intersected.float(), coords)
322
+ with torch.no_grad():
323
+ shape_slat = shape_encoder(vertices_sparse, intersected_sparse)
324
+ shape_slat = sp.SparseTensor(shape_slat.feats.cuda(), shape_slat.coords.cuda())
325
+ shape_decoder.set_resolution(512)
326
+ meshes, subs = shape_decoder(shape_slat, return_subs=True)
327
+
328
+ base_color = (data['base_color'] / 255)
329
+ metallic = (data['metallic'] / 255)
330
+ roughness = (data['roughness'] / 255)
331
+ alpha = (data['alpha'] / 255)
332
+ attr = torch.cat([base_color, metallic, roughness, alpha], dim=-1).float().cuda() * 2 - 1
333
+ with torch.no_grad():
334
+ tex_slat = tex_encoder(sp.SparseTensor(attr, coords))
335
+ return shape_slat, meshes, subs, tex_slat
336
+
337
+
338
+ def preprocess_image(rembg_model, input):
+ # Detect a usable alpha channel before any mode conversion; in the original
+ # ordering the image was flattened to RGB first, so the RGBA check below it
+ # could never fire and the matting model always ran.
+ has_alpha = False
+ if input.mode == 'RGBA':
+ alpha = np.array(input)[:, :, 3]
+ if not np.all(alpha == 255):
+ has_alpha = True
+ if input.mode != "RGB" and not has_alpha:
+ input = input.convert("RGB")
348
+ max_size = max(input.size)
349
+ scale = min(1, 1024 / max_size)
350
+ if scale < 1:
351
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
352
+ if has_alpha:
353
+ output = input
354
+ else:
355
+ input = input.convert('RGB')
356
+ output = rembg_model(input)
357
+ output_np = np.array(output)
358
+ alpha = output_np[:, :, 3]
359
+ bbox = np.argwhere(alpha > 0.8 * 255)
360
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
361
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
362
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
363
+ size = int(size * 1) # crop padding factor (1 = tight bounding box)
364
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
365
+ output = output.crop(bbox) # type: ignore
366
+ output = np.array(output).astype(np.float32) / 255
367
+ output = output[:, :, :3] * output[:, :, 3:4]
368
+ output = Image.fromarray((output * 255).astype(np.uint8))
369
+ return output
370
+
371
+
372
+ def get_cond(image_cond_model, image):
373
+ image_cond_model.image_size = 512
374
+ cond = image_cond_model(image)
375
+ neg_cond = torch.zeros_like(cond)
376
+ return {'cond': cond, 'neg_cond': neg_cond}
377
+
378
+
379
+ def tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, input_tex_slat, cond_dict):
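+ # Normalize the shape / input-texture SLATs with the pipeline statistics, run flow sampling
+ # from Gaussian noise, then de-normalize the predicted texture SLAT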
380
+ device = shape_slat.feats.device
381
+ shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None].to(device)
382
+ shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None].to(device)
383
+ tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None].to(device)
384
+ tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None].to(device)
385
+ shape_slat = ((shape_slat - shape_mean) / shape_std)
386
+ input_tex_slat = ((input_tex_slat - tex_mean) / tex_std)
387
+ coords_len_list = [shape_slat.coords.shape[0]]
388
+ noise = sp.SparseTensor(torch.randn_like(input_tex_slat.feats), shape_slat.coords)
389
+ output_tex_slat = sampler.sample(gen3dseg, noise, input_tex_slat, shape_slat, coords_len_list, cond_dict, pipeline_args['tex_slat_sampler']['params'])
390
+ output_tex_slat = output_tex_slat * tex_std + tex_mean
391
+ return output_tex_slat
392
+
393
+
394
+ def slat_to_glb(meshes, tex_voxels, resolution=512):
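+ # Wrap decoded texture voxels around the decoded mesh as a MeshWithVoxel, then bake a
+ # textured GLB via o_voxel postprocessing (decimation, remeshing, 4K texture)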
395
+ pbr_attr_layout = {
396
+ 'base_color': slice(0, 3),
397
+ 'metallic': slice(3, 4),
398
+ 'roughness': slice(4, 5),
399
+ 'alpha': slice(5, 6),
400
+ }
401
+ out_mesh = []
402
+ for m, v in zip(meshes, tex_voxels):
403
+ m.fill_holes()
404
+ out_mesh.append(
405
+ MeshWithVoxel(
406
+ m.vertices, m.faces,
407
+ origin = [-0.5, -0.5, -0.5],
408
+ voxel_size = 1 / resolution,
409
+ coords = v.coords[:, 1:],
410
+ attrs = v.feats,
411
+ voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
412
+ layout=pbr_attr_layout
413
+ )
414
+ )
415
+ mesh = out_mesh[0]
416
+ mesh.simplify(10000000)
417
+ glb = o_voxel.postprocess.to_glb(
418
+ vertices = mesh.vertices,
419
+ faces = mesh.faces,
420
+ attr_volume = mesh.attrs,
421
+ coords = mesh.coords,
422
+ attr_layout = mesh.layout,
423
+ voxel_size = mesh.voxel_size,
424
+ aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
425
+ decimation_target = 100000,
426
+ texture_size = 4096,
427
+ remesh = True,
428
+ remesh_band = 1,
429
+ remesh_project = 0,
430
+ verbose = True
431
+ )
432
+ return glb
433
+
434
+
435
+ class _LoadedPipeline:
436
+ def __init__(self):
437
+ self.loaded = False
438
+ self.current_ckpt = None
439
+
440
+ self.pipeline_args = None
441
+ self.tex_slat_flow_model = None
442
+ self.gen3dseg = None
443
+ self.sampler = None
444
+
445
+ self.shape_encoder = None
446
+ self.tex_encoder = None
447
+ self.shape_decoder = None
448
+ self.tex_decoder = None
449
+
450
+ self.rembg_model = None
451
+ self.image_cond_model = None
452
+
453
+ def load_all_models(self):
454
+ if self.loaded:
455
+ return
456
+
457
+ print("-" * 100)
458
+ print("[Init] Loading pipeline config ............")
459
+ with open(TRELLIS_PIPELINE_JSON, "r") as f:
460
+ pipeline_config = json.load(f)
461
+ self.pipeline_args = pipeline_config['args']
462
+
463
+ print("-" * 100)
464
+ print("[Init] Loading TRELLIS backbone ............")
465
+ self.tex_slat_flow_model = models.from_pretrained(TRELLIS_TEX_FLOW)
466
+
467
+ self.gen3dseg = Gen3DSeg(self.tex_slat_flow_model)
468
+ self.gen3dseg.eval()
469
+ self.gen3dseg.cuda()
470
+
471
+ self.sampler = Sampler()
472
+
473
+ self.shape_encoder = models.from_pretrained(TRELLIS_SHAPE_ENC).cuda().eval()
474
+ self.tex_encoder = models.from_pretrained(TRELLIS_TEX_ENC).cuda().eval()
475
+ self.shape_decoder = models.from_pretrained(TRELLIS_SHAPE_DEC).cuda().eval()
476
+ self.tex_decoder = models.from_pretrained(TRELLIS_TEX_DEC).cuda().eval()
477
+
478
+ print("-" * 100)
479
+ print("[Init] Loading conditioners ............")
480
+
481
+ self.rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
482
+ self.rembg_model.cuda()
483
+
484
+ self.image_cond_model = DinoV3FeatureExtractor(DINO_PATH)
485
+ self.image_cond_model.cuda()
486
+
487
+ self.loaded = True
488
+ print("[Init] Done.")
489
+
490
+ def load_ckpt_if_needed(self, ckpt_path: str):
491
+ if self.current_ckpt == ckpt_path:
492
+ return
493
+
494
+ print("-" * 100)
495
+ print(f"[CKPT] Loading ckpt: {ckpt_path}")
496
+ state_dict = torch.load(ckpt_path)['state_dict']
497
+ state_dict = OrderedDict([(k.replace("gen3dseg.", ""), v) for k, v in state_dict.items()])
498
+ self.gen3dseg.load_state_dict(state_dict)
499
+ self.gen3dseg.eval()
500
+ self.gen3dseg.cuda()
501
+ self.current_ckpt = ckpt_path
502
+
503
+
504
+ PIPE = _LoadedPipeline()
505
+
506
+
507
+ def inference_with_loaded_models(ckpt_path, item):
508
+ PIPE.load_all_models()
509
+ PIPE.load_ckpt_if_needed(ckpt_path)
510
+
511
+ if PIPE.rembg_model is None:
512
+ raise RuntimeError("PIPE.rembg_model is None. Check BiRefNet loading and .cuda() usage.")
513
+ if PIPE.image_cond_model is None:
514
+ raise RuntimeError("PIPE.image_cond_model is None. Check DinoV3FeatureExtractor loading and .cuda() usage.")
515
+
516
+ process_glb_to_vxz(item['glb'], item['input_vxz'])
517
+ shape_slat, meshes, subs, tex_slat = vxz_to_latent_slat(
518
+ PIPE.shape_encoder, PIPE.shape_decoder, PIPE.tex_encoder, item['input_vxz']
519
+ )
520
+
521
+ if not item['2d_map']:
522
+ render_from_transforms(item['glb'], item['transforms'], item['img'])
523
+
524
+ image = Image.open(item['img'])
525
+ image = preprocess_image(PIPE.rembg_model, image)
526
+ cond = get_cond(PIPE.image_cond_model, [image])
527
+
528
+ output_tex_slat = tex_slat_sample_single(
529
+ PIPE.gen3dseg, PIPE.sampler, PIPE.pipeline_args, shape_slat, tex_slat, cond
530
+ )
531
+ with torch.no_grad():
532
+ tex_voxels = PIPE.tex_decoder(output_tex_slat, guide_subs=subs) * 0.5 + 0.5
533
+
534
+ glb = slat_to_glb(meshes, tex_voxels)
535
+
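+ # Rotate +90 deg about X (y -> z, z -> -y) to swap up-axis conventions before export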
536
+ T = np.eye(4, dtype=np.float64)
537
+ T[:3, :3] = np.array(
538
+ [
539
+ [1, 0, 0],
540
+ [0, 0, -1],
541
+ [0, 1, 0],
542
+ ],
543
+ dtype=np.float64,
544
+ )
545
+
546
+ if hasattr(glb, "apply_transform") and callable(getattr(glb, "apply_transform")):
547
+ glb.apply_transform(T)
548
+ glb.export(item["export_glb"])
549
+ else:
550
+ glb.export(item["export_glb"])
551
+ scene_or_mesh = trimesh.load(item["export_glb"], force="scene")
552
+ scene_or_mesh.apply_transform(T)
553
+ scene_or_mesh.export(item["export_glb"])
inference_full_ori.py ADDED
@@ -0,0 +1,383 @@
+ import os
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import trimesh
+ import o_voxel
+ import numpy as np
+ import torch.nn as nn
+ import trellis2.modules.sparse as sp
+
+ from PIL import Image
+ from tqdm import tqdm
+ from trellis2 import models
+ from collections import OrderedDict
+ from trellis2.pipelines.rembg import BiRefNet
+ from trellis2.representations import MeshWithVoxel
+ from data_toolkit.bpy_render import render_from_transforms
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
+
+ class Sampler:
+     def _inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond):
+         t = torch.tensor([t*1000] * x_t.shape[0], dtype=torch.float32).cuda()
+         return model(x_t, tex_slat, shape_slat, t, cond, coords_len_list)
+
+     def guidance_inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale=0.0):
+         if guidance_strength == 1:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['cond'])
+         elif guidance_strength == 0:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['neg_cond'])
+         else:
+             pred_pos = self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['cond'])
+             pred_neg = self._inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict['neg_cond'])
+             pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
+             if guidance_rescale > 0:
+                 x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
+                 x_0_cfg = self._pred_to_xstart(x_t, t, pred)
+                 std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
+                 std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
+                 x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
+                 x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
+                 pred = self._xstart_to_pred(x_t, t, x_0)
+             return pred
+
+     def interval_inference_model(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, sampler_params):
+         guidance_strength = sampler_params['guidance_strength']
+         guidance_interval = sampler_params['guidance_interval']
+         guidance_rescale = sampler_params['guidance_rescale']
+         if guidance_interval[0] <= t <= guidance_interval[1]:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale)
+         else:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, 1, guidance_rescale)
+
+     @torch.no_grad()
+     def sample_once(self, model, x_t, tex_slat, shape_slat, coords_len_list, t, t_prev, cond_dict, sampler_params):
+         pred_v = self.interval_inference_model(model, x_t, tex_slat, shape_slat, coords_len_list, t, cond_dict, sampler_params)
+         pred_x_prev = x_t - (t - t_prev) * pred_v
+         return pred_x_prev
+
+     @torch.no_grad()
+     def sample(self, model, noise, tex_slat, shape_slat, coords_len_list, cond_dict, sampler_params):
+         sample = noise
+         steps = sampler_params['steps']
+         rescale_t = sampler_params['rescale_t']
+         t_seq = np.linspace(1, 0, steps + 1)
+         t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
+         t_seq = t_seq.tolist()
+         t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
+         for t, t_prev in tqdm(t_pairs, desc="Sampling"):
+             sample = self.sample_once(model, sample, tex_slat, shape_slat, coords_len_list, t, t_prev, cond_dict, sampler_params)
+         return sample
+
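Sampler is a plain Euler integrator for a velocity-prediction (rectified-flow) model: t runs from 1 (noise) to 0, the schedule is warped by rescale_t (values > 1 concentrate steps near t = 1), and each step applies x_prev = x_t - (t - t_prev) * v. A self-contained sketch of the same update on dense tensors, with a toy velocity standing in for the flow model:

    import numpy as np
    import torch

    def warped_schedule(steps=8, rescale_t=3.0):
        t = np.linspace(1, 0, steps + 1)
        return rescale_t * t / (1 + (rescale_t - 1) * t)  # same warp as Sampler.sample

    velocity = lambda x, t: x          # hypothetical stand-in for the flow model
    x = torch.randn(4, 32)             # pure noise at t = 1
    t_seq = warped_schedule()
    for t, t_prev in zip(t_seq[:-1], t_seq[1:]):
        x = x - (t - t_prev) * velocity(x, t)  # the Euler update from sample_once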
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+
+     def forward(self, x_t, tex_slats, shape_slats, t, cond, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         output_tex_slats = self.flow_model(x_t, t, cond, shape_slats)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
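Gen3DSeg.forward doubles the token stream before calling the backbone: per sample, the coords_len noisy texture tokens are followed by the coords_len clean (input) texture tokens, with the shape latent duplicated to match, and only the first half of each sample's output is kept. A dense toy version of that pack/unpack bookkeeping (sizes are made up):

    import torch

    coords_len_list = [3, 2]               # token counts per sample
    x_t   = torch.arange(5).float()        # noisy tokens, all samples flattened
    clean = x_t + 100                      # clean texture tokens

    packed, begin = [], 0
    for n in coords_len_list:              # pack: [x_t_i ; clean_i] per sample
        packed += [x_t[begin:begin + n], clean[begin:begin + n]]
        begin += n
    h = torch.cat(packed)                  # what the flow model sees

    out, begin = [], 0
    for n in coords_len_list:              # unpack: keep only the noisy half
        out.append(h[begin:begin + n])
        begin += 2 * n
    assert torch.equal(torch.cat(out), x_t)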
+ def make_texture_square_pow2(img: Image.Image, target_size=None):
+     w, h = img.size
+     max_side = max(w, h)
+     pow2 = 1
+     while pow2 < max_side:
+         pow2 *= 2
+     if target_size is not None:
+         pow2 = target_size
+     pow2 = min(pow2, 2048)
+     return img.resize((pow2, pow2), Image.BILINEAR)
+
+ def preprocess_scene_textures(asset):
+     if not isinstance(asset, trimesh.Scene):
+         return asset
+     TEX_KEYS = ["baseColorTexture", "normalTexture", "metallicRoughnessTexture", "emissiveTexture", "occlusionTexture"]
+     for geom in asset.geometry.values():
+         visual = getattr(geom, "visual", None)
+         mat = getattr(visual, "material", None)
+         if mat is None:
+             continue
+         for key in TEX_KEYS:
+             if not hasattr(mat, key):
+                 continue
+             tex = getattr(mat, key)
+             if tex is None:
+                 continue
+             if isinstance(tex, Image.Image):
+                 setattr(mat, key, make_texture_square_pow2(tex))
+             elif hasattr(tex, "image") and tex.image is not None:
+                 img = tex.image
+                 if not isinstance(img, Image.Image):
+                     img = Image.fromarray(img)
+                 tex.image = make_texture_square_pow2(img)
+         if hasattr(mat, "image") and mat.image is not None:
+             img = mat.image
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+             mat.image = make_texture_square_pow2(img)
+     return asset
+
+ def process_glb_to_vxz(glb_path, vxz_path):
+     asset = trimesh.load(glb_path, force='scene')
+     asset = preprocess_scene_textures(asset)
+     aabb = asset.bounding_box.bounds
+     center = (aabb[0] + aabb[1]) / 2
+     scale = 0.99999 / (aabb[1] - aabb[0]).max()
+     asset.apply_translation(-center)
+     asset.apply_scale(scale)
+     mesh = asset.to_mesh()
+     vertices = torch.from_numpy(mesh.vertices).float()
+     faces = torch.from_numpy(mesh.faces).long()
+
+     voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
+         vertices, faces, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         face_weight=1.0, boundary_weight=0.2, regularization_weight=1e-2, timing=False
+     )
+     vid = o_voxel.serialize.encode_seq(voxel_indices)
+     mapping = torch.argsort(vid)
+     voxel_indices = voxel_indices[mapping]
+     dual_vertices = dual_vertices[mapping]
+     intersected = intersected[mapping]
+
+     voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr(
+         asset, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], timing=False
+     )
+     vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat)
+     mapping_mat = torch.argsort(vid_mat)
+     attributes = {k: v[mapping_mat] for k, v in attributes.items()}
+
+     dual_vertices = dual_vertices * 512 - voxel_indices
+     dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8)
+     intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
+
+     attributes['dual_vertices'] = dual_vertices
+     attributes['intersected'] = intersected
+     o_voxel.io.write(vxz_path, voxel_indices, attributes)
+
+ def vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, vxz_path):
+     coords, data = o_voxel.io.read(vxz_path)
+     coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords], dim=1).cuda()
+     vertices = (data['dual_vertices'].cuda() / 255)
+     intersected = torch.cat([data['intersected'] % 2, data['intersected'] // 2 % 2, data['intersected'] // 4 % 2], dim=-1).bool().cuda()
+     vertices_sparse = sp.SparseTensor(vertices, coords)
+     intersected_sparse = sp.SparseTensor(intersected.float(), coords)
+     with torch.no_grad():
+         shape_slat = shape_encoder(vertices_sparse, intersected_sparse)
+         shape_slat = sp.SparseTensor(shape_slat.feats.cuda(), shape_slat.coords.cuda())
+         shape_decoder.set_resolution(512)
+         meshes, subs = shape_decoder(shape_slat, return_subs=True)
+
+     base_color = (data['base_color'] / 255)
+     metallic = (data['metallic'] / 255)
+     roughness = (data['roughness'] / 255)
+     alpha = (data['alpha'] / 255)
+     attr = torch.cat([base_color, metallic, roughness, alpha], dim=-1).float().cuda() * 2 - 1
+     with torch.no_grad():
+         tex_slat = tex_encoder(sp.SparseTensor(attr, coords))
+     return shape_slat, meshes, subs, tex_slat
+
+ def preprocess_image(rembg_model, input):
+     if input.mode != "RGB":
+         bg = Image.new("RGB", input.size, (255, 255, 255))
+         bg.paste(input, mask=input.split()[3])
+         input = bg
+     has_alpha = False
+     if input.mode == 'RGBA':
+         alpha = np.array(input)[:, :, 3]
+         if not np.all(alpha == 255):
+             has_alpha = True
+     max_size = max(input.size)
+     scale = min(1, 1024 / max_size)
+     if scale < 1:
+         input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+     if has_alpha:
+         output = input
+     else:
+         input = input.convert('RGB')
+         output = rembg_model(input)
+     output_np = np.array(output)
+     alpha = output_np[:, :, 3]
+     bbox = np.argwhere(alpha > 0.8 * 255)
+     bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+     center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+     size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+     size = int(size * 1)
+     bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+     output = output.crop(bbox)  # type: ignore
+     output = np.array(output).astype(np.float32) / 255
+     output = output[:, :, :3] * output[:, :, 3:4]
+     output = Image.fromarray((output * 255).astype(np.uint8))
+     return output
+
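preprocess_image crops a square window around the foreground: the tight bounding box of pixels with alpha above 0.8*255 is expanded to a square on its longer side, cropped, and the RGB is premultiplied by alpha onto black. Note the first branch flattens any non-RGB input onto white, so by the time the RGBA check runs the mode is always RGB and background removal always executes. A toy numpy illustration of the bbox-to-square step (values are made up):

    import numpy as np

    alpha = np.zeros((8, 8), dtype=np.uint8)
    alpha[2:5, 3:7] = 255                       # hypothetical foreground blob
    ys, xs = np.nonzero(alpha > 0.8 * 255)
    x0, y0, x1, y1 = xs.min(), ys.min(), xs.max(), ys.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    side = max(x1 - x0, y1 - y0)                # square crop on the longer side
    print((cx - side // 2, cy - side // 2, cx + side // 2, cy + side // 2))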
+ def get_cond(image_cond_model, image):
+     image_cond_model.image_size = 512
+     cond = image_cond_model(image)
+     neg_cond = torch.zeros_like(cond)
+     return {'cond': cond, 'neg_cond': neg_cond}
+
+ def tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, input_tex_slat, cond_dict):
+     device = shape_slat.feats.device
+     shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None].to(device)
+     shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None].to(device)
+     tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None].to(device)
+     tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None].to(device)
+     shape_slat = ((shape_slat - shape_mean) / shape_std)
+     input_tex_slat = ((input_tex_slat - tex_mean) / tex_std)
+     coords_len_list = [shape_slat.coords.shape[0]]
+     noise = sp.SparseTensor(torch.randn_like(input_tex_slat.feats), shape_slat.coords)
+     output_tex_slat = sampler.sample(gen3dseg, noise, input_tex_slat, shape_slat, coords_len_list, cond_dict, pipeline_args['tex_slat_sampler']['params'])
+     output_tex_slat = output_tex_slat * tex_std + tex_mean
+     return output_tex_slat
+
+ def slat_to_glb(meshes, tex_voxels, resolution=512):
+     pbr_attr_layout = {
+         'base_color': slice(0, 3),
+         'metallic': slice(3, 4),
+         'roughness': slice(4, 5),
+         'alpha': slice(5, 6),
+     }
+     out_mesh = []
+     for m, v in zip(meshes, tex_voxels):
+         m.fill_holes()
+         out_mesh.append(
+             MeshWithVoxel(
+                 m.vertices, m.faces,
+                 origin = [-0.5, -0.5, -0.5],
+                 voxel_size = 1 / resolution,
+                 coords = v.coords[:, 1:],
+                 attrs = v.feats,
+                 voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+                 layout=pbr_attr_layout
+             )
+         )
+     mesh = out_mesh[0]
+     mesh.simplify(10000000)
+     # mesh.simplify(16777216) # nvdiffrast limit
+     glb = o_voxel.postprocess.to_glb(
+         vertices = mesh.vertices,
+         faces = mesh.faces,
+         attr_volume = mesh.attrs,
+         coords = mesh.coords,
+         attr_layout = mesh.layout,
+         voxel_size = mesh.voxel_size,
+         aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         decimation_target = 100000, # 1000000
+         texture_size = 4096,
+         remesh = True,
+         # remesh = False,
+         remesh_band = 1,
+         remesh_project = 0,
+         verbose = True
+     )
+     return glb
+
+ def inference(ckpt_path, item):
+     print("-"*100)
+     print("Loading model ............")
+
+     with open("/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/texturing_pipeline.json", "r") as f:
+         pipeline_config = json.load(f)
+     pipeline_args = pipeline_config['args']
+     tex_slat_flow_model = models.from_pretrained(
+         "/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
+
+     gen3dseg = Gen3DSeg(tex_slat_flow_model)
+     state_dict = torch.load(ckpt_path)['state_dict']
+     state_dict = OrderedDict([(k.replace("gen3dseg.", ""), v) for k, v in state_dict.items()])
+     gen3dseg.load_state_dict(state_dict)
+     gen3dseg.eval()
+     gen3dseg.cuda()
+     sampler = Sampler()
+
+     shape_encoder = models.from_pretrained(
+         "/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
+     tex_encoder = models.from_pretrained(
+         "/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
+     shape_decoder = models.from_pretrained(
+         "/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16").cuda().eval()
+     tex_decoder = models.from_pretrained(
+         "/media/nfs/tmp_data/fenghr/download/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16").cuda().eval()
+
+     rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
+     rembg_model.cuda()
+     # image_cond_model = DinoV3FeatureExtractor(
+     #     model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
+     image_cond_model = DinoV3FeatureExtractor("/media/nfs/tmp_data/fenghr/download/dinov3")
+     image_cond_model.cuda()
+
+     process_glb_to_vxz(item['glb'], item['input_vxz'])
+     shape_slat, meshes, subs, tex_slat = vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, item['input_vxz'])
+
+     print("-"*100)
+     print("Getting cond ............")
+     if not item['2d_map']:
+         render_from_transforms(item['glb'], item['transforms'], item['img'])
+     image = Image.open(item['img'])
+     image = preprocess_image(rembg_model, image)
+     cond = get_cond(image_cond_model, [image])
+
+     print("-"*100)
+     print("Sampling .................")
+     output_tex_slat = tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, tex_slat, cond)
+     with torch.no_grad():
+         tex_voxels = tex_decoder(output_tex_slat, guide_subs=subs) * 0.5 + 0.5
+
+     print("-"*100)
+     print("Exporting glb ............")
+     glb = slat_to_glb(meshes, tex_voxels)
+     glb.export(item['export_glb'])
+
+ if __name__ == "__main__":
+     _2d_map = False
+     if _2d_map:
+         ckpt_path = "/media/nfs/tmp_data/fenghr/SegviGen/pretrained_models/full_seg_w_2d_map.ckpt"
+         item = {
+             "2d_map": True,
+             "glb": "./data_toolkit/assets/example.glb",
+             "input_vxz": "./data_toolkit/assets/input.vxz",
+             "img": "./data_toolkit/assets/full_seg_w_2d_map/2d_map.png",
+             "export_glb": "./data_toolkit/assets/output.glb"
+         }
+     else:
+         ckpt_path = "/media/nfs/tmp_data/fenghr/SegviGen/pretrained_models/full_seg.ckpt"
+         item = {
+             "2d_map": False,
+             "glb": "./data_toolkit/assets/example.glb",
+             "input_vxz": "./data_toolkit/assets/input.vxz",
+             "transforms": "./data_toolkit/transforms.json",
+             "img": "./data_toolkit/assets/img.png",
+             "export_glb": "./data_toolkit/assets/output.glb"
+         }
+     inference(ckpt_path, item)
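guidance_inference_model above combines the conditional and unconditional passes as pred = s*pred_pos + (1 - s)*pred_neg; the optional rescale branch matches the std of the guided x0 to the conditional one, but it calls _pred_to_xstart/_xstart_to_pred, which are not defined on Sampler in this file, so guidance_rescale must stay 0 with this script as-is. A dense sketch of the combine step, applying the std-matching directly to the prediction as a simplified stand-in for the x0-space version:

    import torch

    def cfg_combine(pred_pos, pred_neg, s, rescale=0.0):
        pred = s * pred_pos + (1 - s) * pred_neg            # classifier-free guidance
        if rescale > 0:                                     # simplified std-matching
            ratio = pred_pos.std(dim=-1, keepdim=True) / pred.std(dim=-1, keepdim=True)
            pred = rescale * (pred * ratio) + (1 - rescale) * pred
        return pred

    print(cfg_combine(torch.randn(2, 8), torch.randn(2, 8), s=3.0, rescale=0.7).shape)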
inference_interactive.py ADDED
@@ -0,0 +1,435 @@
+ import os
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import trimesh
+ import o_voxel
+ import numpy as np
+ import torch.nn as nn
+ import trellis2.modules.sparse as sp
+
+ from PIL import Image
+ from tqdm import tqdm
+ from trellis2 import models
+ from types import MethodType
+ from collections import OrderedDict
+ from torch.nn import functional as F
+ from trellis2.pipelines.rembg import BiRefNet
+ from trellis2.modules.utils import manual_cast
+ from trellis2.representations import MeshWithVoxel
+ from data_toolkit.bpy_render import render_from_transforms
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
+
+ class Sampler:
+     def _inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond):
+         t = torch.tensor([t*1000] * x_t.shape[0], dtype=torch.float32).cuda()
+         return model(x_t, tex_slat, shape_slat, t, cond, input_points, coords_len_list)
+
+     def guidance_inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale=0.0):
+         if guidance_strength == 1:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['cond'])
+         elif guidance_strength == 0:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['neg_cond'])
+         else:
+             pred_pos = self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['cond'])
+             pred_neg = self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['neg_cond'])
+             pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
+             if guidance_rescale > 0:
+                 x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
+                 x_0_cfg = self._pred_to_xstart(x_t, t, pred)
+                 std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
+                 std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
+                 x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
+                 x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
+                 pred = self._xstart_to_pred(x_t, t, x_0)
+             return pred
+
+     def interval_inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, sampler_params):
+         guidance_strength = sampler_params['guidance_strength']
+         guidance_interval = sampler_params['guidance_interval']
+         guidance_rescale = sampler_params['guidance_rescale']
+         if guidance_interval[0] <= t <= guidance_interval[1]:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, guidance_strength, guidance_rescale)
+         else:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, 1, guidance_rescale)
+
+     @torch.no_grad()
+     def sample_once(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, t_prev, cond_dict, sampler_params):
+         pred_v = self.interval_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, sampler_params)
+         pred_x_prev = x_t - (t - t_prev) * pred_v
+         return pred_x_prev
+
+     @torch.no_grad()
+     def sample(self, model, noise, tex_slat, shape_slat, input_points, coords_len_list, cond_dict, sampler_params):
+         sample = noise
+         steps = sampler_params['steps']
+         rescale_t = sampler_params['rescale_t']
+         t_seq = np.linspace(1, 0, steps + 1)
+         t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
+         t_seq = t_seq.tolist()
+         t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
+         for t, t_prev in tqdm(t_pairs, desc="Sampling"):
+             sample = self.sample_once(model, sample, tex_slat, shape_slat, input_points, coords_len_list, t, t_prev, cond_dict, sampler_params)
+         return sample
+
+ def flow_forward(self, x, t, cond, concat_cond, point_embeds, coords_len_list):
+     # x.feats: [N, 32]
+     x = sp.sparse_cat([x, concat_cond], dim=-1)
+     if isinstance(cond, list):
+         cond = sp.VarLenTensor.from_tensor_list(cond)
+     # x.feats: [N, 64]
+     h = self.input_layer(x)
+     # h.feats: [N, 1536]
+     h = manual_cast(h, self.dtype)
+     t_emb = self.t_embedder(t)
+     t_emb = self.adaLN_modulation(t_emb)
+     t_emb = manual_cast(t_emb, self.dtype)
+     cond = manual_cast(cond, self.dtype)
+     point_embeds = manual_cast(point_embeds, self.dtype)
+
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         h_feats_list.append(point_embeds.feats[i*10:(i+1)*10])
+         h_coords_list.append(point_embeds.coords[i*10:(i+1)*10])
+         begin = end + 10
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     for block in self.blocks:
+         h = block(h, t_emb, cond)
+
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         begin = end
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     h = manual_cast(h, x.dtype)
+     h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+     # h.feats: [N, 1536]
+     h = self.out_layer(h)
+     # h.feats: [N, 32]
+     return h
+
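flow_forward (monkey-patched onto the pretrained DiT below via MethodType) rebuilds the token stream so each sample's 2*coords_len latent tokens are followed by exactly 10 point-prompt tokens, runs the transformer blocks, then strips the prompt tokens again. Note the stride in the second loop (begin = end) only lines up for a batch of one, which is how this script uses it. A dense toy version of the append/strip step:

    import torch

    coords_len = 3                        # single sample, as in this script
    lat = torch.zeros(2 * coords_len, 8)  # noisy + clean latent tokens
    pts = torch.ones(10, 8)               # fixed budget of point-prompt tokens

    h = torch.cat([lat, pts])             # append prompts after the latents
    # ... transformer blocks run on h here ...
    h = h[:2 * coords_len]                # drop the prompt tokens again
    assert h.shape == lat.shape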
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+         self.seg_embeddings = nn.Embedding(1, 1536)
+
+     def get_positional_encoding(self, input_points):
+         point_feats_embed = torch.zeros((10, 1536), dtype=torch.float32).to(input_points['point_slats'].feats.device)
+         labels = input_points['point_labels'].squeeze(-1)
+         point_feats_embed[labels == 1] = self.seg_embeddings.weight
+         return sp.SparseTensor(point_feats_embed, input_points['point_slats'].coords)
+
+     def forward(self, x_t, tex_slats, shape_slats, t, cond, input_points, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         point_embeds = self.get_positional_encoding(input_points)
+         output_tex_slats = self.flow_model(x_t, t, cond, shape_slats, point_embeds, coords_len_list)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
+ def make_texture_square_pow2(img: Image.Image, target_size=None):
+     w, h = img.size
+     max_side = max(w, h)
+     pow2 = 1
+     while pow2 < max_side:
+         pow2 *= 2
+     if target_size is not None:
+         pow2 = target_size
+     pow2 = min(pow2, 2048)
+     return img.resize((pow2, pow2), Image.BILINEAR)
+
+ def preprocess_scene_textures(asset):
+     if not isinstance(asset, trimesh.Scene):
+         return asset
+     TEX_KEYS = ["baseColorTexture", "normalTexture", "metallicRoughnessTexture", "emissiveTexture", "occlusionTexture"]
+     for geom in asset.geometry.values():
+         visual = getattr(geom, "visual", None)
+         mat = getattr(visual, "material", None)
+         if mat is None:
+             continue
+         for key in TEX_KEYS:
+             if not hasattr(mat, key):
+                 continue
+             tex = getattr(mat, key)
+             if tex is None:
+                 continue
+             if isinstance(tex, Image.Image):
+                 setattr(mat, key, make_texture_square_pow2(tex))
+             elif hasattr(tex, "image") and tex.image is not None:
+                 img = tex.image
+                 if not isinstance(img, Image.Image):
+                     img = Image.fromarray(img)
+                 tex.image = make_texture_square_pow2(img)
+         if hasattr(mat, "image") and mat.image is not None:
+             img = mat.image
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+             mat.image = make_texture_square_pow2(img)
+     return asset
+
+ def process_glb_to_vxz(glb_path, vxz_path):
+     asset = trimesh.load(glb_path, force='scene')
+     asset = preprocess_scene_textures(asset)
+     aabb = asset.bounding_box.bounds
+     center = (aabb[0] + aabb[1]) / 2
+     scale = 0.99999 / (aabb[1] - aabb[0]).max()
+     asset.apply_translation(-center)
+     asset.apply_scale(scale)
+     mesh = asset.to_mesh()
+     vertices = torch.from_numpy(mesh.vertices).float()
+     faces = torch.from_numpy(mesh.faces).long()
+
+     voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
+         vertices, faces, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         face_weight=1.0, boundary_weight=0.2, regularization_weight=1e-2, timing=False
+     )
+     vid = o_voxel.serialize.encode_seq(voxel_indices)
+     mapping = torch.argsort(vid)
+     voxel_indices = voxel_indices[mapping]
+     dual_vertices = dual_vertices[mapping]
+     intersected = intersected[mapping]
+
+     voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr(
+         asset, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], timing=False
+     )
+     vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat)
+     mapping_mat = torch.argsort(vid_mat)
+     attributes = {k: v[mapping_mat] for k, v in attributes.items()}
+
+     dual_vertices = dual_vertices * 512 - voxel_indices
+     dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8)
+     intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
+
+     attributes['dual_vertices'] = dual_vertices
+     attributes['intersected'] = intersected
+     o_voxel.io.write(vxz_path, voxel_indices, attributes)
+
+ def vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, vxz_path):
+     coords, data = o_voxel.io.read(vxz_path)
+     coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords], dim=1).cuda()
+     vertices = (data['dual_vertices'].cuda() / 255)
+     intersected = torch.cat([data['intersected'] % 2, data['intersected'] // 2 % 2, data['intersected'] // 4 % 2], dim=-1).bool().cuda()
+     vertices_sparse = sp.SparseTensor(vertices, coords)
+     intersected_sparse = sp.SparseTensor(intersected.float(), coords)
+     with torch.no_grad():
+         shape_slat = shape_encoder(vertices_sparse, intersected_sparse)
+         shape_slat = sp.SparseTensor(shape_slat.feats.cuda(), shape_slat.coords.cuda())
+         shape_decoder.set_resolution(512)
+         meshes, subs = shape_decoder(shape_slat, return_subs=True)
+
+     base_color = (data['base_color'] / 255)
+     metallic = (data['metallic'] / 255)
+     roughness = (data['roughness'] / 255)
+     alpha = (data['alpha'] / 255)
+     attr = torch.cat([base_color, metallic, roughness, alpha], dim=-1).float().cuda() * 2 - 1
+     with torch.no_grad():
+         tex_slat = tex_encoder(sp.SparseTensor(attr, coords))
+     return shape_slat, meshes, subs, tex_slat
+
+ def preprocess_image(rembg_model, input):
+     if input.mode != "RGB":
+         bg = Image.new("RGB", input.size, (255, 255, 255))
+         bg.paste(input, mask=input.split()[3])
+         input = bg
+     has_alpha = False
+     if input.mode == 'RGBA':
+         alpha = np.array(input)[:, :, 3]
+         if not np.all(alpha == 255):
+             has_alpha = True
+     max_size = max(input.size)
+     scale = min(1, 1024 / max_size)
+     if scale < 1:
+         input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+     if has_alpha:
+         output = input
+     else:
+         input = input.convert('RGB')
+         output = rembg_model(input)
+     output_np = np.array(output)
+     alpha = output_np[:, :, 3]
+     bbox = np.argwhere(alpha > 0.8 * 255)
+     bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+     center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+     size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+     size = int(size * 1)
+     bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+     output = output.crop(bbox)  # type: ignore
+     output = np.array(output).astype(np.float32) / 255
+     output = output[:, :, :3] * output[:, :, 3:4]
+     output = Image.fromarray((output * 255).astype(np.uint8))
+     return output
+
+ def get_cond(image_cond_model, image):
+     image_cond_model.image_size = 512
+     cond = image_cond_model(image)
+     neg_cond = torch.zeros_like(cond)
+     return {'cond': cond, 'neg_cond': neg_cond}
+
+ def tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, input_tex_slat, cond_dict, input_points):
+     device = shape_slat.feats.device
+     shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None].to(device)
+     shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None].to(device)
+     tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None].to(device)
+     tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None].to(device)
+     shape_slat = ((shape_slat - shape_mean) / shape_std)
+     input_tex_slat = ((input_tex_slat - tex_mean) / tex_std)
+     coords_len_list = [shape_slat.coords.shape[0]]
+     noise = sp.SparseTensor(torch.randn_like(input_tex_slat.feats), shape_slat.coords)
+     output_tex_slat = sampler.sample(gen3dseg, noise, input_tex_slat, shape_slat, input_points, coords_len_list, cond_dict, pipeline_args['tex_slat_sampler']['params'])
+     output_tex_slat = output_tex_slat * tex_std + tex_mean
+     return output_tex_slat
+
+ def slat_to_glb(meshes, tex_voxels, resolution=512):
+     pbr_attr_layout = {
+         'base_color': slice(0, 3),
+         'metallic': slice(3, 4),
+         'roughness': slice(4, 5),
+         'alpha': slice(5, 6),
+     }
+     out_mesh = []
+     for m, v in zip(meshes, tex_voxels):
+         m.fill_holes()
+         out_mesh.append(
+             MeshWithVoxel(
+                 m.vertices, m.faces,
+                 origin = [-0.5, -0.5, -0.5],
+                 voxel_size = 1 / resolution,
+                 coords = v.coords[:, 1:],
+                 attrs = v.feats,
+                 voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+                 layout=pbr_attr_layout
+             )
+         )
+     mesh = out_mesh[0]
+     mesh.simplify(10000000)
+     # mesh.simplify(16777216) # nvdiffrast limit
+     glb = o_voxel.postprocess.to_glb(
+         vertices = mesh.vertices,
+         faces = mesh.faces,
+         attr_volume = mesh.attrs,
+         coords = mesh.coords,
+         attr_layout = mesh.layout,
+         voxel_size = mesh.voxel_size,
+         aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         decimation_target = 100000, # 1000000
+         texture_size = 4096,
+         remesh = True,
+         # remesh = False,
+         remesh_band = 1,
+         remesh_project = 0,
+         verbose = True
+     )
+     return glb
+
+ def inference(ckpt_path, item, input_vxz_points_list):
+     print("-"*100)
+     print("Loading model ............")
+     with open("microsoft/TRELLIS.2-4B/pipeline.json", "r") as f:
+         pipeline_config = json.load(f)
+     pipeline_args = pipeline_config['args']
+     tex_slat_flow_model = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
+     tex_slat_flow_model.forward = MethodType(flow_forward, tex_slat_flow_model)
+
+     gen3dseg = Gen3DSeg(tex_slat_flow_model)
+     state_dict = torch.load(ckpt_path)['state_dict']
+     state_dict = OrderedDict([(k.replace("gen3dseg.", ""), v) for k, v in state_dict.items()])
+     gen3dseg.load_state_dict(state_dict)
+     gen3dseg.eval()
+     gen3dseg.cuda()
+     sampler = Sampler()
+
+     shape_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
+     tex_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
+     shape_decoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16").cuda().eval()
+     tex_decoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16").cuda().eval()
+
+     rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
+     rembg_model.cuda()
+     image_cond_model = DinoV3FeatureExtractor(model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
+     image_cond_model.cuda()
+
+     process_glb_to_vxz(item['glb'], item['input_vxz'])
+     shape_slat, meshes, subs, tex_slat = vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, item['input_vxz'])
+
+     print("-"*100)
+     print("Getting cond ............")
+     render_from_transforms(item['glb'], item['transforms'], item['img'])
+     image = Image.open(item['img'])
+     image = preprocess_image(rembg_model, image)
+     cond = get_cond(image_cond_model, [image])
+
+     print("-"*100)
+     print("Sampling .................")
+     vxz_points_coords = torch.tensor(input_vxz_points_list, dtype=torch.int32).cuda()
+     vxz_points_coords = torch.cat([torch.zeros((vxz_points_coords.shape[0], 1), dtype=torch.int32).cuda(), vxz_points_coords], dim=1)
+     input_points_coords = tex_encoder(sp.SparseTensor(torch.zeros((vxz_points_coords.shape[0], 6), dtype=torch.float32).cuda(), vxz_points_coords)).coords
+     input_points_coords = torch.unique(input_points_coords, dim=0)
+     point_num = input_points_coords.shape[0]
+     if point_num >= 10:
+         input_points_coords = input_points_coords[:10]
+         point_labels = torch.tensor(([[1]]*10), dtype=torch.int32).cuda()
+     else:
+         input_points_coords = torch.cat([input_points_coords, torch.zeros((10 - point_num, 4), dtype=torch.int32).cuda()], dim=0)
+         point_labels = torch.tensor(([[1]]*point_num+[[0]]*(10-point_num)), dtype=torch.int32).cuda()
+     input_points = {'point_slats': sp.SparseTensor(input_points_coords, input_points_coords), 'point_labels': point_labels}
+
+     output_tex_slat = tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, tex_slat, cond, input_points)
+     with torch.no_grad():
+         tex_voxels = tex_decoder(output_tex_slat, guide_subs=subs) * 0.5 + 0.5
+
+     print("-"*100)
+     print("Exporting glb ............")
+     glb = slat_to_glb(meshes, tex_voxels)
+     glb.export(item['export_glb'])
+
+ if __name__ == "__main__":
+     ckpt_path = "path/to/interactive_seg.ckpt"
+     item = {
+         "glb": "./data_toolkit/assets/example.glb",
+         "input_vxz": "./data_toolkit/assets/input.vxz",
+         "transforms": "./data_toolkit/transforms.json",
+         "img": "./data_toolkit/assets/img.png",
+         "export_glb": "./data_toolkit/assets/output.glb"
+     }
+     input_vxz_points_list = [[388, 448, 392]]  # example
+     inference(ckpt_path, item, input_vxz_points_list)
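The interactive entry point maps user clicks from the 512^3 voxel grid into latent coordinates through tex_encoder, deduplicates them, and pads to a fixed budget of 10 prompt tokens, where label 1 marks a real click and label 0 padding. A toy sketch of the padding rule, assuming points already in latent coordinates:

    import torch

    def pad_points(coords, budget=10):
        coords = torch.unique(coords, dim=0)[:budget]   # dedupe, cap at the budget
        n = coords.shape[0]
        pad = torch.zeros(budget - n, coords.shape[1], dtype=coords.dtype)
        labels = torch.tensor([[1]] * n + [[0]] * (budget - n), dtype=torch.int32)
        return torch.cat([coords, pad]), labels

    pts = torch.tensor([[0, 24, 28, 24], [0, 24, 28, 24]], dtype=torch.int32)
    coords, labels = pad_points(pts)
    print(coords.shape, labels.squeeze(-1))  # torch.Size([10, 4]); one real point, nine pads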
inference_unified.py ADDED
@@ -0,0 +1,473 @@
+ import os
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import trimesh
+ import o_voxel
+ import numpy as np
+ import torch.nn as nn
+ import trellis2.modules.sparse as sp
+
+ from PIL import Image
+ from tqdm import tqdm
+ from trellis2 import models
+ from types import MethodType
+ from collections import OrderedDict
+ from torch.nn import functional as F
+ from trellis2.pipelines.rembg import BiRefNet
+ from trellis2.modules.utils import manual_cast
+ from trellis2.representations import MeshWithVoxel
+ from data_toolkit.bpy_render import render_from_transforms
+ from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor
+
+ class Sampler:
+     def _inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond, tag):
+         t = torch.tensor([t*1000] * x_t.shape[0], dtype=torch.float32).cuda()
+         tag = torch.tensor([tag] * x_t.shape[0], dtype=torch.float32).cuda()
+         return model(x_t, tex_slat, shape_slat, t, tag, cond, input_points, coords_len_list)
+
+     def guidance_inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, tag, guidance_strength, guidance_rescale=0.0):
+         if guidance_strength == 1:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['cond'], tag)
+         elif guidance_strength == 0:
+             return self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['neg_cond'], tag)
+         else:
+             pred_pos = self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['cond'], tag)
+             pred_neg = self._inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict['neg_cond'], tag)
+             pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
+             if guidance_rescale > 0:
+                 x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
+                 x_0_cfg = self._pred_to_xstart(x_t, t, pred)
+                 std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
+                 std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
+                 x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
+                 x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
+                 pred = self._xstart_to_pred(x_t, t, x_0)
+             return pred
+
+     def interval_inference_model(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, tag, sampler_params):
+         guidance_strength = sampler_params['guidance_strength']
+         guidance_interval = sampler_params['guidance_interval']
+         guidance_rescale = sampler_params['guidance_rescale']
+         if guidance_interval[0] <= t <= guidance_interval[1]:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, tag, guidance_strength, guidance_rescale)
+         else:
+             return self.guidance_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, tag, 1, guidance_rescale)
+
+     @torch.no_grad()
+     def sample_once(self, model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, t_prev, cond_dict, tag, sampler_params):
+         pred_v = self.interval_inference_model(model, x_t, tex_slat, shape_slat, input_points, coords_len_list, t, cond_dict, tag, sampler_params)
+         pred_x_prev = x_t - (t - t_prev) * pred_v
+         return pred_x_prev
+
+     @torch.no_grad()
+     def sample(self, model, noise, tex_slat, shape_slat, input_points, coords_len_list, cond_dict, tag, sampler_params):
+         sample = noise
+         steps = sampler_params['steps']
+         rescale_t = sampler_params['rescale_t']
+         t_seq = np.linspace(1, 0, steps + 1)
+         t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
+         t_seq = t_seq.tolist()
+         t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
+         for t, t_prev in tqdm(t_pairs, desc="Sampling"):
+             sample = self.sample_once(model, sample, tex_slat, shape_slat, input_points, coords_len_list, t, t_prev, cond_dict, tag, sampler_params)
+         return sample
+
+ def flow_forward(self, x, t, tag_embeds, cond, concat_cond, point_embeds, coords_len_list):
+     # x.feats: [N, 32]
+     x = sp.sparse_cat([x, concat_cond], dim=-1)
+     if isinstance(cond, list):
+         cond = sp.VarLenTensor.from_tensor_list(cond)
+     # x.feats: [N, 64]
+     h = self.input_layer(x)
+     # h.feats: [N, 1536]
+     h = manual_cast(h, self.dtype)
+     t_emb = self.t_embedder(t)
+     t_emb = self.adaLN_modulation(t_emb)
+     tag_embeds = self.adaLN_modulation(tag_embeds)
+     t_emb = t_emb + tag_embeds
+     t_emb = manual_cast(t_emb, self.dtype)
+     cond = manual_cast(cond, self.dtype)
+     point_embeds = manual_cast(point_embeds, self.dtype)
+
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         h_feats_list.append(point_embeds.feats[i*10:(i+1)*10])
+         h_coords_list.append(point_embeds.coords[i*10:(i+1)*10])
+         begin = end + 10
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     for block in self.blocks:
+         h = block(h, t_emb, cond)
+
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         begin = end
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     h = manual_cast(h, x.dtype)
+     h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+     # h.feats: [N, 1536]
+     h = self.out_layer(h)
+     # h.feats: [N, 32]
+     return h
+
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+         self.seg_embeddings = nn.Embedding(1, 1536)
+         self.tag_mlp = nn.Sequential(nn.Linear(256, 1536, bias=True), nn.SiLU(), nn.Linear(1536, 1536, bias=True))
+
+     def tag_embedding(self, tag):
+         freqs = torch.exp(-np.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=tag.device)
+         args = tag[:, None].float() * freqs[None]
+         tag_freq = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         tag_embeds = self.tag_mlp(tag_freq)
+         return tag_embeds
+
+     def get_positional_encoding(self, input_points):
+         point_feats_embed = torch.zeros((10, 1536), dtype=torch.float32).to(input_points['point_slats'].feats.device)
+         labels = input_points['point_labels'].squeeze(-1)
+         point_feats_embed[labels == 1] = self.seg_embeddings.weight
+         return sp.SparseTensor(point_feats_embed, input_points['point_slats'].coords)
+
+     def forward(self, x_t, tex_slats, shape_slats, t, tags, cond, input_points, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         tag_embeds = self.tag_embedding(tags)
+         point_embeds = self.get_positional_encoding(input_points)
+         output_tex_slats = self.flow_model(x_t, t, tag_embeds, cond, shape_slats, point_embeds, coords_len_list)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
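The unified model conditions on a scalar task tag (0 = interactive seg, 1 = full seg, 2 = full seg with a 2D map, per the __main__ block below) the same way diffusion models embed timesteps: a 256-dim sinusoidal encoding followed by an MLP, added onto the timestep modulation in flow_forward. A standalone sketch of the sinusoidal half:

    import numpy as np
    import torch

    def sinusoidal_embed(tag, dim=256):
        half = dim // 2
        freqs = torch.exp(-np.log(10000) * torch.arange(half, dtype=torch.float32) / half)
        args = tag[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, dim]

    print(sinusoidal_embed(torch.tensor([0.0, 1.0, 2.0])).shape)  # torch.Size([3, 256])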
+ def make_texture_square_pow2(img: Image.Image, target_size=None):
+     w, h = img.size
+     max_side = max(w, h)
+     pow2 = 1
+     while pow2 < max_side:
+         pow2 *= 2
+     if target_size is not None:
+         pow2 = target_size
+     pow2 = min(pow2, 2048)
+     return img.resize((pow2, pow2), Image.BILINEAR)
+
+ def preprocess_scene_textures(asset):
+     if not isinstance(asset, trimesh.Scene):
+         return asset
+     TEX_KEYS = ["baseColorTexture", "normalTexture", "metallicRoughnessTexture", "emissiveTexture", "occlusionTexture"]
+     for geom in asset.geometry.values():
+         visual = getattr(geom, "visual", None)
+         mat = getattr(visual, "material", None)
+         if mat is None:
+             continue
+         for key in TEX_KEYS:
+             if not hasattr(mat, key):
+                 continue
+             tex = getattr(mat, key)
+             if tex is None:
+                 continue
+             if isinstance(tex, Image.Image):
+                 setattr(mat, key, make_texture_square_pow2(tex))
+             elif hasattr(tex, "image") and tex.image is not None:
+                 img = tex.image
+                 if not isinstance(img, Image.Image):
+                     img = Image.fromarray(img)
+                 tex.image = make_texture_square_pow2(img)
+         if hasattr(mat, "image") and mat.image is not None:
+             img = mat.image
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+             mat.image = make_texture_square_pow2(img)
+     return asset
+
+ def process_glb_to_vxz(glb_path, vxz_path):
+     asset = trimesh.load(glb_path, force='scene')
+     asset = preprocess_scene_textures(asset)
+     aabb = asset.bounding_box.bounds
+     center = (aabb[0] + aabb[1]) / 2
+     scale = 0.99999 / (aabb[1] - aabb[0]).max()
+     asset.apply_translation(-center)
+     asset.apply_scale(scale)
+     mesh = asset.to_mesh()
+     vertices = torch.from_numpy(mesh.vertices).float()
+     faces = torch.from_numpy(mesh.faces).long()
+
+     voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
+         vertices, faces, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         face_weight=1.0, boundary_weight=0.2, regularization_weight=1e-2, timing=False
+     )
+     vid = o_voxel.serialize.encode_seq(voxel_indices)
+     mapping = torch.argsort(vid)
+     voxel_indices = voxel_indices[mapping]
+     dual_vertices = dual_vertices[mapping]
+     intersected = intersected[mapping]
+
+     voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr(
+         asset, grid_size=512, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], timing=False
+     )
+     vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat)
+     mapping_mat = torch.argsort(vid_mat)
+     attributes = {k: v[mapping_mat] for k, v in attributes.items()}
+
+     dual_vertices = dual_vertices * 512 - voxel_indices
+     dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8)
+     intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
+
+     attributes['dual_vertices'] = dual_vertices
+     attributes['intersected'] = intersected
+     o_voxel.io.write(vxz_path, voxel_indices, attributes)
+
+ def vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, vxz_path):
+     coords, data = o_voxel.io.read(vxz_path)
+     coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords], dim=1).cuda()
+     vertices = (data['dual_vertices'].cuda() / 255)
+     intersected = torch.cat([data['intersected'] % 2, data['intersected'] // 2 % 2, data['intersected'] // 4 % 2], dim=-1).bool().cuda()
+     vertices_sparse = sp.SparseTensor(vertices, coords)
+     intersected_sparse = sp.SparseTensor(intersected.float(), coords)
+     with torch.no_grad():
+         shape_slat = shape_encoder(vertices_sparse, intersected_sparse)
+         shape_slat = sp.SparseTensor(shape_slat.feats.cuda(), shape_slat.coords.cuda())
+         shape_decoder.set_resolution(512)
+         meshes, subs = shape_decoder(shape_slat, return_subs=True)
+
+     base_color = (data['base_color'] / 255)
+     metallic = (data['metallic'] / 255)
+     roughness = (data['roughness'] / 255)
+     alpha = (data['alpha'] / 255)
+     attr = torch.cat([base_color, metallic, roughness, alpha], dim=-1).float().cuda() * 2 - 1
+     with torch.no_grad():
+         tex_slat = tex_encoder(sp.SparseTensor(attr, coords))
+     return shape_slat, meshes, subs, tex_slat
+
+ def preprocess_image(rembg_model, input):
+     if input.mode != "RGB":
+         bg = Image.new("RGB", input.size, (255, 255, 255))
+         bg.paste(input, mask=input.split()[3])
+         input = bg
+     has_alpha = False
+     if input.mode == 'RGBA':
+         alpha = np.array(input)[:, :, 3]
+         if not np.all(alpha == 255):
+             has_alpha = True
+     max_size = max(input.size)
+     scale = min(1, 1024 / max_size)
+     if scale < 1:
+         input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+     if has_alpha:
+         output = input
+     else:
+         input = input.convert('RGB')
+         output = rembg_model(input)
+     output_np = np.array(output)
+     alpha = output_np[:, :, 3]
+     bbox = np.argwhere(alpha > 0.8 * 255)
+     bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+     center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+     size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+     size = int(size * 1)
+     bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+     output = output.crop(bbox)  # type: ignore
+     output = np.array(output).astype(np.float32) / 255
+     output = output[:, :, :3] * output[:, :, 3:4]
+     output = Image.fromarray((output * 255).astype(np.uint8))
+     return output
+
+ def get_cond(image_cond_model, image):
+     image_cond_model.image_size = 512
+     cond = image_cond_model(image)
+     neg_cond = torch.zeros_like(cond)
+     return {'cond': cond, 'neg_cond': neg_cond}
+
+ def tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, input_tex_slat, cond_dict, input_points, tag):
+     device = shape_slat.feats.device
+     shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None].to(device)
+     shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None].to(device)
+     tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None].to(device)
+     tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None].to(device)
+     shape_slat = ((shape_slat - shape_mean) / shape_std)
+     input_tex_slat = ((input_tex_slat - tex_mean) / tex_std)
+     coords_len_list = [shape_slat.coords.shape[0]]
+     noise = sp.SparseTensor(torch.randn_like(input_tex_slat.feats), shape_slat.coords)
+     output_tex_slat = sampler.sample(gen3dseg, noise, input_tex_slat, shape_slat, input_points, coords_len_list, cond_dict, tag, pipeline_args['tex_slat_sampler']['params'])
+     output_tex_slat = output_tex_slat * tex_std + tex_mean
+     return output_tex_slat
+
+ def slat_to_glb(meshes, tex_voxels, resolution=512):
+     pbr_attr_layout = {
+         'base_color': slice(0, 3),
+         'metallic': slice(3, 4),
+         'roughness': slice(4, 5),
+         'alpha': slice(5, 6),
+     }
+     out_mesh = []
+     for m, v in zip(meshes, tex_voxels):
+         m.fill_holes()
+         out_mesh.append(
+             MeshWithVoxel(
+                 m.vertices, m.faces,
+                 origin = [-0.5, -0.5, -0.5],
+                 voxel_size = 1 / resolution,
+                 coords = v.coords[:, 1:],
+                 attrs = v.feats,
+                 voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+                 layout=pbr_attr_layout
+             )
+         )
+     mesh = out_mesh[0]
+     mesh.simplify(10000000)
+     # mesh.simplify(16777216) # nvdiffrast limit
+     glb = o_voxel.postprocess.to_glb(
+         vertices = mesh.vertices,
+         faces = mesh.faces,
+         attr_volume = mesh.attrs,
+         coords = mesh.coords,
+         attr_layout = mesh.layout,
+         voxel_size = mesh.voxel_size,
+         aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         decimation_target = 100000, # 1000000
+         texture_size = 4096,
+         remesh = True,
+         # remesh = False,
+         remesh_band = 1,
+         remesh_project = 0,
+         verbose = True
+     )
+     return glb
+
376
+ def inference(ckpt_path, item, tag, input_vxz_points_list=None):
377
+ print("-"*100)
378
+ print("Loading model ............")
379
+ with open("microsoft/TRELLIS.2-4B/pipeline.json", "r") as f:
380
+ pipeline_config = json.load(f)
381
+ pipeline_args = pipeline_config['args']
382
+ tex_slat_flow_model = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
383
+ tex_slat_flow_model.forward = MethodType(flow_forward, tex_slat_flow_model)
384
+
385
+ gen3dseg = Gen3DSeg(tex_slat_flow_model)
386
+ state_dict = torch.load(ckpt_path)['state_dict']
387
+ state_dict = OrderedDict([(k.replace("gen3dseg.", ""), v) for k, v in state_dict.items()])
388
+ gen3dseg.load_state_dict(state_dict)
389
+ gen3dseg.eval()
390
+ gen3dseg.cuda()
391
+ sampler = Sampler()
392
+
393
+ shape_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16").cuda().eval()
394
+ tex_encoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16").cuda().eval()
395
+ shape_decoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16").cuda().eval()
396
+ tex_decoder = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16").cuda().eval()
397
+
398
+ rembg_model = BiRefNet(model_name="briaai/RMBG-2.0")
399
+ rembg_model.cuda()
400
+ image_cond_model = DinoV3FeatureExtractor(model_name="facebook/dinov3-vitl16-pretrain-lvd1689m")
401
+ image_cond_model.cuda()
402
+
403
+ process_glb_to_vxz(item['glb'], item['input_vxz'])
404
+ shape_slat, meshes, subs, tex_slat = vxz_to_latent_slat(shape_encoder, shape_decoder, tex_encoder, item['input_vxz'])
405
+
406
+ print("-"*100)
407
+ print("Getting cond ............")
408
+ if tag in [0, 1]:
409
+ render_from_transforms(item['glb'], item['transforms'], item['img'])
410
+ image = Image.open(item['img'])
411
+ image = preprocess_image(rembg_model, image)
412
+ cond = get_cond(image_cond_model, [image])
413
+
414
+ print("-"*100)
415
+ print("Sampling .................")
416
+ if tag == 0:
417
+ vxz_points_coords = torch.tensor(input_vxz_points_list, dtype=torch.int32).cuda()
418
+ vxz_points_coords = torch.cat([torch.zeros((vxz_points_coords.shape[0], 1), dtype=torch.int32).cuda(), vxz_points_coords], dim=1)
419
+ input_points_coords = tex_encoder(sp.SparseTensor(torch.zeros((vxz_points_coords.shape[0], 6), dtype=torch.float32).cuda(), vxz_points_coords)).coords
420
+ input_points_coords = torch.unique(input_points_coords, dim=0)
421
+ point_num = input_points_coords.shape[0]
422
+ if point_num >= 10:
423
+ input_points_coords = input_points_coords[:10]
424
+ point_labels = torch.tensor(([[1]]*10), dtype=torch.int32).cuda()
425
+ else:
426
+ input_points_coords = torch.cat([input_points_coords, torch.zeros((10 - point_num, 4), dtype=torch.int32).cuda()], dim=0)
427
+ point_labels = torch.tensor(([[1]]*point_num+[[0]]*(10-point_num)), dtype=torch.int32).cuda()
428
+ else:
429
+ input_points_coords = torch.zeros((10, 4), dtype=torch.int32).cuda()
430
+ point_labels = torch.tensor(([[0]]*10), dtype=torch.int32).cuda()
431
+ input_points = {'point_slats': sp.SparseTensor(input_points_coords, input_points_coords), 'point_labels': point_labels}
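+ # (Reading aid) The click prompt is padded to a fixed 10 points: label 1 marks a real
+ # point in latent-grid coordinates, label 0 marks zero padding; tags other than 0 pass
+ # an all-padding prompt, i.e. full segmentation without interactive clicks.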
432
+
433
+ output_tex_slat = tex_slat_sample_single(gen3dseg, sampler, pipeline_args, shape_slat, tex_slat, cond, input_points, tag)
434
+ with torch.no_grad():
435
+ tex_voxels = tex_decoder(output_tex_slat, guide_subs=subs) * 0.5 + 0.5
436
+
437
+ print("-"*100)
438
+ print("Exporting glb ............")
439
+ glb = slat_to_glb(meshes, tex_voxels)
440
+ glb.export(item['export_glb'])
441
+
442
+ if __name__ == "__main__":
443
+ ckpt_path = "path/to/unified.ckpt"
444
+ tag = 0
445
+ if tag == 0: # interactive seg
446
+ item = {
447
+ "glb": "./data_toolkit/assets/example.glb",
448
+ "input_vxz": "./data_toolkit/assets/input.vxz",
449
+ "transforms": "./data_toolkit/transforms.json",
450
+ "img": "./data_toolkit/assets/img.png",
451
+ "export_glb": "./data_toolkit/assets/output.glb"
452
+ }
453
+ input_vxz_points_list = [[388, 448, 392]] # example
454
+ inference(ckpt_path, item, tag, input_vxz_points_list)
455
+ elif tag == 1: # full seg
456
+ item = {
457
+ "glb": "./data_toolkit/assets/example.glb",
458
+ "input_vxz": "./data_toolkit/assets/input.vxz",
459
+ "transforms": "./data_toolkit/transforms.json",
460
+ "img": "./data_toolkit/assets/img.png",
461
+ "export_glb": "./data_toolkit/assets/output.glb"
462
+ }
463
+ inference(ckpt_path, item, tag)
464
+ elif tag == 2: # full seg with 2d map
465
+ item = {
466
+ "glb": "./data_toolkit/assets/example.glb",
467
+ "input_vxz": "./data_toolkit/assets/input.vxz",
468
+ "img": "./data_toolkit/assets/full_seg_w_2d_map/2d_map.png.png",
469
+ "export_glb": "./data_toolkit/assets/output.glb"
470
+ }
471
+ inference(ckpt_path, item, tag)
472
+ else:
473
+ raise ValueError(f"Invalid tag: {tag}")
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+
3
+ torch==2.6.0
4
+ torchvision==0.21.0
5
+ triton==3.2.0
6
+ pillow==12.0.0
7
+ imageio==2.37.2
8
+ imageio-ffmpeg==0.6.0
9
+ tqdm==4.67.1
10
+ easydict==1.13
11
+ opencv-python-headless==4.12.0.88
12
+ trimesh==4.10.1
13
+ transformers==4.57.3
14
+ zstandard==0.25.0
15
+ kornia==0.8.2
16
+ timm==1.0.22
17
+ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
18
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
19
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
20
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
21
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
22
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrast-0.4.0-cp310-cp310-linux_x86_64.whl
23
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl
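+ # Installation note (an assumption inferred from the wheel names, not upstream docs):
+ # the cp310 wheels above target Python 3.10 and the cu124 index targets CUDA 12.4,
+ # so install inside a matching environment with:
+ #   pip install -r requirements.txt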
split.py ADDED
@@ -0,0 +1,833 @@
1
+ import os
2
+ import struct
3
+ from collections import defaultdict
4
+ from typing import Optional, Tuple
5
+
6
+ import numpy as np
7
+ import trimesh
8
+ from PIL import Image
9
+
10
+ # =========================
+ # Edit only this section
+ # =========================
13
+ # INPUT_GLB = "/mnt/pfs/users/huangzehuan/projects/SegviGen/examples/trellis2_output.glb"
14
+
15
+
16
+ # UID = "demonic_warrior_red_bronze_armor"
17
+ # UID = "playful_pose_white_top_portrait"
18
+ # UID = "african_inspired_metallic_silver_ensemble_with_headwrap"
19
+ # UID = "cyberpunk_bowser_motorcycle"
20
+ # UID = "crimson_battle_mecha_with_spikes"
21
+ UID = "black_lace_lingerie_ensemble"
22
+
23
+ INPUT_GLB = (
24
+ f"/mnt/pfs/users/maxueqi/studio/datasets/dense_mesh/segvigen_bak/{UID}/output.glb"
25
+ )
26
+
27
+ # RGB only (alpha/transparency is ignored)
+ COLOR_QUANT_STEP = 16  # RGB quantization step: 0/4/8/16 (larger merges colors more aggressively)
+ PALETTE_SAMPLE_PIXELS = 2_000_000
+ PALETTE_MIN_PIXELS = 500  # colors with fewer sampled pixels are dropped as noise (anti-aliased boundary colors)
+ PALETTE_MAX_COLORS = 256  # maximum number of dominant colors to keep
+ PALETTE_MERGE_DIST = 32  # merge near-identical palette colors (fixes "looks identical but split in two")
+
+ SAMPLES_PER_FACE = 4  # 1 or 4 (4 recommended; markedly reduces boundary sampling errors)
+ FLIP_V = True  # glTF textures usually need the V axis flipped
+ UV_WRAP_REPEAT = True  # True: repeat (mod 1); False: clamp to [0,1]
37
+
38
+ MIN_FACES_PER_PART = 50
39
+ BAKE_TRANSFORMS = True
40
+ DEBUG_PRINT = True
41
+ # =========================
42
+
43
+
44
+ CHUNK_TYPE_JSON = 0x4E4F534A # b'JSON'
45
+ CHUNK_TYPE_BIN = 0x004E4942 # b'BIN\0'
46
+
47
+
48
+ def _default_out_path(in_path: str) -> str:
49
+ root, ext = os.path.splitext(in_path)
50
+ if ext.lower() not in [".glb", ".gltf"]:
51
+ ext = ".glb"
52
+ return f"{root}_seg.glb"
53
+
54
+
55
+ def _quantize_rgb(rgb: np.ndarray, step: int) -> np.ndarray:
56
+ """
57
+ rgb: (...,3) uint8
58
+ """
59
+ if step is None or step <= 0:
60
+ return rgb
61
+ q = (rgb.astype(np.int32) + step // 2) // step * step
62
+ q = np.clip(q, 0, 255).astype(np.uint8)
63
+ return q
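+ # Worked example of the rounding above (step=16): 23 -> (23 + 8) // 16 * 16 = 16,
+ # while 24 -> (24 + 8) // 16 * 16 = 32; each channel snaps to the nearest multiple
+ # of the step, with ties rounding up.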
64
+
65
+
66
+ def _load_glb_json_and_bin(glb_path: str) -> Tuple[dict, bytes]:
67
+ data = open(glb_path, "rb").read()
68
+ if len(data) < 12:
69
+ raise RuntimeError("Invalid GLB: too small")
70
+
71
+ magic, version, length = struct.unpack_from("<4sII", data, 0)
72
+ if magic != b"glTF":
73
+ raise RuntimeError("Not a GLB file (missing glTF header)")
74
+
75
+ offset = 12
76
+ gltf_json = None
77
+ bin_chunk = None
78
+
79
+ while offset + 8 <= len(data):
80
+ chunk_len, chunk_type = struct.unpack_from("<II", data, offset)
81
+ offset += 8
82
+ chunk_data = data[offset : offset + chunk_len]
83
+ offset += chunk_len
84
+
85
+ if chunk_type == CHUNK_TYPE_JSON:
86
+ gltf_json = chunk_data.decode("utf-8", errors="replace")
87
+ elif chunk_type == CHUNK_TYPE_BIN:
88
+ bin_chunk = chunk_data
89
+
90
+ if gltf_json is None:
91
+ raise RuntimeError("GLB missing JSON chunk")
92
+ if bin_chunk is None:
93
+ raise RuntimeError("GLB missing BIN chunk")
94
+
95
+ import json
96
+
97
+ return json.loads(gltf_json), bin_chunk
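+ # GLB layout recap (per the glTF 2.0 spec): a 12-byte header (magic b'glTF', version,
+ # total length) followed by chunks of [uint32 length][uint32 type][payload], where
+ # type 0x4E4F534A marks the JSON chunk and 0x004E4942 the binary buffer chunk.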
98
+
99
+
100
+ def _extract_basecolor_texture_image(glb_path: str) -> np.ndarray:
101
+ """
102
+ Extract the embedded baseColorTexture PNG/JPG from the GLB and return it as (H,W,4) uint8 RGBA.
103
+ """
104
+ gltf, bin_chunk = _load_glb_json_and_bin(glb_path)
105
+
106
+ materials = gltf.get("materials", [])
107
+ textures = gltf.get("textures", [])
108
+ images = gltf.get("images", [])
109
+ buffer_views = gltf.get("bufferViews", [])
110
+
111
+ if not materials:
112
+ raise RuntimeError("No materials in GLB")
113
+
114
+ # Take baseColorTexture from material[0] (this GLB has a single material/primitive)
115
+ pbr = materials[0].get("pbrMetallicRoughness", {})
116
+ base_tex_index = pbr.get("baseColorTexture", {}).get("index", None)
117
+ if base_tex_index is None:
118
+ raise RuntimeError("Material has no baseColorTexture")
119
+
120
+ if base_tex_index >= len(textures):
121
+ raise RuntimeError("baseColorTexture index out of range")
122
+
123
+ tex = textures[base_tex_index]
124
+ img_index = tex.get("source", None)
125
+ if img_index is None or img_index >= len(images):
126
+ raise RuntimeError("Texture has no valid image source")
127
+
128
+ img_info = images[img_index]
129
+ bv_index = img_info.get("bufferView", None)
130
+ mime = img_info.get("mimeType", None)
131
+ if bv_index is None:
132
+ uri = img_info.get("uri", None)
133
+ raise RuntimeError(f"Image is not embedded (bufferView missing). uri={uri}")
134
+
135
+ if bv_index >= len(buffer_views):
136
+ raise RuntimeError("image.bufferView out of range")
137
+
138
+ bv = buffer_views[bv_index]
139
+ bo = int(bv.get("byteOffset", 0))
140
+ bl = int(bv.get("byteLength", 0))
141
+ img_bytes = bin_chunk[bo : bo + bl]
142
+
143
+ if DEBUG_PRINT:
144
+ print(
145
+ f"[Texture] baseColorTextureIndex={base_tex_index}, imageIndex={img_index}, "
146
+ f"bufferView={bv_index}, mime={mime}, bytes={len(img_bytes)}"
147
+ )
148
+
149
+ pil = Image.open(trimesh.util.wrap_as_stream(img_bytes)).convert("RGBA")
150
+ return np.array(pil, dtype=np.uint8)
151
+
152
+
153
+ def _merge_palette_rgb(
154
+ palette_rgb: np.ndarray, counts: np.ndarray, merge_dist: float
155
+ ) -> np.ndarray:
156
+ """
157
+ Approximately merge similar RGB entries within the palette, using counts as weights when updating cluster centers.
158
+ palette_rgb: (K,3) uint8
159
+ counts: (K,) int
160
+ """
161
+ if palette_rgb is None or len(palette_rgb) == 0:
162
+ return palette_rgb
163
+ if merge_dist is None or merge_dist <= 0:
164
+ return palette_rgb
165
+
166
+ rgb = palette_rgb.astype(np.float32)
167
+ counts = counts.astype(np.int64)
168
+
169
+ order = np.argsort(-counts)
170
+
171
+ centers = []
172
+ center_w = []
173
+ thr2 = float(merge_dist) * float(merge_dist)
174
+
175
+ for idx in order:
176
+ x = rgb[idx]
177
+ w = int(counts[idx])
178
+
179
+ if not centers:
180
+ centers.append(x.copy())
181
+ center_w.append(w)
182
+ continue
183
+
184
+ C = np.stack(centers, axis=0) # (M,3)
185
+ d2 = np.sum((C - x[None, :]) ** 2, axis=1)
186
+ k = int(np.argmin(d2))
187
+
188
+ if float(d2[k]) <= thr2:
189
+ cw = center_w[k]
190
+ centers[k] = (centers[k] * cw + x * w) / (cw + w)
191
+ center_w[k] = cw + w
192
+ else:
193
+ centers.append(x.copy())
194
+ center_w.append(w)
195
+
196
+ merged = np.clip(np.rint(np.stack(centers, axis=0)), 0, 255).astype(np.uint8)
197
+
198
+ if DEBUG_PRINT:
199
+ print(
200
+ f"[PaletteMerge] before={len(palette_rgb)} after={len(merged)} merge_dist={merge_dist}"
201
+ )
202
+
203
+ return merged
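+ # Worked example of the weighted merge: centers (200,0,0) with weight 300 and
+ # (208,0,0) with weight 100 lie within merge_dist=32, so they collapse to
+ # ((200*300 + 208*100) / 400, 0, 0) = (202, 0, 0).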
204
+
205
+
206
+ def _build_palette_rgb(tex_rgba: np.ndarray) -> np.ndarray:
207
+ """
208
+ Build the dominant-color RGB palette from the texture (alpha ignored).
+ Returns: (K,3) uint8
210
+ """
211
+ rgb = tex_rgba[:, :, :3].reshape(-1, 3)
212
+ n = rgb.shape[0]
213
+
214
+ if n > PALETTE_SAMPLE_PIXELS:
215
+ rng = np.random.default_rng(0)
216
+ idx = rng.choice(n, size=PALETTE_SAMPLE_PIXELS, replace=False)
217
+ rgb = rgb[idx]
218
+
219
+ rgb = _quantize_rgb(rgb, COLOR_QUANT_STEP)
220
+
221
+ uniq, counts = np.unique(rgb, axis=0, return_counts=True)
222
+ order = np.argsort(-counts)
223
+ uniq = uniq[order]
224
+ counts = counts[order]
225
+
226
+ keep = counts >= PALETTE_MIN_PIXELS
227
+ uniq = uniq[keep]
228
+ counts = counts[keep]
229
+
230
+ if len(uniq) > PALETTE_MAX_COLORS:
231
+ uniq = uniq[:PALETTE_MAX_COLORS]
232
+ counts = counts[:PALETTE_MAX_COLORS]
233
+
234
+ if DEBUG_PRINT:
235
+ print(
236
+ f"[Palette] quant_step={COLOR_QUANT_STEP} palette_size(before_merge)={len(uniq)} "
237
+ f"min_pixels={PALETTE_MIN_PIXELS}"
238
+ )
239
+ for i in range(min(15, len(uniq))):
240
+ r, g, b = [int(x) for x in uniq[i]]
241
+ print(f" {i:02d} rgb=({r},{g},{b}) count={int(counts[i])}")
242
+
243
+ uniq = _merge_palette_rgb(uniq.astype(np.uint8), counts, PALETTE_MERGE_DIST)
244
+
245
+ if DEBUG_PRINT:
246
+ print(f"[Palette] palette_size(after_merge)={len(uniq)}")
247
+ for i in range(min(15, len(uniq))):
248
+ r, g, b = [int(x) for x in uniq[i]]
249
+ print(f" {i:02d} rgb=({r},{g},{b})")
250
+
251
+ return uniq.astype(np.uint8)
252
+
253
+
254
+ def _unwrap_uv3_for_seam(uv3: np.ndarray) -> np.ndarray:
255
+ """
256
+ uv3: (F,3,2). If a face spans a UV seam (range > 0.5), shift the side below 0.5 by +1 so the mean cannot drift to the wrong side.
257
+ """
258
+ out = uv3.copy()
259
+ for d in range(2):
260
+ v = out[:, :, d]
261
+ vmin = v.min(axis=1)
262
+ vmax = v.max(axis=1)
263
+ seam = (vmax - vmin) > 0.5
264
+ if np.any(seam):
265
+ vv = v[seam]
266
+ vv = np.where(vv < 0.5, vv + 1.0, vv)
267
+ out[seam, :, d] = vv
268
+ return out
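+ # Seam example: a face with u-coordinates (0.95, 0.97, 0.02) spans 0.95 > 0.5, so the
+ # 0.02 corner is shifted to 1.02; interpolated samples then stay near u ~= 1.0 (mod 1)
+ # instead of sweeping across the middle of the texture.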
269
+
270
+
271
+ def _barycentric_samples(uv3: np.ndarray, samples_per_face: int) -> np.ndarray:
272
+ """
273
+ uv3: (F,3,2)
274
+ return: (F,S,2)
275
+ """
276
+ uv3 = _unwrap_uv3_for_seam(uv3)
277
+
278
+ if samples_per_face == 1:
279
+ w = np.array([1 / 3, 1 / 3, 1 / 3], dtype=np.float32)
280
+ uvs = uv3[:, 0, :] * w[0] + uv3[:, 1, :] * w[1] + uv3[:, 2, :] * w[2]
281
+ return uvs[:, None, :]
282
+
283
+ # 4 samples: the centroid plus three interior points near each vertex (kept away from the anti-aliased boundary band)
284
+ ws = np.array(
285
+ [
286
+ [1 / 3, 1 / 3, 1 / 3],
287
+ [0.80, 0.10, 0.10],
288
+ [0.10, 0.80, 0.10],
289
+ [0.10, 0.10, 0.80],
290
+ ],
291
+ dtype=np.float32,
292
+ )
293
+ uvs = (
294
+ uv3[:, None, 0, :] * ws[None, :, 0, None]
295
+ + uv3[:, None, 1, :] * ws[None, :, 1, None]
296
+ + uv3[:, None, 2, :] * ws[None, :, 2, None]
297
+ )
298
+ return uvs
299
+
300
+
301
+ def _wrap_or_clamp_uv(uv: np.ndarray) -> np.ndarray:
302
+ if UV_WRAP_REPEAT:
303
+ return np.mod(uv, 1.0)
304
+ return np.clip(uv, 0.0, 1.0)
305
+
306
+
307
+ def _sample_texture_nearest_rgb(tex_rgba: np.ndarray, uv: np.ndarray) -> np.ndarray:
308
+ """
309
+ tex_rgba: (H,W,4) uint8
310
+ uv: (N,2) float
311
+ return: (N,3) uint8
312
+ """
313
+ h, w = tex_rgba.shape[0], tex_rgba.shape[1]
314
+ uv = _wrap_or_clamp_uv(uv)
315
+
316
+ u = uv[:, 0]
317
+ v = uv[:, 1]
318
+ if FLIP_V:
319
+ v = 1.0 - v
320
+
321
+ x = np.rint(u * (w - 1)).astype(np.int32)
322
+ y = np.rint(v * (h - 1)).astype(np.int32)
323
+ x = np.clip(x, 0, w - 1)
324
+ y = np.clip(y, 0, h - 1)
325
+
326
+ return tex_rgba[y, x, :3].astype(np.uint8)
327
+
328
+
329
+ def _map_to_palette_rgb(
330
+ colors_rgb: np.ndarray, palette_rgb: np.ndarray, chunk: int = 20000
331
+ ) -> Tuple[np.ndarray, np.ndarray]:
332
+ """
333
+ Map the sampled RGB values to the nearest palette RGB.
+ If the palette is empty, use the unique colors of colors_rgb as a temporary palette.
+ Returns:
336
+ labels: (N,) int
337
+ used_palette_rgb: (K,3) uint8
338
+ """
339
+ if palette_rgb is None or len(palette_rgb) == 0:
340
+ uniq, inv = np.unique(colors_rgb, axis=0, return_inverse=True)
341
+ return inv.astype(np.int32), uniq.astype(np.uint8)
342
+
343
+ c = colors_rgb.astype(np.float32)
344
+ p = palette_rgb.astype(np.float32)
345
+
346
+ out = np.empty((c.shape[0],), dtype=np.int32)
347
+ for i in range(0, c.shape[0], chunk):
348
+ cc = c[i : i + chunk]
349
+ d2 = ((cc[:, None, :] - p[None, :, :]) ** 2).sum(axis=2)
350
+ out[i : i + chunk] = np.argmin(d2, axis=1).astype(np.int32)
351
+
352
+ return out, palette_rgb
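+ # Chunking keeps the pairwise distance matrix at (chunk, K) rather than (N, K),
+ # bounding peak memory to roughly 20000 * len(palette) floats per step.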
353
+
354
+
355
+ def _face_labels_from_texture_rgb(
356
+ mesh: trimesh.Trimesh,
357
+ tex_rgba: np.ndarray,
358
+ palette_rgb: np.ndarray,
359
+ ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
360
+ """
361
+ Sample an RGB value for each face via TEXCOORD_0 + baseColorTexture and map it to a palette label.
+ Returns:
363
+ face_label: (F,) int
364
+ label_rgb: (K,3) uint8
365
+ """
366
+ uv = getattr(mesh.visual, "uv", None)
367
+ if uv is None:
368
+ return None
369
+
370
+ uv = np.asarray(uv, dtype=np.float32)
371
+ if uv.ndim != 2 or uv.shape[1] != 2 or uv.shape[0] != len(mesh.vertices):
372
+ return None
373
+
374
+ faces = mesh.faces
375
+ uv3 = uv[faces] # (F,3,2)
376
+
377
+ uvs = _barycentric_samples(uv3, SAMPLES_PER_FACE) # (F,S,2)
378
+ F, S = uvs.shape[0], uvs.shape[1]
379
+ flat_uv = uvs.reshape(-1, 2)
380
+
381
+ sampled_rgb = _sample_texture_nearest_rgb(tex_rgba, flat_uv) # (F*S,3)
382
+ sampled_rgb = _quantize_rgb(sampled_rgb, COLOR_QUANT_STEP)
383
+
384
+ sample_label, used_palette = _map_to_palette_rgb(sampled_rgb, palette_rgb)
385
+ sample_label = sample_label.reshape(F, S)
386
+
387
+ if S == 1:
388
+ return sample_label[:, 0].astype(np.int32), used_palette
389
+
390
+ # Majority vote over the 4 samples (vectorized)
391
+ l0, l1, l2, l3 = (
392
+ sample_label[:, 0],
393
+ sample_label[:, 1],
394
+ sample_label[:, 2],
395
+ sample_label[:, 3],
396
+ )
397
+ c0 = 1 + (l0 == l1) + (l0 == l2) + (l0 == l3)
398
+ c1 = 1 + (l1 == l0) + (l1 == l2) + (l1 == l3)
399
+ c2 = 1 + (l2 == l0) + (l2 == l1) + (l2 == l3)
400
+ c3 = 1 + (l3 == l0) + (l3 == l1) + (l3 == l2)
401
+
402
+ counts = np.stack([c0, c1, c2, c3], axis=1) # (F,4)
403
+ vals = np.stack([l0, l1, l2, l3], axis=1) # (F,4)
404
+ best = vals[np.arange(F), np.argmax(counts, axis=1)]
405
+ return best.astype(np.int32), used_palette
406
+
407
+
408
+ # =========================
+ # Topology-based label correction
+ # =========================
411
+
412
414
+ from scipy.sparse import coo_matrix
415
+ from scipy.sparse.csgraph import connected_components
416
+
417
+
418
+ def _get_physical_face_adjacency(mesh: trimesh.Trimesh) -> np.ndarray:
419
+ """
420
+ Compute face adjacency in pure physical space, ignoring UV seams.
421
+ """
422
+ # 1. Round the vertex coordinates (absorbing tiny floating-point errors) to find the truly unique physical vertices
+ v_rounded = np.round(mesh.vertices, decimals=3)
+ v_unique, inv_indices = np.unique(v_rounded, axis=0, return_inverse=True)
+
+ # 2. Remap the original face indices onto these unique physical vertices,
+ # so faces that straddle a UV seam now reference identical vertex indices
+ physical_faces = inv_indices[mesh.faces]
+
+ # 3. Build a temporary "shadow mesh" (process=False is critical: it stops trimesh from reordering faces internally)
+ tmp_mesh = trimesh.Trimesh(vertices=v_unique, faces=physical_faces, process=False)
+
+ # Return the physical face adjacency of the shadow mesh
+ return tmp_mesh.face_adjacency
435
+
436
+
437
+ def smooth_face_labels_by_topology(
438
+ mesh: trimesh.Trimesh, face_label: np.ndarray, min_faces: int = 50
439
+ ) -> np.ndarray:
440
+ """
441
+ Filter speckle faces using the true 3D physical topology and merge color patches across UV seams.
+
+ Phase 1: on the same-label connectivity graph, absorb small components adjacent to large ones into them.
+ Phase 2: for leftover small components (whose neighbors are all small), fall back to full physical
+ adjacency and absorb them by the majority label among physical neighbors.
+ Phase 3: for fully isolated faces (no physical adjacency edges), find the nearest non-isolated
+ face by centroid distance and inherit its label.
448
+ """
449
+ labels = face_label.copy()
450
+ edges = _get_physical_face_adjacency(mesh)
451
+ F = len(mesh.faces)
452
+
453
+ # ---- Phase 1: smooth same-label connected components ----
454
+ for iteration in range(3):
455
+ same_label = labels[edges[:, 0]] == labels[edges[:, 1]]
456
+ sub_edges = edges[same_label]
457
+
458
+ if len(sub_edges) > 0:
459
+ data = np.ones(len(sub_edges), dtype=bool)
460
+ graph = coo_matrix((data, (sub_edges[:, 0], sub_edges[:, 1])), shape=(F, F))
461
+ graph = graph.maximum(graph.T)
462
+ n_components, comp_labels = connected_components(graph, directed=False)
463
+ else:
464
+ n_components = F
465
+ comp_labels = np.arange(F)
466
+
467
+ comp_sizes = np.bincount(comp_labels, minlength=n_components)
468
+ small_comps = np.where(comp_sizes < min_faces)[0]
469
+ if len(small_comps) == 0:
470
+ break
471
+
472
+ is_small = np.isin(comp_labels, small_comps)
473
+
474
+ mask0 = is_small[edges[:, 0]]
475
+ mask1 = is_small[edges[:, 1]]
476
+
477
+ boundary_edges_0 = edges[mask0 & ~mask1]
478
+ boundary_edges_1 = edges[mask1 & ~mask0]
479
+
480
+ b_inner = np.concatenate([boundary_edges_0[:, 0], boundary_edges_1[:, 1]])
481
+ b_outer = np.concatenate([boundary_edges_0[:, 1], boundary_edges_1[:, 0]])
482
+
483
+ if len(b_inner) == 0:
484
+ break
485
+
486
+ outer_labels = labels[b_outer]
487
+ inner_comps = comp_labels[b_inner]
488
+
489
+ for cid in np.unique(inner_comps):
490
+ cid_mask = inner_comps == cid
491
+ surrounding_labels = outer_labels[cid_mask]
492
+ if len(surrounding_labels) > 0:
493
+ best_label = np.bincount(surrounding_labels).argmax()
494
+ labels[comp_labels == cid] = best_label
495
+
496
+ # ---- Phase 2: resolve leftover small components via full physical adjacency ----
+ # Recompute the same-label connected components to find the small ones that remain
498
+ same_label = labels[edges[:, 0]] == labels[edges[:, 1]]
499
+ sub_edges = edges[same_label]
500
+ if len(sub_edges) > 0:
501
+ data = np.ones(len(sub_edges), dtype=bool)
502
+ graph = coo_matrix((data, (sub_edges[:, 0], sub_edges[:, 1])), shape=(F, F))
503
+ graph = graph.maximum(graph.T)
504
+ n_components, comp_labels = connected_components(graph, directed=False)
505
+ else:
506
+ n_components = F
507
+ comp_labels = np.arange(F)
508
+
509
+ comp_sizes = np.bincount(comp_labels, minlength=n_components)
510
+ small_comps_set = set(np.where(comp_sizes < min_faces)[0])
511
+
512
+ if small_comps_set:
513
+ is_small = np.array([comp_labels[i] in small_comps_set for i in range(F)])
514
+
515
+ # Build the full physical adjacency lookup table: face -> set of neighbor faces
516
+ adj = defaultdict(set)
517
+ for e0, e1 in edges:
518
+ adj[int(e0)].add(int(e1))
519
+ adj[int(e1)].add(int(e0))
520
+
521
+ # Iterate: each round, small-component faces take the majority label voted by their physical neighbors (ignoring color)
522
+ for _ in range(3):
523
+ changed = False
524
+ small_comps_now = set(
525
+ int(c)
526
+ for c in range(n_components)
527
+ if comp_sizes[c] < min_faces and c in small_comps_set
528
+ )
529
+ if not small_comps_now:
530
+ break
531
+
532
+ for cid in small_comps_now:
533
+ cid_faces = np.where(comp_labels == cid)[0]
534
+ # Collect the labels of all physical neighbors that do not belong to this component
535
+ neighbor_labels = []
536
+ for fi in cid_faces:
537
+ for nf in adj[int(fi)]:
538
+ if comp_labels[nf] != cid:
539
+ neighbor_labels.append(labels[nf])
540
+
541
+ if len(neighbor_labels) > 0:
542
+ best_label = int(np.bincount(neighbor_labels).argmax())
543
+ labels[cid_faces] = best_label
544
+ changed = True
545
+
546
+ if not changed:
547
+ break
548
+
549
+ # Recompute the connected components
550
+ same_label = labels[edges[:, 0]] == labels[edges[:, 1]]
551
+ sub_edges = edges[same_label]
552
+ if len(sub_edges) > 0:
553
+ data = np.ones(len(sub_edges), dtype=bool)
554
+ graph = coo_matrix(
555
+ (data, (sub_edges[:, 0], sub_edges[:, 1])), shape=(F, F)
556
+ )
557
+ graph = graph.maximum(graph.T)
558
+ n_components, comp_labels = connected_components(graph, directed=False)
559
+ else:
560
+ n_components = F
561
+ comp_labels = np.arange(F)
562
+ comp_sizes = np.bincount(comp_labels, minlength=n_components)
563
+ small_comps_set = set(np.where(comp_sizes < min_faces)[0])
564
+
565
+ # ---- Phase 3: fully isolated faces (no physical adjacency edges) inherit a label by centroid distance ----
566
+ same_label = labels[edges[:, 0]] == labels[edges[:, 1]]
567
+ sub_edges = edges[same_label]
568
+ if len(sub_edges) > 0:
569
+ data = np.ones(len(sub_edges), dtype=bool)
570
+ graph = coo_matrix((data, (sub_edges[:, 0], sub_edges[:, 1])), shape=(F, F))
571
+ graph = graph.maximum(graph.T)
572
+ _, comp_labels = connected_components(graph, directed=False)
573
+ else:
574
+ comp_labels = np.arange(F)
575
+ comp_sizes = np.bincount(comp_labels)
576
+ orphan_comps = set(np.where(comp_sizes < min_faces)[0])
577
+
578
+ if orphan_comps:
579
+ orphan_mask = np.array([comp_labels[i] in orphan_comps for i in range(F)])
580
+ non_orphan_mask = ~orphan_mask
581
+ if non_orphan_mask.any() and orphan_mask.any():
582
+ centroids = mesh.triangles_center
583
+ orphan_indices = np.where(orphan_mask)[0]
584
+ non_orphan_indices = np.where(non_orphan_mask)[0]
585
+ non_orphan_centroids = centroids[non_orphan_indices]
586
+
587
+ for oi in orphan_indices:
588
+ dists = np.linalg.norm(non_orphan_centroids - centroids[oi], axis=1)
589
+ nearest = non_orphan_indices[np.argmin(dists)]
590
+ labels[oi] = labels[nearest]
591
+
592
+ if DEBUG_PRINT:
593
+ n_orphan = int(orphan_mask.sum())
594
+ print(f" [Phase3] Assigned {n_orphan} orphan faces by centroid proximity")
595
+
596
+ return labels
597
+
598
+
599
+ # =========================
+ # Main split routine
+ # =========================
602
+
603
+
604
705
+
706
+
707
+ def split_glb_by_texture_palette_rgb(
708
+ in_glb_path: str,
709
+ out_glb_path: Optional[str] = None,
710
+ min_faces_per_part: int = 1,
711
+ bake_transforms: bool = True,
712
+ color_quant_step: int = 16,
713
+ palette_sample_pixels: int = 2_000_000,
714
+ palette_min_pixels: int = 500,
715
+ palette_max_colors: int = 256,
716
+ palette_merge_dist: int = 32,
717
+ samples_per_face: int = 4,
718
+ flip_v: bool = True,
719
+ uv_wrap_repeat: bool = True,
720
+ transition_conf_thresh: float = 1.0,
721
+ transition_prop_iters: int = 6,
722
+ transition_neighbor_min: int = 1,
723
+ small_component_action: str = "reassign",
724
+ small_component_min_faces: int = 50,
725
+ postprocess_iters: int = 3,
726
+ debug_print: bool = True,
727
+ ) -> str:
728
+ """
729
+ Input: GLB (no COLOR_0, but with baseColorTexture + TEXCOORD_0).
+ Output: parts split by palette labels derived from the baseColorTexture.
+ Note: the palette/UV/transition keyword arguments mirror the module-level constants
+ above, but the helper functions still read the module-level values; only
+ min_faces_per_part, bake_transforms and debug_print are honored here, so override
+ the globals to change the other behaviors.
731
+ """
732
+ if out_glb_path is None:
733
+ out_glb_path = _default_out_path(in_glb_path)
734
+
735
+ tex_rgba = _extract_basecolor_texture_image(in_glb_path)
736
+ palette_rgb = _build_palette_rgb(tex_rgba)
737
+
738
+ scene = trimesh.load(in_glb_path, force="scene", process=False)
739
+ out_scene = trimesh.Scene()
740
+
741
+ part_count = 0
742
+ base = os.path.splitext(os.path.basename(in_glb_path))[0]
743
+
744
+ for node_name in scene.graph.nodes_geometry:
745
+ geom_name = scene.graph[node_name][1]
746
+ if geom_name is None:
747
+ continue
748
+
749
+ geom = scene.geometry.get(geom_name, None)
750
+ if geom is None or not isinstance(geom, trimesh.Trimesh):
751
+ continue
752
+
753
+ mesh = geom.copy()
754
+
755
+ if bake_transforms:
756
+ T, _ = scene.graph.get(node_name)
757
+ if T is not None:
758
+ mesh.apply_transform(T)
759
+
760
+ res = _face_labels_from_texture_rgb(mesh, tex_rgba, palette_rgb)
761
+ if res is None:
762
+ if debug_print:
763
+ print(f"[{node_name}] no uv / cannot sample -> keep orig")
764
+ out_scene.add_geometry(mesh, geom_name=f"{base}__{node_name}__orig")
765
+ continue
766
+
767
+ face_label, label_rgb = res
768
+
769
+ # =========================
770
+ # 🔥 New: Apply topology correction to merge small disconnected components
771
+ # =========================
772
+ face_label = smooth_face_labels_by_topology(mesh, face_label, min_faces=100)
773
+
774
+ if debug_print:
775
+ uniq_labels, cnts = np.unique(face_label, return_counts=True)
776
+ order = np.argsort(-cnts)
777
+ print(
778
+ f"[{node_name}] faces={len(mesh.faces)} labels_used={len(uniq_labels)} palette_size={len(label_rgb)}"
779
+ )
780
+ for i in order[:10]:
781
+ lab = int(uniq_labels[i])
782
+ r, g, b = (
783
+ [int(x) for x in label_rgb[lab]]
784
+ if 0 <= lab < len(label_rgb)
785
+ else (0, 0, 0)
786
+ )
787
+ print(f" label={lab} rgb=({r},{g},{b}) faces={int(cnts[i])}")
788
+
789
+ groups = defaultdict(list)
790
+ for fi, lab in enumerate(face_label):
791
+ groups[int(lab)].append(fi)
792
+
793
+ for lab, face_ids in groups.items():
794
+ if len(face_ids) < min_faces_per_part:
795
+ continue
796
+
797
+ sub = mesh.submesh([np.array(face_ids, dtype=np.int64)], append=True, repair=False)
798
+ if sub is None:
799
+ continue
800
+ if isinstance(sub, (list, tuple)):
801
+ if not sub:
802
+ continue
803
+ sub = sub[0]
804
+
805
+ if 0 <= lab < len(label_rgb):
806
+ r, g, b = [int(x) for x in label_rgb[lab]]
807
+ part_name = f"{base}__{node_name}__label_{lab}__rgb_{r}_{g}_{b}"
808
+ else:
809
+ part_name = f"{base}__{node_name}__label_{lab}"
810
+
811
+ out_scene.add_geometry(sub, geom_name=part_name)
812
+ part_count += 1
813
+
814
+ if part_count == 0:
815
+ if debug_print:
816
+ print("[INFO] part_count==0, fallback to original scene export.")
817
+ out_scene = scene
818
+
819
+ out_scene.export(out_glb_path)
820
+ return out_glb_path
821
+
822
+ def main():
823
+ out_path = split_glb_by_texture_palette_rgb(
824
+ INPUT_GLB,
825
+ out_glb_path=None,
826
+ min_faces_per_part=MIN_FACES_PER_PART,
827
+ bake_transforms=BAKE_TRANSFORMS,
828
+ )
829
+ print("Done. Exported:", out_path)
830
+
831
+
832
+ if __name__ == "__main__":
833
+ main()
split_ori.py ADDED
@@ -0,0 +1,686 @@
1
+ import os
2
+ import struct
3
+ from collections import defaultdict, Counter
4
+ from typing import Optional, Tuple
5
+
6
+ import numpy as np
7
+ import trimesh
8
+ from PIL import Image
9
+
10
+
11
+ # =========================
+ # Edit only this section
+ # =========================
+ INPUT_GLB = "/media/nfs/tmp_data/fenghr/SegviGen/data_toolkit/assets/output.glb"  # input GLB path
+
+ # -------------------------
+ # Color / palette (RGB-only)
+ # -------------------------
+ COLOR_QUANT_STEP = 28
+ # RGB quantization step: round each channel to a multiple of this step (reduces color jitter / transition colors)
+ # - larger: coarser colors, more readily merged into a few classes (fewer speckles/transition colors), but classes that should differ may merge
+ # - smaller: finer colors, differences preserved, but more transition/noise colors and more pressure on post-processing
+
+ PALETTE_SAMPLE_PIXELS = 2_000_000
+ # Maximum number of texture pixels sampled when building the palette's color-frequency statistics
+ # - larger: more accurate statistics (closer to the true distribution), but slower and more memory-hungry
+ # - smaller: faster, but small classes may be missed or the statistics may be unstable
+
+ PALETTE_MIN_PIXELS = 300
+ # Palette filter threshold: sampled colors occurring fewer times than this are treated as "boundary transition / noise" and excluded
+ # - larger: drops rare colors more aggressively (fewer new labels from boundary transition colors), but may wrongly delete small-part classes
+ # - smaller: keeps more rare colors (small classes survive more easily), but admits more noise colors
+
+ PALETTE_MAX_COLORS = 256
+ # Maximum number of dominant colors kept in the palette (truncated in descending frequency order)
+ # - larger: allows more classes (finer splits), but more likely to include noise colors and produce fragments
+ # - smaller: fewer, cleaner classes, but several classes may be merged into one
+
+ PALETTE_MERGE_DIST = 65
+ # Merge threshold for similar palette colors (RGB Euclidean distance): collapses "visually identical" RGBs into one representative color
+ # - larger: merges similar colors more readily (fixes "looks the same yet split in two"), but too large may merge adjacent classes
+ # - smaller: preserves differences, but one visual color may still split into several labels
+
+ # -------------------------
+ # UV sampling
+ # -------------------------
+ SAMPLES_PER_FACE = 4
+ # Samples per triangle: 1 = centroid only; 4 = centroid + 3 near-vertex points (with voting), more robust to boundary anti-aliasing
+ # - 4: noticeably fewer misclassifications / transition bands from sampling mid-boundary colors
+ # - 1: fastest, but more prone to transition bands and speckles
+
+ FLIP_V = True
+ # Whether to flip the V axis of the UVs (glTF usually needs the flip to align with image coordinates)
+ # - if colors come out globally misplaced or completely wrong, try setting this to False first
+
+ UV_WRAP_REPEAT = True
+ # Handling of UVs outside [0,1]: True = repeat (mod 1), False = clamp to the border
+ # - repeat: for textures addressed with repeat wrapping
+ # - clamp: for non-repeating textures where out-of-range UVs should stick to the edge (repeat could sample the wrong region)
+
+ # -------------------------
+ # "Transition face / boundary band" absorption (fixes the spurious extra transition region)
+ # -------------------------
+ TRANSITION_CONF_THRESH = 1.0
+ # Transition-face criterion: confidence = (max votes among the 4 samples) / 4
+ # - 1.0: any face that is not a unanimous 4/4 counts as a transition face (strongest band removal, least likely to leave an extra block)
+ # - 0.75: only 2-2 splits or worse count as transition faces (more conservative; boundaries stay rawer, but bands may remain)
+ # - larger: more faces treated as transitions and snapped onto either side, but edges may smear more visibly
+ # - smaller: fewer transition faces and better edge detail, but transition regions are more likely to survive
+
+ TRANSITION_PROP_ITERS = 6
+ # Label-propagation iterations: transition faces are absorbed into stable regions by neighbor majority vote, round after round
+ # - larger: fuller propagation, transition bands more readily "eaten", but boundaries may be pushed further / over-smoothed
+ # - smaller: less propagation and better-preserved boundaries, but parts of the band may remain
+
+ TRANSITION_NEIGHBOR_MIN = 1
+ # Minimum number of neighbor votes required before a transition face is relabeled (guards against changes on too little evidence)
+ # - larger: more cautious updates, harder to mislead with a few neighbors, but stray transition faces may remain
+ # - smaller: faces are absorbed more easily, stronger denoising, but edges may blur slightly more
+
+ # -------------------------
+ # Small-component cleanup (removes speckle "islands", not boundary bands)
+ # -------------------------
+ SMALL_COMPONENT_ACTION = "reassign"
+ # How small connected components are handled:
+ # - "reassign": merge the small component into surrounding large ones via spatial adjacency (usually the intended "despeckle" behavior)
+ # - "drop": delete those faces outright (the output loses faces; not recommended unless holes are acceptable)
+
+ SMALL_COMPONENT_MIN_FACES = 50
+ # Small-component threshold: a connected component of one label with fewer faces than this is treated as a speckle
+ # - larger: stronger despeckling, but genuine small parts may be wiped out
+ # - smaller: small parts preserved, but more speckles survive
+
+ POSTPROCESS_ITERS = 3
+ # Iterations of small-component handling (for "reassign"):
+ # - larger: more thorough speckle cleanup, but more likely to erase small details
+ # - smaller: more conservative
+
+ # -------------------------
+ # Export filtering / misc
+ # -------------------------
+ MIN_FACES_PER_PART = 1
+ # Minimum face count at export: a part with fewer faces than this is not exported
+ # - larger: cleaner output (fewer fragments), but small parts are lost
+ # - smaller: keeps everything (including tiny fragments)
+
+ BAKE_TRANSFORMS = True
+ # Whether to bake each node's world transform into the vertices (True is safer; exported positions are less likely to be wrong)
+ # - generally keep True, unless you explicitly want to preserve the hierarchy transforms
+
+ DEBUG_PRINT = True
+ # Whether to print debug info (palette size, transition-face counts, per-iteration changes, ...)
+ # - True: convenient while tuning; can be switched off once stable
+ # =========================
115
+
116
+
117
+ CHUNK_TYPE_JSON = 0x4E4F534A # b'JSON'
118
+ CHUNK_TYPE_BIN = 0x004E4942 # b'BIN\0'
119
+
120
+
121
+ def _default_out_path(in_path: str) -> str:
122
+ root, ext = os.path.splitext(in_path)
123
+ if ext.lower() not in [".glb", ".gltf"]:
124
+ ext = ".glb"
125
+ return f"{root}_seg.glb"
126
+
127
+
128
+ def _quantize_rgb(rgb: np.ndarray, step: int) -> np.ndarray:
129
+ if step is None or step <= 0:
130
+ return rgb
131
+ q = (rgb.astype(np.int32) + step // 2) // step * step
132
+ q = np.clip(q, 0, 255).astype(np.uint8)
133
+ return q
134
+
135
+
136
+ def _load_glb_json_and_bin(glb_path: str) -> Tuple[dict, bytes]:
137
+ data = open(glb_path, "rb").read()
138
+ if len(data) < 12:
139
+ raise RuntimeError("Invalid GLB: too small")
140
+
141
+ magic, version, length = struct.unpack_from("<4sII", data, 0)
142
+ if magic != b"glTF":
143
+ raise RuntimeError("Not a GLB file (missing glTF header)")
144
+
145
+ offset = 12
146
+ gltf_json = None
147
+ bin_chunk = None
148
+
149
+ while offset + 8 <= len(data):
150
+ chunk_len, chunk_type = struct.unpack_from("<II", data, offset)
151
+ offset += 8
152
+ chunk_data = data[offset: offset + chunk_len]
153
+ offset += chunk_len
154
+
155
+ if chunk_type == CHUNK_TYPE_JSON:
156
+ gltf_json = chunk_data.decode("utf-8", errors="replace")
157
+ elif chunk_type == CHUNK_TYPE_BIN:
158
+ bin_chunk = chunk_data
159
+
160
+ if gltf_json is None:
161
+ raise RuntimeError("GLB missing JSON chunk")
162
+ if bin_chunk is None:
163
+ raise RuntimeError("GLB missing BIN chunk")
164
+
165
+ import json
166
+ return json.loads(gltf_json), bin_chunk
167
+
168
+
169
+ def _extract_basecolor_texture_image(glb_path: str) -> np.ndarray:
170
+ gltf, bin_chunk = _load_glb_json_and_bin(glb_path)
171
+
172
+ materials = gltf.get("materials", [])
173
+ textures = gltf.get("textures", [])
174
+ images = gltf.get("images", [])
175
+ buffer_views = gltf.get("bufferViews", [])
176
+
177
+ if not materials:
178
+ raise RuntimeError("No materials in GLB")
179
+
180
+ pbr = materials[0].get("pbrMetallicRoughness", {})
181
+ base_tex_index = pbr.get("baseColorTexture", {}).get("index", None)
182
+ if base_tex_index is None:
183
+ raise RuntimeError("Material has no baseColorTexture")
184
+ if base_tex_index >= len(textures):
185
+ raise RuntimeError("baseColorTexture index out of range")
186
+
187
+ tex = textures[base_tex_index]
188
+ img_index = tex.get("source", None)
189
+ if img_index is None or img_index >= len(images):
190
+ raise RuntimeError("Texture has no valid image source")
191
+
192
+ img_info = images[img_index]
193
+ bv_index = img_info.get("bufferView", None)
194
+ mime = img_info.get("mimeType", None)
195
+ if bv_index is None:
196
+ uri = img_info.get("uri", None)
197
+ raise RuntimeError(f"Image is not embedded (bufferView missing). uri={uri}")
198
+ if bv_index >= len(buffer_views):
199
+ raise RuntimeError("image.bufferView out of range")
200
+
201
+ bv = buffer_views[bv_index]
202
+ bo = int(bv.get("byteOffset", 0))
203
+ bl = int(bv.get("byteLength", 0))
204
+ img_bytes = bin_chunk[bo: bo + bl]
205
+
206
+ if DEBUG_PRINT:
207
+ print(
208
+ f"[Texture] baseColorTextureIndex={base_tex_index}, imageIndex={img_index}, "
209
+ f"bufferView={bv_index}, mime={mime}, bytes={len(img_bytes)}"
210
+ )
211
+
212
+ pil = Image.open(trimesh.util.wrap_as_stream(img_bytes)).convert("RGBA")
213
+ return np.array(pil, dtype=np.uint8)
214
+
215
+
216
+ def _merge_palette_rgb(palette_rgb: np.ndarray, counts: np.ndarray, merge_dist: float) -> np.ndarray:
217
+ if palette_rgb is None or len(palette_rgb) == 0:
218
+ return palette_rgb
219
+ if merge_dist is None or merge_dist <= 0:
220
+ return palette_rgb
221
+
222
+ rgb = palette_rgb.astype(np.float32)
223
+ counts = counts.astype(np.int64)
224
+
225
+ order = np.argsort(-counts)
226
+ centers = []
227
+ center_w = []
228
+ thr2 = float(merge_dist) * float(merge_dist)
229
+
230
+ for idx in order:
231
+ x = rgb[idx]
232
+ w = int(counts[idx])
233
+
234
+ if not centers:
235
+ centers.append(x.copy())
236
+ center_w.append(w)
237
+ continue
238
+
239
+ C = np.stack(centers, axis=0)
240
+ d2 = np.sum((C - x[None, :]) ** 2, axis=1)
241
+ k = int(np.argmin(d2))
242
+
243
+ if float(d2[k]) <= thr2:
244
+ cw = center_w[k]
245
+ centers[k] = (centers[k] * cw + x * w) / (cw + w)
246
+ center_w[k] = cw + w
247
+ else:
248
+ centers.append(x.copy())
249
+ center_w.append(w)
250
+
251
+ merged = np.clip(np.rint(np.stack(centers, axis=0)), 0, 255).astype(np.uint8)
252
+
253
+ if DEBUG_PRINT:
254
+ print(f"[PaletteMerge] before={len(palette_rgb)} after={len(merged)} merge_dist={merge_dist}")
255
+
256
+ return merged
257
+
258
+
259
+ def _build_palette_rgb(tex_rgba: np.ndarray) -> np.ndarray:
260
+ rgb = tex_rgba[:, :, :3].reshape(-1, 3)
261
+ n = rgb.shape[0]
262
+
263
+ if n > PALETTE_SAMPLE_PIXELS:
264
+ rng = np.random.default_rng(0)
265
+ idx = rng.choice(n, size=PALETTE_SAMPLE_PIXELS, replace=False)
266
+ rgb = rgb[idx]
267
+
268
+ rgb = _quantize_rgb(rgb, COLOR_QUANT_STEP)
269
+
270
+ uniq, counts = np.unique(rgb, axis=0, return_counts=True)
271
+ order = np.argsort(-counts)
272
+ uniq = uniq[order]
273
+ counts = counts[order]
274
+
275
+ keep = counts >= PALETTE_MIN_PIXELS
276
+ uniq = uniq[keep]
277
+ counts = counts[keep]
278
+
279
+ if len(uniq) > PALETTE_MAX_COLORS:
280
+ uniq = uniq[:PALETTE_MAX_COLORS]
281
+ counts = counts[:PALETTE_MAX_COLORS]
282
+
283
+ if DEBUG_PRINT:
284
+ print(
285
+ f"[Palette] quant_step={COLOR_QUANT_STEP} palette_size(before_merge)={len(uniq)} "
286
+ f"min_pixels={PALETTE_MIN_PIXELS}"
287
+ )
288
+ for i in range(min(15, len(uniq))):
289
+ r, g, b = [int(x) for x in uniq[i]]
290
+ print(f" {i:02d} rgb=({r},{g},{b}) count={int(counts[i])}")
291
+
292
+ uniq = _merge_palette_rgb(uniq.astype(np.uint8), counts, PALETTE_MERGE_DIST)
293
+
294
+ if DEBUG_PRINT:
295
+ print(f"[Palette] palette_size(after_merge)={len(uniq)}")
296
+ for i in range(min(15, len(uniq))):
297
+ r, g, b = [int(x) for x in uniq[i]]
298
+ print(f" {i:02d} rgb=({r},{g},{b})")
299
+
300
+ return uniq.astype(np.uint8)
301
+
302
+
303
+ def _unwrap_uv3_for_seam(uv3: np.ndarray) -> np.ndarray:
304
+ out = uv3.copy()
305
+ for d in range(2):
306
+ v = out[:, :, d]
307
+ vmin = v.min(axis=1)
308
+ vmax = v.max(axis=1)
309
+ seam = (vmax - vmin) > 0.5
310
+ if np.any(seam):
311
+ vv = v[seam]
312
+ vv = np.where(vv < 0.5, vv + 1.0, vv)
313
+ out[seam, :, d] = vv
314
+ return out
315
+
316
+
317
+ def _barycentric_samples(uv3: np.ndarray, samples_per_face: int) -> np.ndarray:
318
+ uv3 = _unwrap_uv3_for_seam(uv3)
319
+
320
+ if samples_per_face == 1:
321
+ w = np.array([1 / 3, 1 / 3, 1 / 3], dtype=np.float32)
322
+ uvs = uv3[:, 0, :] * w[0] + uv3[:, 1, :] * w[1] + uv3[:, 2, :] * w[2]
323
+ return uvs[:, None, :]
324
+
325
+ ws = np.array(
326
+ [
327
+ [1 / 3, 1 / 3, 1 / 3],
328
+ [0.80, 0.10, 0.10],
329
+ [0.10, 0.80, 0.10],
330
+ [0.10, 0.10, 0.80],
331
+ ],
332
+ dtype=np.float32,
333
+ )
334
+ uvs = (
335
+ uv3[:, None, 0, :] * ws[None, :, 0, None]
336
+ + uv3[:, None, 1, :] * ws[None, :, 1, None]
337
+ + uv3[:, None, 2, :] * ws[None, :, 2, None]
338
+ )
339
+ return uvs
340
+
341
+
342
+ def _wrap_or_clamp_uv(uv: np.ndarray) -> np.ndarray:
343
+ if UV_WRAP_REPEAT:
344
+ return np.mod(uv, 1.0)
345
+ return np.clip(uv, 0.0, 1.0)
346
+
347
+
348
+ def _sample_texture_nearest_rgb(tex_rgba: np.ndarray, uv: np.ndarray) -> np.ndarray:
349
+ h, w = tex_rgba.shape[0], tex_rgba.shape[1]
350
+ uv = _wrap_or_clamp_uv(uv)
351
+
352
+ u = uv[:, 0]
353
+ v = uv[:, 1]
354
+ if FLIP_V:
355
+ v = 1.0 - v
356
+
357
+ x = np.rint(u * (w - 1)).astype(np.int32)
358
+ y = np.rint(v * (h - 1)).astype(np.int32)
359
+ x = np.clip(x, 0, w - 1)
360
+ y = np.clip(y, 0, h - 1)
361
+
362
+ return tex_rgba[y, x, :3].astype(np.uint8)
363
+
364
+
365
+ def _map_to_palette_rgb(colors_rgb: np.ndarray, palette_rgb: np.ndarray, chunk: int = 20000) -> Tuple[np.ndarray, np.ndarray]:
366
+ if palette_rgb is None or len(palette_rgb) == 0:
367
+ uniq, inv = np.unique(colors_rgb, axis=0, return_inverse=True)
368
+ return inv.astype(np.int32), uniq.astype(np.uint8)
369
+
370
+ c = colors_rgb.astype(np.float32)
371
+ p = palette_rgb.astype(np.float32)
372
+
373
+ out = np.empty((c.shape[0],), dtype=np.int32)
374
+ for i in range(0, c.shape[0], chunk):
375
+ cc = c[i:i + chunk]
376
+ d2 = ((cc[:, None, :] - p[None, :, :]) ** 2).sum(axis=2)
377
+ out[i:i + chunk] = np.argmin(d2, axis=1).astype(np.int32)
378
+
379
+ return out, palette_rgb
380
+
381
+
382
+ def _face_labels_and_confidence_from_texture_rgb(
383
+ mesh: trimesh.Trimesh,
384
+ tex_rgba: np.ndarray,
385
+ palette_rgb: np.ndarray,
386
+ ) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
387
+ uv = getattr(mesh.visual, "uv", None)
388
+ if uv is None:
389
+ return None
390
+
391
+ uv = np.asarray(uv, dtype=np.float32)
392
+ if uv.ndim != 2 or uv.shape[1] != 2 or uv.shape[0] != len(mesh.vertices):
393
+ return None
394
+
395
+ faces = mesh.faces
396
+ uv3 = uv[faces]
397
+
398
+ uvs = _barycentric_samples(uv3, SAMPLES_PER_FACE)
399
+ F, S = uvs.shape[0], uvs.shape[1]
400
+ flat_uv = uvs.reshape(-1, 2)
401
+
402
+ sampled_rgb = _sample_texture_nearest_rgb(tex_rgba, flat_uv)
403
+ sampled_rgb = _quantize_rgb(sampled_rgb, COLOR_QUANT_STEP)
404
+
405
+ sample_label, used_palette = _map_to_palette_rgb(sampled_rgb, palette_rgb)
406
+ sample_label = sample_label.reshape(F, S)
407
+
408
+ if S == 1:
409
+ face_label = sample_label[:, 0].astype(np.int32)
410
+ face_conf = np.ones((F,), dtype=np.float32)
411
+ return face_label, face_conf, used_palette
412
+
413
+ face_label = np.empty((F,), dtype=np.int32)
414
+ face_conf = np.empty((F,), dtype=np.float32)
415
+ for i in range(F):
416
+ row = sample_label[i].tolist()
417
+ c = Counter(row)
418
+ lab, cnt = c.most_common(1)[0]
419
+ face_label[i] = int(lab)
420
+ face_conf[i] = float(cnt) / float(S)
421
+
422
+ return face_label, face_conf, used_palette
423
+
424
+
425
+ def _build_face_adjacency_list(mesh: trimesh.Trimesh) -> Optional[list]:
426
+ adj_pairs = mesh.face_adjacency
427
+ if adj_pairs is None or len(adj_pairs) == 0:
428
+ return None
429
+ F = len(mesh.faces)
430
+ adj = [[] for _ in range(F)]
431
+ for a, b in adj_pairs:
432
+ adj[a].append(b)
433
+ adj[b].append(a)
434
+ return adj
435
+
436
+
437
+ def _reassign_transition_faces(
438
+ face_label: np.ndarray,
439
+ face_conf: np.ndarray,
440
+ adj: list,
441
+ conf_thresh: float,
442
+ iters: int,
443
+ neighbor_min: int,
444
+ ) -> np.ndarray:
445
+ labels = face_label.copy()
446
+ F = labels.shape[0]
447
+
448
+ transition = face_conf < float(conf_thresh)
449
+ if DEBUG_PRINT:
450
+ print(f"[Transition] faces={F}, transition_faces={int(transition.sum())}, conf_thresh={conf_thresh}")
451
+
452
+ if not np.any(transition):
453
+ return labels
454
+
455
+ for it in range(max(1, iters)):
456
+ changed = 0
457
+ for f in range(F):
458
+ if not transition[f]:
459
+ continue
460
+ neigh = adj[f]
461
+ if not neigh:
462
+ continue
463
+
464
+ votes = defaultdict(int)
465
+ for nb in neigh:
466
+ if transition[nb]:
467
+ continue
468
+ votes[int(labels[nb])] += 1
469
+
470
+ if not votes:
471
+ for nb in neigh:
472
+ votes[int(labels[nb])] += 1
473
+
474
+ if not votes:
475
+ continue
476
+
477
+ best_lab, best_cnt = max(votes.items(), key=lambda x: x[1])
478
+ if best_cnt < neighbor_min:
479
+ continue
480
+
481
+ if int(labels[f]) != int(best_lab):
482
+ labels[f] = int(best_lab)
483
+ changed += 1
484
+
485
+ if DEBUG_PRINT:
486
+ print(f"[Transition] iter={it+1}/{max(1,iters)} changed={changed}")
487
+ if changed == 0:
488
+ break
489
+
490
+ return labels
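+ # (Reading aid) The vote prefers stable (non-transition) neighbors and only falls back
+ # to counting all neighbors when a face has none, so transition bands are absorbed
+ # from the outside in over the iterations.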
491
+
492
+
493
+ def _postprocess_small_components(
494
+ mesh: trimesh.Trimesh,
495
+ face_label: np.ndarray,
496
+ min_component_faces: int,
497
+ action: str,
498
+ iters: int,
499
+ ) -> np.ndarray:
500
+ if min_component_faces is None or min_component_faces <= 0:
501
+ return face_label
502
+ if action not in ("drop", "reassign"):
503
+ raise ValueError('SMALL_COMPONENT_ACTION must be "drop" or "reassign"')
504
+
505
+ adj = _build_face_adjacency_list(mesh)
506
+ if adj is None:
507
+ return face_label
508
+
509
+ F = len(mesh.faces)
510
+ labels = face_label.copy()
511
+
512
+ for it in range(max(1, iters)):
513
+ visited = np.zeros(F, dtype=bool)
514
+ changed = False
515
+
516
+ for seed in range(F):
517
+ if visited[seed]:
518
+ continue
519
+ lab = int(labels[seed])
520
+ if lab < 0:
521
+ visited[seed] = True
522
+ continue
523
+
524
+ q = [seed]
525
+ visited[seed] = True
526
+ comp = [seed]
527
+
528
+ while q:
529
+ f = q.pop()
530
+ for nb in adj[f]:
531
+ if not visited[nb] and int(labels[nb]) == lab:
532
+ visited[nb] = True
533
+ q.append(nb)
534
+ comp.append(nb)
535
+
536
+ if len(comp) >= min_component_faces:
537
+ continue
538
+
539
+ if action == "drop":
540
+ labels[np.array(comp, dtype=np.int64)] = -1
541
+ changed = True
542
+ continue
543
+
544
+ neigh_counts = defaultdict(int)
545
+ for f in comp:
546
+ for nb in adj[f]:
547
+ nl = int(labels[nb])
548
+ if nl >= 0 and nl != lab:
549
+ neigh_counts[nl] += 1
550
+
551
+ if not neigh_counts:
552
+ continue
553
+
554
+ new_lab = max(neigh_counts.items(), key=lambda x: x[1])[0]
555
+ labels[np.array(comp, dtype=np.int64)] = int(new_lab)
556
+ changed = True
557
+
558
+ if DEBUG_PRINT:
559
+ print(
560
+ f"[SmallComp] iter={it+1}/{max(1,iters)} action={action} "
561
+ f"min_comp_faces={min_component_faces} changed={changed}"
562
+ )
563
+
564
+ if not changed:
565
+ break
566
+
567
+ return labels
568
+
569
+
570
+ def split_glb_by_texture_palette_rgb(
571
+ in_glb_path: str,
572
+ out_glb_path: Optional[str] = None,
573
+ bake_transforms: bool = True,
574
+ ) -> str:
575
+ if out_glb_path is None:
576
+ out_glb_path = _default_out_path(in_glb_path)
577
+
578
+ tex_rgba = _extract_basecolor_texture_image(in_glb_path)
579
+ palette_rgb = _build_palette_rgb(tex_rgba)
580
+
581
+ scene = trimesh.load(in_glb_path, force="scene", process=False)
582
+ out_scene = trimesh.Scene()
583
+
584
+ part_count = 0
585
+ base = os.path.splitext(os.path.basename(in_glb_path))[0]
586
+
587
+ for node_name in scene.graph.nodes_geometry:
588
+ geom_name = scene.graph[node_name][1]
589
+ if geom_name is None:
590
+ continue
591
+
592
+ geom = scene.geometry.get(geom_name, None)
593
+ if geom is None or not isinstance(geom, trimesh.Trimesh):
594
+ continue
595
+
596
+ mesh = geom.copy()
597
+
598
+ if bake_transforms:
599
+ T, _ = scene.graph.get(node_name)
600
+ if T is not None:
601
+ mesh.apply_transform(T)
602
+
603
+ res = _face_labels_and_confidence_from_texture_rgb(mesh, tex_rgba, palette_rgb)
604
+ if res is None:
605
+ if DEBUG_PRINT:
606
+ print(f"[{node_name}] no uv / cannot sample -> keep orig")
607
+ out_scene.add_geometry(mesh, geom_name=f"{base}__{node_name}__orig")
608
+ continue
609
+
610
+ face_label, face_conf, label_rgb = res
611
+
612
+ if DEBUG_PRINT:
613
+ u, _ = np.unique(face_label, return_counts=True)
614
+ print(f"[{node_name}] raw labels_used={len(u)} palette_size={len(label_rgb)}")
615
+
616
+ adj = _build_face_adjacency_list(mesh)
617
+ if adj is not None:
618
+ face_label = _reassign_transition_faces(
619
+ face_label=face_label,
620
+ face_conf=face_conf,
621
+ adj=adj,
622
+ conf_thresh=TRANSITION_CONF_THRESH,
623
+ iters=TRANSITION_PROP_ITERS,
624
+ neighbor_min=TRANSITION_NEIGHBOR_MIN,
625
+ )
626
+
627
+ face_label = _postprocess_small_components(
628
+ mesh=mesh,
629
+ face_label=face_label,
630
+ min_component_faces=SMALL_COMPONENT_MIN_FACES,
631
+ action=SMALL_COMPONENT_ACTION,
632
+ iters=POSTPROCESS_ITERS,
633
+ )
634
+
635
+ if DEBUG_PRINT:
636
+ u2, _ = np.unique(face_label[face_label >= 0], return_counts=True)
637
+ print(f"[{node_name}] after post labels_used={len(u2)}")
638
+
639
+ groups = defaultdict(list)
640
+ for fi, lab in enumerate(face_label):
641
+ lab = int(lab)
642
+ if lab < 0:
643
+ continue
644
+ groups[lab].append(fi)
645
+
646
+ for lab, face_ids in groups.items():
647
+ if len(face_ids) < MIN_FACES_PER_PART:
648
+ continue
649
+
650
+ sub = mesh.submesh([np.array(face_ids, dtype=np.int64)], append=True, repair=False)
651
+ if sub is None:
652
+ continue
653
+ if isinstance(sub, (list, tuple)):
654
+ if not sub:
655
+ continue
656
+ sub = sub[0]
657
+
658
+ if 0 <= lab < len(label_rgb):
659
+ r, g, b = [int(x) for x in label_rgb[lab]]
660
+ part_name = f"{base}__{node_name}__label_{lab}__rgb_{r}_{g}_{b}"
661
+ else:
662
+ part_name = f"{base}__{node_name}__label_{lab}"
663
+
664
+ out_scene.add_geometry(sub, geom_name=part_name)
665
+ part_count += 1
666
+
667
+ if part_count == 0:
668
+ if DEBUG_PRINT:
669
+ print("[INFO] part_count==0, fallback to original scene export.")
670
+ out_scene = scene
671
+
672
+ out_scene.export(out_glb_path)
673
+ return out_glb_path
674
+
675
+
676
+ def main():
677
+ out_path = split_glb_by_texture_palette_rgb(
678
+ INPUT_GLB,
679
+ out_glb_path=None,
680
+ bake_transforms=BAKE_TRANSFORMS,
681
+ )
682
+ print("Done. Exported:", out_path)
683
+
684
+
685
+ if __name__ == "__main__":
686
+ main()
train_full.py ADDED
@@ -0,0 +1,227 @@
+ import os
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import trellis2.modules.sparse as sp
+
+ from trellis2 import models
+ from torch.nn import functional as F
+ from pytorch_lightning import Trainer
+ from torch.utils.data import Dataset, DataLoader
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+
+     def forward(self, x_t, tex_slats, shape_slats, t, cond, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         output_tex_slats = self.flow_model(x_t, t, cond, shape_slats)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
+ class Gen3DSegDataset(Dataset):
+     def __init__(self, dataset_path, indices, split="train", repeat=1):
+         super().__init__()
+         self.repeat = repeat
+         self.split = split
+         self.indices = indices
+         with open(dataset_path, "r") as f:
+             all_samples = json.load(f)
+         if self.indices == -1:
+             self.indices = [0, len(all_samples)]
+         self.all_samples = self.split_data(all_samples, split)
+
+     def split_data(self, all_samples, split):
+         repeat = self.repeat if split == "train" else 1
+         all_samples = all_samples[self.indices[0] : self.indices[1]]
+         all_samples = all_samples * repeat
+         return all_samples
+
+     def __len__(self):
+         return len(self.all_samples)
+
+     def load_instance(self, index):
+         shape_slat = torch.load(self.all_samples[index]["shape_slat"])
+         shape_slat = sp.SparseTensor(shape_slat["feats"], shape_slat["coords"])
+         input_tex_slat = torch.load(self.all_samples[index]["input_tex_slat"])
+         input_tex_slat = sp.SparseTensor(input_tex_slat["feats"], input_tex_slat["coords"])
+         output_tex_slat_gt = torch.load(self.all_samples[index]["output_tex_slat_gt"])
+         output_tex_slat_gt = sp.SparseTensor(output_tex_slat_gt["feats"], output_tex_slat_gt["coords"])
+         cond_dict = torch.load(self.all_samples[index]["cond"])
+         return {"shape_slat": shape_slat, "input_tex_slat": input_tex_slat, "output_tex_slat_gt": output_tex_slat_gt, "cond_dict": cond_dict}
+
+     def __getitem__(self, index):
+         try:
+             return self.load_instance(index)
+         except Exception as e:
+             print(f"Error in {self.all_samples[index]}: {e}")
+             return self.__getitem__((index + 1) % self.__len__())
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, batch_size, num_workers, dataset_path, indices, repeat, shuffle, seed):
+         super().__init__()
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.dataset_path = dataset_path
+         self.indices = indices
+         self.repeat = repeat
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def setup(self, stage=None):
+         if stage in (None, "fit"):
+             self.train_dataset = Gen3DSegDataset(self.dataset_path, self.indices, "train", self.repeat)
+
+     def collate_fn(self, batch):
+         shape_slats = sp.sparse_cat([sample["shape_slat"] for sample in batch])
+         input_tex_slats = sp.sparse_cat([sample["input_tex_slat"] for sample in batch])
+         output_tex_slat_gts = sp.sparse_cat([sample["output_tex_slat_gt"] for sample in batch])
+         cond_dicts = [sample["cond_dict"] for sample in batch]
+         coords_len_list = [sample["shape_slat"].coords.shape[0] for sample in batch]
+         return {"shape_slats": shape_slats, "input_tex_slats": input_tex_slats, "output_tex_slat_gts": output_tex_slat_gts, "cond_dicts": cond_dicts, "coords_len_list": coords_len_list}
+
+     def train_dataloader(self):
+         distributed_sampler = None
+         if hasattr(self.trainer, "world_size") and self.trainer.world_size > 1:
+             from torch.utils.data.distributed import DistributedSampler
+             distributed_sampler = DistributedSampler(
+                 self.train_dataset,
+                 num_replicas=self.trainer.world_size,
+                 rank=self.trainer.global_rank,
+                 shuffle=self.shuffle,
+                 seed=self.seed
+             )
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=self.collate_fn,
+             sampler=distributed_sampler,
+             shuffle=False,
+         )
+
+ class System(pl.LightningModule):
+     def __init__(self, gen3dseg, pipeline_args, sigma_min, p_uncond, print_every):
+         super().__init__()
+         self.gen3dseg = gen3dseg
+         self.sigma_min = sigma_min
+         self.p_uncond = p_uncond
+         self.print_every = print_every
+         self.shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None]
+         self.shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None]
+         self.tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None]
+         self.tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None]
+         for param in self.gen3dseg.parameters():
+             param.requires_grad = True
+         self.gen3dseg.train()
+
+     def forward(self, shape_slats, input_tex_slats, output_tex_slat_gts, cond_dicts, coords_len_list):
+         batch_size = len(coords_len_list)
+         device = shape_slats.feats.device
+         shape_slats = ((shape_slats - self.shape_mean.to(device)) / self.shape_std.to(device))
+         input_tex_slats = ((input_tex_slats - self.tex_mean.to(device)) / self.tex_std.to(device))
+
+         x_0 = (output_tex_slat_gts - self.tex_mean.to(device)) / self.tex_std.to(device)
+         t = torch.sigmoid(torch.randn(batch_size) * 1.0 + 1.0).to(device)
+         t_x = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
+         noise = sp.SparseTensor(torch.randn_like(x_0.feats), x_0.coords).to(device)
+         x_t = (1 - t_x) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t_x) * noise
+
+         mask = list(np.random.rand(batch_size) < self.p_uncond)
+         cond_list = []
+         for i in range(batch_size):
+             if mask[i]:
+                 cond_list.append(cond_dicts[i]['neg_cond'])
+             else:
+                 cond_list.append(cond_dicts[i]['cond'])
+         cond = torch.cat(cond_list, dim=0)
+
+         pred = self.gen3dseg(x_t, input_tex_slats, shape_slats, t*1000, cond, coords_len_list)
+
+         target = (1 - self.sigma_min) * noise - x_0
+         loss = F.mse_loss(pred.feats, target.feats)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.gen3dseg.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
+         scheduler = {"scheduler": torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=9999999), "interval": "step"}
+         return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+     def training_step(self, batch, batch_idx):
+         loss = self(batch["shape_slats"], batch["input_tex_slats"], batch["output_tex_slat_gts"], batch["cond_dicts"], batch["coords_len_list"])
+         self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+         torch.cuda.empty_cache()
+
+         if (self.global_step + 1) % self.print_every == 0:
+             self.print(f"[step {self.global_step+1}] train_loss = {loss.item():.6f}")
+         return loss
+
+ def train(dataset_path, ckpts_path):
+     pl.seed_everything(42, workers=True)
+     data_module = DataModule(1, 16, dataset_path, -1, 1, True, 42)
+
+     with open("microsoft/TRELLIS.2-4B/pipeline.json", "r") as f:
+         pipeline_config = json.load(f)
+     pipeline_args = pipeline_config['args']
+     tex_slat_flow_model = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
+
+     gen3dseg = Gen3DSeg(tex_slat_flow_model)
+     sigma_min = pipeline_args['tex_slat_sampler']['args']['sigma_min']
+     system = System(gen3dseg, pipeline_args, sigma_min, p_uncond=0.1, print_every=10)
+     ckpt_callback = ModelCheckpoint(
+         dirpath=ckpts_path,
+         filename="step_{step}",
+         every_n_train_steps=500,
+         save_top_k=-1
+     )
+     trainer = Trainer(
+         callbacks=[ckpt_callback],
+         accelerator="gpu",
+         devices=-1,
+         max_epochs=1,
+         gradient_clip_val=1.0
+     )
+     trainer.fit(system, datamodule=data_module)
+
+ if __name__ == "__main__":
+     _2d_map = True
+     if _2d_map:
+         dataset_path = "./data_toolkit/assets/full_seg_w_2d_map/dataset.json"
+         ckpts_path = "path/to/ckpts_full_seg_w_2d_map"
+     else:
+         dataset_path = "./data_toolkit/assets/full_seg/dataset.json"
+         ckpts_path = "path/to/ckpts_full_seg"
+     train(dataset_path, ckpts_path)
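
For readers skimming the diff: the objective in System.forward above is conditional flow matching. The script interpolates x_t between clean latents and noise at a logit-normal timestep, then regresses the velocity, which is constant in t. A self-contained dense-tensor sketch of the same arithmetic (shapes and sigma_min illustrative; the script applies it to SparseTensor feats):

    import torch
    import torch.nn.functional as F

    sigma_min = 1e-5                   # illustrative; the script reads it from pipeline.json
    x_0 = torch.randn(4, 32)           # clean, normalized texture latents
    noise = torch.randn_like(x_0)
    t = torch.sigmoid(torch.randn(4) * 1.0 + 1.0).view(-1, 1)  # logit-normal t in (0, 1)

    x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise
    target = (1 - sigma_min) * noise - x_0  # d(x_t)/dt, the regression target

    pred = torch.zeros_like(target)         # stands in for gen3dseg(...)
    loss = F.mse_loss(pred, target)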
train_interactive.py ADDED
@@ -0,0 +1,287 @@
+ import os
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import random
+ import numpy as np
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import trellis2.modules.sparse as sp
+
+ from trellis2 import models
+ from types import MethodType
+ from torch.nn import functional as F
+ from pytorch_lightning import Trainer
+ from trellis2.modules.utils import manual_cast
+ from torch.utils.data import Dataset, DataLoader
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ def flow_forward(self, x, t, cond, concat_cond, point_embeds, coords_len_list):
+     # x.feats: [N, 32]
+     x = sp.sparse_cat([x, concat_cond], dim=-1)
+     if isinstance(cond, list):
+         cond = sp.VarLenTensor.from_tensor_list(cond)
+     # x.feats: [N, 64]
+     h = self.input_layer(x)
+     # h.feats: [N, 1536]
+     h = manual_cast(h, self.dtype)
+     t_emb = self.t_embedder(t)
+     t_emb = self.adaLN_modulation(t_emb)
+     t_emb = manual_cast(t_emb, self.dtype)
+     cond = manual_cast(cond, self.dtype)
+     point_embeds = manual_cast(point_embeds, self.dtype)
+
+     # Append each sample's 10 point-prompt tokens after its 2*coords_len latent tokens.
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         h_feats_list.append(point_embeds.feats[i*10:(i+1)*10])
+         h_coords_list.append(point_embeds.coords[i*10:(i+1)*10])
+         begin = end  # h holds only latent tokens here; the point tokens come from point_embeds
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     for block in self.blocks:
+         h = block(h, t_emb, cond)
+
+     # Strip the point tokens again, keeping only the latent tokens for the output head.
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         begin = end + 10  # skip this sample's 10 point tokens
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     h = manual_cast(h, x.dtype)
+     h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+     # h.feats: [N, 1536]
+     h = self.out_layer(h)
+     # h.feats: [N, 32]
+     return h
+
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+         self.seg_embeddings = nn.Embedding(1, 1536)
+
+     def get_positional_encoding(self, input_points):
+         # Fixed budget of 10 prompt points; positive clicks (label 1) get the learned seg embedding.
+         point_feats_embed = torch.zeros((10, 1536), dtype=torch.float32).to(input_points['point_slats'].feats.device)
+         labels = input_points['point_labels'].squeeze(-1)
+         point_feats_embed[labels == 1] = self.seg_embeddings.weight
+         return sp.SparseTensor(point_feats_embed, input_points['point_slats'].coords)
+
+     def forward(self, x_t, tex_slats, shape_slats, t, cond, input_points, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         point_embeds = self.get_positional_encoding(input_points)
+         output_tex_slats = self.flow_model(x_t, t, cond, shape_slats, point_embeds, coords_len_list)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
+ class Gen3DSegDataset(Dataset):
+     def __init__(self, dataset_path, indices, split="train", repeat=1):
+         super().__init__()
+         self.repeat = repeat
+         self.split = split
+         self.indices = indices
+         with open(dataset_path, "r") as f:
+             all_samples = json.load(f)
+         if self.indices == -1:
+             self.indices = [0, len(all_samples)]
+         self.all_samples = self.split_data(all_samples, split)
+
+     def split_data(self, all_samples, split):
+         repeat = self.repeat if split == "train" else 1
+         all_samples = all_samples[self.indices[0] : self.indices[1]]
+         all_samples = all_samples * repeat
+         return all_samples
+
+     def __len__(self):
+         return len(self.all_samples)
+
+     def load_instance(self, index):
+         shape_slat = torch.load(self.all_samples[index]["shape_slat"])
+         shape_slat = sp.SparseTensor(shape_slat["feats"], shape_slat["coords"])
+         input_tex_slat = torch.load(self.all_samples[index]["input_tex_slat"])
+         input_tex_slat = sp.SparseTensor(input_tex_slat["feats"], input_tex_slat["coords"])
+         output_tex_slat_gt = torch.load(self.all_samples[index]["output_tex_slat_gt"])
+         output_tex_slat_gt = sp.SparseTensor(output_tex_slat_gt["feats"], output_tex_slat_gt["coords"])
+         cond_dict = torch.load(self.all_samples[index]["cond"])
+         max_point_num = self.all_samples[index]["max_point_num"]
+         point_num = random.randint(1, max_point_num)
+         input_points = torch.load(self.all_samples[index]["input_points"].format(point_num=point_num))
+         return {"shape_slat": shape_slat, "input_tex_slat": input_tex_slat, "output_tex_slat_gt": output_tex_slat_gt, "cond_dict": cond_dict, "input_points": input_points}
+
+     def __getitem__(self, index):
+         try:
+             return self.load_instance(index)
+         except Exception as e:
+             print(f"Error in {self.all_samples[index]}: {e}")
+             return self.__getitem__((index + 1) % self.__len__())
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, batch_size, num_workers, dataset_path, indices, repeat, shuffle, seed):
+         super().__init__()
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.dataset_path = dataset_path
+         self.indices = indices
+         self.repeat = repeat
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def setup(self, stage=None):
+         if stage in (None, "fit"):
+             self.train_dataset = Gen3DSegDataset(self.dataset_path, self.indices, "train", self.repeat)
+
+     def collate_fn(self, batch):
+         shape_slats = sp.sparse_cat([sample["shape_slat"] for sample in batch])
+         input_tex_slats = sp.sparse_cat([sample["input_tex_slat"] for sample in batch])
+         output_tex_slat_gts = sp.sparse_cat([sample["output_tex_slat_gt"] for sample in batch])
+         cond_dicts = [sample["cond_dict"] for sample in batch]
+         # point_feats doubles as the coords argument; get_positional_encoding only reads .coords and .feats.device
+         point_slats = sp.sparse_cat([sp.SparseTensor(sample["input_points"]["point_feats"], sample["input_points"]["point_feats"]) for sample in batch])
+         point_labels = torch.cat([sample["input_points"]["point_labels"] for sample in batch])
+         input_points = {'point_slats': point_slats, 'point_labels': point_labels}
+         coords_len_list = [sample["shape_slat"].coords.shape[0] for sample in batch]
+         return {"shape_slats": shape_slats, "input_tex_slats": input_tex_slats, "output_tex_slat_gts": output_tex_slat_gts, "cond_dicts": cond_dicts, "input_points": input_points, "coords_len_list": coords_len_list}
+
+     def train_dataloader(self):
+         distributed_sampler = None
+         if hasattr(self.trainer, "world_size") and self.trainer.world_size > 1:
+             from torch.utils.data.distributed import DistributedSampler
+             distributed_sampler = DistributedSampler(
+                 self.train_dataset,
+                 num_replicas=self.trainer.world_size,
+                 rank=self.trainer.global_rank,
+                 shuffle=self.shuffle,
+                 seed=self.seed
+             )
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=self.collate_fn,
+             sampler=distributed_sampler,
+             shuffle=False,
+         )
+
+ class System(pl.LightningModule):
+     def __init__(self, gen3dseg, pipeline_args, sigma_min, p_uncond, print_every):
+         super().__init__()
+         self.gen3dseg = gen3dseg
+         self.sigma_min = sigma_min
+         self.p_uncond = p_uncond
+         self.print_every = print_every
+         self.shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None]
+         self.shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None]
+         self.tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None]
+         self.tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None]
+         for param in self.gen3dseg.parameters():
+             param.requires_grad = True
+         self.gen3dseg.train()
+
+     def forward(self, shape_slats, input_tex_slats, output_tex_slat_gts, cond_dicts, input_points, coords_len_list):
+         batch_size = len(coords_len_list)
+         device = shape_slats.feats.device
+         shape_slats = ((shape_slats - self.shape_mean.to(device)) / self.shape_std.to(device))
+         input_tex_slats = ((input_tex_slats - self.tex_mean.to(device)) / self.tex_std.to(device))
+
+         x_0 = (output_tex_slat_gts - self.tex_mean.to(device)) / self.tex_std.to(device)
+         t = torch.sigmoid(torch.randn(batch_size) * 1.0 + 1.0).to(device)
+         t_x = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
+         noise = sp.SparseTensor(torch.randn_like(x_0.feats), x_0.coords).to(device)
+         x_t = (1 - t_x) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t_x) * noise
+
+         mask = list(np.random.rand(batch_size) < self.p_uncond)
+         cond_list = []
+         for i in range(batch_size):
+             if mask[i]:
+                 cond_list.append(cond_dicts[i]['neg_cond'])
+             else:
+                 cond_list.append(cond_dicts[i]['cond'])
+         cond = torch.cat(cond_list, dim=0)
+
+         pred = self.gen3dseg(x_t, input_tex_slats, shape_slats, t*1000, cond, input_points, coords_len_list)
+
+         target = (1 - self.sigma_min) * noise - x_0
+         loss = F.mse_loss(pred.feats, target.feats)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.gen3dseg.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
+         scheduler = {"scheduler": torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=9999999), "interval": "step"}
+         return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+     def training_step(self, batch, batch_idx):
+         loss = self(batch["shape_slats"], batch["input_tex_slats"], batch["output_tex_slat_gts"], batch["cond_dicts"], batch["input_points"], batch["coords_len_list"])
+         self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+         torch.cuda.empty_cache()
+
+         if (self.global_step + 1) % self.print_every == 0:
+             self.print(f"[step {self.global_step+1}] train_loss = {loss.item():.6f}")
+         return loss
+
+ def train(dataset_path, ckpts_path):
+     pl.seed_everything(42, workers=True)
+     data_module = DataModule(1, 16, dataset_path, -1, 1, True, 42)
+
+     with open("microsoft/TRELLIS.2-4B/pipeline.json", "r") as f:
+         pipeline_config = json.load(f)
+     pipeline_args = pipeline_config['args']
+     tex_slat_flow_model = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
+     tex_slat_flow_model.forward = MethodType(flow_forward, tex_slat_flow_model)
+
+     gen3dseg = Gen3DSeg(tex_slat_flow_model)
+     sigma_min = pipeline_args['tex_slat_sampler']['args']['sigma_min']
+     system = System(gen3dseg, pipeline_args, sigma_min, p_uncond=0.1, print_every=10)
+     ckpt_callback = ModelCheckpoint(
+         dirpath=ckpts_path,
+         filename="step_{step}",
+         every_n_train_steps=500,
+         save_top_k=-1
+     )
+     trainer = Trainer(
+         callbacks=[ckpt_callback],
+         accelerator="gpu",
+         devices=-1,
+         max_epochs=1,
+         gradient_clip_val=1.0
+     )
+     trainer.fit(system, datamodule=data_module)
+
+ if __name__ == "__main__":
+     dataset_path = "./data_toolkit/assets/interactive_seg/dataset.json"
+     ckpts_path = "path/to/ckpts_interactive_seg"
+     train(dataset_path, ckpts_path)
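
The subtle part of flow_forward above is the token bookkeeping: each sample contributes 2*coords_len latent tokens (noisy and concat-cond tokens, interleaved upstream in Gen3DSeg.forward) plus a fixed block of 10 point-prompt tokens, which is appended before the transformer blocks and stripped again before the output head. A tiny standalone sketch of the resulting layout (sizes made up):

    coords_len_list = [3, 5]   # per-sample latent lengths, illustrative
    layout, begin = [], 0
    for n in coords_len_list:
        layout.append(("latents", begin, begin + 2 * n))
        layout.append(("points", begin + 2 * n, begin + 2 * n + 10))
        begin = begin + 2 * n + 10
    print(layout)
    # [('latents', 0, 6), ('points', 6, 16), ('latents', 16, 26), ('points', 26, 36)]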
train_unified.py ADDED
@@ -0,0 +1,303 @@
+ import os
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ import json
+ import torch
+ import random
+ import numpy as np
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import trellis2.modules.sparse as sp
+
+ from trellis2 import models
+ from types import MethodType
+ from torch.nn import functional as F
+ from pytorch_lightning import Trainer
+ from trellis2.modules.utils import manual_cast
+ from torch.utils.data import Dataset, DataLoader
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ def flow_forward(self, x, t, tag_embeds, cond, concat_cond, point_embeds, coords_len_list):
+     # x.feats: [N, 32]
+     x = sp.sparse_cat([x, concat_cond], dim=-1)
+     if isinstance(cond, list):
+         cond = sp.VarLenTensor.from_tensor_list(cond)
+     # x.feats: [N, 64]
+     h = self.input_layer(x)
+     # h.feats: [N, 1536]
+     h = manual_cast(h, self.dtype)
+     t_emb = self.t_embedder(t)
+     t_emb = self.adaLN_modulation(t_emb)
+     tag_embeds = self.adaLN_modulation(tag_embeds)
+     t_emb = t_emb + tag_embeds
+     t_emb = manual_cast(t_emb, self.dtype)
+     cond = manual_cast(cond, self.dtype)
+     point_embeds = manual_cast(point_embeds, self.dtype)
+
+     # Append each sample's 10 point-prompt tokens after its 2*coords_len latent tokens.
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         h_feats_list.append(point_embeds.feats[i*10:(i+1)*10])
+         h_coords_list.append(point_embeds.coords[i*10:(i+1)*10])
+         begin = end  # h holds only latent tokens here; the point tokens come from point_embeds
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     for block in self.blocks:
+         h = block(h, t_emb, cond)
+
+     # Strip the point tokens again, keeping only the latent tokens for the output head.
+     h_feats_list = []
+     h_coords_list = []
+     begin = 0
+     for i, coords_len in enumerate(coords_len_list):
+         end = begin + 2 * coords_len
+         h_feats_list.append(h.feats[begin:end])
+         h_coords_list.append(h.coords[begin:end])
+         begin = end + 10  # skip this sample's 10 point tokens
+     h = sp.SparseTensor(torch.cat(h_feats_list), torch.cat(h_coords_list))
+
+     h = manual_cast(h, x.dtype)
+     h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+     # h.feats: [N, 1536]
+     h = self.out_layer(h)
+     # h.feats: [N, 32]
+     return h
+
+ class Gen3DSeg(nn.Module):
+     def __init__(self, flow_model):
+         super().__init__()
+         self.flow_model = flow_model
+         self.seg_embeddings = nn.Embedding(1, 1536)
+         self.tag_mlp = nn.Sequential(nn.Linear(256, 1536, bias=True), nn.SiLU(), nn.Linear(1536, 1536, bias=True))
+
+     def tag_embedding(self, tag):
+         freqs = torch.exp(-np.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=tag.device)
+         args = tag[:, None].float() * freqs[None]
+         tag_freq = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         tag_embeds = self.tag_mlp(tag_freq)
+         return tag_embeds
+
+     def get_positional_encoding(self, input_points):
+         # Fixed budget of 10 prompt points; positive clicks (label 1) get the learned seg embedding.
+         point_feats_embed = torch.zeros((10, 1536), dtype=torch.float32).to(input_points['point_slats'].feats.device)
+         labels = input_points['point_labels'].squeeze(-1)
+         point_feats_embed[labels == 1] = self.seg_embeddings.weight
+         return sp.SparseTensor(point_feats_embed, input_points['point_slats'].coords)
+
+     def forward(self, x_t, tex_slats, shape_slats, t, tags, cond, input_points, coords_len_list):
+         input_tex_feats_list = []
+         input_tex_coords_list = []
+         shape_feats_list = []
+         shape_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             input_tex_feats_list.append(x_t.feats[begin:end])
+             input_tex_feats_list.append(tex_slats.feats[begin:end])
+             input_tex_coords_list.append(x_t.coords[begin:end])
+             input_tex_coords_list.append(tex_slats.coords[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_feats_list.append(shape_slats.feats[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             shape_coords_list.append(shape_slats.coords[begin:end])
+             begin = end
+         x_t = sp.SparseTensor(torch.cat(input_tex_feats_list), torch.cat(input_tex_coords_list))
+         shape_slats = sp.SparseTensor(torch.cat(shape_feats_list), torch.cat(shape_coords_list))
+
+         tag_embeds = self.tag_embedding(tags)
+         point_embeds = self.get_positional_encoding(input_points)
+         output_tex_slats = self.flow_model(x_t, t, tag_embeds, cond, shape_slats, point_embeds, coords_len_list)
+
+         output_tex_feats_list = []
+         output_tex_coords_list = []
+         begin = 0
+         for coords_len in coords_len_list:
+             end = begin + coords_len
+             output_tex_feats_list.append(output_tex_slats.feats[begin:end])
+             output_tex_coords_list.append(output_tex_slats.coords[begin:end])
+             begin = begin + 2 * coords_len
+         output_tex_slat = sp.SparseTensor(torch.cat(output_tex_feats_list), torch.cat(output_tex_coords_list))
+         return output_tex_slat
+
+ class Gen3DSegDataset(Dataset):
+     def __init__(self, dataset_path, indices, split="train", repeat=1):
+         super().__init__()
+         self.repeat = repeat
+         self.split = split
+         self.indices = indices
+         with open(dataset_path, "r") as f:
+             all_samples = json.load(f)
+         if self.indices == -1:
+             self.indices = [0, len(all_samples)]
+         self.all_samples = self.split_data(all_samples, split)
+
+     def split_data(self, all_samples, split):
+         repeat = self.repeat if split == "train" else 1
+         all_samples = all_samples[self.indices[0] : self.indices[1]]
+         all_samples = all_samples * repeat
+         return all_samples
+
+     def __len__(self):
+         return len(self.all_samples)
+
+     def load_instance(self, index):
+         shape_slat = torch.load(self.all_samples[index]["shape_slat"])
+         shape_slat = sp.SparseTensor(shape_slat["feats"], shape_slat["coords"])
+         input_tex_slat = torch.load(self.all_samples[index]["input_tex_slat"])
+         input_tex_slat = sp.SparseTensor(input_tex_slat["feats"], input_tex_slat["coords"])
+         output_tex_slat_gt = torch.load(self.all_samples[index]["output_tex_slat_gt"])
+         output_tex_slat_gt = sp.SparseTensor(output_tex_slat_gt["feats"], output_tex_slat_gt["coords"])
+         cond_dict = torch.load(self.all_samples[index]["cond"])
+         tag = torch.tensor([self.all_samples[index]["tag"]])
+         max_point_num = self.all_samples[index]["max_point_num"]
+         if max_point_num == 0:
+             # No click prompts for this sample: a zeroed 10-point block (labels 0 map to zero embeddings).
+             input_points = {"point_feats": torch.zeros(10, 4, dtype=torch.int32), "point_labels": torch.zeros(10, 1, dtype=torch.int32)}
+         else:
+             point_num = random.randint(1, max_point_num)
+             input_points = torch.load(self.all_samples[index]["input_points"].format(point_num=point_num))
+         return {"shape_slat": shape_slat, "input_tex_slat": input_tex_slat, "output_tex_slat_gt": output_tex_slat_gt, "cond_dict": cond_dict, "input_points": input_points, "tag": tag}
+
+     def __getitem__(self, index):
+         try:
+             return self.load_instance(index)
+         except Exception as e:
+             print(f"Error in {self.all_samples[index]}: {e}")
+             return self.__getitem__((index + 1) % self.__len__())
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, batch_size, num_workers, dataset_path, indices, repeat, shuffle, seed):
+         super().__init__()
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.dataset_path = dataset_path
+         self.indices = indices
+         self.repeat = repeat
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def setup(self, stage=None):
+         if stage in (None, "fit"):
+             self.train_dataset = Gen3DSegDataset(self.dataset_path, self.indices, "train", self.repeat)
+
+     def collate_fn(self, batch):
+         shape_slats = sp.sparse_cat([sample["shape_slat"] for sample in batch])
+         input_tex_slats = sp.sparse_cat([sample["input_tex_slat"] for sample in batch])
+         output_tex_slat_gts = sp.sparse_cat([sample["output_tex_slat_gt"] for sample in batch])
+         cond_dicts = [sample["cond_dict"] for sample in batch]
+         # point_feats doubles as the coords argument; get_positional_encoding only reads .coords and .feats.device
+         point_slats = sp.sparse_cat([sp.SparseTensor(sample["input_points"]["point_feats"], sample["input_points"]["point_feats"]) for sample in batch])
+         point_labels = torch.cat([sample["input_points"]["point_labels"] for sample in batch])
+         input_points = {'point_slats': point_slats, 'point_labels': point_labels}
+         coords_len_list = [sample["shape_slat"].coords.shape[0] for sample in batch]
+         tags = [sample["tag"] for sample in batch]
+         return {"shape_slats": shape_slats, "input_tex_slats": input_tex_slats, "output_tex_slat_gts": output_tex_slat_gts, "cond_dicts": cond_dicts, "input_points": input_points, "coords_len_list": coords_len_list, "tags": tags}
+
+     def train_dataloader(self):
+         distributed_sampler = None
+         if hasattr(self.trainer, "world_size") and self.trainer.world_size > 1:
+             from torch.utils.data.distributed import DistributedSampler
+             distributed_sampler = DistributedSampler(
+                 self.train_dataset,
+                 num_replicas=self.trainer.world_size,
+                 rank=self.trainer.global_rank,
+                 shuffle=self.shuffle,
+                 seed=self.seed
+             )
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=self.collate_fn,
+             sampler=distributed_sampler,
+             shuffle=False,
+         )
+
+ class System(pl.LightningModule):
+     def __init__(self, gen3dseg, pipeline_args, sigma_min, p_uncond, print_every):
+         super().__init__()
+         self.gen3dseg = gen3dseg
+         self.sigma_min = sigma_min
+         self.p_uncond = p_uncond
+         self.print_every = print_every
+         self.shape_std = torch.tensor(pipeline_args['shape_slat_normalization']['std'])[None]
+         self.shape_mean = torch.tensor(pipeline_args['shape_slat_normalization']['mean'])[None]
+         self.tex_std = torch.tensor(pipeline_args['tex_slat_normalization']['std'])[None]
+         self.tex_mean = torch.tensor(pipeline_args['tex_slat_normalization']['mean'])[None]
+         for param in self.gen3dseg.parameters():
+             param.requires_grad = True
+         self.gen3dseg.train()
+
+     def forward(self, shape_slats, input_tex_slats, output_tex_slat_gts, cond_dicts, input_points, coords_len_list, tags):
+         batch_size = len(coords_len_list)
+         device = shape_slats.feats.device
+         shape_slats = ((shape_slats - self.shape_mean.to(device)) / self.shape_std.to(device))
+         input_tex_slats = ((input_tex_slats - self.tex_mean.to(device)) / self.tex_std.to(device))
+
+         x_0 = (output_tex_slat_gts - self.tex_mean.to(device)) / self.tex_std.to(device)
+         t = torch.sigmoid(torch.randn(batch_size) * 1.0 + 1.0).to(device)
+         t_x = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
+         noise = sp.SparseTensor(torch.randn_like(x_0.feats), x_0.coords).to(device)
+         x_t = (1 - t_x) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t_x) * noise
+
+         mask = list(np.random.rand(batch_size) < self.p_uncond)
+         cond_list = []
+         for i in range(batch_size):
+             if mask[i]:
+                 cond_list.append(cond_dicts[i]['neg_cond'])
+             else:
+                 cond_list.append(cond_dicts[i]['cond'])
+         cond = torch.cat(cond_list, dim=0)
+
+         # batch_size is 1 in train(), so the first tag covers the whole batch
+         pred = self.gen3dseg(x_t, input_tex_slats, shape_slats, t*1000, tags[0], cond, input_points, coords_len_list)
+
+         target = (1 - self.sigma_min) * noise - x_0
+         loss = F.mse_loss(pred.feats, target.feats)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.gen3dseg.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
+         scheduler = {"scheduler": torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=9999999), "interval": "step"}
+         return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+     def training_step(self, batch, batch_idx):
+         loss = self(batch["shape_slats"], batch["input_tex_slats"], batch["output_tex_slat_gts"], batch["cond_dicts"], batch["input_points"], batch["coords_len_list"], batch["tags"])
+         self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+         torch.cuda.empty_cache()
+
+         if (self.global_step + 1) % self.print_every == 0:
+             self.print(f"[step {self.global_step+1}] train_loss = {loss.item():.6f}")
+         return loss
+
+ def train(dataset_path, ckpts_path):
+     pl.seed_everything(42, workers=True)
+     data_module = DataModule(1, 16, dataset_path, -1, 1, True, 42)
+
+     with open("microsoft/TRELLIS.2-4B/pipeline.json", "r") as f:
+         pipeline_config = json.load(f)
+     pipeline_args = pipeline_config['args']
+     tex_slat_flow_model = models.from_pretrained("microsoft/TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16")
+     tex_slat_flow_model.forward = MethodType(flow_forward, tex_slat_flow_model)
+
+     gen3dseg = Gen3DSeg(tex_slat_flow_model)
+     sigma_min = pipeline_args['tex_slat_sampler']['args']['sigma_min']
+     system = System(gen3dseg, pipeline_args, sigma_min, p_uncond=0.1, print_every=10)
+     ckpt_callback = ModelCheckpoint(
+         dirpath=ckpts_path,
+         filename="step_{step}",
+         every_n_train_steps=500,
+         save_top_k=-1
+     )
+     trainer = Trainer(
+         callbacks=[ckpt_callback],
+         accelerator="gpu",
+         devices=-1,
+         max_epochs=1,
+         gradient_clip_val=1.0
+     )
+     trainer.fit(system, datamodule=data_module)
+
+ if __name__ == "__main__":
+     dataset_path = "./data_toolkit/assets/unified/dataset.json"
+     ckpts_path = "path/to/ckpts_unified"
+     train(dataset_path, ckpts_path)
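
What train_unified.py adds on top of the interactive variant is a scalar task tag injected alongside the timestep: Gen3DSeg.tag_embedding applies a standard sinusoidal encoding and an MLP, and flow_forward sums the result with the timestep embedding. A dense sketch of just the frequency encoding (dimensions as in the code; the tag values are illustrative, and tag_mlp/adaLN_modulation are omitted):

    import numpy as np
    import torch

    tags = torch.tensor([0.0, 1.0])  # one scalar tag per task mode; values illustrative
    freqs = torch.exp(-np.log(10000) * torch.arange(0, 128, dtype=torch.float32) / 128)
    args = tags[:, None] * freqs[None]                                # [2, 128]
    tag_freq = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [2, 256]
    # tag_freq then passes through tag_mlp (256 -> 1536) and adaLN_modulation,
    # and is added to the timestep embedding t_emb.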
trellis2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from . import models
+ from . import modules
+ from . import pipelines
+ from . import renderers
+ from . import representations
+ from . import utils