CountEx / app.py
yifehuang97's picture
Update app.py
8eeb8f0 verified
import os
import json
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
import google.generativeai as genai
# Global variables for model and processor; populated lazily by load_model().
model = None
processor = None
device = None

# Configure Gemini (used only to parse natural-language instructions).
# The API key is read from the environment -- set GEMINI_API_KEY (or
# GOOGLE_API_KEY) as a Space secret. Never commit the key to source control;
# the key that was previously hard-coded here must be treated as leaked and
# revoked. If neither variable is set, Gemini requests fail and
# parse_counting_instruction falls back to using the raw instruction.
genai.configure(
    api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
)
gemini_model = genai.GenerativeModel("gemini-2.0-flash")
# Few-shot prompt sent to Gemini to split a "Count A, not B" instruction into
# include (A) and exclude (B) lists. It is filled via
# str.format(instruction=...), so the literal braces of the JSON template
# below are doubled.
PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists—A (include) and B (exclude)—splitting on "and", "or", and commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").
Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)
Case 1 — Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]
Case 2 — Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []
Case 3 — B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]
Case 4 — B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []
User instruction: {instruction}
Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""
def _extract_json_object(text: str) -> str:
    """Return the outermost ``{...}`` span of *text*, or the stripped text if none.

    The model may wrap its answer in a Markdown code fence (```json ... ```)
    or surround it with extra prose; slicing from the first "{" to the last
    "}" handles both cases.
    """
    text = text.strip()
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def _items_to_caption(items) -> str:
    """Join item names into a caption such as "red beans and black beans."

    A bare string is treated as a one-item list. Non-string entries and blank
    names are skipped. Trailing periods on individual names are dropped so the
    caption ends with exactly one period. Returns "" when no usable names remain.
    """
    if isinstance(items, str):
        items = [items]
    if not isinstance(items, (list, tuple)):
        return ""
    names = [item.strip().rstrip(".").strip() for item in items if isinstance(item, str)]
    names = [name for name in names if name]
    return " and ".join(names) + "." if names else ""


def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """
    Parse natural language counting instruction using Gemini 2.0 Flash.

    Args:
        instruction: Natural language instruction like "count apples, not green apples"

    Returns:
        tuple: (positive_caption, negative_caption). Each caption ends with a
        single period. The negative caption is "None." when nothing is
        excluded. If Gemini fails or returns nothing to count, the raw
        instruction is used as the positive caption.
    """
    # Fallback positive caption: the raw instruction with exactly one trailing period.
    fallback_pos = (instruction or "").strip().rstrip(".").rstrip() + "."
    try:
        prompt = PARSING_PROMPT.format(instruction=instruction)
        response = gemini_model.generate_content(prompt)
        result = json.loads(_extract_json_object(response.text))
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object, got {type(result).__name__}")
    except Exception as e:
        # Best-effort boundary: any failure in the Gemini call or in JSON
        # decoding degrades to the fallback instead of failing the request.
        print(f"Error parsing instruction: {e}")
        return fallback_pos, "None."

    pos_caption = _items_to_caption(result.get("A", []))
    neg_caption = _items_to_caption(result.get("B", [])) or "None."
    if not pos_caption:
        # An empty include list would give the detector no text to ground.
        return fallback_pos, neg_caption
    return pos_caption, neg_caption
def load_model():
    """Populate the module-level ``model``, ``processor`` and ``device`` globals.

    Chooses CUDA when ``torch.cuda.is_available()``, otherwise CPU. Loads the
    CountEX weights, casts them to bfloat16, moves them to that device and puts
    the model in eval mode. Then loads the GroundingDINO processor.

    Returns:
        The tuple ``(model, processor, device)`` that was just stored globally.
    """
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hub repo holding the CountEX weights; HF_TOKEN (when set) is forwarded as
    # the Hub access token.
    model_id = "yifehuang97/CountEX-KC-v2"  # Change to your HF model repo
    hub_token = os.environ.get("HF_TOKEN")
    model = (
        CountEX.from_pretrained(model_id, token=hub_token)
        .to(torch.bfloat16)
        .to(device)
    )
    model.eval()

    processor = GroundingDinoProcessor.from_pretrained("fushh7/llmdet_swin_tiny_hf")

    return model, processor, device
