yifehuang97 commited on
Commit
476c991
·
1 Parent(s): ea3f939

(feat) update neg process

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -34,14 +34,15 @@ def load_model():
34
 
35
  import numpy as np
36
 
37
- def filter_points_by_negative(points, neg_points, distance_threshold=0.05):
38
  """
39
  Filter out positive points that are too close to any negative point.
40
 
41
  Args:
42
- points: List of [x, y] positive points (normalized coordinates)
43
- neg_points: List of [x, y] negative points (normalized coordinates)
44
- distance_threshold: Minimum distance threshold (in normalized coords)
 
45
 
46
  Returns:
47
  filtered_points: List of points that are far enough from all negative points
@@ -50,19 +51,24 @@ def filter_points_by_negative(points, neg_points, distance_threshold=0.05):
50
  if not neg_points or not points:
51
  return points, list(range(len(points)))
52
 
53
- points_arr = np.array(points) # (N, 2)
54
- neg_points_arr = np.array(neg_points) # (M, 2)
55
 
56
- # Compute pairwise distances: (N, M)
57
- # Using broadcasting: (N, 1, 2) - (1, M, 2) -> (N, M, 2) -> (N, M)
58
- diff = points_arr[:, None, :] - neg_points_arr[None, :, :]
 
 
 
 
 
 
59
  distances = np.linalg.norm(diff, axis=-1) # (N, M)
60
 
61
  # Find minimum distance to any negative point for each positive point
62
  min_distances = distances.min(axis=1) # (N,)
63
 
64
  # Keep points where min distance > threshold
65
- keep_mask = min_distances > distance_threshold
66
 
67
  filtered_points = points_arr[keep_mask].tolist()
68
  filtered_indices = np.where(keep_mask)[0].tolist()
@@ -161,10 +167,12 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
161
  neg_boxes = [box.tolist() for box in neg_boxes]
162
  neg_points = [[box[0], box[1]] for box in neg_boxes]
163
 
 
164
  filtered_points, kept_indices = filter_points_by_negative(
165
  points,
166
  neg_points,
167
- distance_threshold=0.05
 
168
  )
169
 
170
  filtered_boxes = [boxes[i] for i in kept_indices]
 
34
 
35
  import numpy as np
36
 
37
+ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5):
38
  """
39
  Filter out positive points that are too close to any negative point.
40
 
41
  Args:
42
+ points: List of [x, y] positive points (normalized coordinates, 0-1)
43
+ neg_points: List of [x, y] negative points (normalized coordinates, 0-1)
44
+ image_size: Tuple of (width, height) in pixels
45
+ pixel_threshold: Minimum distance threshold in pixels
46
 
47
  Returns:
48
  filtered_points: List of points that are far enough from all negative points
 
51
  if not neg_points or not points:
52
  return points, list(range(len(points)))
53
 
54
+ width, height = image_size
 
55
 
56
+ points_arr = np.array(points) # (N, 2) normalized
57
+ neg_points_arr = np.array(neg_points) # (M, 2) normalized
58
+
59
+ # Convert to pixel coordinates
60
+ points_pixel = points_arr * np.array([width, height]) # (N, 2)
61
+ neg_points_pixel = neg_points_arr * np.array([width, height]) # (M, 2)
62
+
63
+ # Compute pairwise distances in pixels: (N, M)
64
+ diff = points_pixel[:, None, :] - neg_points_pixel[None, :, :]
65
  distances = np.linalg.norm(diff, axis=-1) # (N, M)
66
 
67
  # Find minimum distance to any negative point for each positive point
68
  min_distances = distances.min(axis=1) # (N,)
69
 
70
  # Keep points where min distance > threshold
71
+ keep_mask = min_distances > pixel_threshold
72
 
73
  filtered_points = points_arr[keep_mask].tolist()
74
  filtered_indices = np.where(keep_mask)[0].tolist()
 
167
  neg_boxes = [box.tolist() for box in neg_boxes]
168
  neg_points = [[box[0], box[1]] for box in neg_boxes]
169
 
170
+ img_size = image.size
171
  filtered_points, kept_indices = filter_points_by_negative(
172
  points,
173
  neg_points,
174
+ image_size=img_size,
175
+ pixel_threshold=5
176
  )
177
 
178
  filtered_boxes = [boxes[i] for i in kept_indices]