Spaces:
Sleeping
Sleeping
Commit
·
f517a75
1
Parent(s):
5a0ba26
(feat) update post_process_grounded_object_detection_with_queries
Browse files
utils.py
CHANGED
|
@@ -55,7 +55,7 @@ def post_process_grounded_object_detection_with_queries(
|
|
| 55 |
Now also returns the query embeddings for each kept prediction.
|
| 56 |
"""
|
| 57 |
logits, boxes = outputs.logits, outputs.pred_boxes
|
| 58 |
-
assert len(logits) == queries.shape[0], "logits and queries must have the same batch size"
|
| 59 |
|
| 60 |
probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
|
| 61 |
scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
|
|
|
|
| 55 |
Now also returns the query embeddings for each kept prediction.
|
| 56 |
"""
|
| 57 |
logits, boxes = outputs.logits, outputs.pred_boxes
|
| 58 |
+
assert len(logits) == queries.shape[0], "logits and queries must have the same batch size, but got {} and {}".format(len(logits), queries.shape[0])
|
| 59 |
|
| 60 |
probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
|
| 61 |
scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
|