yifehuang97 commited on
Commit
f517a75
·
1 Parent(s): 5a0ba26

(feat) update post_process_grounded_object_detection_with_queries

Browse files
Files changed (1) hide show
  1. utils.py +1 -1
utils.py CHANGED
@@ -55,7 +55,7 @@ def post_process_grounded_object_detection_with_queries(
55
  Now also returns the query embeddings for each kept prediction.
56
  """
57
  logits, boxes = outputs.logits, outputs.pred_boxes
58
- assert len(logits) == queries.shape[0], "logits and queries must have the same batch size"
59
 
60
  probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
61
  scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
 
55
  Now also returns the query embeddings for each kept prediction.
56
  """
57
  logits, boxes = outputs.logits, outputs.pred_boxes
58
+ assert len(logits) == queries.shape[0], "logits and queries must have the same batch size, but got {} and {}".format(len(logits), queries.shape[0])
59
 
60
  probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
61
  scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)