import numpy as np
def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,  # (N, D) numpy array
    neg_queries,  # (M, D) numpy array
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """Discriminative Point Suppression (DPS).

    Every positive point is paired with its spatially nearest negative point.
    The positive is dropped only when the pair is BOTH
      * spatially overlapping: pixel distance < ``pixel_threshold``, and
      * semantically conflicting: cosine similarity of their query
        embeddings > ``similarity_threshold``.

    Args:
        points: [x, y] positive points in normalized (0-1) coordinates.
        neg_points: [x, y] negative points in normalized (0-1) coordinates.
        pos_queries: (N, D) embeddings; row i belongs to points[i].
        neg_queries: (M, D) embeddings; row j belongs to neg_points[j].
        image_size: (width, height) in pixels, used to measure distances in pixels.
        pixel_threshold: spatial distance threshold in pixels (strict).
        similarity_threshold: cosine similarity threshold (strict).

    Returns:
        (filtered_points, filtered_indices, suppression_info). When either point
        list is empty, the input points are returned untouched with ``{}`` as the
        info. Otherwise the info holds per-positive lists under
        "nearest_neg_idx", "nearest_neg_dist", "query_similarity",
        "spatially_close" and "semantically_similar", plus "suppressed_indices".
    """
    # Nothing to compare against (or nothing to filter): keep every point.
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    to_pixels = np.array([width, height])

    # Stage 1 -- spatial matching on a (N, M) pixel-distance matrix.
    pos_px = np.array(points) * to_pixels
    neg_px = np.array(neg_points) * to_pixels
    pairwise = np.linalg.norm(pos_px[:, None, :] - neg_px[None, :, :], axis=-1)
    nearest_idx = pairwise.argmin(axis=1)
    nearest_dist = pairwise.min(axis=1)
    is_close = nearest_dist < pixel_threshold

    def unit_rows(mat):
        # Row-wise L2 normalization; the epsilon avoids dividing by zero.
        return mat / (np.linalg.norm(mat, axis=-1, keepdims=True) + 1e-8)

    # Stage 2 -- cosine similarity between each positive and its matched negative.
    partner_q = unit_rows(neg_queries)[nearest_idx]
    similarity = (unit_rows(pos_queries) * partner_q).sum(axis=-1)
    is_similar = similarity > similarity_threshold

    # Stage 3 -- suppress only when both conditions hold.
    suppressed = np.logical_and(is_close, is_similar)
    kept = np.logical_not(suppressed)

    info = {
        "nearest_neg_idx": nearest_idx.tolist(),
        "nearest_neg_dist": nearest_dist.tolist(),
        "query_similarity": similarity.tolist(),
        "spatially_close": is_close.tolist(),
        "semantically_similar": is_similar.tolist(),
        "suppressed_indices": np.flatnonzero(suppressed).tolist(),
    }
    survivors = np.array(points)[kept].tolist()
    survivor_idx = np.flatnonzero(kept).tolist()
    return survivors, survivor_idx, info
def _normalize_caption(caption, empty):
    """Return *caption* stripped and terminated by one period, or *empty* if blank.

    A caption that is only whitespace or only "." counts as blank.
    """
    caption = (caption or "").strip()
    if caption in ("", "."):
        return empty
    return caption if caption.endswith(".") else caption + "."


def _build_model_inputs(image, pos_caption, neg_caption):
    """Run the processor on both captions and merge them into one kwargs dict for ``model``.

    The negative caption's tensors are added to the positive inputs under
    ``neg_*`` keys. Pixel values are cast to bfloat16 to match the model dtype
    set in load_model().
    """
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

    for key in ("token_type_ids", "attention_mask", "pixel_mask", "pixel_values", "input_ids"):
        pos_inputs[f"neg_{key}"] = neg_inputs[key]
    # The negative branch is always enabled; "nothing to exclude" is expressed
    # through the placeholder caption "None." rather than by turning it off.
    pos_inputs['use_neg'] = True
    return pos_inputs


def _boxes_to_points(boxes):
    """Keep the first two coordinates (x, y) of each box tensor as a [x, y] list."""
    return [box.tolist()[:2] for box in boxes]


def _predict_points(outputs, box_threshold, image_size):
    """Convert raw model outputs into the final positive points.

    Positive and negative detections are thresholded separately. Positives are
    then filtered with discriminative_point_suppression against the negatives.
    Returned points are [x, y] pairs in the same coordinates as the model's
    boxes (later scaled by the image size for drawing).
    """
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    # A non-positive slider value defers to the model's own default threshold.
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    # Take the last entry along the leading axis left after squeezing the
    # batch dim, then restore a leading dim of 1.
    # NOTE(review): presumably the last decoder layer's queries -- confirm.
    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0).unsqueeze(0)
    neg_queries = neg_queries[-1].squeeze(0).unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(
        outputs, pos_queries, box_threshold=threshold
    )[0]
    points = _boxes_to_points(results["boxes"])

    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        # Reuse the same post-processor for the negative branch by swapping
        # the negative tensors into a copy of the outputs.
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(
            neg_outputs, neg_queries, box_threshold=threshold
        )[0]
        neg_points = _boxes_to_points(neg_results["boxes"])

    if not neg_points:
        return points
    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy()
    if len(neg_queries_np) == 0:
        return points

    kept_points, _, _ = discriminative_point_suppression(
        points,
        neg_points,
        pos_queries_np,
        neg_queries_np,
        image_size=image_size,
        pixel_threshold=5,
        similarity_threshold=0.3,
    )
    return kept_points


