yjwnb6 commited on
Commit
6c5665b
·
1 Parent(s): c9e1219
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -11,7 +11,7 @@ import tempfile
11
  import threading
12
  import uuid
13
  from pathlib import Path
14
- from typing import Any, Dict, List, Optional, Sequence, Tuple
15
 
16
  import cv2
17
  import gradio as gr
@@ -50,6 +50,12 @@ GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
50
  GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
51
  ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
52
  ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
 
 
 
 
 
 
53
  MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
54
  WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
55
  WHOLE_IMAGE_MAX_MASKS = 1000
@@ -360,6 +366,20 @@ def choose_device() -> torch.device:
360
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
361
 
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
364
  tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
365
  return tensor
@@ -835,10 +855,19 @@ def run_video_frame_segmentation(
835
  return overlay, status
836
 
837
 
838
- if spaces is not None and ZERO_GPU_ENABLED:
839
- segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
840
- else:
841
- segment_fn = _run_segmentation
 
 
 
 
 
 
 
 
 
842
 
843
 
844
  def build_demo() -> gr.Blocks:
@@ -1018,7 +1047,7 @@ def build_demo() -> gr.Blocks:
1018
  whole_status = gr.Markdown(" Ready for whole-image masks.")
1019
 
1020
  whole_generate_btn.click(
1021
- run_whole_image_segmentation,
1022
  inputs=[
1023
  whole_image_input,
1024
  whole_granularity,
@@ -1162,7 +1191,7 @@ def build_demo() -> gr.Blocks:
1162
  )
1163
 
1164
  video_frame_btn.click(
1165
- run_video_frame_segmentation,
1166
  inputs=[
1167
  video_state,
1168
  video_points_state,
@@ -1201,7 +1230,7 @@ def build_demo() -> gr.Blocks:
1201
  )
1202
 
1203
  video_segment_btn.click(
1204
- run_video_segmentation,
1205
  inputs=[
1206
  video_state,
1207
  video_points_state,
 
11
  import threading
12
  import uuid
13
  from pathlib import Path
14
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
15
 
16
  import cv2
17
  import gradio as gr
 
50
  GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
51
  ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
52
  ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
53
+ ZERO_GPU_WHOLE_DURATION = int(
54
+ os.getenv("UNSAMV2_ZEROGPU_WHOLE_DURATION", str(ZERO_GPU_DURATION))
55
+ )
56
+ ZERO_GPU_VIDEO_DURATION = int(
57
+ os.getenv("UNSAMV2_ZEROGPU_VIDEO_DURATION", str(max(120, ZERO_GPU_DURATION)))
58
+ )
59
  MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
60
  WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
61
  WHOLE_IMAGE_MAX_MASKS = 1000
 
366
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
367
 
368
 
369
+ def wrap_with_zero_gpu(
370
+ fn: Callable[..., Any],
371
+ duration: int,
372
+ ) -> Callable[..., Any]:
373
+ if spaces is None or not ZERO_GPU_ENABLED:
374
+ return fn
375
+ try:
376
+ LOGGER.info("Enabling ZeroGPU (duration=%ss) for %s", duration, fn.__name__)
377
+ return spaces.GPU(duration=duration)(fn) # type: ignore[misc]
378
+ except Exception: # pragma: no cover - defensive logging
379
+ LOGGER.exception("Failed to wrap %s with ZeroGPU; running on CPU", fn.__name__)
380
+ return fn
381
+
382
+
383
  def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
384
  tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
385
  return tensor
 
855
  return overlay, status
856
 
857
 
858
+ segment_fn = wrap_with_zero_gpu(_run_segmentation, ZERO_GPU_DURATION)
859
+ whole_image_fn = wrap_with_zero_gpu(
860
+ run_whole_image_segmentation,
861
+ ZERO_GPU_WHOLE_DURATION,
862
+ )
863
+ video_frame_fn = wrap_with_zero_gpu(
864
+ run_video_frame_segmentation,
865
+ ZERO_GPU_VIDEO_DURATION,
866
+ )
867
+ video_segmentation_fn = wrap_with_zero_gpu(
868
+ run_video_segmentation,
869
+ ZERO_GPU_VIDEO_DURATION,
870
+ )
871
 
872
 
873
  def build_demo() -> gr.Blocks:
 
1047
  whole_status = gr.Markdown(" Ready for whole-image masks.")
1048
 
1049
  whole_generate_btn.click(
1050
+ whole_image_fn,
1051
  inputs=[
1052
  whole_image_input,
1053
  whole_granularity,
 
1191
  )
1192
 
1193
  video_frame_btn.click(
1194
+ video_frame_fn,
1195
  inputs=[
1196
  video_state,
1197
  video_points_state,
 
1230
  )
1231
 
1232
  video_segment_btn.click(
1233
+ video_segmentation_fn,
1234
  inputs=[
1235
  video_state,
1236
  video_points_state,