yifehuang97 committed
Commit 3bef090 · 1 Parent(s): 97e2fd4

(feat) llm parse

Files changed (1)
  1. app.py +325 -157
app.py CHANGED
@@ -1,16 +1,86 @@
import os
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries

# Global variables for model and processor
model = None
processor = None
device = None


def load_model():
    """Load model and processor once at startup"""
@@ -34,49 +104,6 @@ def load_model():

import numpy as np

- def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
-     """
-     Filter out positive points that are too close to any negative point.
-
-     Args:
-         points: List of [x, y] positive points (normalized coordinates, 0-1)
-         neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
-         image_size: Tuple of (width, height) in pixels
-         pixel_threshold: Minimum distance threshold in pixels
-
-     Returns:
-         filtered_points: List of points that are far enough from all negative points
-         filtered_indices: Indices of the kept points in the original list
-     """
-     if not neg_points or not points:
-         return points, list(range(len(points)))
-
-     width, height = image_size
-
-     points_arr = np.array(points)          # (N, 2) normalized
-     neg_points_arr = np.array(neg_points)  # (M, 2) normalized
-
-     # Convert to pixel coordinates
-     points_pixel = points_arr * np.array([width, height])          # (N, 2)
-     neg_points_pixel = neg_points_arr * np.array([width, height])  # (M, 2)
-
-     # Compute pairwise distances in pixels: (N, M)
-     diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
-     distances = np.linalg.norm(diff, axis=-1)  # (N, M)
-
-     # Find minimum distance to any negative point for each positive point
-     min_distances = distances.min(axis=1)  # (N,)
-
-     # Keep points where min distance > threshold
-     keep_mask = min_distances > pixel_threshold
-
-     filtered_points = points_arr[keep_mask].tolist()
-     filtered_indices = np.where(keep_mask)[0].tolist()
-
-     return filtered_points, filtered_indices
-
-
- import numpy as np

def discriminative_point_suppression(
    points,
@@ -166,35 +193,166 @@ def discriminative_point_suppression(

    return filtered_points, filtered_indices, suppression_info

- def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects

    Args:
        image: Input PIL Image
-         pos_caption: Positive prompt (objects to count)
-         neg_caption: Negative prompt (objects to exclude)
        box_threshold: Detection confidence threshold
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
-         Annotated image and count
    """
    global model, processor, device

    if model is None:
        load_model()

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Ensure captions end with period
-     if not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'

    # Process positive caption
    pos_inputs = processor(
@@ -206,12 +364,10 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

-     # Process negative caption if provided
-     use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
-

    if not use_neg:
-         # print('neg_caption: ', neg_caption)
        neg_caption = "None."
    neg_inputs = processor(
        images=image,
@@ -229,31 +385,12 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True
-     # else:
-     #     neg_caption = "None."
-     #     neg_inputs = processor(
-     #         images=image,
-     #         text=neg_caption,
-     #         return_tensors="pt",
-     #         padding=True
-     #     )
-     #     neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
-     #     neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
-
-     #     # Add negative inputs to positive inputs dict
-     #     pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
-     #     pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
-     #     pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
-     #     pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
-     #     pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
-     #     pos_inputs['use_neg'] = False

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
-     # positive prediction
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

@@ -270,7 +407,9 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

-     # negative prediction
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
@@ -283,31 +422,25 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

-         pos_queries = results["queries"]
-         neg_queries = neg_results["queries"]
-         pos_queries = pos_queries.cpu().numpy()
-         neg_queries = neg_queries.cpu().numpy()

    img_size = image.size
-     # filtered_points, kept_indices = filter_points_by_negative(
-     #     points,
-     #     neg_points,
-     #     image_size=img_size,
-     #     pixel_threshold=5
-     # )
-     filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
-         points,
-         neg_points,
-         pos_queries,
-         neg_queries,
-         image_size=img_size,
-         pixel_threshold=5,
-         similarity_threshold=0.3,
-     )

-     filtered_boxes = [boxes[i] for i in kept_indices]
-     if "scores" in results:
-         filtered_scores = [results["scores"][i].item() for i in kept_indices]

    points = filtered_points
    boxes = filtered_boxes
@@ -324,18 +457,10 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
-
-     # for point in neg_points:
-     #     x = point[0] * img_w
-     #     y = point[1] * img_h
-     #     draw.ellipse(
-     #         [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
-     #         fill="red"
-     #     )

    count = len(points)

-     return img_draw, f"Count: {count}"


# Create Gradio interface
@@ -343,76 +468,119 @@ def create_demo():
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
-         Count specific objects in images using positive and negative text prompts.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
-
-                 pos_caption = gr.Textbox(
-                     label="Positive Prompt",
-                     placeholder="e.g., Green Apple",
-                     value="Pos Caption Here."
-                 )
-
-                 neg_caption = gr.Textbox(
-                     label="Negative Prompt (optional)",
-                     placeholder="e.g., Red Apple",
-                     value="None."
-                 )
-
-                 box_threshold = gr.Slider(
-                     minimum=0.0,
-                     maximum=1.0,
-                     value=0.42,
-                     step=0.01,
-                     label="Detection Threshold (0.42 = use model default)"
-                 )
-
-                 point_radius = gr.Slider(
-                     minimum=1,
-                     maximum=20,
-                     value=5,
-                     step=1,
-                     label="Point Radius"
-                 )
-
-                 point_color = gr.Dropdown(
-                     choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
-                     value="blue",
-                     label="Point Color"
-                 )
-
-                 submit_btn = gr.Button("Count Objects", variant="primary")
-
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")

-         # Example images
-         # ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
        gr.Examples(
            examples=[
-                 ["examples/apples.png", "apple.", "Green apple."],
-                 ["examples/apple.jpg", "apple.", "red apple."],
-                 ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
-                 ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
-                 ["examples/strawberry.jpg", "strawberry and blueberry.", "strawberry."],
-                 ["examples/strawberry2.jpg", "strawberry and blueberry.", "strawberry."],
-                 ["examples/women.jpg", "person.", "woman."],
-                 ["examples/boat-1.jpg", "boat.", "blue boat."],
            ],
-             inputs=[input_image, pos_caption, neg_caption],
-             outputs=[output_image, count_output],
            fn=count_objects,
            cache_examples=False,
        )
-
        submit_btn.click(
-             fn=count_objects,
-             inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
-             outputs=[output_image, count_output]
        )

    return demo
 
import os
+ import json
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
+ import google.generativeai as genai

# Global variables for model and processor
model = None
processor = None
device = None

+ # Configure Gemini (read the API key from the environment instead of hard-coding it)
+ genai.configure(api_key=os.environ.get("GEMINI_API_KEY", ""))
+ gemini_model = genai.GenerativeModel("gemini-2.0-flash")
+
+ PARSING_PROMPT = """Parse the user's counting instruction into two lists:
+ - A (include): objects to count
+ - B (exclude): objects to exclude from counting
+
+ Rules:
+ 1. Split on "and", "or", and commas
+ 2. Reattach shared head nouns (e.g., "red and black beans" → "red beans", "black beans")
+ 3. Remove from B items that are equivalent to A (synonyms/variants/abbreviations)
+ 4. Remove from B items that are more specific than A
+ 5. If B is more general than A but shares head noun, rewrite B to specific non-overlapping forms
+
+ Examples:
+ - "Count green apples, not red apples" → A: ["green apples"], B: ["red apples"]
+ - "Count apples, not green apples" → A: ["apples"], B: []
+ - "Count green apples, not apples" → A: ["green apples"], B: ["non-green apples"]
+ - "Count fries, not chips" → A: ["fries"], B: []
+ - "Count black beans, not poker chips" → A: ["black beans"], B: ["poker chips"]
+
+ User instruction: {instruction}
+
+ Respond ONLY with a JSON object in this exact format, no other text:
+ {{"A": ["item1", "item2"], "B": ["item3"]}}
+ """
+
+
+ def parse_counting_instruction(instruction: str) -> tuple[str, str]:
+     """
+     Parse natural language counting instruction using Gemini 2.0 Flash.
+
+     Args:
+         instruction: Natural language instruction like "count apples, not green apples"
+
+     Returns:
+         tuple: (positive_caption, negative_caption)
+     """
+     try:
+         prompt = PARSING_PROMPT.format(instruction=instruction)
+         response = gemini_model.generate_content(prompt)
+         response_text = response.text.strip()
+
+         # Clean up response - remove markdown code blocks if present
+         if response_text.startswith("```"):
+             response_text = response_text.split("```")[1]
+             if response_text.startswith("json"):
+                 response_text = response_text[4:]
+             response_text = response_text.strip()
+
+         result = json.loads(response_text)
+
+         # Convert lists to caption strings
+         pos_items = result.get("A", [])
+         neg_items = result.get("B", [])
+
+         # Join items with " and " and add period
+         pos_caption = " and ".join(pos_items) + "." if pos_items else ""
+         neg_caption = " and ".join(neg_items) + "." if neg_items else "None."
+
+         return pos_caption, neg_caption
+
+     except Exception as e:
+         print(f"Error parsing instruction: {e}")
+         # Fallback: treat entire instruction as positive caption
+         return instruction.strip() + ".", "None."
+

def load_model():
    """Load model and processor once at startup"""

import numpy as np


def discriminative_point_suppression(
    points,

    return filtered_points, filtered_indices, suppression_info

+
+ def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """
    Main inference function for counting objects

    Args:
        image: Input PIL Image
+         instruction: Natural language instruction (e.g., "count apples, not green apples")
        box_threshold: Detection confidence threshold
        point_radius: Radius of visualization points
        point_color: Color of visualization points

    Returns:
+         Annotated image, count, and parsed captions
    """
    global model, processor, device

    if model is None:
        load_model()

+     # Parse instruction using Gemini
+     pos_caption, neg_caption = parse_counting_instruction(instruction)
+     parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"
+
    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

+     # Process positive caption
+     pos_inputs = processor(
+         images=image,
+         text=pos_caption,
+         return_tensors="pt",
+         padding=True
+     )
+     pos_inputs = pos_inputs.to(device)
+     pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
+
+     # Process negative caption
+     use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
+
+     if not use_neg:
+         neg_caption = "None."
+     neg_inputs = processor(
+         images=image,
+         text=neg_caption,
+         return_tensors="pt",
+         padding=True
+     )
+     neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
+     neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
+
+     # Add negative inputs to positive inputs dict
+     pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
+     pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
+     pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
+     pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
+     pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
+     pos_inputs['use_neg'] = True
+
+     # Run inference
+     with torch.no_grad():
+         outputs = model(**pos_inputs)
+
+     # Post-process outputs
+     outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
+     outputs["pred_logits"] = outputs["logits"]
+
+     threshold = box_threshold if box_threshold > 0 else model.box_threshold
+     pos_queries = outputs["pos_queries"].squeeze(0).float()
+     neg_queries = outputs["neg_queries"].squeeze(0).float()
+     pos_queries = pos_queries[-1].squeeze(0)
+     neg_queries = neg_queries[-1].squeeze(0)
+     pos_queries = pos_queries.unsqueeze(0)
+     neg_queries = neg_queries.unsqueeze(0)
+     results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
+
+     boxes = results["boxes"]
+     boxes = [box.tolist() for box in boxes]
+     points = [[box[0], box[1]] for box in boxes]
+
+     # Negative prediction
+     neg_points = []
+     neg_results = None
+     if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
+         neg_outputs = outputs.copy()
+         neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
+         neg_outputs["logits"] = outputs["neg_logits"]
+         neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
+         neg_outputs["pred_logits"] = outputs["neg_logits"]
+
+         neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
+         neg_boxes = neg_results["boxes"]
+         neg_boxes = [box.tolist() for box in neg_boxes]
+         neg_points = [[box[0], box[1]] for box in neg_boxes]
+
+     pos_queries_np = results["queries"].cpu().numpy()
+     neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])
+
+     img_size = image.size
+
+     if len(neg_points) > 0 and len(neg_queries_np) > 0:
+         filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
+             points,
+             neg_points,
+             pos_queries_np,
+             neg_queries_np,
+             image_size=img_size,
+             pixel_threshold=5,
+             similarity_threshold=0.3,
+         )
+         filtered_boxes = [boxes[i] for i in kept_indices]
+     else:
+         filtered_points = points
+         filtered_boxes = boxes
+
+     points = filtered_points
+     boxes = filtered_boxes
+
+     # Visualize results
+     img_w, img_h = image.size
+     img_draw = image.copy()
+     draw = ImageDraw.Draw(img_draw)
+
+     for point in points:
+         x = point[0] * img_w
+         y = point[1] * img_h
+         draw.ellipse(
+             [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
+             fill=point_color
+         )
+
+     count = len(points)
+
+     return img_draw, f"Count: {count}", parsed_info
+
+
+ def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
+     """
+     Manual mode: directly use provided positive and negative captions.
+     """
+     global model, processor, device
+
+     if model is None:
+         load_model()
+
    # Ensure captions end with period
+     if pos_caption and not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'
+
+     if not neg_caption or neg_caption.strip() == '':
+         neg_caption = "None."
+
+     parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"
+
+     # Ensure image is RGB
+     if image.mode != "RGB":
+         image = image.convert("RGB")

    # Process positive caption
    pos_inputs = processor(
    pos_inputs = pos_inputs.to(device)
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)

+     # Process negative caption
+     use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')

    if not use_neg:
        neg_caption = "None."
    neg_inputs = processor(
        images=image,
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    pos_inputs['use_neg'] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]

+     # Negative prediction
+     neg_points = []
+     neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

+     pos_queries_np = results["queries"].cpu().numpy()
+     neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])

    img_size = image.size

+     if len(neg_points) > 0 and len(neg_queries_np) > 0:
+         filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
+             points,
+             neg_points,
+             pos_queries_np,
+             neg_queries_np,
+             image_size=img_size,
+             pixel_threshold=5,
+             similarity_threshold=0.3,
+         )
+         filtered_boxes = [boxes[i] for i in kept_indices]
+     else:
+         filtered_points = points
+         filtered_boxes = boxes

    points = filtered_points
    boxes = filtered_boxes
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )

    count = len(points)

+     return img_draw, f"Count: {count}", parsed_info


# Create Gradio interface
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
+         Count specific objects in images using text prompts with exclusion capability.
        """)
+
+         # State to track current input mode
+         current_mode = gr.State(value="natural_language")

        with gr.Row():
+             # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
+
+                 with gr.Tabs() as input_tabs:
+                     # Tab 1: Natural Language Input
+                     with gr.TabItem("Natural Language", id=0) as tab_nl:
+                         instruction = gr.Textbox(
+                             label="Counting Instruction",
+                             placeholder="e.g., Count apples, not green apples",
+                             value="Count apples, not green apples",
+                             lines=2
+                         )
+                         gr.Markdown("""
+                         **Examples:**
+                         - "Count apples, not green apples"
+                         - "Count red and black beans, exclude white beans"
+                         - "Count people, not women"
+                         """)
+
+                     # Tab 2: Manual Input
+                     with gr.TabItem("Manual Input", id=1) as tab_manual:
+                         pos_caption = gr.Textbox(
+                             label="Positive Prompt (objects to count)",
+                             placeholder="e.g., apple",
+                             value="apple."
+                         )
+                         neg_caption = gr.Textbox(
+                             label="Negative Prompt (objects to exclude)",
+                             placeholder="e.g., green apple",
+                             value="None."
+                         )
+
+                 # Single submit button outside tabs
+                 submit_btn = gr.Button("Count Objects", variant="primary", size="lg")
+
+                 # Shared settings
+                 with gr.Accordion("Advanced Settings", open=False):
+                     box_threshold = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.42,
+                         step=0.01,
+                         label="Detection Threshold"
+                     )
+                     point_radius = gr.Slider(
+                         minimum=1,
+                         maximum=20,
+                         value=5,
+                         step=1,
+                         label="Point Radius"
+                     )
+                     point_color = gr.Dropdown(
+                         choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
+                         value="blue",
+                         label="Point Color"
+                     )
+
+             # Right column - Output
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
+                 parsed_output = gr.Textbox(label="Parsed Captions", lines=2)

+         # Examples for Natural Language mode
+         gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
+                 ["examples/apples.png", "Count apples, exclude green apples"],
+                 ["examples/apple.jpg", "Count apples, not red apples"],
+                 ["examples/black_beans.jpg", "Count black beans, not soy beans"],
+                 ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
+                 ["examples/strawberry.jpg", "Count blueberries"],
+                 ["examples/strawberry2.jpg", "Count blueberries"],
+                 ["examples/women.jpg", "Count people, not women"],
+                 ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
            ],
+             inputs=[input_image, instruction],
+             outputs=[output_image, count_output, parsed_output],
            fn=count_objects,
            cache_examples=False,
        )
+
+         # Update mode when tab changes
+         def set_mode_nl():
+             return "natural_language"
+
+         def set_mode_manual():
+             return "manual"
+
+         tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
+         tab_manual.select(fn=set_mode_manual, outputs=[current_mode])
+
+         # Unified handler that routes based on mode
+         def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
+             if mode == "natural_language":
+                 return count_objects(image, instr, threshold, radius, color)
+             else:
+                 return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)
+
+         # Single button click handler
+         submit_btn.click(
+             fn=handle_submit,
+             inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
+                     box_threshold, point_radius, point_color],
+             outputs=[output_image, count_output, parsed_output]
        )

    return demo
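For reference, a minimal sketch of the instruction-to-caption contract that parse_counting_instruction relies on. The Gemini call is stubbed out; FakeGeminiModel, captions_from_json, and the sample strings below are illustrative only and are not part of app.py.

import json

class FakeGeminiModel:
    """Stand-in for genai.GenerativeModel; returns a canned JSON reply."""
    def generate_content(self, prompt):
        class Reply:
            text = '{"A": ["green apples"], "B": ["red apples"]}'
        return Reply()

def captions_from_json(response_text):
    # Same A/B -> caption mapping as parse_counting_instruction above.
    result = json.loads(response_text)
    pos_items = result.get("A", [])
    neg_items = result.get("B", [])
    pos_caption = " and ".join(pos_items) + "." if pos_items else ""
    neg_caption = " and ".join(neg_items) + "." if neg_items else "None."
    return pos_caption, neg_caption

reply = FakeGeminiModel().generate_content("Count green apples, not red apples")
print(captions_from_json(reply.text))  # ('green apples.', 'red apples.')

With an API key supplied via the environment (the GEMINI_API_KEY name used above is a placeholder, not something the library requires), the real gemini_model.generate_content call slots in where FakeGeminiModel does here, and the resulting positive/negative captions feed the detector exactly as in count_objects.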