yifehuang97 committed on
Commit
6c2d054
·
1 Parent(s): 20914e5
Files changed (3) hide show
  1. app.py +5 -3
  2. examples/strawberry.jpg +3 -0
  3. utils.py +5 -12
app.py CHANGED
@@ -243,8 +243,8 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
243
  neg_queries = outputs["neg_queries"].squeeze(0).float()
244
  pos_queries = pos_queries[-1].squeeze(0)
245
  neg_queries = neg_queries[-1].squeeze(0)
246
- pos_queries = pos_queries.cpu().numpy()
247
- neg_queries = neg_queries.cpu().numpy()
248
  results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
249
 
250
  boxes = results["boxes"]
@@ -266,6 +266,8 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
266
 
267
  pos_queries = results["queries"]
268
  neg_queries = neg_results["queries"]
 
 
269
 
270
  img_size = image.size
271
  # filtered_points, kept_indices = filter_points_by_negative(
@@ -323,7 +325,6 @@ def create_demo():
323
  gr.Markdown("""
324
  # CountEx: Fine-Grained Counting via Exemplars and Exclusion
325
  Count specific objects in images using positive and negative text prompts.
326
- **Important Note: Both the Positive and Negative prompts must end with a period (.) for the model to correctly interpret the instruction.**
327
  """)
328
 
329
  with gr.Row():
@@ -377,6 +378,7 @@ def create_demo():
377
  ["examples/apples.png", "Green apple.", "Red apple."],
378
  ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
379
  ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
 
380
  ],
381
  inputs=[input_image, pos_caption, neg_caption],
382
  outputs=[output_image, count_output],
 
243
  neg_queries = outputs["neg_queries"].squeeze(0).float()
244
  pos_queries = pos_queries[-1].squeeze(0)
245
  neg_queries = neg_queries[-1].squeeze(0)
246
+ pos_queries = pos_queries.cpu()
247
+ neg_queries = neg_queries.cpu()
248
  results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
249
 
250
  boxes = results["boxes"]
 
266
 
267
  pos_queries = results["queries"]
268
  neg_queries = neg_results["queries"]
269
+ pos_queries = pos_queries.numpy()
270
+ neg_queries = neg_queries.numpy()
271
 
272
  img_size = image.size
273
  # filtered_points, kept_indices = filter_points_by_negative(
 
325
  gr.Markdown("""
326
  # CountEx: Fine-Grained Counting via Exemplars and Exclusion
327
  Count specific objects in images using positive and negative text prompts.
 
328
  """)
329
 
330
  with gr.Row():
 
378
  ["examples/apples.png", "Green apple.", "Red apple."],
379
  ["examples/black_beans.jpg", "Black bean.", "Soy bean."],
380
  ["examples/candy.jpg", "Brown coffee candy.", "Black coffee candy."],
381
+ ["examples/strawberry.jpg", "strawberry.", "None."],
382
  ],
383
  inputs=[input_image, pos_caption, neg_caption],
384
  outputs=[output_image, count_output],
examples/strawberry.jpg ADDED

Git LFS Details

  • SHA256: aea3767562c09cc516f972743428152d6c796394624f68e4a9f5507394bae2c9
  • Pointer size: 130 Bytes
  • Size of remote file: 34 kB
utils.py CHANGED
@@ -55,10 +55,7 @@ def post_process_grounded_object_detection_with_queries(
55
  Now also returns the query embeddings for each kept prediction.
56
  """
57
  logits, boxes = outputs.logits, outputs.pred_boxes
58
- print("logits: ", logits.shape)
59
- print("boxes: ", boxes.shape)
60
- print("queries: ", queries.shape)
61
- assert len(logits[0]) == queries.shape[0], "logits and queries must have the same batch size, but got {} and {}".format(len(logits), queries.shape[0])
62
 
63
  probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
64
  scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
@@ -69,15 +66,11 @@ def post_process_grounded_object_detection_with_queries(
69
  score = s[mask]
70
  box = b[mask]
71
  prob = p[mask]
72
-
73
- result = {"scores": score, "boxes": box}
74
-
75
- # 保存对应的 query embeddings
76
- if queries is not None:
77
- result["queries"] = queries[idx][mask] # (num_kept, D)
78
-
79
  results.append(result)
80
- assert len(results['scores']) == len(results['boxes']) == results['queries'].shape[0], "scores, boxes and queries must have the same length"
81
  return results
82
 
83
 
 
55
  Now also returns the query embeddings for each kept prediction.
56
  """
57
  logits, boxes = outputs.logits, outputs.pred_boxes
58
+ assert logits.shape == queries.shape, "logits and queries must have the same batch size, but got {} and {}".format(logits.shape[0], queries.shape[0])
 
 
 
59
 
60
  probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
61
  scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
 
66
  score = s[mask]
67
  box = b[mask]
68
  prob = p[mask]
69
+ queries = queries[mask]
70
+ result = {"scores": score, "boxes": box, "queries": queries}
71
+ print('scores: ', score.shape, 'boxes: ', box.shape, 'queries: ', queries.shape)
 
 
 
 
72
  results.append(result)
73
+ assert results['scores'].shape == results['boxes'].shape == results['queries'].shape, "scores, boxes and queries must have the same shape, but got {} and {}".format(results['scores'].shape, results['boxes'].shape, results['queries'].shape)
74
  return results
75
 
76