Spaces:
Sleeping
Sleeping
Commit
·
9115945
1
Parent(s):
288ef96
(feat) spatial semantic sup
Browse files- app.py +128 -2
- hf_model/CountEX.py +2 -0
- hf_model/modeling_grounding_dino.py +2 -0
app.py
CHANGED
|
@@ -75,6 +75,117 @@ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5)
|
|
| 75 |
|
| 76 |
return filtered_points, filtered_indices
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
|
| 79 |
"""
|
| 80 |
Main inference function for counting objects
|
|
@@ -167,12 +278,27 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
|
|
| 167 |
neg_boxes = [box.tolist() for box in neg_boxes]
|
| 168 |
neg_points = [[box[0], box[1]] for box in neg_boxes]
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
img_size = image.size
|
| 171 |
-
filtered_points, kept_indices = filter_points_by_negative(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
points,
|
| 173 |
neg_points,
|
|
|
|
|
|
|
| 174 |
image_size=img_size,
|
| 175 |
-
pixel_threshold=5
|
|
|
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
filtered_boxes = [boxes[i] for i in kept_indices]
|
|
|
|
| 75 |
|
| 76 |
return filtered_points, filtered_indices
|
| 77 |
|
| 78 |
+
def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,
    neg_queries,
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.5,
    mode="and"
):
    """
    Discriminative Point Suppression (DPS).

    Suppress positive predictions that are both spatially close to AND
    semantically similar with negative predictions.

    Motivation: spatial proximity alone may cause false suppression when
    positive and negative queries represent different semantic concepts.
    By jointly verifying spatial AND semantic alignment, suppression only
    occurs for true conflicts.

    Args:
        points: list of [x, y] positive points (normalized, 0-1).
        neg_points: list of [x, y] negative points (normalized, 0-1).
        pos_queries: (N, D) query embeddings for positive predictions;
            row i must correspond to points[i].
        neg_queries: (M, D) query embeddings for negative predictions;
            row j must correspond to neg_points[j].
        image_size: (width, height) in pixels.
        pixel_threshold: spatial distance threshold in pixels.
        similarity_threshold: cosine-similarity threshold ("and" mode) or
            combined-score threshold ("weighted" mode).
        mode: "and" for hard joint condition, "weighted" for soft combination.

    Returns:
        filtered_points: points kept after suppression.
        filtered_indices: indices (into `points`) of the kept points.
        suppression_info: dict with detailed suppression decisions for
            analysis/visualization (empty dict when there is nothing to do,
            matching the original early-return contract).

    Raises:
        ValueError: if `mode` is unknown, or if the number of query
            embeddings does not match the number of points on either side.
    """
    # Nothing to suppress (or nothing to keep): return inputs untouched.
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    scale = np.array([width, height], dtype=float)
    pos_xy = np.asarray(points, dtype=float) * scale      # (N, 2) in pixels
    neg_xy = np.asarray(neg_points, dtype=float) * scale  # (M, 2) in pixels

    pos_queries = np.asarray(pos_queries)
    neg_queries = np.asarray(neg_queries)
    # Fail loudly on a query/point count mismatch: with mismatched counts the
    # (N, M) grids below would either raise an opaque broadcast error or, when
    # one side has length 1, silently broadcast and suppress the wrong points.
    if len(pos_queries) != len(pos_xy) or len(neg_queries) != len(neg_xy):
        raise ValueError(
            f"Query/point count mismatch: {len(pos_queries)} pos queries for "
            f"{len(pos_xy)} points, {len(neg_queries)} neg queries for "
            f"{len(neg_xy)} neg points"
        )

    # === Spatial Distance ===  pairwise Euclidean distance, (N, M) in pixels
    spatial_dist = np.linalg.norm(
        pos_xy[:, None, :] - neg_xy[None, :, :], axis=-1
    )

    # === Query Similarity (Cosine) ===  (N, M), range [-1, 1]
    # 1e-8 guards against division by zero for all-zero embeddings.
    pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
    query_sim = np.dot(pos_q, neg_q.T)

    # === Joint Suppression Decision ===
    if mode == "and":
        # Hard condition: suppress only if BOTH spatially close AND
        # semantically similar to at least one negative.
        spatial_close = spatial_dist < pixel_threshold        # (N, M)
        semantic_similar = query_sim > similarity_threshold   # (N, M)
        # A positive is suppressed if ANY negative satisfies both conditions.
        should_suppress = (spatial_close & semantic_similar).any(axis=1)  # (N,)
    elif mode == "weighted":
        # Soft combination: exponential proximity score (1 at distance 0)
        # times cosine similarity rescaled from [-1, 1] to [0, 1].
        spatial_proximity = np.exp(-spatial_dist / pixel_threshold)  # (N, M)
        semantic_score = (query_sim + 1) / 2                         # (N, M)
        suppression_score = spatial_proximity * semantic_score       # (N, M)
        should_suppress = suppression_score.max(axis=1) > similarity_threshold
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # === Filter ===
    keep_mask = ~should_suppress
    filtered_points = np.asarray(points)[keep_mask].tolist()
    filtered_indices = np.where(keep_mask)[0].tolist()

    # === Suppression Info (for analysis/visualization) ===
    suppression_info = {
        "spatial_dist": spatial_dist,
        "query_similarity": query_sim,
        "suppressed_indices": np.where(should_suppress)[0].tolist(),
        "suppressed_reasons": [],
    }

    # Record why each point was suppressed.
    for i in np.where(should_suppress)[0]:
        if mode == "and":
            matched = np.where(spatial_close[i] & semantic_similar[i])[0].tolist()
        else:
            # Cast to plain int (argmax returns np.int64) so the info dict
            # stays JSON-serializable, matching the "and" branch.
            matched = [int(suppression_score[i].argmax())]

        suppression_info["suppressed_reasons"].append({
            "pos_idx": int(i),
            "matched_neg_idx": matched,
            "min_spatial_dist": float(spatial_dist[i].min()),
            "max_query_sim": float(query_sim[i].max()),
        })

    return filtered_points, filtered_indices, suppression_info
|
| 188 |
+
|
| 189 |
def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
|
| 190 |
"""
|
| 191 |
Main inference function for counting objects
|
|
|
|
| 278 |
neg_boxes = [box.tolist() for box in neg_boxes]
|
| 279 |
neg_points = [[box[0], box[1]] for box in neg_boxes]
|
| 280 |
|
| 281 |
+
pos_queries = outputs["pos_queries"].squeeze(0)
|
| 282 |
+
neg_queries = outputs["neg_queries"].squeeze(0)
|
| 283 |
+
pos_queries = pos_queries.cpu().numpy()
|
| 284 |
+
neg_queries = neg_queries.cpu().numpy()
|
| 285 |
+
|
| 286 |
img_size = image.size
|
| 287 |
+
# filtered_points, kept_indices = filter_points_by_negative(
|
| 288 |
+
# points,
|
| 289 |
+
# neg_points,
|
| 290 |
+
# image_size=img_size,
|
| 291 |
+
# pixel_threshold=5
|
| 292 |
+
# )
|
| 293 |
+
filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
|
| 294 |
points,
|
| 295 |
neg_points,
|
| 296 |
+
pos_queries,
|
| 297 |
+
neg_queries,
|
| 298 |
image_size=img_size,
|
| 299 |
+
pixel_threshold=5,
|
| 300 |
+
similarity_threshold=0.5,
|
| 301 |
+
mode="and"
|
| 302 |
)
|
| 303 |
|
| 304 |
filtered_boxes = [boxes[i] for i in kept_indices]
|
hf_model/CountEX.py
CHANGED
|
@@ -578,6 +578,8 @@ class CountEX(GroundingDinoForObjectDetection):
|
|
| 578 |
extra_logs=logs,
|
| 579 |
neg_logits=neg_logits,
|
| 580 |
neg_pred_boxes=neg_pred_boxes,
|
|
|
|
|
|
|
| 581 |
)
|
| 582 |
|
| 583 |
return dict_outputs
|
|
|
|
| 578 |
extra_logs=logs,
|
| 579 |
neg_logits=neg_logits,
|
| 580 |
neg_pred_boxes=neg_pred_boxes,
|
| 581 |
+
pos_queries=hidden_states,
|
| 582 |
+
neg_queries=neg_hidden_states,
|
| 583 |
)
|
| 584 |
|
| 585 |
return dict_outputs
|
hf_model/modeling_grounding_dino.py
CHANGED
|
@@ -373,6 +373,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
|
|
| 373 |
extra_logs: Optional[Dict] = None
|
| 374 |
neg_logits: Optional[torch.FloatTensor] = None
|
| 375 |
neg_pred_boxes: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino
|
|
|
|
| 373 |
extra_logs: Optional[Dict] = None
|
| 374 |
neg_logits: Optional[torch.FloatTensor] = None
|
| 375 |
neg_pred_boxes: Optional[torch.FloatTensor] = None
|
| 376 |
+
pos_queries: Optional[torch.FloatTensor] = None
|
| 377 |
+
neg_queries: Optional[torch.FloatTensor] = None
|
| 378 |
|
| 379 |
|
| 380 |
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino
|