yjwnb6 committed
Commit 9375c3b · 1 Parent(s): cd90646
Files changed (28)
  1. .gradio/certificate.pem +31 -0
  2. app.py +99 -77
  3. demo/bird.webp +0 -0
  4. sam2/sam2/__pycache__/__init__.cpython-310.pyc +0 -0
  5. sam2/sam2/__pycache__/build_sam.cpython-310.pyc +0 -0
  6. sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc +0 -0
  7. sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc +0 -0
  8. sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  9. sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc +0 -0
  10. sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc +0 -0
  11. sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc +0 -0
  12. sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc +0 -0
  13. sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc +0 -0
  14. sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
  15. sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc +0 -0
  16. sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc +0 -0
  17. sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc +0 -0
  18. sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc +0 -0
  19. sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc +0 -0
  20. sam2/sam2/modeling/sam/__pycache__/mask_decoder.cpython-310.pyc +0 -0
  21. sam2/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-310.pyc +0 -0
  22. sam2/sam2/modeling/sam/__pycache__/transformer.cpython-310.pyc +0 -0
  23. sam2/sam2/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  24. sam2/sam2/utils/__pycache__/misc.cpython-310.pyc +0 -0
  25. sam2/sam2/utils/__pycache__/transforms.cpython-310.pyc +0 -0
  26. sam2/training/__pycache__/__init__.cpython-310.pyc +0 -0
  27. sam2/training/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  28. sam2/training/utils/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
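For context on the blob above: the base64 decodes to the ISRG Root X1 root certificate (the strings "ISRG Root X1" and "Internet Security Research Group" are visible in the encoded subject), a file Gradio typically writes next to an app when it prepares a share link. A minimal sketch to inspect it locally, assuming the third-party `cryptography` package is installed (this repo does not declare it):

    # Hypothetical inspection snippet, not part of the commit.
    from cryptography import x509

    with open(".gradio/certificate.pem", "rb") as fh:
        cert = x509.load_pem_x509_certificate(fh.read())

    print(cert.subject.rfc4514_string())  # expected: CN=ISRG Root X1,O=Internet Security Research Group,C=US
    print(cert.not_valid_before, "->", cert.not_valid_after)  # roughly 2015-06-04 -> 2035-06-04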
app.py CHANGED
@@ -8,7 +8,7 @@ import os
 import sys
 import threading
 from pathlib import Path
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence
 
 import cv2
 import gradio as gr
@@ -28,9 +28,11 @@ if SAM2_REPO.exists():
 from sam2.build_sam import build_sam2  # noqa: E402
 from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402
 
-logging.basicConfig(level=os.getenv("UNSAMV2_LOGLEVEL", "INFO"))
+logging.basicConfig(level=logging.INFO)
 LOGGER = logging.getLogger("unsamv2-gradio")
 
+USE_M2M_REFINEMENT = True
+
 CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
 CKPT_PATH = Path(
     os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
@@ -53,6 +55,22 @@ POINT_COLORS_BGR = {
 MASK_COLOR_BGR = (0, 196, 255)
 OUTLINE_COLOR_BGR = (0, 165, 255)
 
+DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
+
+
+def _load_default_image() -> Optional[np.ndarray]:
+    if not DEFAULT_IMAGE_PATH.exists():
+        LOGGER.warning("Default image missing at %s", DEFAULT_IMAGE_PATH)
+        return None
+    img_bgr = cv2.imread(str(DEFAULT_IMAGE_PATH), cv2.IMREAD_COLOR)
+    if img_bgr is None:
+        LOGGER.warning("Could not read default image at %s", DEFAULT_IMAGE_PATH)
+        return None
+    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+
+DEFAULT_IMAGE = _load_default_image()
+
 
 class ModelManager:
     """Keeps SAM2 models on each device and spawns lightweight predictors."""
@@ -82,7 +100,7 @@ class ModelManager:
         return self._models[key]
 
     def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
-        return SAM2ImagePredictor(self.get_model(device))
+        return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)
 
 
 MODEL_MANAGER = ModelManager()
@@ -119,6 +137,41 @@ def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor
     return tensor
 
 
+def apply_m2m_refinement(
+    predictor,
+    point_coords,
+    point_labels,
+    granularity,
+    logits,
+    best_mask_idx,
+    use_m2m: bool = True,
+):
+    """Optionally run a second M2M pass using the best mask's logits."""
+    if not use_m2m:
+        return None
+
+    logging.info("Applying M2M refinement...")
+    try:
+        if logits is None:
+            raise ValueError("logits must be provided for M2M refinement.")
+
+        low_res_logits = logits[best_mask_idx : best_mask_idx + 1]
+        refined_masks, refined_scores, _ = predictor.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+            multimask_output=False,
+            gra=granularity,
+            mask_input=low_res_logits,
+        )
+        refined_mask = refined_masks[0]
+        refined_score = float(refined_scores[0])
+        logging.info("M2M refinement completed with score: %.3f", refined_score)
+        return refined_mask, refined_score
+    except Exception as exc:  # pragma: no cover - logging only
+        logging.error("M2M refinement failed: %s, using original mask", exc)
+        return None
+
+
 def draw_overlay(
     image: np.ndarray,
     mask: Optional[np.ndarray],
@@ -150,37 +203,21 @@ def draw_overlay(
     return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)
 
 
-def points_table(points: Sequence[Sequence[float]], labels: Sequence[int]) -> List[List[str]]:
-    table = []
-    for idx, ((x, y), lbl) in enumerate(zip(points, labels), start=1):
-        table.append([
-            idx,
-            round(float(x), 1),
-            round(float(y), 1),
-            "fg" if lbl == 1 else "bg",
-        ])
-    return table
-
-
 def handle_image_upload(image: Optional[np.ndarray]):
     img = ensure_uint8(image)
     if img is None:
         return (
-            None,
             None,
             None,
             [],
             [],
-            [],
             "Upload an image to start adding clicks.",
         )
     return (
         img,
-        None,
         img,
         [],
         [],
-        [],
         "Image loaded. Choose click type, then tap on the image.",
     )
@@ -195,20 +232,16 @@ def handle_click(
     if image is None:
         return (
             gr.update(),
-            None,
             pts,
             lbls,
-            points_table(pts, lbls),
             "Upload an image first.",
         )
     coord = evt.index  # (x, y)
     if coord is None:
         return (
             gr.update(),
-            None,
             pts,
             lbls,
-            points_table(pts, lbls),
             "Couldn't read click position.",
         )
     x, y = coord
@@ -217,29 +250,27 @@ def handle_click(
     lbls = lbls + [label]
     overlay = draw_overlay(image, None, pts, lbls)
     status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
-    return overlay, None, pts, lbls, points_table(pts, lbls), status
+    return overlay, pts, lbls, status
 
 
 def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
     if not pts:
         return (
             gr.update(),
-            None,
             pts,
             lbls,
-            points_table(pts, lbls),
             "No clicks to undo.",
         )
     pts = pts[:-1]
     lbls = lbls[:-1]
     overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
     status = "Removed the last click."
-    return overlay, None, pts, lbls, points_table(pts, lbls), status
+    return overlay, pts, lbls, status
 
 
 def clear_clicks(image: Optional[np.ndarray]):
     overlay = image if image is not None else None
-    return overlay, None, [], [], [], "Cleared all clicks."
+    return overlay, [], [], "Cleared all clicks."
 
 
 def _run_segmentation(
@@ -250,9 +281,9 @@ def _run_segmentation(
 ):
     img = ensure_uint8(image)
     if img is None:
-        return None, None, "Upload an image to segment."
+        return None, "Upload an image to segment."
     if not pts:
-        return draw_overlay(img, None, [], []), None, "Add at least one click before running segmentation."
+        return draw_overlay(img, None, [], []), "Add at least one click before running segmentation."
 
     device = choose_device()
     predictor = MODEL_MANAGER.make_predictor(device)
@@ -262,7 +293,7 @@ def _run_segmentation(
     labels = np.asarray(lbls, dtype=np.int32)
     gran_tensor = build_granularity_tensor(granularity, predictor.device)
 
-    masks, scores, _ = predictor.predict(
+    masks, scores, logits = predictor.predict(
         point_coords=coords,
         point_labels=labels,
         multimask_output=True,
@@ -271,10 +302,27 @@ def _run_segmentation(
     )
     best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(bool)
+    status = (
+        f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | "
+        f"granularity={granularity:.2f}"
+    )
+
+    refinement = apply_m2m_refinement(
+        predictor=predictor,
+        point_coords=coords,
+        point_labels=labels,
+        granularity=float(granularity),
+        logits=logits,
+        best_mask_idx=best_idx,
+        use_m2m=USE_M2M_REFINEMENT,
+    )
+    if refinement is not None:
+        refined_mask, refined_score = refinement
+        best_mask = refined_mask.astype(bool)
+        status += f" | M2M IoU: {refined_score:.3f}"
+
     overlay = draw_overlay(img, best_mask, pts, lbls)
-    mask_vis = (best_mask.astype(np.uint8) * 255)
-    status = f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | granularity={granularity:.2f}"
-    return overlay, mask_vis, status
+    return overlay, status
 
 
 if spaces is not None and ZERO_GPU_ENABLED:
@@ -287,30 +335,20 @@ def build_demo() -> gr.Blocks:
     with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
         gr.Markdown(
             """## UnSAMv2 · Interactive Granularity Control
-Upload an image, add positive/negative clicks, tune granularity, and run segmentation.
-ZeroGPU automatically pulls a GPU when available; otherwise the app falls back to CPU."""
+Upload an image, add positive/negative clicks, tune granularity, and run segmentation."""
         )
 
-        image_state = gr.State()
+        image_state = gr.State(DEFAULT_IMAGE)
         points_state = gr.State([])
         labels_state = gr.State([])
 
-        with gr.Row():
-            image_input = gr.Image(
-                label="1 · Upload image & click to add prompts",
-                type="numpy",
-                height=480,
-            )
-            overlay_output = gr.Image(
-                label="Segmentation preview",
-                interactive=False,
-                height=480,
-            )
-            mask_output = gr.Image(
-                label="Binary mask",
-                interactive=False,
-                height=480,
-            )
+        image_input = gr.Image(
+            label="Image · clicks & mask",
+            type="numpy",
+            height=480,
+            value=DEFAULT_IMAGE,
+            sources=["upload"],
+        )
 
         with gr.Row():
             point_mode = gr.Radio(
@@ -322,34 +360,26 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
             minimum=GRANULARITY_MIN,
             maximum=GRANULARITY_MAX,
             value=0.2,
-            step=0.05,
+            step=0.01,
             label="Granularity",
             info="Lower = finer details, Higher = coarser regions",
         )
-        segment_button = gr.Button("3 · Segment", variant="primary")
+        segment_button = gr.Button("Segment", variant="primary")
 
         with gr.Row():
             undo_button = gr.Button("Undo last click")
             clear_button = gr.Button("Clear clicks")
 
-        points_table_output = gr.Dataframe(
-            headers=["#", "x", "y", "type"],
-            datatype=["number", "number", "number", "str"],
-            interactive=False,
-            label="2 · Click history",
-        )
         status_markdown = gr.Markdown(" Ready.")
 
         image_input.upload(
             handle_image_upload,
             inputs=[image_input],
             outputs=[
-                overlay_output,
-                mask_output,
+                image_input,
                 image_state,
                 points_state,
                 labels_state,
-                points_table_output,
                 status_markdown,
             ],
         )
@@ -358,12 +388,10 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
             handle_image_upload,
             inputs=[image_input],
             outputs=[
-                overlay_output,
-                mask_output,
+                image_input,
                 image_state,
                 points_state,
                 labels_state,
-                points_table_output,
                 status_markdown,
             ],
         )
@@ -377,11 +405,9 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
                 image_state,
             ],
             outputs=[
-                overlay_output,
-                mask_output,
+                image_input,
                 points_state,
                 labels_state,
-                points_table_output,
                 status_markdown,
             ],
        )
@@ -390,11 +416,9 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
             undo_last_click,
             inputs=[image_state, points_state, labels_state],
             outputs=[
-                overlay_output,
-                mask_output,
+                image_input,
                 points_state,
                 labels_state,
-                points_table_output,
                 status_markdown,
             ],
         )
@@ -403,11 +427,9 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
             clear_clicks,
             inputs=[image_state],
             outputs=[
-                overlay_output,
-                mask_output,
+                image_input,
                 points_state,
                 labels_state,
-                points_table_output,
                 status_markdown,
             ],
         )
@@ -415,7 +437,7 @@ ZeroGPU automatically pulls a GPU when available; otherwise the app falls back t
         segment_button.click(
             segment_fn,
             inputs=[image_state, points_state, labels_state, granularity_slider],
-            outputs=[overlay_output, mask_output, status_markdown],
+            outputs=[image_input, status_markdown],
        )
 
         demo.queue(max_size=8)
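Taken together, the app.py changes collapse the three image panels (input, overlay preview, binary mask) and the click-history table into a single image_input component that shows clicks and the mask in place, and they add an optional mask-to-mask (M2M) refinement pass gated by USE_M2M_REFINEMENT. The new prediction path amounts to the sketch below. The checkpoint/config paths, the gra keyword, the mask_input wiring, and mask_threshold=-1.0 are taken from the diff; the click coordinates are placeholders, and a plain float granularity is used for both passes here (the app itself builds a tensor via build_granularity_tensor for the first pass), so treat this as illustrative rather than a verified standalone script.

    # Illustrative two-pass flow mirroring _run_segmentation + apply_m2m_refinement.
    import cv2
    import numpy as np
    import torch

    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_sam2("configs/unsamv2_small.yaml", "checkpoints/unsamv2_plus_ckpt.pt", device=device)
    predictor = SAM2ImagePredictor(model, mask_threshold=-1.0)  # threshold from this commit

    # Same default image the demo now ships with (demo/bird.webp).
    img = cv2.cvtColor(cv2.imread("demo/bird.webp", cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    predictor.set_image(img)

    coords = np.array([[320.0, 240.0]], dtype=np.float32)  # placeholder positive click
    labels = np.array([1], dtype=np.int32)
    granularity = 0.2  # slider default; lower = finer, higher = coarser

    # Pass 1: multimask prediction, keeping the logits for refinement.
    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
        gra=granularity,
    )
    best = int(np.argmax(scores))

    # Pass 2 (optional M2M): feed the best mask's low-res logits back in.
    refined_masks, refined_scores, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=False,
        gra=granularity,
        mask_input=logits[best : best + 1],
    )
    final_mask = refined_masks[0].astype(bool)
    print(f"pass-1 IoU {float(scores[best]):.3f} | M2M IoU {float(refined_scores[0]):.3f}")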
demo/bird.webp ADDED
sam2/sam2/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/__pycache__/__init__.cpython-310.pyc and b/sam2/sam2/__pycache__/__init__.cpython-310.pyc differ
 
sam2/sam2/__pycache__/build_sam.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/__pycache__/build_sam.cpython-310.pyc and b/sam2/sam2/__pycache__/build_sam.cpython-310.pyc differ
 
sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc and b/sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc differ
 
sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc and b/sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc differ
 
sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc and b/sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc differ
 
sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc and b/sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc differ
 
sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc and b/sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc differ
 
sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc and b/sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc differ
 
sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc and b/sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc differ
 
sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc and b/sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc differ
 
sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc and b/sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc differ
 
sam2/sam2/modeling/sam/__pycache__/mask_decoder.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/sam/__pycache__/mask_decoder.cpython-310.pyc and b/sam2/sam2/modeling/sam/__pycache__/mask_decoder.cpython-310.pyc differ
 
sam2/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-310.pyc and b/sam2/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-310.pyc differ
 
sam2/sam2/modeling/sam/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/modeling/sam/__pycache__/transformer.cpython-310.pyc and b/sam2/sam2/modeling/sam/__pycache__/transformer.cpython-310.pyc differ
 
sam2/sam2/utils/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/utils/__pycache__/__init__.cpython-310.pyc and b/sam2/sam2/utils/__pycache__/__init__.cpython-310.pyc differ
 
sam2/sam2/utils/__pycache__/misc.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/utils/__pycache__/misc.cpython-310.pyc and b/sam2/sam2/utils/__pycache__/misc.cpython-310.pyc differ
 
sam2/sam2/utils/__pycache__/transforms.cpython-310.pyc CHANGED
Binary files a/sam2/sam2/utils/__pycache__/transforms.cpython-310.pyc and b/sam2/sam2/utils/__pycache__/transforms.cpython-310.pyc differ
 
sam2/training/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/training/__pycache__/__init__.cpython-310.pyc and b/sam2/training/__pycache__/__init__.cpython-310.pyc differ
 
sam2/training/utils/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/sam2/training/utils/__pycache__/__init__.cpython-310.pyc and b/sam2/training/utils/__pycache__/__init__.cpython-310.pyc differ
 
sam2/training/utils/__pycache__/checkpoint_utils.cpython-310.pyc CHANGED
Binary files a/sam2/training/utils/__pycache__/checkpoint_utils.cpython-310.pyc and b/sam2/training/utils/__pycache__/checkpoint_utils.cpython-310.pyc differ