import os
import json

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries

import google.generativeai as genai
# Global variables for model and processor
model = None
processor = None
device = None

# Configure Gemini (read the key from the GEMINI_API_KEY environment variable /
# Space secret instead of hardcoding it in the source)
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
gemini_model = genai.GenerativeModel("gemini-2.0-flash")
PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists, A (include) and B (exclude), splitting on "and", "or", and commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").

Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)

Case 1: Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]

Case 2: Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []

Case 3: B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]

Case 4: B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []

User instruction: {instruction}

Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""
def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """
    Parse a natural language counting instruction using Gemini 2.0 Flash.

    Args:
        instruction: Natural language instruction like "count apples, not green apples"

    Returns:
        tuple: (positive_caption, negative_caption)
    """
    try:
        prompt = PARSING_PROMPT.format(instruction=instruction)
        response = gemini_model.generate_content(prompt)
        response_text = response.text.strip()

        # Clean up response - remove markdown code blocks if present
        if response_text.startswith("```"):
            response_text = response_text.split("```")[1]
            if response_text.startswith("json"):
                response_text = response_text[4:]
            response_text = response_text.strip()

        result = json.loads(response_text)

        # Convert lists to caption strings
        pos_items = result.get("A", [])
        neg_items = result.get("B", [])

        # Join items with " and " and add a trailing period
        pos_caption = " and ".join(pos_items) + "." if pos_items else ""
        neg_caption = " and ".join(neg_items) + "." if neg_items else "None."

        return pos_caption, neg_caption

    except Exception as e:
        print(f"Error parsing instruction: {e}")
        # Fallback: treat the entire instruction as the positive caption
        return instruction.strip() + ".", "None."
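

# Illustrative usage (assumes a valid Gemini key is configured). For the instruction
# "Count apples, not green apples" the LLM is expected to return
# {"A": ["apples"], "B": ["green apples"]}, which this function converts to the pair
# ("apples.", "green apples.").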
def load_model():
    """Load model and processor once at startup"""
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model - change path for HF Spaces
    model_id = "yifehuang97/CountEX-KC-v2"  # Change to your HF model repo
    model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16)
    model = model.to(device)
    model.eval()

    # Load processor
    processor_id = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_id)

    return model, processor, device
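

# Note: HF_TOKEN only needs to be set (e.g., as a Space secret) if the model repo is
# private or gated; os.environ.get returns None otherwise, and from_pretrained then
# falls back to cached or anonymous credentials.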
def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,  # (N, D) numpy array
    neg_queries,  # (M, D) numpy array
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """
    Discriminative Point Suppression (DPS):
    Step 1: Find the spatially closest negative point for each positive point
    Step 2: If the distance < pixel_threshold, check query similarity
    Step 3: Suppress only if query similarity > similarity_threshold

    This two-stage design ensures suppression only when predictions are
    both spatially overlapping AND semantically conflicting.

    Args:
        points: List of [x, y] positive points (normalized, 0-1)
        neg_points: List of [x, y] negative points (normalized, 0-1)
        pos_queries: (N, D) query embeddings for positive predictions
        neg_queries: (M, D) query embeddings for negative predictions
        image_size: (width, height) in pixels
        pixel_threshold: spatial distance threshold in pixels
        similarity_threshold: cosine similarity threshold for semantic conflict

    Returns:
        filtered_points: points after suppression
        filtered_indices: indices of kept points
        suppression_info: dict with detailed suppression decisions
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    N, M = len(points), len(neg_points)

    # === Step 1: Spatial Matching ===
    points_arr = np.array(points) * np.array([width, height])          # (N, 2)
    neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)

    # Compute pairwise distances
    spatial_dist = np.linalg.norm(
        points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
    )  # (N, M)

    # Find the nearest negative for each positive
    nearest_neg_idx = spatial_dist.argmin(axis=1)  # (N,)
    nearest_neg_dist = spatial_dist.min(axis=1)    # (N,)

    # Check spatial condition
    spatially_close = nearest_neg_dist < pixel_threshold  # (N,)

    # === Step 2: Query Similarity Check (only for spatially close pairs) ===
    # Normalize queries
    pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)

    # Compute similarity only for matched pairs
    matched_neg_q = neg_q[nearest_neg_idx]            # (N, D)
    query_sim = (pos_q * matched_neg_q).sum(axis=-1)  # (N,) cosine similarity

    # Check semantic condition
    semantically_similar = query_sim > similarity_threshold  # (N,)

    # === Step 3: Joint Decision ===
    # Suppress only if BOTH conditions are met
    should_suppress = spatially_close & semantically_similar  # (N,)

    # === Filter ===
    keep_mask = ~should_suppress
    filtered_points = np.array(points)[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    # === Suppression Info ===
    suppression_info = {
        "nearest_neg_idx": nearest_neg_idx.tolist(),
        "nearest_neg_dist": nearest_neg_dist.tolist(),
        "query_similarity": query_sim.tolist(),
        "spatially_close": spatially_close.tolist(),
        "semantically_similar": semantically_similar.tolist(),
        "suppressed_indices": np.where(should_suppress)[0].tolist(),
    }

    return filtered_points, filtered_indices, suppression_info
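

# Minimal sketch of the DPS decision rule on synthetic data. It is not called anywhere
# in the app; the points and 2-D "embeddings" below are made up purely for illustration.
def _dps_example():
    pts = [[0.10, 0.10], [0.50, 0.50]]          # positive points (normalized coords)
    neg = [[0.101, 0.101]]                      # one negative ~1.4 px from the first positive
    pos_q = np.array([[1.0, 0.0], [0.0, 1.0]])  # toy query embeddings for the positives
    neg_q = np.array([[1.0, 0.0]])              # identical to the first positive's embedding
    kept, kept_idx, info = discriminative_point_suppression(
        pts, neg, pos_q, neg_q,
        image_size=(1000, 1000), pixel_threshold=5, similarity_threshold=0.3,
    )
    # The first positive is within 5 px of a negative AND has cosine similarity 1.0,
    # so it is suppressed; kept == [[0.5, 0.5]] and kept_idx == [1].
    return kept, kept_idx, info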
def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects.

    Args:
        image: Input PIL Image
        instruction: Natural language instruction (e.g., "count apples, not green apples")
        box_threshold: Detection confidence threshold
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
        Annotated image, count, and parsed captions
    """
    global model, processor, device

    if model is None:
        load_model()

    # Parse instruction using Gemini
    pos_caption, neg_caption = parse_counting_instruction(instruction)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."

    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

    # Add negative inputs to the positive inputs dict
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

    # Negative prediction
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])

    img_size = image.size
    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes

    points = filtered_points
    boxes = filtered_boxes

    # Visualize results
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )

    count = len(points)
    return img_draw, f"Count: {count}", parsed_info
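

# Example call (assuming the model has been loaded and the example image exists):
#   img = Image.open("examples/apples.png")
#   annotated, count_str, parsed = count_objects(
#       img, "Count apples, not green apples",
#       box_threshold=0.42, point_radius=5, point_color="blue",
#   )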
def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """
    Manual mode: directly use the provided positive and negative captions.
    """
    global model, processor, device

    if model is None:
        load_model()

    # Ensure captions end with a period
    if pos_caption and not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'
    if not neg_caption or neg_caption.strip() == '':
        neg_caption = "None."

    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."

    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

    # Add negative inputs to the positive inputs dict
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

    # Negative prediction
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])

    img_size = image.size
    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes

    points = filtered_points
    boxes = filtered_boxes

    # Visualize results
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )

    count = len(points)
    return img_draw, f"Count: {count}", parsed_info
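

# Example call (manual mode, same assumptions as above; a trailing period is added to
# the captions automatically):
#   annotated, count_str, parsed = count_objects_manual(
#       img, "apple", "green apple",
#       box_threshold=0.42, point_radius=5, point_color="red",
#   )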
# Create Gradio interface
def create_demo():
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion

        Count specific objects in images using text prompts with exclusion capability.
        """)

        # State to track the current input mode
        current_mode = gr.State(value="natural_language")

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")

                with gr.Tabs() as input_tabs:
                    # Tab 1: Natural Language Input
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
                        **Examples:**
                        - "Count apples, not green apples"
                        - "Count red and black beans, exclude white beans"
                        - "Count people, not women"
                        """)

                    # Tab 2: Manual Input
                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )

                # Single submit button outside the tabs
                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")

                # Shared settings
                with gr.Accordion("Advanced Settings", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.42,
                        step=0.01,
                        label="Detection Threshold"
                    )
                    point_radius = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                        value="blue",
                        label="Point Color"
                    )

            # Right column - Output
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)

        # Examples for Natural Language mode
        gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/strawberry.jpg", "Count blueberries and strawberries"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberries"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not men"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=count_objects,
            cache_examples=False,
        )

        # Update mode when the tab changes
        def set_mode_nl():
            return "natural_language"

        def set_mode_manual():
            return "manual"

        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])

        # Unified handler that routes based on the selected mode
        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            else:
                return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)

        # Single button click handler
        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )

    return demo
if __name__ == "__main__":
    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()