import os

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries

# Global variables for model and processor
model = None
processor = None
device = None


def load_model():
    """Load model and processor once at startup."""
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model - change path for HF Spaces
    model_id = "yifehuang97/CountEX-KC-v2"  # Change to your HF model repo
    model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16)
    model = model.to(device)
    model.eval()

    # Load processor
    processor_id = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_id)

    return model, processor, device


def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
    """
    Filter out positive points that are too close to any negative point.

    Args:
        points: List of [x, y] positive points (normalized coordinates, 0-1)
        neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
        image_size: Tuple of (width, height) in pixels
        pixel_threshold: Minimum distance threshold in pixels

    Returns:
        filtered_points: List of points that are far enough from all negative points
        filtered_indices: Indices of the kept points in the original list
    """
    if not neg_points or not points:
        return points, list(range(len(points)))

    width, height = image_size
    points_arr = np.array(points)          # (N, 2) normalized
    neg_points_arr = np.array(neg_points)  # (M, 2) normalized

    # Convert to pixel coordinates
    points_pixel = points_arr * np.array([width, height])          # (N, 2)
    neg_points_pixel = neg_points_arr * np.array([width, height])  # (M, 2)

    # Compute pairwise distances in pixels: (N, M)
    diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
    distances = np.linalg.norm(diff, axis=-1)  # (N, M)

    # Find minimum distance to any negative point for each positive point
    min_distances = distances.min(axis=1)  # (N,)

    # Keep points where min distance > threshold
    keep_mask = min_distances > pixel_threshold
    filtered_points = points_arr[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    return filtered_points, filtered_indices
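
# Example (hypothetical values): with image_size=(100, 100) and the default
# pixel_threshold=5, a positive point at [0.50, 0.50] is dropped when a negative
# point sits at [0.51, 0.51] (~1.4 px away), while one at [0.10, 0.10] is kept:
#   filter_points_by_negative([[0.5, 0.5], [0.1, 0.1]], [[0.51, 0.51]], (100, 100))
#   -> ([[0.1, 0.1]], [1])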


def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,  # (N, D) numpy array
    neg_queries,  # (M, D) numpy array
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """
    Discriminative Point Suppression (DPS):
        Step 1: Find the spatially closest negative point for each positive point
        Step 2: If the distance < pixel_threshold, check query similarity
        Step 3: Suppress only if query similarity > similarity_threshold

    This two-stage design ensures suppression only when predictions are
    both spatially overlapping AND semantically conflicting.

    Args:
        points: List of [x, y] positive points (normalized, 0-1)
        neg_points: List of [x, y] negative points (normalized, 0-1)
        pos_queries: (N, D) query embeddings for positive predictions
        neg_queries: (M, D) query embeddings for negative predictions
        image_size: (width, height) in pixels
        pixel_threshold: spatial distance threshold in pixels
        similarity_threshold: cosine similarity threshold for semantic conflict

    Returns:
        filtered_points: points after suppression
        filtered_indices: indices of kept points
        suppression_info: dict with detailed suppression decisions
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    N, M = len(points), len(neg_points)

    # === Step 1: Spatial Matching ===
    points_arr = np.array(points) * np.array([width, height])          # (N, 2)
    neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)

    # Compute pairwise distances
    spatial_dist = np.linalg.norm(
        points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
    )  # (N, M)

    # Find nearest negative for each positive
    nearest_neg_idx = spatial_dist.argmin(axis=1)  # (N,)
    nearest_neg_dist = spatial_dist.min(axis=1)    # (N,)

    # Check spatial condition
    spatially_close = nearest_neg_dist < pixel_threshold  # (N,)

    # === Step 2: Query Similarity Check (only for spatially close pairs) ===
    # Normalize queries
    pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)

    # Compute similarity only for matched pairs
    matched_neg_q = neg_q[nearest_neg_idx]  # (N, D)
    query_sim = (pos_q * matched_neg_q).sum(axis=-1)  # (N,) cosine similarity

    # Check semantic condition
    semantically_similar = query_sim > similarity_threshold  # (N,)

    # === Step 3: Joint Decision ===
    # Suppress only if BOTH conditions are met
    should_suppress = spatially_close & semantically_similar  # (N,)

    # === Filter ===
    keep_mask = ~should_suppress
    filtered_points = np.array(points)[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    # === Suppression Info ===
    suppression_info = {
        "nearest_neg_idx": nearest_neg_idx.tolist(),
        "nearest_neg_dist": nearest_neg_dist.tolist(),
        "query_similarity": query_sim.tolist(),
        "spatially_close": spatially_close.tolist(),
        "semantically_similar": semantically_similar.tolist(),
        "suppressed_indices": np.where(should_suppress)[0].tolist(),
    }

    return filtered_points, filtered_indices, suppression_info
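
# Example (hypothetical values): one negative point overlaps the first positive
# both spatially (< 5 px at 640x480) and semantically (identical embeddings,
# cosine similarity 1.0 > 0.3), so only the second positive point survives:
#   pts, idx, info = discriminative_point_suppression(
#       [[0.10, 0.10], [0.80, 0.80]], [[0.101, 0.101]],
#       np.ones((2, 256)), np.ones((1, 256)), image_size=(640, 480))
#   -> pts == [[0.80, 0.80]], idx == [1], info["suppressed_indices"] == [0]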


def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects.

    Args:
        image: Input PIL Image
        pos_caption: Positive prompt (objects to count)
        neg_caption: Negative prompt (objects to exclude)
        box_threshold: Detection confidence threshold
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
        Annotated image and count
    """
    global model, processor, device
    if model is None:
        load_model()

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Ensure captions end with a period
    if not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption if provided; treat the default "None." placeholder as empty
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption.strip() not in ('.', 'None.'))
    if use_neg:
        neg_inputs = processor(
            images=image,
            text=neg_caption,
            return_tensors="pt",
            padding=True
        )
        neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
        neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

        # Add negative inputs to the positive inputs dict
        pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
        pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
        pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
        pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
        pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
        pos_inputs['use_neg'] = True
    else:
        pos_inputs['use_neg'] = False

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    # Positive prediction
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    # Select the last set of decoder queries and restore the batch dimension
    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

    # Negative prediction
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

        # Per-detection query embeddings returned by the post-processor
        pos_queries = results["queries"]
        neg_queries = neg_results["queries"]
        pos_queries = pos_queries.cpu().numpy()
        neg_queries = neg_queries.cpu().numpy()

        img_size = image.size
        # filtered_points, kept_indices = filter_points_by_negative(
        #     points,
        #     neg_points,
        #     image_size=img_size,
        #     pixel_threshold=5
        # )
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries,
            neg_queries,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
        if "scores" in results:
            filtered_scores = [results["scores"][i].item() for i in kept_indices]
        points = filtered_points
        boxes = filtered_boxes

    # Visualize results
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
    # for point in neg_points:
    #     x = point[0] * img_w
    #     y = point[1] * img_h
    #     draw.ellipse(
    #         [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
    #         fill="red"
    #     )

    count = len(points)
    return img_draw, f"Count: {count}"


# Create Gradio interface
def create_demo():
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
        Count specific objects in images using positive and negative text prompts.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
                pos_caption = gr.Textbox(
                    label="Positive Prompt",
                    placeholder="e.g., Green Apple",
                    value="Pos Caption Here."
                )
                neg_caption = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., Red Apple",
                    value="None."
                )
                box_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.42,
                    step=0.01,
                    label="Detection Threshold (set to 0 to use the model default)"
                )
                point_radius = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Point Radius"
                )
                point_color = gr.Dropdown(
                    choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                    value="blue",
                    label="Point Color"
                )
                submit_btn = gr.Button("Count Objects", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")

        # Example images
        # ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
        gr.Examples(
            examples=[
                ["examples/apples.png", "apple.", "Green apple."],
                ["examples/apples.png", "apple.", "red apple."],
                ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
                ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
                ["examples/strawberry.jpg", "strawberry and blueberry.", "strawberry."],
                ["examples/strawberry2.jpg", "strawberry and blueberry.", "strawberry."],
                ["examples/women.jpg", "person.", "woman."],
            ],
            inputs=[input_image, pos_caption, neg_caption],
            outputs=[output_image, count_output],
            fn=count_objects,
            cache_examples=False,
        )

        submit_btn.click(
            fn=count_objects,
            inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
            outputs=[output_image, count_output]
        )

    return demo


if __name__ == "__main__":
    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()