yifehuang97 committed
Commit 03ae022 · verified · 1 Parent(s): de9b038

Update app.py

Files changed (1)
  1. app.py +409 -409
app.py CHANGED
@@ -1,410 +1,410 @@
⋯
-     model_id = "yifehuang97/CountEx_KC_aug_full_ft_20251130-v2"  # Change to your HF model repo
⋯ (all other removed lines are identical to the new version below)
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image, ImageDraw
+ from transformers import GroundingDinoProcessor
+ from hf_model import CountEX
+ from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
+
+ # Global variables for model and processor
+ model = None
+ processor = None
+ device = None
+
+
+ def load_model():
+     """Load model and processor once at startup"""
+     global model, processor, device
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load model - change path for HF Spaces
+     model_id = "yifehuang97/CountEX-KC-v2"  # Change to your HF model repo
+     model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
+     model = model.to(torch.bfloat16)
+     model = model.to(device)
+     model.eval()
+
+     # Load processor
+     processor_id = "fushh7/llmdet_swin_tiny_hf"
+     processor = GroundingDinoProcessor.from_pretrained(processor_id)
+
+     return model, processor, device
+
+
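# The checkpoint above is fetched with token=os.environ.get("HF_TOKEN"), so a gated or
# private model repo requires HF_TOKEN to be set (e.g., as a Space secret). A minimal
# local sketch, with a placeholder token value:
#
#   os.environ.setdefault("HF_TOKEN", "<your-token>")  # hypothetical placeholder
#   model, processor, device = load_model()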
+ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
+     """
+     Filter out positive points that are too close to any negative point.
+
+     Args:
+         points: List of [x, y] positive points (normalized coordinates, 0-1)
+         neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
+         image_size: Tuple of (width, height) in pixels
+         pixel_threshold: Minimum distance threshold in pixels
+
+     Returns:
+         filtered_points: List of points that are far enough from all negative points
+         filtered_indices: Indices of the kept points in the original list
+     """
+     if not neg_points or not points:
+         return points, list(range(len(points)))
+
+     width, height = image_size
+
+     points_arr = np.array(points)          # (N, 2) normalized
+     neg_points_arr = np.array(neg_points)  # (M, 2) normalized
+
+     # Convert to pixel coordinates
+     points_pixel = points_arr * np.array([width, height])          # (N, 2)
+     neg_points_pixel = neg_points_arr * np.array([width, height])  # (M, 2)
+
+     # Compute pairwise distances in pixels: (N, M)
+     diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
+     distances = np.linalg.norm(diff, axis=-1)  # (N, M)
+
+     # Find minimum distance to any negative point for each positive point
+     min_distances = distances.min(axis=1)  # (N,)
+
+     # Keep points where min distance > threshold
+     keep_mask = min_distances > pixel_threshold
+
+     filtered_points = points_arr[keep_mask].tolist()
+     filtered_indices = np.where(keep_mask)[0].tolist()
+
+     return filtered_points, filtered_indices
+
+
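# A minimal sketch of the pixel-distance filter with hypothetical values (not part of the
# app's runtime path): on a 1000x1000 image the first positive point lies ~1 px from the
# negative point and is dropped, while the second is far away and survives.
#
#   pts  = [[0.100, 0.100], [0.500, 0.500]]  # normalized positive points
#   negs = [[0.101, 0.100]]                  # normalized negative point
#   kept, idx = filter_points_by_negative(pts, negs, image_size=(1000, 1000), pixel_threshold=5)
#   # kept == [[0.5, 0.5]], idx == [1]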
+ def discriminative_point_suppression(
+     points,
+     neg_points,
+     pos_queries,  # (N, D) numpy array
+     neg_queries,  # (M, D) numpy array
+     image_size,
+     pixel_threshold=5,
+     similarity_threshold=0.3,
+ ):
+     """
+     Discriminative Point Suppression (DPS):
+
+     Step 1: Find spatially closest negative point for each positive point
+     Step 2: If distance < pixel_threshold, check query similarity
+     Step 3: Suppress only if query similarity > similarity_threshold
+
+     This two-stage design ensures suppression only when predictions are
+     both spatially overlapping AND semantically conflicting.
+
+     Args:
+         points: List of [x, y] positive points (normalized, 0-1)
+         neg_points: List of [x, y] negative points (normalized, 0-1)
+         pos_queries: (N, D) query embeddings for positive predictions
+         neg_queries: (M, D) query embeddings for negative predictions
+         image_size: (width, height) in pixels
+         pixel_threshold: spatial distance threshold in pixels
+         similarity_threshold: cosine similarity threshold for semantic conflict
+
+     Returns:
+         filtered_points: points after suppression
+         filtered_indices: indices of kept points
+         suppression_info: dict with detailed suppression decisions
+     """
+     if not neg_points or not points:
+         return points, list(range(len(points))), {}
+
+     width, height = image_size
+     N, M = len(points), len(neg_points)
+
+     # === Step 1: Spatial Matching ===
+     points_arr = np.array(points) * np.array([width, height])          # (N, 2)
+     neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)
+
+     # Compute pairwise distances
+     spatial_dist = np.linalg.norm(
+         points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
+     )  # (N, M)
+
+     # Find nearest negative for each positive
+     nearest_neg_idx = spatial_dist.argmin(axis=1)  # (N,)
+     nearest_neg_dist = spatial_dist.min(axis=1)    # (N,)
+
+     # Check spatial condition
+     spatially_close = nearest_neg_dist < pixel_threshold  # (N,)
+
+     # === Step 2: Query Similarity Check (only for spatially close pairs) ===
+     # Normalize queries
+     pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
+     neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
+
+     # Compute similarity only for matched pairs
+     matched_neg_q = neg_q[nearest_neg_idx]            # (N, D)
+     query_sim = (pos_q * matched_neg_q).sum(axis=-1)  # (N,) cosine similarity
+
+     # Check semantic condition
+     semantically_similar = query_sim > similarity_threshold  # (N,)
+
+     # === Step 3: Joint Decision ===
+     # Suppress only if BOTH conditions are met
+     should_suppress = spatially_close & semantically_similar  # (N,)
+
+     # === Filter ===
+     keep_mask = ~should_suppress
+     filtered_points = np.array(points)[keep_mask].tolist()
+     filtered_indices = np.where(keep_mask)[0].tolist()
+
+     # === Suppression Info ===
+     suppression_info = {
+         "nearest_neg_idx": nearest_neg_idx.tolist(),
+         "nearest_neg_dist": nearest_neg_dist.tolist(),
+         "query_similarity": query_sim.tolist(),
+         "spatially_close": spatially_close.tolist(),
+         "semantically_similar": semantically_similar.tolist(),
+         "suppressed_indices": np.where(should_suppress)[0].tolist(),
+     }
+
+     return filtered_points, filtered_indices, suppression_info
+
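# A minimal sketch of the two-stage DPS decision with hypothetical toy values (not part of
# the app's runtime path): a positive point is suppressed only when it is both within
# pixel_threshold of its nearest negative point AND its query embedding is similar to that
# negative's query.
#
#   pts   = [[0.100, 0.100], [0.500, 0.500]]
#   negs  = [[0.101, 0.100]]
#   pos_q = np.array([[1.0, 0.0], [0.0, 1.0]])  # toy 2-D query embeddings
#   neg_q = np.array([[1.0, 0.0]])
#   kept, idx, info = discriminative_point_suppression(
#       pts, negs, pos_q, neg_q, image_size=(1000, 1000),
#       pixel_threshold=5, similarity_threshold=0.3,
#   )
#   # point 0 is ~1 px away with cosine similarity 1.0 -> suppressed (info["suppressed_indices"] == [0]);
#   # point 1 is far away -> kept, so idx == [1]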
+ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
+     """
+     Main inference function for counting objects
+
+     Args:
+         image: Input PIL Image
+         pos_caption: Positive prompt (objects to count)
+         neg_caption: Negative prompt (objects to exclude)
+         box_threshold: Detection confidence threshold
+         point_radius: Radius of visualization points
+         point_color: Color of visualization points
+
+     Returns:
+         Annotated image and count
+     """
+     global model, processor, device
+
+     if model is None:
+         load_model()
+
+     # Ensure image is RGB
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     # Ensure captions end with period
+     if not pos_caption.endswith('.'):
+         pos_caption = pos_caption + '.'
+     if neg_caption and not neg_caption.endswith('.'):
+         neg_caption = neg_caption + '.'
+
+     # Process positive caption
+     pos_inputs = processor(
+         images=image,
+         text=pos_caption,
+         return_tensors="pt",
+         padding=True
+     )
+     pos_inputs = pos_inputs.to(device)
+     pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
+
+     # Process negative caption if provided
+     use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
+
+     if use_neg:
+         neg_inputs = processor(
+             images=image,
+             text=neg_caption,
+             return_tensors="pt",
+             padding=True
+         )
+         neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
+         neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
+
+         # Add negative inputs to positive inputs dict
+         pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
+         pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
+         pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
+         pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
+         pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
+         pos_inputs['use_neg'] = True
+     else:
+         pos_inputs['use_neg'] = False
+
+     # Run inference
+     with torch.no_grad():
+         outputs = model(**pos_inputs)
+
+     # Post-process outputs
+     # positive prediction
+     outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
+     outputs["pred_logits"] = outputs["logits"]
+
+     threshold = box_threshold if box_threshold > 0 else model.box_threshold
+     pos_queries = outputs["pos_queries"].squeeze(0).float()
+     neg_queries = outputs["neg_queries"].squeeze(0).float()
+     pos_queries = pos_queries[-1].squeeze(0)
+     neg_queries = neg_queries[-1].squeeze(0)
+     pos_queries = pos_queries.unsqueeze(0)
+     neg_queries = neg_queries.unsqueeze(0)
+     results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
+
+     boxes = results["boxes"]
+     boxes = [box.tolist() for box in boxes]
+     points = [[box[0], box[1]] for box in boxes]
+
+     # negative prediction
+     if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
+         neg_outputs = outputs.copy()
+         neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
+         neg_outputs["logits"] = outputs["neg_logits"]
+         neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
+         neg_outputs["pred_logits"] = outputs["neg_logits"]
+
+         neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
+         neg_boxes = neg_results["boxes"]
+         neg_boxes = [box.tolist() for box in neg_boxes]
+         neg_points = [[box[0], box[1]] for box in neg_boxes]
+
+         pos_queries = results["queries"]
+         neg_queries = neg_results["queries"]
+         pos_queries = pos_queries.cpu().numpy()
+         neg_queries = neg_queries.cpu().numpy()
+
+         img_size = image.size
+         # filtered_points, kept_indices = filter_points_by_negative(
+         #     points,
+         #     neg_points,
+         #     image_size=img_size,
+         #     pixel_threshold=5
+         # )
+         filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
+             points,
+             neg_points,
+             pos_queries,
+             neg_queries,
+             image_size=img_size,
+             pixel_threshold=5,
+             similarity_threshold=0.3,
+         )
+
+         filtered_boxes = [boxes[i] for i in kept_indices]
+         if "scores" in results:
+             filtered_scores = [results["scores"][i].item() for i in kept_indices]
+
+         points = filtered_points
+         boxes = filtered_boxes
+
+     # Visualize results
+     img_w, img_h = image.size
+     img_draw = image.copy()
+     draw = ImageDraw.Draw(img_draw)
+
+     for point in points:
+         x = point[0] * img_w
+         y = point[1] * img_h
+         draw.ellipse(
+             [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
+             fill=point_color
+         )
+
+     # for point in neg_points:
+     #     x = point[0] * img_w
+     #     y = point[1] * img_h
+     #     draw.ellipse(
+     #         [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
+     #         fill="red"
+     #     )
+
+     count = len(points)
+
+     return img_draw, f"Count: {count}"
+
+
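# Headless usage sketch with hypothetical inputs (the Gradio UI below is the intended entry
# point); the example image path and default values come from the demo configuration:
#
#   img = Image.open("examples/apples.png")
#   annotated, label = count_objects(
#       img, "apple.", "Green apple.",
#       box_threshold=0.42, point_radius=5, point_color="blue",
#   )
#   # `annotated` is a PIL image with one dot per counted object; `label` reads "Count: <N>"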
+ # Create Gradio interface
+ def create_demo():
+     with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
+         gr.Markdown("""
+         # CountEx: Fine-Grained Counting via Exemplars and Exclusion
+         Count specific objects in images using positive and negative text prompts.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 input_image = gr.Image(type="pil", label="Input Image")
+
+                 pos_caption = gr.Textbox(
+                     label="Positive Prompt",
+                     placeholder="e.g., Green Apple",
+                     value="Pos Caption Here."
+                 )
+
+                 neg_caption = gr.Textbox(
+                     label="Negative Prompt (optional)",
+                     placeholder="e.g., Red Apple",
+                     value="None."
+                 )
+
+                 box_threshold = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.42,
+                     step=0.01,
+                     label="Detection Threshold (0.42 = use model default)"
+                 )
+
+                 point_radius = gr.Slider(
+                     minimum=1,
+                     maximum=20,
+                     value=5,
+                     step=1,
+                     label="Point Radius"
+                 )
+
+                 point_color = gr.Dropdown(
+                     choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
+                     value="blue",
+                     label="Point Color"
+                 )
+
+                 submit_btn = gr.Button("Count Objects", variant="primary")
+
+             with gr.Column(scale=1):
+                 output_image = gr.Image(type="pil", label="Result")
+                 count_output = gr.Textbox(label="Count Result")
+
+         # Example images
+         # ["examples/in_the_wild.jpg", "Green plastic cup.", "Blue plastic cup."],
+         gr.Examples(
+             examples=[
+                 ["examples/apples.png", "apple.", "Green apple."],
+                 ["examples/apple.jpg", "apple.", "red apple."],
+                 ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
+                 ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
+                 ["examples/strawberry.jpg", "strawberry and blueberry.", "strawberry."],
+                 ["examples/strawberry2.jpg", "strawberry and blueberry.", "strawberry."],
+                 ["examples/women.jpg", "person.", "woman."],
+                 ["examples/boat-1.jpg", "boat.", "blue boat."],
+             ],
+             inputs=[input_image, pos_caption, neg_caption],
+             outputs=[output_image, count_output],
+             fn=count_objects,
+             cache_examples=False,
+         )
+
+         submit_btn.click(
+             fn=count_objects,
+             inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
+             outputs=[output_image, count_output]
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     # Load model at startup
+     print("Loading model...")
+     load_model()
+     print("Model loaded!")
+
+     # Create and launch demo
+     demo = create_demo()
+     demo.launch()