import os
import json

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import (
    post_process_grounded_object_detection,
    post_process_grounded_object_detection_with_queries,
)

import google.generativeai as genai

# Global variables for model and processor
model = None
processor = None
device = None

# Configure Gemini (read the key from the GEMINI_API_KEY environment variable; do not hard-code credentials)
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
gemini_model = genai.GenerativeModel("gemini-2.0-flash")

PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists—A (include) and B (exclude)—splitting on "and", "or", and commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").

Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)

Case 1 — Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]

Case 2 — Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []

Case 3 — B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]

Case 4 — B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []

User instruction: {instruction}

Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""


def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """
    Parse a natural-language counting instruction using Gemini 2.0 Flash.

    Args:
        instruction: Natural-language instruction like "count apples, not green apples"

    Returns:
        tuple: (positive_caption, negative_caption)
    """
    try:
        prompt = PARSING_PROMPT.format(instruction=instruction)
        response = gemini_model.generate_content(prompt)
        response_text = response.text.strip()

        # Clean up the response - remove markdown code blocks if present
        if response_text.startswith("```"):
            response_text = response_text.split("```")[1]
            if response_text.startswith("json"):
                response_text = response_text[4:]
            response_text = response_text.strip()

        result = json.loads(response_text)

        # Convert lists to caption strings
        pos_items = result.get("A", [])
        neg_items = result.get("B", [])

        # Join items with " and " and add a period
        pos_caption = " and ".join(pos_items) + "." if pos_items else ""
        neg_caption = " and ".join(neg_items) + "." if neg_items else "None."

        return pos_caption, neg_caption
    except Exception as e:
        print(f"Error parsing instruction: {e}")
        # Fallback: treat the entire instruction as the positive caption
        return instruction.strip() + ".", "None."
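# Expected behavior of parse_counting_instruction, following the cases in PARSING_PROMPT.
# The actual output depends on the Gemini response, so these are illustrative only:
#   "Count apples, not green apples"                -> ("apples.", "green apples.")    # Case 3: keep more specific B
#   "Count green apples, not apples"                -> ("green apples.", "None.")      # Case 4: drop more general B
#   "Count fries and TV, not chips and television"  -> ("fries and TV.", "None.")      # Case 2: drop equivalent B items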
def load_model():
    """Load model and processor once at startup"""
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model - change path for HF Spaces
    model_id = "yifehuang97/CountEx_KC_aug_v3_12131215"  # Change to your HF model repo
    model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16)
    model = model.to(device)
    model.eval()

    # Load processor
    processor_id = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_id)

    return model, processor, device


def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,  # (N, D) numpy array
    neg_queries,  # (M, D) numpy array
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """
    Discriminative Point Suppression (DPS):

    Step 1: Find spatially closest negative point for each positive point
    Step 2: If distance < pixel_threshold, check query similarity
    Step 3: Suppress only if query similarity > similarity_threshold

    This two-stage design ensures suppression only when predictions are both
    spatially overlapping AND semantically conflicting.

    Args:
        points: List of [x, y] positive points (normalized, 0-1)
        neg_points: List of [x, y] negative points (normalized, 0-1)
        pos_queries: (N, D) query embeddings for positive predictions
        neg_queries: (M, D) query embeddings for negative predictions
        image_size: (width, height) in pixels
        pixel_threshold: spatial distance threshold in pixels
        similarity_threshold: cosine similarity threshold for semantic conflict

    Returns:
        filtered_points: points after suppression
        filtered_indices: indices of kept points
        suppression_info: dict with detailed suppression decisions
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    N, M = len(points), len(neg_points)

    # === Step 1: Spatial Matching ===
    points_arr = np.array(points) * np.array([width, height])          # (N, 2)
    neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)

    # Compute pairwise distances
    spatial_dist = np.linalg.norm(
        points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
    )  # (N, M)

    # Find nearest negative for each positive
    nearest_neg_idx = spatial_dist.argmin(axis=1)  # (N,)
    nearest_neg_dist = spatial_dist.min(axis=1)    # (N,)

    # Check spatial condition
    spatially_close = nearest_neg_dist < pixel_threshold  # (N,)

    # === Step 2: Query Similarity Check (only for spatially close pairs) ===
    # Normalize queries
    pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)

    # Compute similarity only for matched pairs
    matched_neg_q = neg_q[nearest_neg_idx]            # (N, D)
    query_sim = (pos_q * matched_neg_q).sum(axis=-1)  # (N,) cosine similarity

    # Check semantic condition
    semantically_similar = query_sim > similarity_threshold  # (N,)

    # === Step 3: Joint Decision ===
    # Suppress only if BOTH conditions are met
    should_suppress = spatially_close & semantically_similar  # (N,)

    # === Filter ===
    keep_mask = ~should_suppress
    filtered_points = np.array(points)[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    # === Suppression Info ===
    suppression_info = {
        "nearest_neg_idx": nearest_neg_idx.tolist(),
        "nearest_neg_dist": nearest_neg_dist.tolist(),
        "query_similarity": query_sim.tolist(),
        "spatially_close": spatially_close.tolist(),
        "semantically_similar": semantically_similar.tolist(),
        "suppressed_indices": np.where(should_suppress)[0].tolist(),
    }

    return filtered_points, filtered_indices, suppression_info


def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects

    Args:
        image: Input PIL Image
        instruction: Natural language instruction (e.g., "count apples, not green apples")
        box_threshold: Detection confidence threshold
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
        Annotated image, count, and parsed captions
    """
    global model, processor, device

    if model is None:
        load_model()

    # Parse instruction using Gemini
    pos_caption, neg_caption = parse_counting_instruction(instruction)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."

    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

    # Add negative inputs to the positive inputs dict
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

    # Negative prediction
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])
    img_size = image.size

    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes

    points = filtered_points
    boxes = filtered_boxes

    # Visualize results
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    for point in points:
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )

    count = len(points)

    return img_draw, f"Count: {count}", parsed_info
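# Illustrative sanity check for DPS (not wired into the app; the values below are made up).
# The second positive point coincides with the negative point and their query embeddings
# point the same way, so it is suppressed; the first point is kept.
def _dps_example():
    points = [[0.10, 0.10], [0.50, 0.50]]
    neg_points = [[0.50, 0.50]]
    pos_queries = np.array([[1.0, 0.0], [0.0, 1.0]])
    neg_queries = np.array([[0.0, 1.0]])
    kept, kept_idx, info = discriminative_point_suppression(
        points, neg_points, pos_queries, neg_queries,
        image_size=(640, 480), pixel_threshold=5, similarity_threshold=0.3,
    )
    # Expected: kept == [[0.10, 0.10]], kept_idx == [0], info["suppressed_indices"] == [1]
    return kept, kept_idx, info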
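# Example call outside Gradio (illustrative; the path and values match the demo defaults):
#   img = Image.open("examples/apples.png")
#   annotated, count_str, parsed = count_objects(
#       img, "Count apples, not green apples",
#       box_threshold=0.42, point_radius=5, point_color="blue",
#   )
#   annotated.save("result.png")  # count_str has the form "Count: <n>"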
def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """
    Manual mode: directly use the provided positive and negative captions.
    """
    global model, processor, device

    if model is None:
        load_model()

    # Ensure captions end with a period
    if pos_caption and not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'
    if not neg_caption or neg_caption.strip() == '':
        neg_caption = "None."

    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."

    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

    # Add negative inputs to the positive inputs dict
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

    # Negative prediction
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])
    img_size = image.size

    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes

    points = filtered_points
    boxes = filtered_boxes

    # Visualize results
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    for point in points:
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )

    count = len(points)

    return img_draw, f"Count: {count}", parsed_info
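# Manual-mode equivalent (illustrative; captions are passed directly, no Gemini parsing,
# and a trailing period is added automatically if missing):
#   img = Image.open("examples/apples.png")
#   annotated, count_str, parsed = count_objects_manual(
#       img, pos_caption="apple", neg_caption="green apple",
#       box_threshold=0.42, point_radius=5, point_color="blue",
#   )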
# Create Gradio interface
def create_demo():
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion

        Count specific objects in images using text prompts with exclusion capability.
        """)

        # State to track current input mode
        current_mode = gr.State(value="natural_language")

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")

                with gr.Tabs() as input_tabs:
                    # Tab 1: Natural Language Input
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
                        **Examples:**
                        - "Count apples, not green apples"
                        - "Count red and black beans, exclude white beans"
                        - "Count people, not women"
                        """)

                    # Tab 2: Manual Input
                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )
                # Single submit button outside tabs
                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")

                # Shared settings
                with gr.Accordion("Advanced Settings", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.42, step=0.01,
                        label="Detection Threshold"
                    )
                    point_radius = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                        value="blue",
                        label="Point Color"
                    )

            # Right column - Output
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)

        # Examples for Natural Language mode
        gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/strawberry.jpg", "Count blueberries and strawberry"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberry"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not man"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=count_objects,
            cache_examples=False,
        )

        # Update mode when tab changes
        def set_mode_nl():
            return "natural_language"

        def set_mode_manual():
            return "manual"

        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])

        # Unified handler that routes based on mode
        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            else:
                return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)

        # Single button click handler
        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )

    return demo


if __name__ == "__main__":
    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()
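# Deployment note (optional): for a hosted demo you may want to call demo.queue() before
# launching, or demo.launch(share=True) for a temporary public link; both are standard
# Gradio options and are not required for local use.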