# CountEx / app.py
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import post_process_grounded_object_detection

# Global variables for model and processor
model = None
processor = None
device = None


def load_model():
"""Load model and processor once at startup"""
global model, processor, device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model - change path for HF Spaces
model_id = "yifehuang97/CountEX-KC-v2" # Change to your HF model repo
model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
model = model.to(torch.bfloat16)
model = model.to(device)
model.eval()
# Load processor
processor_id = "fushh7/llmdet_swin_tiny_hf"
processor = GroundingDinoProcessor.from_pretrained(processor_id)
return model, processor, device


def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
"""
Filter out positive points that are too close to any negative point.
Args:
points: List of [x, y] positive points (normalized coordinates, 0-1)
neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
image_size: Tuple of (width, height) in pixels
pixel_threshold: Minimum distance threshold in pixels
Returns:
filtered_points: List of points that are far enough from all negative points
filtered_indices: Indices of the kept points in the original list
"""
if not neg_points or not points:
return points, list(range(len(points)))
width, height = image_size
points_arr = np.array(points) # (N, 2) normalized
neg_points_arr = np.array(neg_points) # (M, 2) normalized
# Convert to pixel coordinates
points_pixel = points_arr * np.array([width, height]) # (N, 2)
neg_points_pixel = neg_points_arr * np.array([width, height]) # (M, 2)
# Compute pairwise distances in pixels: (N, M)
diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
distances = np.linalg.norm(diff, axis=-1) # (N, M)
# Find minimum distance to any negative point for each positive point
min_distances = distances.min(axis=1) # (N,)
# Keep points where min distance > threshold
keep_mask = min_distances > pixel_threshold
filtered_points = points_arr[keep_mask].tolist()
filtered_indices = np.where(keep_mask)[0].tolist()
return filtered_points, filtered_indices
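

# Illustrative sketch (toy values, not from the app): on a 100x100 image with the
# default pixel_threshold=5, a positive point at (0.50, 0.50) lies ~1.4 px from a
# negative at (0.51, 0.51) and is dropped, while one at (0.10, 0.10) is kept:
#   filter_points_by_negative([[0.5, 0.5], [0.1, 0.1]], [[0.51, 0.51]], (100, 100))
#   -> ([[0.1, 0.1]], [1])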


def discriminative_point_suppression(
points,
neg_points,
pos_queries,
neg_queries,
image_size,
pixel_threshold=5,
similarity_threshold=0.5,
mode="and"
):
"""
Discriminative Point Suppression (DPS):
    Suppress positive predictions that are both spatially close to
    AND semantically similar to negative predictions.
Motivation: Spatial proximity alone may cause false suppression when
positive and negative queries represent different semantic concepts.
By jointly verifying spatial AND semantic alignment, we ensure
suppression only occurs for true conflicts.
Args:
points: List of [x, y] positive points (normalized, 0-1)
neg_points: List of [x, y] negative points (normalized, 0-1)
pos_queries: (N, D) query embeddings for positive predictions
neg_queries: (M, D) query embeddings for negative predictions
image_size: (width, height) in pixels
pixel_threshold: spatial distance threshold in pixels
similarity_threshold: cosine similarity threshold for semantic match
mode: "and" for hard joint condition, "weighted" for soft combination
Returns:
filtered_points: points after suppression
filtered_indices: indices of kept points
suppression_info: dict with detailed suppression decisions (for analysis)
"""
if not neg_points or not points:
return points, list(range(len(points))), {}
width, height = image_size
N, M = len(points), len(neg_points)
# === Spatial Distance ===
points_arr = np.array(points) * np.array([width, height]) # (N, 2)
neg_points_arr = np.array(neg_points) * np.array([width, height]) # (M, 2)
spatial_dist = np.linalg.norm(
points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
) # (N, M)
# === Query Similarity (Cosine) ===
# Normalize queries
pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
query_sim = np.dot(pos_q, neg_q.T) # (N, M), range [-1, 1]
# === Joint Suppression Decision ===
if mode == "and":
# Hard condition: suppress only if BOTH spatially close AND semantically similar
spatial_close = spatial_dist < pixel_threshold # (N, M)
semantic_similar = query_sim > similarity_threshold # (N, M)
# A positive is suppressed if ANY negative satisfies both conditions
should_suppress = (spatial_close & semantic_similar).any(axis=1) # (N,)
elif mode == "weighted":
# Soft combination: weighted score
# Convert distance to proximity score (0-1, higher = closer)
spatial_proximity = np.exp(-spatial_dist / pixel_threshold) # (N, M)
# Normalize similarity to [0, 1]
semantic_score = (query_sim + 1) / 2 # (N, M)
# Combined suppression score
suppression_score = spatial_proximity * semantic_score # (N, M)
max_suppression = suppression_score.max(axis=1) # (N,)
should_suppress = max_suppression > similarity_threshold
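        # Worked example (illustrative numbers): with pixel_threshold=5 and
        # similarity_threshold=0.5, a pair at spatial_dist=2 px with query_sim=0.8
        # gives proximity exp(-2/5) ~= 0.67 and semantic score (0.8 + 1) / 2 = 0.90,
        # so suppression_score ~= 0.60 > 0.5 and that positive point is suppressed.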
else:
raise ValueError(f"Unknown mode: {mode}")
# === Filter ===
keep_mask = ~should_suppress
filtered_points = np.array(points)[keep_mask].tolist()
filtered_indices = np.where(keep_mask)[0].tolist()
# === Suppression Info (for analysis/visualization) ===
suppression_info = {
"spatial_dist": spatial_dist,
"query_similarity": query_sim,
"suppressed_indices": np.where(should_suppress)[0].tolist(),
"suppressed_reasons": []
}
# Record why each point was suppressed
for i in np.where(should_suppress)[0]:
if mode == "and":
matching_negs = np.where(spatial_close[i] & semantic_similar[i])[0]
else:
matching_negs = [suppression_score[i].argmax()]
suppression_info["suppressed_reasons"].append({
"pos_idx": int(i),
"matched_neg_idx": matching_negs.tolist() if isinstance(matching_negs, np.ndarray) else matching_negs,
"min_spatial_dist": float(spatial_dist[i].min()),
"max_query_sim": float(query_sim[i].max())
})
return filtered_points, filtered_indices, suppression_info
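

# Illustrative sketch (toy values): a positive and a negative prediction at the same
# location are suppressed only if their query embeddings also agree, e.g.
#   pts, idx, info = discriminative_point_suppression(
#       points=[[0.5, 0.5]], neg_points=[[0.5, 0.5]],
#       pos_queries=np.array([[1.0, 0.0]]), neg_queries=np.array([[0.0, 1.0]]),
#       image_size=(100, 100))
# Cosine similarity is 0 here, so the point survives despite zero spatial distance.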


def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
"""
Main inference function for counting objects
Args:
image: Input PIL Image
pos_caption: Positive prompt (objects to count)
neg_caption: Negative prompt (objects to exclude)
box_threshold: Detection confidence threshold
point_radius: Radius of visualization points
point_color: Color of visualization points
Returns:
Annotated image and count
"""
global model, processor, device
if model is None:
load_model()
# Ensure image is RGB
if image.mode != "RGB":
image = image.convert("RGB")
# Ensure captions end with period
if not pos_caption.endswith('.'):
pos_caption = pos_caption + '.'
if neg_caption and not neg_caption.endswith('.'):
neg_caption = neg_caption + '.'
# Process positive caption
pos_inputs = processor(
images=image,
text=pos_caption,
return_tensors="pt",
padding=True
)
pos_inputs = pos_inputs.to(device)
pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
# Process negative caption if provided
use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
if use_neg:
neg_inputs = processor(
images=image,
text=neg_caption,
return_tensors="pt",
padding=True
)
neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
# Add negative inputs to positive inputs dict
pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
pos_inputs['use_neg'] = True
else:
pos_inputs['use_neg'] = False
# Run inference
with torch.no_grad():
outputs = model(**pos_inputs)
# Post-process outputs
# positive prediction
outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
outputs["pred_logits"] = outputs["logits"]
threshold = box_threshold if box_threshold > 0 else model.box_threshold
results = post_process_grounded_object_detection(outputs, box_threshold=threshold)[0]
boxes = results["boxes"]
boxes = [box.tolist() for box in boxes]
points = [[box[0], box[1]] for box in boxes]
# negative prediction
if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
neg_outputs = outputs.copy()
neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
neg_outputs["logits"] = outputs["neg_logits"]
neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
neg_outputs["pred_logits"] = outputs["neg_logits"]
neg_results = post_process_grounded_object_detection(neg_outputs, box_threshold=threshold)[0]
neg_boxes = neg_results["boxes"]
neg_boxes = [box.tolist() for box in neg_boxes]
neg_points = [[box[0], box[1]] for box in neg_boxes]
        pos_queries = outputs["pos_queries"].squeeze(0)
        neg_queries = outputs["neg_queries"].squeeze(0)
        # bfloat16 tensors cannot be converted to numpy directly; cast to float32 first
        pos_queries = pos_queries.float().cpu().numpy()
        neg_queries = neg_queries.float().cpu().numpy()
img_size = image.size
# filtered_points, kept_indices = filter_points_by_negative(
# points,
# neg_points,
# image_size=img_size,
# pixel_threshold=5
# )
filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
points,
neg_points,
pos_queries,
neg_queries,
image_size=img_size,
pixel_threshold=5,
similarity_threshold=0.5,
mode="and"
)
filtered_boxes = [boxes[i] for i in kept_indices]
if "scores" in results:
filtered_scores = [results["scores"][i].item() for i in kept_indices]
points = filtered_points
boxes = filtered_boxes
# Visualize results
img_w, img_h = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
for point in points:
x = point[0] * img_w
y = point[1] * img_h
draw.ellipse(
[x - point_radius, y - point_radius, x + point_radius, y + point_radius],
fill=point_color
)
# for point in neg_points:
# x = point[0] * img_w
# y = point[1] * img_h
# draw.ellipse(
# [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
# fill="red"
# )
count = len(points)
return img_draw, f"Count: {count}"
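

# Example of calling the pipeline outside Gradio (assumes the bundled example image
# and model weights are available; prompt values taken from the demo examples below):
#   img = Image.open("examples/apples.png")
#   annotated, count_text = count_objects(
#       img, "Green apple.", "Red apple.",
#       box_threshold=0.42, point_radius=5, point_color="blue")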


# Create Gradio interface
def create_demo():
with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
gr.Markdown("""
# CountEx: Fine-Grained Counting via Exemplars and Exclusion
Count specific objects in images using positive and negative text prompts.
        **Note: Positive and Negative prompts should end with a period (.); one is appended automatically if it is missing.**
""")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="Input Image")
pos_caption = gr.Textbox(
label="Positive Prompt",
placeholder="e.g., Green Apple",
value="Pos Caption Here."
)
neg_caption = gr.Textbox(
label="Negative Prompt (optional)",
placeholder="e.g., Red Apple",
value="None."
)
box_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.42,
step=0.01,
label="Detection Threshold (0.42 = use model default)"
)
point_radius = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Point Radius"
)
point_color = gr.Dropdown(
choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
value="blue",
label="Point Color"
)
submit_btn = gr.Button("Count Objects", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(type="pil", label="Result")
count_output = gr.Textbox(label="Count Result")
# Example images
# ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
gr.Examples(
examples=[
["examples/apples.png", "Green apple.", "Red apple."],
["examples/black_beans.jpg", "Black bean.", "Soy bean."],
["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
],
inputs=[input_image, pos_caption, neg_caption],
outputs=[output_image, count_output],
fn=count_objects,
cache_examples=False,
)
submit_btn.click(
fn=count_objects,
inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
outputs=[output_image, count_output]
)
return demo


if __name__ == "__main__":
# Load model at startup
print("Loading model...")
load_model()
print("Model loaded!")
# Create and launch demo
demo = create_demo()
demo.launch()