def _draw_points(image, points, point_radius, point_color):
    """Return a copy of *image* with a filled circle drawn at each point."""
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for px, py in points:
        # Points are scaled by the image size to get pixel coordinates.
        x = px * img_w
        y = py * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
    return img_draw


def _run_counting(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """Shared inference path: returns (annotated image, "Count: N" text)."""
    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Anything that is not a real exclusion prompt is sent as the placeholder "None.".
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."

    inputs = _build_model_inputs(image, pos_caption, neg_caption)
    with torch.no_grad():
        outputs = model(**inputs)

    points = _predict_points(outputs, box_threshold, image.size)
    annotated = _draw_points(image, points, point_radius, point_color)
    return annotated, f"Count: {len(points)}"


def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects.

    Args:
        image: Input PIL Image
        instruction: Natural language instruction (e.g., "count apples, not green apples")
        box_threshold: Detection confidence threshold (<= 0 uses model.box_threshold)
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
        Annotated image, count, and parsed captions

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model is None:
        load_model()

    # Parse instruction using Gemini
    pos_caption, neg_caption = parse_counting_instruction(instruction)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    annotated, count_text = _run_counting(
        image, pos_caption, neg_caption, box_threshold, point_radius, point_color
    )
    return annotated, count_text, parsed_info


def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """
    Manual mode: directly use provided positive and negative captions.

    Captions are stripped and given a trailing period. A blank (or ".")
    negative caption becomes "None.".

    Returns:
        Annotated image, count, and parsed captions

    Raises:
        gr.Error: if no image was provided or the positive caption is blank.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Strip before normalizing so whitespace-only prompts are treated as empty
    # instead of becoming captions like "  .".
    pos_caption = _normalize_caption(pos_caption, empty="")
    neg_caption = _normalize_caption(neg_caption, empty="None.")
    if not pos_caption:
        raise gr.Error("Please enter a positive prompt (objects to count).")

    if model is None:
        load_model()

    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"
    annotated, count_text = _run_counting(
        image, pos_caption, neg_caption, box_threshold, point_radius, point_color
    )
    return annotated, count_text, parsed_info
# Create Gradio interface
def create_demo():
    """Build and return the CountEx Gradio Blocks UI (not launched).

    Layout: a left column with the image input, two input tabs ("Natural
    Language" and "Manual Input"), the submit button and advanced settings.
    A right column shows the annotated image, count text and parsed captions.
    A ``gr.State`` records which tab was last selected, so the single submit
    button can route to count_objects or count_objects_manual.
    """
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
# CountEx: Fine-Grained Counting via Exemplars and Exclusion
Count specific objects in images using text prompts with exclusion capability.
""")
        # State to track current input mode ("natural_language" or "manual");
        # updated by the tab .select handlers below.
        current_mode = gr.State(value="natural_language")
        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
                with gr.Tabs() as input_tabs:
                    # Tab 1: Natural Language Input
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
**Examples:**
- "Count apples, not green apples"
- "Count red and black beans, exclude white beans"
- "Count people, not women"
""")
                    # Tab 2: Manual Input
                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )
                # Single submit button outside tabs
                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")
                # Shared settings (used by both input modes)
                with gr.Accordion("Advanced Settings", open=False):
                    # Passed as box_threshold; the count functions fall back to
                    # model.box_threshold when this is 0.
                    box_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.42,
                        step=0.01,
                        label="Detection Threshold"
                    )
                    point_radius = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                        value="blue",
                        label="Point Color"
                    )
            # Right column - Output
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)
        # Examples for Natural Language mode. Only (image, instruction) are
        # filled in. With cache_examples=False the NOTE(review) assumption is
        # that fn is not run on click; if caching is enabled, count_objects would
        # need all five inputs -- confirm before changing.
        gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/strawberry.jpg", "Count blueberries and strawberry"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberry"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not man"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=count_objects,
            cache_examples=False,
        )
        # Update mode when tab changes. These are named functions (not lambdas)
        # because Gradio may derive endpoint names from the handler's __name__.
        def set_mode_nl():
            return "natural_language"
        def set_mode_manual():
            return "manual"
        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])
        # Unified handler that routes based on mode: any value other than
        # "natural_language" is treated as manual mode.
        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            else:
                return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)
        # Single button click handler; inputs are listed in handle_submit's
        # positional parameter order.
        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )
    return demo
if __name__ == "__main__":
    # Load the weights eagerly, so the first request does not have to wait for
    # the lazy load_model() call inside the count functions.
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Build the Blocks UI, keep it bound to the module-level name `demo`, then
    # start the Gradio server.
    demo = create_demo()
    demo.launch()