yifehuang97 committed
Commit 9115945 · 1 Parent(s): 288ef96

(feat) spatial semantic sup

app.py CHANGED
@@ -75,6 +75,117 @@ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5)
 
     return filtered_points, filtered_indices
 
+def discriminative_point_suppression(
+    points,
+    neg_points,
+    pos_queries,
+    neg_queries,
+    image_size,
+    pixel_threshold=5,
+    similarity_threshold=0.5,
+    mode="and"
+):
+    """
+    Discriminative Point Suppression (DPS):
+    Suppress positive predictions that are both spatially close to
+    AND semantically similar to negative predictions.
+
+    Motivation: Spatial proximity alone may cause false suppression when
+    positive and negative queries represent different semantic concepts.
+    By jointly verifying spatial AND semantic alignment, we ensure
+    suppression only occurs for true conflicts.
+
+    Args:
+        points: List of [x, y] positive points (normalized, 0-1)
+        neg_points: List of [x, y] negative points (normalized, 0-1)
+        pos_queries: (N, D) query embeddings for positive predictions
+        neg_queries: (M, D) query embeddings for negative predictions
+        image_size: (width, height) in pixels
+        pixel_threshold: spatial distance threshold in pixels
+        similarity_threshold: cosine similarity threshold for semantic match
+        mode: "and" for hard joint condition, "weighted" for soft combination
+
+    Returns:
+        filtered_points: points after suppression
+        filtered_indices: indices of kept points
+        suppression_info: dict with detailed suppression decisions (for analysis)
+    """
+    if not neg_points or not points:
+        return points, list(range(len(points))), {}
+
+    width, height = image_size
+    N, M = len(points), len(neg_points)
+
+    # === Spatial Distance ===
+    points_arr = np.array(points) * np.array([width, height])  # (N, 2)
+    neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)
+
+    spatial_dist = np.linalg.norm(
+        points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
+    )  # (N, M)
+
+    # === Query Similarity (Cosine) ===
+    # Normalize queries
+    pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
+    neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
+
+    query_sim = np.dot(pos_q, neg_q.T)  # (N, M), range [-1, 1]
+
+    # === Joint Suppression Decision ===
+    if mode == "and":
+        # Hard condition: suppress only if BOTH spatially close AND semantically similar
+        spatial_close = spatial_dist < pixel_threshold  # (N, M)
+        semantic_similar = query_sim > similarity_threshold  # (N, M)
+
+        # A positive is suppressed if ANY negative satisfies both conditions
+        should_suppress = (spatial_close & semantic_similar).any(axis=1)  # (N,)
+
+    elif mode == "weighted":
+        # Soft combination: weighted score
+        # Convert distance to proximity score (0-1, higher = closer)
+        spatial_proximity = np.exp(-spatial_dist / pixel_threshold)  # (N, M)
+
+        # Normalize similarity to [0, 1]
+        semantic_score = (query_sim + 1) / 2  # (N, M)
+
+        # Combined suppression score
+        suppression_score = spatial_proximity * semantic_score  # (N, M)
+        max_suppression = suppression_score.max(axis=1)  # (N,)
+
+        should_suppress = max_suppression > similarity_threshold
+
+    else:
+        raise ValueError(f"Unknown mode: {mode}")
+
+    # === Filter ===
+    keep_mask = ~should_suppress
+    filtered_points = np.array(points)[keep_mask].tolist()
+    filtered_indices = np.where(keep_mask)[0].tolist()
+
+    # === Suppression Info (for analysis/visualization) ===
+    suppression_info = {
+        "spatial_dist": spatial_dist,
+        "query_similarity": query_sim,
+        "suppressed_indices": np.where(should_suppress)[0].tolist(),
+        "suppressed_reasons": []
+    }
+
+    # Record why each point was suppressed
+    for i in np.where(should_suppress)[0]:
+        if mode == "and":
+            matching_negs = np.where(spatial_close[i] & semantic_similar[i])[0]
+        else:
+            matching_negs = [suppression_score[i].argmax()]
+
+        suppression_info["suppressed_reasons"].append({
+            "pos_idx": int(i),
+            "matched_neg_idx": matching_negs.tolist() if isinstance(matching_negs, np.ndarray) else matching_negs,
+            "min_spatial_dist": float(spatial_dist[i].min()),
+            "max_query_sim": float(query_sim[i].max())
+        })
+
+    return filtered_points, filtered_indices, suppression_info
+
 def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
     """
     Main inference function for counting objects
@@ -167,12 +278,27 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
     neg_boxes = [box.tolist() for box in neg_boxes]
     neg_points = [[box[0], box[1]] for box in neg_boxes]
 
+    pos_queries = outputs["pos_queries"].squeeze(0)
+    neg_queries = outputs["neg_queries"].squeeze(0)
+    pos_queries = pos_queries.cpu().numpy()
+    neg_queries = neg_queries.cpu().numpy()
+
     img_size = image.size
-    filtered_points, kept_indices = filter_points_by_negative(
+    # filtered_points, kept_indices = filter_points_by_negative(
+    #     points,
+    #     neg_points,
+    #     image_size=img_size,
+    #     pixel_threshold=5
+    # )
+    filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
         points,
         neg_points,
+        pos_queries,
+        neg_queries,
         image_size=img_size,
-        pixel_threshold=5
+        pixel_threshold=5,
+        similarity_threshold=0.5,
+        mode="and"
     )
 
     filtered_boxes = [boxes[i] for i in kept_indices]
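
Note: the block below is a minimal, self-contained sketch of how the new discriminative_point_suppression routine could be exercised on toy inputs. The import path, coordinates, embeddings, and expected outputs are illustrative assumptions, not values taken from the repository.

import numpy as np

from app import discriminative_point_suppression  # assumes app.py imports cleanly without launching the demo

# Two hypothetical positive points and one negative point (normalized x, y),
# with tiny 2-D embeddings standing in for the real (N, D) / (M, D) query features.
points = [[0.50, 0.50], [0.10, 0.10]]
neg_points = [[0.505, 0.505]]                      # ~4 px from the first positive at 640x480
pos_queries = np.array([[1.0, 0.0], [0.0, 1.0]])
neg_queries = np.array([[0.9, 0.1]])               # semantically close to the first positive

kept_points, kept_indices, info = discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,
    neg_queries,
    image_size=(640, 480),
    pixel_threshold=5,
    similarity_threshold=0.5,
    mode="and",
)

# The first positive is both within 5 px of the negative and above 0.5 cosine
# similarity, so it is suppressed; only the second positive survives.
print(kept_indices)                # -> [1]
print(info["suppressed_indices"])  # -> [0]

With mode="weighted", the same pair would instead be scored as exp(-dist / pixel_threshold) * (cos_sim + 1) / 2 and suppressed only when that product exceeds similarity_threshold.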
hf_model/CountEX.py CHANGED
@@ -578,6 +578,8 @@ class CountEX(GroundingDinoForObjectDetection):
             extra_logs=logs,
             neg_logits=neg_logits,
             neg_pred_boxes=neg_pred_boxes,
+            pos_queries=hidden_states,
+            neg_queries=neg_hidden_states,
         )
 
         return dict_outputs
hf_model/modeling_grounding_dino.py CHANGED
@@ -373,6 +373,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
     extra_logs: Optional[Dict] = None
     neg_logits: Optional[torch.FloatTensor] = None
     neg_pred_boxes: Optional[torch.FloatTensor] = None
+    pos_queries: Optional[torch.FloatTensor] = None
+    neg_queries: Optional[torch.FloatTensor] = None
 
 
 # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino
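
For completeness, here is a rough sketch of how the two new output fields are expected to flow from the model output into the app-side suppression call. The (batch, num_queries, hidden_dim) shapes are assumptions inferred from the .squeeze(0) handling in count_objects above, not verified against the model code.

import torch

# Stand-in for the output returned by CountEX, now carrying the query embeddings;
# 900 queries and a 256-dim hidden size are assumed placeholder values.
outputs = {
    "pos_queries": torch.randn(1, 900, 256),
    "neg_queries": torch.randn(1, 900, 256),
}

# Mirrors the new path in count_objects: drop the batch dimension and move to NumPy
# so the embeddings can be passed to discriminative_point_suppression alongside the points.
pos_queries = outputs["pos_queries"].squeeze(0).cpu().numpy()
neg_queries = outputs["neg_queries"].squeeze(0).cpu().numpy()

assert pos_queries.shape[-1] == neg_queries.shape[-1]  # shared embedding dimension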