import os

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import post_process_grounded_object_detection

# Global variables for model and processor; populated by load_model().
model = None
processor = None
device = None


def load_model():
    """Load the model and processor once and store them in module globals.

    Picks CUDA when available and CPU otherwise, loads the CountEX checkpoint
    (using the ``HF_TOKEN`` environment variable as the access token when it
    is set), casts the model to bfloat16, moves it to the selected device and
    switches it to eval mode.

    Returns:
        Tuple of (model, processor, device).
    """
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model - change path for HF Spaces
    model_id = "yifehuang97/CountEX-KC-v2"  # Change to your HF model repo
    model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16)
    model = model.to(device)
    model.eval()

    # Load processor
    processor_id = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_id)

    return model, processor, device


def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
    """Filter out positive points that are too close to any negative point.

    Args:
        points: List of [x, y] positive points (normalized coordinates, 0-1).
        neg_points: List of [x, y] negative points (normalized coordinates, 0-1).
        image_size: Tuple of (width, height) in pixels.
        pixel_threshold: Minimum distance threshold in pixels. A positive
            point is kept only when its distance to every negative point is
            strictly greater than this value.

    Returns:
        filtered_points: List of points that are far enough from all
            negative points.
        filtered_indices: Indices of the kept points in the original list.
    """
    # Nothing to filter against (or nothing to filter): keep everything.
    if not neg_points or not points:
        return points, list(range(len(points)))

    width, height = image_size
    points_arr = np.array(points)          # (N, 2) normalized
    neg_points_arr = np.array(neg_points)  # (M, 2) normalized

    # Convert to pixel coordinates so the threshold is expressed in pixels.
    scale = np.array([width, height])
    points_pixel = points_arr * scale          # (N, 2)
    neg_points_pixel = neg_points_arr * scale  # (M, 2)

    # Pairwise distances in pixels via broadcasting: (N, M)
    diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
    distances = np.linalg.norm(diff, axis=-1)

    # Minimum distance to any negative point for each positive point: (N,)
    min_distances = distances.min(axis=1)

    # Keep points whose nearest negative point is farther than the threshold.
    keep_mask = min_distances > pixel_threshold
    filtered_points = points_arr[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    return filtered_points, filtered_indices


def _normalize_caption(caption):
    """Return *caption* stripped and ending with a period, or "" if blank.

    ``None``, empty, whitespace-only and period-only captions all map to the
    empty string so callers can treat them uniformly as "no prompt given".
    """
    text = (caption or "").strip()
    if not text or text == ".":
        return ""
    return text if text.endswith(".") else text + "."


def _draw_points(image, points, point_radius, point_color):
    """Return a copy of *image* with a filled circle drawn at each point.

    *points* hold normalized [x, y] coordinates; they are scaled by the image
    width and height before drawing.
    """
    img_w, img_h = image.size
    canvas = image.copy()
    draw = ImageDraw.Draw(canvas)
    for px, py in points:
        x = px * img_w
        y = py * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color,
        )
    return canvas


def count_objects(image, pos_caption, neg_caption, box_threshold=0.42,
                  point_radius=5, point_color="blue"):
    """Main inference function for counting objects.

    Args:
        image: Input PIL Image.
        pos_caption: Positive prompt (objects to count).
        neg_caption: Negative prompt (objects to exclude). Blank or
            period-only text disables the negative branch.
        box_threshold: Detection confidence threshold. Values <= 0 fall back
            to ``model.box_threshold``.
        point_radius: Radius of visualization points.
        point_color: Color of visualization points.

    Returns:
        Tuple of (annotated image, "Count: N" string).

    Raises:
        gr.Error: If no image or no positive prompt was provided.
    """
    global model, processor, device

    if model is None:
        load_model()

    # Gradio passes None when the user has not uploaded an image.
    if image is None:
        raise gr.Error("Please upload an image.")

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Strip captions *before* the period check so that whitespace-only input
    # is not turned into " ." and mistaken for a real prompt.
    pos_caption = _normalize_caption(pos_caption)
    if not pos_caption:
        raise gr.Error("Please enter a positive prompt.")
    neg_caption = _normalize_caption(neg_caption)
    use_neg = bool(neg_caption)

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True,
    )
    pos_inputs = pos_inputs.to(device)
    # The model weights are bfloat16 (see load_model), so match the dtype.
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

    # Process negative caption if provided
    if use_neg:
        neg_inputs = processor(
            images=image,
            text=neg_caption,
            return_tensors="pt",
            padding=True,
        )
        neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
        neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)

        # Add negative inputs to positive inputs dict
        pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
        pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
        pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
        pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
        pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
        pos_inputs['use_neg'] = True
    else:
        pos_inputs['use_neg'] = False

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    # positive prediction: the first two box components are used as the point
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

    threshold = box_threshold if box_threshold > 0 else model.box_threshold
    results = post_process_grounded_object_detection(outputs, box_threshold=threshold)[0]

    boxes = [box.tolist() for box in results["boxes"]]
    points = [[box[0], box[1]] for box in boxes]

    # negative prediction: drop positive points that sit on a negative one
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]

        neg_results = post_process_grounded_object_detection(
            neg_outputs, box_threshold=threshold
        )[0]
        neg_boxes = [box.tolist() for box in neg_results["boxes"]]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

        points, _kept_indices = filter_points_by_negative(
            points,
            neg_points,
            image_size=image.size,
            pixel_threshold=5,
        )

    # Visualize results
    img_draw = _draw_points(image, points, point_radius, point_color)

    count = len(points)
    return img_draw, f"Count: {count}"


# Create Gradio interface
def create_demo():
    """Build and return the Gradio Blocks UI for the counting demo.

    The left column holds the input image, both prompts and the
    visualization controls; the right column shows the annotated image and
    the count text. The submit button calls ``count_objects`` with all six
    inputs.
    """
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
        Count specific objects in images using positive and negative text prompts.

        **Important Note: Both the Positive and Negative prompts must end with a period (.) for the model to correctly interpret the instruction.**
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")

                pos_caption = gr.Textbox(
                    label="Positive Prompt",
                    placeholder="e.g., Green Apple",
                    value="Pos Caption Here."
                )
                # NOTE(review): the default "None." is non-empty, so
                # count_objects treats it as a real negative prompt - confirm
                # this is intended.
                neg_caption = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., Red Apple",
                    value="None."
                )

                # NOTE(review): count_objects falls back to the model default
                # only when the threshold is 0; the label below says 0.42 -
                # confirm which is intended.
                box_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.42,
                    step=0.01,
                    label="Detection Threshold (0.42 = use model default)"
                )

                point_radius = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Point Radius"
                )

                point_color = gr.Dropdown(
                    choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                    value="blue",
                    label="Point Color"
                )

                submit_btn = gr.Button("Count Objects", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")

        # Example images
        # ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
        gr.Examples(
            examples=[
                ["examples/apples.png", "Green apple.", "Red apple."],
                ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
                ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
            ],
            inputs=[input_image, pos_caption, neg_caption],
            outputs=[output_image, count_output],
            fn=count_objects,
            cache_examples=False,
        )

        submit_btn.click(
            fn=count_objects,
            inputs=[input_image, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output]
        )

    return demo


if __name__ == "__main__":
    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()