import os
import shutil

import gradio as gr
import torch
import torchvision
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# --- Setup ---
# Clean caches on each restart (helps stay under the 50 GB storage limit)
for cache_dir in [
    os.path.expanduser("~/.cache/huggingface"),
    os.path.expanduser("~/.cache/torch"),
]:
    shutil.rmtree(cache_dir, ignore_errors=True)

# Force the Hugging Face Hub cache to /tmp (ephemeral)
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
os.makedirs(os.environ["HF_HUB_CACHE"], exist_ok=True)

# Gradio temp folder (a relative "tmp" directory inside the working directory)
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)

# Handle ZeroGPU safely for local debugging
try:
    import spaces
except ImportError:
    class spaces:
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn
            return decorator

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Lazy Model Loader ---
MODELS = {}

def get_model(selected_model):
    """Load model + processor on demand and cache in memory."""
    if selected_model in MODELS:
        return MODELS[selected_model]

    print(f"Loading {selected_model}...")
    if selected_model == "NoctOWLv2-Base":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-base-patch16"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
    elif selected_model == "NoctOWLv2-Large":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-large-patch14"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
    else:
        raise gr.Error(f"Unknown model: {selected_model}")

    # Cache in memory so re-selections don't re-load from disk
    MODELS[selected_model] = (model, processor)
    return model, processor
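
# Illustrative usage (an assumption, not part of the original app flow): only the
# first call per model pays the download/load cost; later calls hit MODELS:
#   model, processor = get_model("NoctOWLv2-Base")  # loads from the Hub
#   model, processor = get_model("NoctOWLv2-Base")  # returned from the in-memory cache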
# --- Inference Function ---
@spaces.GPU(duration=120)
def query_image(img, text_queries, score_threshold, selected_model):
    if img is None:
        raise gr.Error("Please upload or select an example image first.")
    if not text_queries.strip():
        raise gr.Error("Please enter at least one text query.")
    if not selected_model:
        raise gr.Error("Please select a model before running inference.")

    model, processor = get_model(selected_model)
    model = model.to(device)

    # Prepare text: one "a <query>" prompt per comma-separated entry
    text_queries = [f"a {t.strip()}" for t in text_queries.split(",") if t.strip()]
    if not text_queries:
        raise gr.Error("No valid queries found. Please check your input text.")

    # Preprocess: the OWLv2 processor pads the image to a square, so boxes are
    # rescaled to a square target of the longer image side
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Postprocess on CPU
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=score_threshold
    )
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    # Non-maximum suppression to drop overlapping duplicate detections
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # Format output as (box, caption) pairs for gr.AnnotatedImage
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        result_labels.append((box, f"{text_queries[label.item()]} ({score:.2f})"))
    return img, result_labels
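
# Optional smoke test: a hedged sketch, not part of the original Space. The
# NOCTOWL_SMOKE_TEST env var and the asset path are illustrative assumptions;
# the guard keeps normal Space startup unaffected.
if __name__ == "__main__" and os.environ.get("NOCTOWL_SMOKE_TEST"):
    import numpy as np
    from PIL import Image

    test_img = np.array(Image.open("assets/pool.jpg").convert("RGB"))
    _, detections = query_image(
        test_img, "white ball, blue ball", 0.1, "NoctOWLv2-Base"
    )
    print(detections)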
# --- Interface Description ---
description = """
# π¦ **NoctOWLv2: Fine-Grained Open-Vocabulary Object Detection**
**NoctOWL** (***N***ot **o**nly **c**oarse-**t**ext **OWL**) extends **OWL-ViT** and **OWLv2** for **Fine-Grained Open-Vocabulary Detection (FG-OVD)**.
It can recognize subtle object differences such as **color, texture, and material**, while retaining strong coarse-grained detection abilities.
**Available Models:**
- π§© **NoctOWLv2-Base** β Smaller and faster.
- π§ **NoctOWLv2-Large** β More accurate, higher capacity.
π [Training & evaluation code](https://github.com/lorebianchi98/FG-OVD/NoctOWL)
"""
# --- Create Interface Layout ---
with gr.Blocks(title="NoctOWLv2 – Fine-Grained Zero-Shot Object Detection") as demo:
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            text_queries = gr.Textbox(
                label="Text Queries (comma-separated)",
                placeholder="e.g., red shoes, striped shirt, yellow ball",
            )
            score_threshold = gr.Slider(
                0, 1, value=0.1, step=0.01, label="Score Threshold"
            )
            model_dropdown = gr.Dropdown(
                choices=["NoctOWLv2-Base", "NoctOWLv2-Large"],
                label="Select Model",
                value=None,
                info="Select which model to use for detection",
            )
            run_button = gr.Button("🚀 Run Detection", interactive=False)
        with gr.Column():
            output_image = gr.AnnotatedImage(label="Detected Objects")

    # --- Enable / Disable Run Button ---
    def toggle_button(model, text):
        return gr.update(interactive=bool(model and text.strip()))

    model_dropdown.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )
    text_queries.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )

    # --- Connect Button to Inference ---
    run_button.click(
        fn=query_image,
        inputs=[input_image, text_queries, score_threshold, model_dropdown],
        outputs=output_image,
    )

    # --- Example Images ---
    gr.Examples(
        examples=[
            ["assets/desciglio.jpg", "striped football shirt, plain red football shirt, yellow shoes, red shoes", 0.07],
            ["assets/pool.jpg", "white ball, blue ball, black ball, yellow ball", 0.1],
            ["assets/patio.jpg", "ceramic mug, glass mug, pink flowers, blue flowers", 0.09],
        ],
        inputs=[input_image, text_queries, score_threshold],
    )

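# Suggestion (an assumption, not part of the original app): calling demo.queue()
# before launch() serializes concurrent requests, which can help on shared
# ZeroGPU hardware.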
demo.launch()