fenghora committed on
Commit 128b5a1 · 1 Parent(s): a9b839c
Files changed (3)
  1. app.py +98 -62
  2. app_ori.py +423 -0
  3. inference_full.py +45 -18
app.py CHANGED
@@ -1,9 +1,17 @@
  import os
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
 
  import urllib.request
+ import shutil
+ import traceback
+ from datetime import datetime
+ from pathlib import Path
+ from typing import List
+
+ import inference_full as inf
+ import split as splitter
 
  os.makedirs("pretrained_model", exist_ok=True)
 
@@ -22,17 +30,6 @@ if not os.path.exists(CKPT_W_2D_MAP):
  CKPT_W_2D_MAP,
  )
 
- CKPT_FULL_SEG = CKPT_W_2D_MAP
-
- import shutil
- import traceback
- from datetime import datetime
- from pathlib import Path
- from typing import List
- import inference_full as inf
- import split as splitter
-
-
  TRANSFORMS_JSON = "./data_toolkit/transforms.json"
 
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -83,25 +80,36 @@ def _collect_examples(example_dir: str) -> List[List[str]]:
 
  examples: List[List[str]] = []
 
- # Search recursively in case you add subfolders later
  glb_files = sorted(d.rglob("*.glb"))
  for glb_path in glb_files:
  png_path = glb_path.with_suffix(".png")
  if png_path.is_file():
  examples.append([str(glb_path), str(png_path)])
- # If png is missing, skip to keep examples consistent (2 inputs required)
 
  return examples
 
 
- # Build examples once at startup
  FULL_SEG_EXAMPLES = _collect_examples(EXAMPLES_DIR)
 
 
- def run_seg(glb_in, img_in):
+ def _toggle_map_input(mode: str):
+ """
+ Show upload image input only when Upload mode is selected.
  """
- Segment button: generates whole segmented GLB and displays in the second box.
- Returns: segmented_glb_path, segmented_glb_path(state)
+ return gr.update(visible=(mode == "Upload"))
+
+
+ def run_seg(glb_in, map_mode, img_in):
+ """
+ Process button:
+ - Upload mode: use the uploaded 2D map directly
+ - Generate mode: generate a 2D map with FLUX2, show it to the user,
+ and use it as if it were the uploaded map
+
+ Returns:
+ segmented_glb_path,
+ segmented_glb_path(state),
+ used_2d_map_path
  """
  try:
  glb_path = _normalize_path(glb_in)
@@ -121,39 +129,49 @@ def run_seg(glb_in, img_in):
  out_glb = os.path.join(workdir, "segmented.glb")
  in_vxz = os.path.join(workdir, "input.vxz")
 
- # If image is provided -> 2d_map=True; otherwise full segmentation (render_from_transforms)
- if img_path is not None and os.path.isfile(img_path):
- ckpt = CKPT_W_2D_MAP
- in_img = os.path.join(workdir, "2d_map.png")
- shutil.copy(img_path, in_img)
- item = {
- "2d_map": True,
- "glb": in_glb,
- "input_vxz": in_vxz,
- "img": in_img,
- "export_glb": out_glb,
- }
- else:
- ckpt = CKPT_FULL_SEG
+ # Always build an item that uses a 2D map in the end.
+ # If the user chooses Generate, we generate the map first.
+ ckpt = CKPT_W_2D_MAP
+
+ if map_mode == "Upload":
+ if img_path is None or (not os.path.isfile(img_path)):
+ _raise_user_error("Please upload a valid 2D segmentation map, or switch to Generate mode.")
+
+ used_img = os.path.join(workdir, "2d_map_uploaded.png")
+ shutil.copy(img_path, used_img)
+
+ elif map_mode == "Generate":
  render_img = os.path.join(workdir, "render.png")
- item = {
- "2d_map": False,
- "glb": in_glb,
- "input_vxz": in_vxz,
- "transforms": TRANSFORMS_JSON,
- "img": render_img,
- "export_glb": out_glb,
- }
+ used_img = os.path.join(workdir, "2d_map_generated.png")
+
+ # Generate the 2D map first, and then use it as the uploaded image.
+ inf.generate_2d_map_from_glb(
+ glb_path=in_glb,
+ transforms_path=TRANSFORMS_JSON,
+ out_img_path=used_img,
+ render_img_path=render_img,
+ )
+
+ if not os.path.isfile(used_img):
+ _raise_user_error("2D map generation failed: generated image not found.")
+
+ else:
+ _raise_user_error(f"Unsupported map mode: {map_mode}")
+
+ item = {
+ "2d_map": True,
+ "glb": in_glb,
+ "input_vxz": in_vxz,
+ "img": used_img,
+ "export_glb": out_glb,
+ }
 
  inf.inference_with_loaded_models(ckpt, item)
 
  if not os.path.isfile(out_glb):
  _raise_user_error("Export failed: output glb not found.")
 
- # Apply X90 rotation for whole segmented output
- # _apply_root_x90_rotation_glb(out_glb)
-
- return out_glb, out_glb
+ return out_glb, out_glb, used_img
 
  except Exception as e:
  err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
@@ -216,10 +234,6 @@ def run_refine_segmentation(
  if not os.path.isfile(out_parts_glb):
  _raise_user_error("Split failed: output parts glb not found.")
 
- # If bake_transforms=False, split output will not have the wrapper transform baked, so we need to apply X90 rotation fix
- # if (not bool(bake_transforms)) and APPLY_OUTPUT_X90_FIX:
- # _apply_root_x90_rotation_glb(out_parts_glb)
-
  return out_parts_glb
 
  except Exception as e:
@@ -230,10 +244,11 @@
 
  CSS_TEXT = """
  <style>
- #in_glb { height: 520px !important; }
- #seg_glb { height: 520px !important; }
- #part_glb{ height: 520px !important; }
- #img { height: 520px !important; }
+ #in_glb { height: 520px !important; }
+ #seg_glb { height: 520px !important; }
+ #part_glb { height: 520px !important; }
+ #img { height: 520px !important; }
+ #used_img { height: 520px !important; }
  </style>
  """
 
@@ -245,7 +260,6 @@ with gr.Blocks() as demo:
  """
  )
 
- # ---------------- 2x2 Layout ----------------
  with gr.Row():
  with gr.Column(scale=1, min_width=260):
  in_glb = gr.Model3D(label="Input GLB", elem_id="in_glb")
@@ -254,12 +268,28 @@ with gr.Blocks() as demo:
 
  with gr.Row():
  with gr.Column(scale=1, min_width=260):
- with gr.Accordion("2D Segmentation Map (Optional)", open=False):
- in_img = gr.Image(label="2D Segmentation Map", type="filepath", elem_id="img")
+ map_mode = gr.Radio(
+ choices=["Upload", "Generate"],
+ value="Upload",
+ label="2D Map Source",
+ )
+
+ with gr.Accordion("2D Segmentation Map", open=True):
+ in_img = gr.Image(
+ label="Upload 2D Segmentation Map",
+ type="filepath",
+ elem_id="img",
+ visible=True,
+ )
+
+ used_img_preview = gr.Image(
+ label="Used 2D Segmentation Map",
+ type="filepath",
+ elem_id="used_img",
+ )
 
  seg_btn = gr.Button("Process", variant="primary")
 
- # ✅ Examples directly under the Process button
  if FULL_SEG_EXAMPLES:
  gr.Examples(
  examples=FULL_SEG_EXAMPLES,
@@ -269,7 +299,9 @@ with gr.Blocks() as demo:
  cache_examples=False,
  )
  else:
- gr.Markdown(f"**No examples found** in: `{EXAMPLES_DIR}` (expected: `*.glb` + same-name `*.png`).")
+ gr.Markdown(
+ f"**No examples found** in: `{EXAMPLES_DIR}` (expected: `*.glb` + same-name `*.png`)."
+ )
 
  with gr.Accordion("Advanced segmentation options", open=False):
  def _g(name, default):
@@ -385,14 +417,18 @@ with gr.Blocks() as demo:
  refine_btn = gr.Button("Segment", variant="secondary")
  part_glb = gr.Model3D(label="Segmented GLB", elem_id="part_glb")
 
- # Hidden states
  seg_glb_state = gr.State(None)
 
- # ---------------- wiring ----------------
+ map_mode.change(
+ fn=_toggle_map_input,
+ inputs=[map_mode],
+ outputs=[in_img],
+ )
+
  seg_btn.click(
  fn=run_seg,
- inputs=[in_glb, in_img],
- outputs=[seg_glb, seg_glb_state],
+ inputs=[in_glb, map_mode, in_img],
+ outputs=[seg_glb, seg_glb_state, used_img_preview],
  )
 
  refine_btn.click(
@@ -414,7 +450,7 @@ with gr.Blocks() as demo:
  small_component_min_faces,
  postprocess_iters,
  min_faces_per_part,
- bake_transforms
+ bake_transforms,
  ],
  outputs=[part_glb],
  )
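
The new Upload/Generate flow above can also be exercised without the UI. A minimal sketch (not part of this commit), assuming app.py is importable, its checkpoint downloads have completed, and a GPU-backed pipeline is available; the example paths are hypothetical:

import app

# Generate mode: run_seg renders the GLB, asks FLUX2 for a 2D segmentation map,
# and segments with CKPT_W_2D_MAP; the uploaded-image argument is ignored.
seg_glb, seg_state, used_map = app.run_seg("examples/chair.glb", "Generate", None)  # hypothetical .glb

# Upload mode: a user-provided 2D segmentation map drives the same pipeline instead.
seg_glb, seg_state, used_map = app.run_seg("examples/chair.glb", "Upload", "examples/chair.png")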
app_ori.py ADDED
@@ -0,0 +1,423 @@
+ import os
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
+
+ import urllib.request
+
+ os.makedirs("pretrained_model", exist_ok=True)
+
+ CKPT_FULL_SEG = "pretrained_model/full_seg.ckpt"
+ CKPT_W_2D_MAP = "pretrained_model/full_seg_w_2d_map.ckpt"
+
+ if not os.path.exists(CKPT_FULL_SEG):
+ urllib.request.urlretrieve(
+ "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg.ckpt",
+ CKPT_FULL_SEG,
+ )
+
+ if not os.path.exists(CKPT_W_2D_MAP):
+ urllib.request.urlretrieve(
+ "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg_w_2d_map.ckpt",
+ CKPT_W_2D_MAP,
+ )
+
+ import shutil
+ import traceback
+ from datetime import datetime
+ from pathlib import Path
+ from typing import List
+ import inference_full as inf
+ import split as splitter
+
+
+ TRANSFORMS_JSON = "./data_toolkit/transforms.json"
+
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ TMP_DIR = os.path.join(ROOT_DIR, "_tmp_gradio_seg")
+ EXAMPLES_CACHE_DIR = os.path.join(TMP_DIR, "examples_cache")
+ os.makedirs(TMP_DIR, exist_ok=True)
+ os.makedirs(EXAMPLES_CACHE_DIR, exist_ok=True)
+
+ os.environ["GRADIO_TEMP_DIR"] = TMP_DIR
+ os.environ["GRADIO_EXAMPLES_CACHE"] = EXAMPLES_CACHE_DIR
+
+ import gradio as gr
+
+ EXAMPLES_DIR = "examples"
+
+
+ def _ensure_dir(p: str):
+ os.makedirs(p, exist_ok=True)
+
+
+ def _normalize_path(x):
+ """
+ Compatible with different Gradio versions: File/Model3D might be str / dict / object
+ """
+ if x is None:
+ return None
+ if isinstance(x, str):
+ return x
+ if isinstance(x, dict):
+ return x.get("name") or x.get("path") or x.get("data")
+ return getattr(x, "name", None) or getattr(x, "path", None) or None
+
+
+ def _raise_user_error(msg: str):
+ if hasattr(gr, "Error"):
+ raise gr.Error(msg)
+ raise RuntimeError(msg)
+
+
+ def _collect_examples(example_dir: str) -> List[List[str]]:
+ """
+ Scan example_dir for pairs: <name>.glb + <name>.png
+ Return a list of examples: [[glb_path, png_path], ...]
+ """
+ d = Path(example_dir)
+ if not d.is_dir():
+ return []
+
+ examples: List[List[str]] = []
+
+ # Search recursively in case you add subfolders later
+ glb_files = sorted(d.rglob("*.glb"))
+ for glb_path in glb_files:
+ png_path = glb_path.with_suffix(".png")
+ if png_path.is_file():
+ examples.append([str(glb_path), str(png_path)])
+ # If png is missing, skip to keep examples consistent (2 inputs required)
+
+ return examples
+
+
+ # Build examples once at startup
+ FULL_SEG_EXAMPLES = _collect_examples(EXAMPLES_DIR)
+
+
+ def run_seg(glb_in, img_in):
+ """
+ Segment button: generates whole segmented GLB and displays in the second box.
+ Returns: segmented_glb_path, segmented_glb_path(state)
+ """
+ try:
+ glb_path = _normalize_path(glb_in)
+ img_path = _normalize_path(img_in)
+
+ if glb_path is None or (not os.path.isfile(glb_path)):
+ _raise_user_error("Please upload a valid .glb file.")
+
+ _ensure_dir(TMP_DIR)
+ run_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ workdir = os.path.join(TMP_DIR, run_id)
+ _ensure_dir(workdir)
+
+ in_glb = os.path.join(workdir, "input.glb")
+ shutil.copy(glb_path, in_glb)
+
+ out_glb = os.path.join(workdir, "segmented.glb")
+ in_vxz = os.path.join(workdir, "input.vxz")
+
+ # If image is provided -> 2d_map=True; otherwise full segmentation (render_from_transforms)
+ if img_path is not None and os.path.isfile(img_path):
+ ckpt = CKPT_W_2D_MAP
+ in_img = os.path.join(workdir, "2d_map.png")
+ shutil.copy(img_path, in_img)
+ item = {
+ "2d_map": True,
+ "glb": in_glb,
+ "input_vxz": in_vxz,
+ "img": in_img,
+ "export_glb": out_glb,
+ }
+ else:
+ ckpt = CKPT_FULL_SEG
+ render_img = os.path.join(workdir, "render.png")
+ item = {
+ "2d_map": False,
+ "glb": in_glb,
+ "input_vxz": in_vxz,
+ "transforms": TRANSFORMS_JSON,
+ "img": render_img,
+ "export_glb": out_glb,
+ }
+
+ inf.inference_with_loaded_models(ckpt, item)
+
+ if not os.path.isfile(out_glb):
+ _raise_user_error("Export failed: output glb not found.")
+
+ # Apply X90 rotation for whole segmented output
+ # _apply_root_x90_rotation_glb(out_glb)
+
+ return out_glb, out_glb
+
+ except Exception as e:
+ err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+ print(err)
+ raise
+
+
+ def run_refine_segmentation(
+ seg_glb_path_state,
+ color_quant_step,
+ palette_sample_pixels,
+ palette_min_pixels,
+ palette_max_colors,
+ palette_merge_dist,
+ samples_per_face,
+ flip_v,
+ uv_wrap_repeat,
+ transition_conf_thresh,
+ transition_prop_iters,
+ transition_neighbor_min,
+ small_component_action,
+ small_component_min_faces,
+ postprocess_iters,
+ min_faces_per_part,
+ bake_transforms,
+ ):
+ """
+ Refine Segmentation button: splits the segmented GLB into smaller parts GLB and displays in the fourth box.
+ """
+ try:
+ seg_glb_path = seg_glb_path_state if isinstance(seg_glb_path_state, str) else None
+ if (seg_glb_path is None) or (not os.path.isfile(seg_glb_path)):
+ _raise_user_error("Please run Segmentation first (the segmented GLB is missing).")
+
+ out_dir = os.path.dirname(seg_glb_path)
+ out_parts_glb = os.path.join(out_dir, "segmented_parts.glb")
+
+ splitter.split_glb_by_texture_palette_rgb(
+ in_glb_path=seg_glb_path,
+ out_glb_path=out_parts_glb,
+ min_faces_per_part=min_faces_per_part,
+ bake_transforms=bool(bake_transforms),
+ color_quant_step=color_quant_step,
+ palette_sample_pixels=palette_sample_pixels,
+ palette_min_pixels=palette_min_pixels,
+ palette_max_colors=palette_max_colors,
+ palette_merge_dist=palette_merge_dist,
+ samples_per_face=samples_per_face,
+ flip_v=flip_v,
+ uv_wrap_repeat=uv_wrap_repeat,
+ transition_conf_thresh=transition_conf_thresh,
+ transition_prop_iters=transition_prop_iters,
+ transition_neighbor_min=transition_neighbor_min,
+ small_component_action=small_component_action,
+ small_component_min_faces=small_component_min_faces,
+ postprocess_iters=postprocess_iters,
+ debug_print=True,
+ )
+
+ if not os.path.isfile(out_parts_glb):
+ _raise_user_error("Split failed: output parts glb not found.")
+
+ # If bake_transforms=False, split output will not have the wrapper transform baked, so we need to apply X90 rotation fix
+ # if (not bool(bake_transforms)) and APPLY_OUTPUT_X90_FIX:
+ # _apply_root_x90_rotation_glb(out_parts_glb)
+
+ return out_parts_glb
+
+ except Exception as e:
+ err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+ print(err)
+ raise
+
+
+ CSS_TEXT = """
+ <style>
+ #in_glb { height: 520px !important; }
+ #seg_glb { height: 520px !important; }
+ #part_glb{ height: 520px !important; }
+ #img { height: 520px !important; }
+ </style>
+ """
+
+ with gr.Blocks() as demo:
+ gr.HTML(CSS_TEXT)
+ gr.Markdown(
+ """
+ # SegviGen: Repurposing 3D Generative Model for Part Segmentation
+ """
+ )
+
+ # ---------------- 2x2 Layout ----------------
+ with gr.Row():
+ with gr.Column(scale=1, min_width=260):
+ in_glb = gr.Model3D(label="Input GLB", elem_id="in_glb")
+ with gr.Column(scale=1, min_width=260):
+ seg_glb = gr.Model3D(label="Processed GLB", elem_id="seg_glb")
+
+ with gr.Row():
+ with gr.Column(scale=1, min_width=260):
+ with gr.Accordion("2D Segmentation Map (Optional)", open=False):
+ in_img = gr.Image(label="2D Segmentation Map", type="filepath", elem_id="img")
+
+ seg_btn = gr.Button("Process", variant="primary")
+
+ # ✅ Examples directly under the Process button
+ if FULL_SEG_EXAMPLES:
+ gr.Examples(
+ examples=FULL_SEG_EXAMPLES,
+ inputs=[in_glb, in_img],
+ label="Examples",
+ examples_per_page=3,
+ cache_examples=False,
+ )
+ else:
+ gr.Markdown(f"**No examples found** in: `{EXAMPLES_DIR}` (expected: `*.glb` + same-name `*.png`).")
+
+ with gr.Accordion("Advanced segmentation options", open=False):
+ def _g(name, default):
+ return getattr(splitter, name, default)
+
+ color_quant_step = gr.Slider(
+ 1, 64, value=_g("COLOR_QUANT_STEP", 16), step=1, label="COLOR_QUANT_STEP"
+ )
+ gr.Markdown(
+ "*COLOR_QUANT_STEP controls the RGB quantization step, where a larger value merges similar colors more aggressively and a smaller value preserves finer color differences.*"
+ )
+
+ palette_sample_pixels = gr.Number(
+ value=_g("PALETTE_SAMPLE_PIXELS", 2_000_000), precision=0, label="PALETTE_SAMPLE_PIXELS"
+ )
+ gr.Markdown(
+ "*PALETTE_SAMPLE_PIXELS sets the maximum number of sampled pixels used to estimate the palette, where more samples improve stability but increase runtime.*"
+ )
+
+ palette_min_pixels = gr.Number(
+ value=_g("PALETTE_MIN_PIXELS", 500), precision=0, label="PALETTE_MIN_PIXELS"
+ )
+ gr.Markdown(
+ "*PALETTE_MIN_PIXELS specifies the minimum pixel count required to keep a color in the palette, where a higher threshold suppresses noise but may discard small parts.*"
+ )
+
+ palette_max_colors = gr.Number(
+ value=_g("PALETTE_MAX_COLORS", 256), precision=0, label="PALETTE_MAX_COLORS"
+ )
+ gr.Markdown(
+ "*PALETTE_MAX_COLORS limits the maximum number of colors retained in the palette, where a larger limit yields finer partitions and a smaller limit enforces stronger merging.*"
+ )
+
+ palette_merge_dist = gr.Number(
+ value=_g("PALETTE_MERGE_DIST", 32), precision=0, label="PALETTE_MERGE_DIST"
+ )
+ gr.Markdown(
+ "*PALETTE_MERGE_DIST defines the distance threshold for merging nearby palette colors in RGB space, where a larger threshold merges near duplicates more often and a smaller threshold keeps colors distinct.*"
+ )
+
+ samples_per_face = gr.Dropdown(
+ choices=[1, 4], value=_g("SAMPLES_PER_FACE", 4), label="SAMPLES_PER_FACE"
+ )
+ gr.Markdown(
+ "*SAMPLES_PER_FACE sets the number of UV samples per triangle used for label voting, where more samples improve robustness near boundaries but increase computation.*"
+ )
+
+ flip_v = gr.Checkbox(value=_g("FLIP_V", True), label="FLIP_V")
+ gr.Markdown(
+ "*FLIP_V toggles whether the V coordinate is flipped to match common glTF texture conventions, and you should disable it only if the texture appears vertically inverted.*"
+ )
+
+ uv_wrap_repeat = gr.Checkbox(value=_g("UV_WRAP_REPEAT", True), label="UV_WRAP_REPEAT")
+ gr.Markdown(
+ "*UV_WRAP_REPEAT selects how out of range UVs are handled by either repeating via modulo or clamping to the unit interval, and repeating is typically preferred for tiled textures.*"
+ )
+
+ transition_conf_thresh = gr.Slider(
+ 0.25, 1.0, value=float(_g("TRANSITION_CONF_THRESH", 1.0)), step=0.25, label="TRANSITION_CONF_THRESH"
+ )
+ gr.Markdown(
+ "*TRANSITION_CONF_THRESH sets the confidence threshold for transition handling, where a higher value makes refinement more conservative and a lower value enables more aggressive smoothing.*"
+ )
+
+ transition_prop_iters = gr.Number(
+ value=_g("TRANSITION_PROP_ITERS", 6), precision=0, label="TRANSITION_PROP_ITERS"
+ )
+ gr.Markdown(
+ "*TRANSITION_PROP_ITERS specifies the number of propagation iterations used in transition refinement, where more iterations strengthen diffusion effects but increase runtime.*"
+ )
+
+ transition_neighbor_min = gr.Number(
+ value=_g("TRANSITION_NEIGHBOR_MIN", 1), precision=0, label="TRANSITION_NEIGHBOR_MIN"
+ )
+ gr.Markdown(
+ "*TRANSITION_NEIGHBOR_MIN requires a minimum number of supporting neighbors to propagate a label, where a higher requirement is more conservative and a lower requirement is more permissive.*"
+ )
+
+ small_component_action = gr.Dropdown(
+ choices=["reassign", "drop"], value=_g("SMALL_COMPONENT_ACTION", "reassign"), label="SMALL_COMPONENT_ACTION"
+ )
+ gr.Markdown(
+ "*SMALL_COMPONENT_ACTION determines how small connected components are handled by either reassigning them to neighboring labels or dropping them entirely.*"
+ )
+
+ small_component_min_faces = gr.Number(
+ value=_g("SMALL_COMPONENT_MIN_FACES", 50), precision=0, label="SMALL_COMPONENT_MIN_FACES"
+ )
+ gr.Markdown(
+ "*SMALL_COMPONENT_MIN_FACES defines the face count threshold used to classify a component as small, where a higher threshold merges or removes more fragments and a lower threshold preserves more small parts.*"
+ )
+
+ postprocess_iters = gr.Number(
+ value=_g("POSTPROCESS_ITERS", 3), precision=0, label="POSTPROCESS_ITERS"
+ )
+ gr.Markdown(
+ "*POSTPROCESS_ITERS sets the number of post processing iterations, where more iterations produce stronger cleanup at the cost of additional computation.*"
+ )
+
+ min_faces_per_part = gr.Number(
+ value=_g("MIN_FACES_PER_PART", 1), precision=0, label="MIN_FACES_PER_PART"
+ )
+ gr.Markdown(
+ "*MIN_FACES_PER_PART enforces a minimum number of faces per exported part, where a larger value filters tiny outputs and a smaller value retains fine components.*"
+ )
+
+ bake_transforms = gr.Checkbox(value=_g("BAKE_TRANSFORMS", True), label="BAKE_TRANSFORMS")
+ gr.Markdown(
+ "*BAKE_TRANSFORMS controls whether scene graph transforms are baked into geometry before splitting, where enabling it improves consistency in world space and disabling it preserves node transforms.*"
+ )
+
+ with gr.Column(scale=1, min_width=260):
+ refine_btn = gr.Button("Segment", variant="secondary")
+ part_glb = gr.Model3D(label="Segmented GLB", elem_id="part_glb")
+
+ # Hidden states
+ seg_glb_state = gr.State(None)
+
+ # ---------------- wiring ----------------
+ seg_btn.click(
+ fn=run_seg,
+ inputs=[in_glb, in_img],
+ outputs=[seg_glb, seg_glb_state],
+ )
+
+ refine_btn.click(
+ fn=run_refine_segmentation,
+ inputs=[
+ seg_glb_state,
+ color_quant_step,
+ palette_sample_pixels,
+ palette_min_pixels,
+ palette_max_colors,
+ palette_merge_dist,
+ samples_per_face,
+ flip_v,
+ uv_wrap_repeat,
+ transition_conf_thresh,
+ transition_prop_iters,
+ transition_neighbor_min,
+ small_component_action,
+ small_component_min_faces,
+ postprocess_iters,
+ min_faces_per_part,
+ bake_transforms
+ ],
+ outputs=[part_glb],
+ )
+
+ if __name__ == "__main__":
+ inf.PIPE.load_all_models()
+ inf.PIPE.load_ckpt_if_needed(CKPT_W_2D_MAP)
+ demo.launch()
inference_full.py CHANGED
@@ -33,6 +33,32 @@ TRELLIS_TEX_DEC = "microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16"
  DINO_PATH = "fenghora/dinov3"
 
 
+ def generate_2d_map_from_glb(glb_path, transforms_path, out_img_path, render_img_path=None):
+ """
+ Render the GLB first, then generate a 2D segmentation map with FLUX2.
+ """
+ PIPE.load_all_models()
+
+ if render_img_path is None:
+ base, _ = os.path.splitext(out_img_path)
+ render_img_path = f"{base}_render.png"
+
+ render_from_transforms(glb_path, transforms_path, render_img_path)
+
+ prompt = "Apply distinct colors to different regions of this image"
+ image = PIPE.flux2(
+ height=512,
+ width=512,
+ prompt=prompt,
+ image=Image.open(render_img_path),
+ num_inference_steps=28,
+ guidance_scale=4,
+ ).images[0]
+
+ image.save(out_img_path)
+ return out_img_path
+
+
  def _colorvisuals_to_texturevisuals(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
  """
  Convert ColorVisuals to TextureVisuals by baking per-face colors into a tiny atlas
@@ -525,25 +551,26 @@ def inference_with_loaded_models(ckpt_path, item):
  PIPE.load_all_models()
  PIPE.load_ckpt_if_needed(ckpt_path)
 
- if not item['2d_map']:
- render_from_transforms(item['glb'], item['transforms'], item['img'])
+ # if not item['2d_map']:
+ #     render_from_transforms(item['glb'], item['transforms'], item['img'])
 
- prompt = "Apply distinct colors to different regions of this image"
- image = PIPE.flux2(
- height=512,
- width=512,
- prompt=prompt,
- image=Image.open(item['img']),
- num_inference_steps=28,
- guidance_scale=4,
- ).images[0]
- image.save(item['img'])
-
- # import gc
- # del flux2
- # gc.collect()
- # torch.cuda.empty_cache()
- # torch.cuda.ipc_collect()
+ # prompt = "Apply distinct colors to different regions of this image"
+ # image = PIPE.flux2(
+ #     height=512,
+ #     width=512,
+ #     prompt=prompt,
+ #     image=Image.open(item['img']),
+ #     num_inference_steps=28,
+ #     guidance_scale=4,
+ # ).images[0]
+ # image.save(item['img'])
+
+ if not item["2d_map"]:
+ generate_2d_map_from_glb(
+ glb_path=item["glb"],
+ transforms_path=item["transforms"],
+ out_img_path=item["img"],
+ )
 
  if PIPE.rembg_model is None:
  raise RuntimeError("PIPE.rembg_model is None. Check BiRefNet loading and .cuda() usage.")
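
The extracted helper can also be called on its own. A minimal sketch (not part of this commit), assuming the SegviGen checkpoints are downloaded and a CUDA-backed PIPE is available; the mesh and output paths are hypothetical:

import inference_full as inf

map_path = inf.generate_2d_map_from_glb(
    glb_path="my_mesh.glb",                             # hypothetical input mesh
    transforms_path="./data_toolkit/transforms.json",   # camera poses used by render_from_transforms
    out_img_path="my_mesh_2d_map.png",                  # hypothetical output location
)
# The returned path can then be passed to inference_with_loaded_models via item["img"]
# with item["2d_map"] = True, mirroring what app.py now does in Generate mode.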