# CountEx / app.py
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
# Global variables for model and processor
model = None
processor = None
device = None
def load_model():
"""Load model and processor once at startup"""
global model, processor, device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model - change path for HF Spaces
model_id = "yifehuang97/CountEX-KC-v2" # Change to your HF model repo
model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
model = model.to(torch.bfloat16)
model = model.to(device)
model.eval()
# Load processor
processor_id = "fushh7/llmdet_swin_tiny_hf"
processor = GroundingDinoProcessor.from_pretrained(processor_id)
return model, processor, device
def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
"""
Filter out positive points that are too close to any negative point.
Args:
points: List of [x, y] positive points (normalized coordinates, 0-1)
neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
image_size: Tuple of (width, height) in pixels
pixel_threshold: Minimum distance threshold in pixels
Returns:
filtered_points: List of points that are far enough from all negative points
filtered_indices: Indices of the kept points in the original list
"""
if not neg_points or not points:
return points, list(range(len(points)))
width, height = image_size
points_arr = np.array(points) # (N, 2) normalized
neg_points_arr = np.array(neg_points) # (M, 2) normalized
# Convert to pixel coordinates
points_pixel = points_arr * np.array([width, height]) # (N, 2)
neg_points_pixel = neg_points_arr * np.array([width, height]) # (M, 2)
# Compute pairwise distances in pixels: (N, M)
diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
distances = np.linalg.norm(diff, axis=-1) # (N, M)
# Find minimum distance to any negative point for each positive point
min_distances = distances.min(axis=1) # (N,)
# Keep points where min distance > threshold
keep_mask = min_distances > pixel_threshold
filtered_points = points_arr[keep_mask].tolist()
filtered_indices = np.where(keep_mask)[0].tolist()
return filtered_points, filtered_indices
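# Hedged, illustrative usage of filter_points_by_negative (a toy sketch, not wired into
# the app): on a 100x100 image with the default 5-pixel threshold, the first positive
# point lies 2 px from the negative point and is dropped, while the second is ~58 px
# away and is kept.
def _example_filter_points_by_negative():
    kept, kept_idx = filter_points_by_negative(
        points=[[0.50, 0.50], [0.10, 0.10]],
        neg_points=[[0.52, 0.50]],
        image_size=(100, 100),
        pixel_threshold=5,
    )
    return kept, kept_idx  # -> ([[0.1, 0.1]], [1])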
def discriminative_point_suppression(
points,
neg_points,
pos_queries, # (N, D) numpy array
neg_queries, # (M, D) numpy array
image_size,
pixel_threshold=5,
similarity_threshold=0.3,
):
"""
Discriminative Point Suppression (DPS):
Step 1: Find spatially closest negative point for each positive point
Step 2: If distance < pixel_threshold, check query similarity
Step 3: Suppress only if query similarity > similarity_threshold
This two-stage design ensures suppression only when predictions are
both spatially overlapping AND semantically conflicting.
Args:
points: List of [x, y] positive points (normalized, 0-1)
neg_points: List of [x, y] negative points (normalized, 0-1)
pos_queries: (N, D) query embeddings for positive predictions
neg_queries: (M, D) query embeddings for negative predictions
image_size: (width, height) in pixels
pixel_threshold: spatial distance threshold in pixels
similarity_threshold: cosine similarity threshold for semantic conflict
Returns:
filtered_points: points after suppression
filtered_indices: indices of kept points
suppression_info: dict with detailed suppression decisions
"""
if not neg_points or not points:
return points, list(range(len(points))), {}
width, height = image_size
N, M = len(points), len(neg_points)
# === Step 1: Spatial Matching ===
points_arr = np.array(points) * np.array([width, height]) # (N, 2)
neg_points_arr = np.array(neg_points) * np.array([width, height]) # (M, 2)
# Compute pairwise distances
spatial_dist = np.linalg.norm(
points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
) # (N, M)
# Find nearest negative for each positive
nearest_neg_idx = spatial_dist.argmin(axis=1) # (N,)
nearest_neg_dist = spatial_dist.min(axis=1) # (N,)
# Check spatial condition
spatially_close = nearest_neg_dist < pixel_threshold # (N,)
# === Step 2: Query Similarity Check (only for spatially close pairs) ===
# Normalize queries
pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
# Compute similarity only for matched pairs
matched_neg_q = neg_q[nearest_neg_idx] # (N, D)
query_sim = (pos_q * matched_neg_q).sum(axis=-1) # (N,) cosine similarity
# Check semantic condition
semantically_similar = query_sim > similarity_threshold # (N,)
# === Step 3: Joint Decision ===
# Suppress only if BOTH conditions are met
should_suppress = spatially_close & semantically_similar # (N,)
# === Filter ===
keep_mask = ~should_suppress
filtered_points = np.array(points)[keep_mask].tolist()
filtered_indices = np.where(keep_mask)[0].tolist()
# === Suppression Info ===
suppression_info = {
"nearest_neg_idx": nearest_neg_idx.tolist(),
"nearest_neg_dist": nearest_neg_dist.tolist(),
"query_similarity": query_sim.tolist(),
"spatially_close": spatially_close.tolist(),
"semantically_similar": semantically_similar.tolist(),
"suppressed_indices": np.where(should_suppress)[0].tolist(),
}
return filtered_points, filtered_indices, suppression_info
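# Hedged, illustrative usage of discriminative_point_suppression (a toy sketch, not wired
# into the app): both positive points sit within 5 px of the negative point on a 100x100
# image, but only the first shares a similar query embedding, so only it is suppressed.
def _example_discriminative_point_suppression():
    pos_queries = np.array([[1.0, 0.0], [0.0, 1.0]])  # (N=2, D=2) toy embeddings
    neg_queries = np.array([[1.0, 0.0]])  # (M=1, D=2)
    kept, kept_idx, info = discriminative_point_suppression(
        points=[[0.50, 0.50], [0.52, 0.50]],
        neg_points=[[0.51, 0.50]],
        pos_queries=pos_queries,
        neg_queries=neg_queries,
        image_size=(100, 100),
        pixel_threshold=5,
        similarity_threshold=0.3,
    )
    return kept, kept_idx, info  # kept -> [[0.52, 0.5]], kept_idx -> [1]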
def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
"""
Main inference function for counting objects
Args:
image: Input PIL Image
pos_caption: Positive prompt (objects to count)
neg_caption: Negative prompt (objects to exclude)
box_threshold: Detection confidence threshold
point_radius: Radius of visualization points
point_color: Color of visualization points
Returns:
Annotated image and count
"""
global model, processor, device
if model is None:
load_model()
# Ensure image is RGB
if image.mode != "RGB":
image = image.convert("RGB")
# Ensure captions end with period
if not pos_caption.endswith('.'):
pos_caption = pos_caption + '.'
if neg_caption and not neg_caption.endswith('.'):
neg_caption = neg_caption + '.'
# Process positive caption
pos_inputs = processor(
images=image,
text=pos_caption,
return_tensors="pt",
padding=True
)
pos_inputs = pos_inputs.to(device)
pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
# Process negative caption if provided
use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
if use_neg:
neg_inputs = processor(
images=image,
text=neg_caption,
return_tensors="pt",
padding=True
)
neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
# Add negative inputs to positive inputs dict
pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
pos_inputs['use_neg'] = True
else:
pos_inputs['use_neg'] = False
# Run inference
with torch.no_grad():
outputs = model(**pos_inputs)
# Post-process outputs
# positive prediction
outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
outputs["pred_logits"] = outputs["logits"]
threshold = box_threshold if box_threshold > 0 else model.box_threshold
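    # Note: pos_queries/neg_queries appear to be stacked per decoder layer; keep only the
    # last set of query embeddings and restore a batch dimension of 1 for post-processing.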
pos_queries = outputs["pos_queries"].squeeze(0).float()
neg_queries = outputs["neg_queries"].squeeze(0).float()
pos_queries = pos_queries[-1].squeeze(0)
neg_queries = neg_queries[-1].squeeze(0)
pos_queries = pos_queries.unsqueeze(0)
neg_queries = neg_queries.unsqueeze(0)
results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
boxes = results["boxes"]
boxes = [box.tolist() for box in boxes]
points = [[box[0], box[1]] for box in boxes]
# negative prediction
if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
neg_outputs = outputs.copy()
neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
neg_outputs["logits"] = outputs["neg_logits"]
neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
neg_outputs["pred_logits"] = outputs["neg_logits"]
neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
neg_boxes = neg_results["boxes"]
neg_boxes = [box.tolist() for box in neg_boxes]
neg_points = [[box[0], box[1]] for box in neg_boxes]
pos_queries = results["queries"]
neg_queries = neg_results["queries"]
pos_queries = pos_queries.cpu().numpy()
neg_queries = neg_queries.cpu().numpy()
img_size = image.size
# filtered_points, kept_indices = filter_points_by_negative(
# points,
# neg_points,
# image_size=img_size,
# pixel_threshold=5
# )
filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
points,
neg_points,
pos_queries,
neg_queries,
image_size=img_size,
pixel_threshold=5,
similarity_threshold=0.3,
)
filtered_boxes = [boxes[i] for i in kept_indices]
if "scores" in results:
filtered_scores = [results["scores"][i].item() for i in kept_indices]
points = filtered_points
boxes = filtered_boxes
# Visualize results
img_w, img_h = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
for point in points:
x = point[0] * img_w
y = point[1] * img_h
draw.ellipse(
[x - point_radius, y - point_radius, x + point_radius, y + point_radius],
fill=point_color
)
# for point in neg_points:
# x = point[0] * img_w
# y = point[1] * img_h
# draw.ellipse(
# [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
# fill="red"
# )
count = len(points)
return img_draw, f"Count: {count}"
# Create Gradio interface
def create_demo():
with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
gr.Markdown("""
# CountEx: Fine-Grained Counting via Exemplars and Exclusion
Count specific objects in images using positive and negative text prompts.
""")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="Input Image")
pos_caption = gr.Textbox(
label="Positive Prompt",
placeholder="e.g., Green Apple",
value="Pos Caption Here."
)
neg_caption = gr.Textbox(
label="Negative Prompt (optional)",
placeholder="e.g., Red Apple",
value="None."
)
box_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.42,
step=0.01,
label="Detection Threshold (0.42 = use model default)"
)
point_radius = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Point Radius"
)
point_color = gr.Dropdown(
choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
value="blue",
label="Point Color"
)
submit_btn = gr.Button("Count Objects", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(type="pil", label="Result")
count_output = gr.Textbox(label="Count Result")
# Example images
# ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
gr.Examples(
examples=[
["examples/apples.png", "apple.", "Green apple."],
["examples/apples.png", "apple.", "red apple."],
["examples/black_beans.jpg", "Black bean.", "Soy bean."],
["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
["examples/strawberry.jpg", "strawberry and blueberry.", "strawberry."],
["examples/strawberry2.jpg", "strawberry and blueberry.", "strawberry."],
["examples/women.jpg", "person.", "woman."],
],
inputs=[input_image, pos_caption, neg_caption],
outputs=[output_image, count_output],
fn=count_objects,
cache_examples=False,
)
submit_btn.click(
fn=count_objects,
inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
outputs=[output_image, count_output]
)
return demo
if __name__ == "__main__":
# Load model at startup
print("Loading model...")
load_model()
print("Model loaded!")
# Create and launch demo
demo = create_demo()
demo.launch()