import torch
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import GroundingDinoProcessor

def prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor):
    # Normalize the annotated point coordinates to [0, 1] using the image size.
    gt_points_b = [np.array(points) / np.array(shapes)[::-1]]
    gt_points_b[0] = gt_points_b[0].squeeze(0)
    gt_points = [torch.from_numpy(img_points).float() for img_points in gt_points_b]
    # One logit row per point over the text-token dimension of the detector.
    gt_logits = [torch.zeros((img_points.shape[0], emb_size)) for img_points in gt_points]
    tokenized = llmdet_processor.tokenizer(caption[0], padding="longest", return_tensors="pt")
    # Token id 1012 is "." in the BERT vocabulary; the last period marks the end of the caption.
    end_idxes = [torch.where(ids == 1012)[0][-1] for ids in tokenized['input_ids']]
    for i, end_idx in enumerate(end_idxes):
        # Mark every caption token before the final period as a positive label.
        gt_logits[i][:, :end_idx] = 1.0
    # +2 so the caption size also covers the final "." and the trailing [SEP] token.
    caption_sizes = [idx + 2 for idx in end_idxes]
    targets = [{"points": p.to(device), "labels": l.to(device), "caption_size": c}
               for p, l, c in zip(gt_points, gt_logits, caption_sizes)]
    return targets
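# Hedged usage sketch for prepare_targets (illustrative only; the point coordinates, caption,
# and embedding size below are made-up placeholders, not values from the dataset).
def _example_prepare_targets(llmdet_processor, device="cpu"):
    points = [[[120.0, 80.0], [300.0, 200.0]]]   # pixel (x, y) coordinates of two annotated points
    caption = [["a photo of red apples."]]       # caption wrapped the same way the collator does
    shapes = [(640, 480)]                        # (width, height) of the image
    emb_size = 256                               # text-embedding width expected by the detector
    return prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor)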
def post_process_grounded_object_detection(
    outputs,
    box_threshold: float = 0.4,
):
    # For the fine-tuned model, the box threshold should be set to 0.50.
    logits, boxes = outputs.logits, outputs.pred_boxes
    probs = torch.sigmoid(logits)         # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b in zip(scores, boxes):
        keep = s > box_threshold
        results.append({"scores": s[keep], "boxes": b[keep]})
    return results
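# Hedged usage sketch: how the post-processing above would typically be applied to raw model
# outputs. `model` and `inputs` are placeholders for an LLMDet/GroundingDINO-style model and its
# processed inputs; the 0.50 threshold mirrors the comment above for fine-tuned models. The
# _with_queries variant below works the same way, additionally carrying the query embeddings.
def _example_post_process(model, inputs):
    with torch.no_grad():
        outputs = model(**inputs)
    results = post_process_grounded_object_detection(outputs, box_threshold=0.50)
    # results[0]["boxes"] holds the kept normalized boxes for the first image,
    # results[0]["scores"] the matching confidence scores.
    return results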
def post_process_grounded_object_detection_with_queries(
    outputs,
    queries,
    box_threshold: float = 0.4,
):
    """
    Post-process grounded object detection outputs.
    Also returns the query embeddings for each kept prediction.
    """
    logits, boxes = outputs.logits, outputs.pred_boxes
    assert logits.shape == queries.shape, \
        "logits and queries must have the same shape, but got {} and {}".format(logits.shape, queries.shape)
    probs = torch.sigmoid(logits)         # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b, q in zip(scores, boxes, queries):
        keep = s > box_threshold
        score, box, query = s[keep], b[keep], q[keep]
        assert score.shape[0] == box.shape[0] == query.shape[0], \
            "scores, boxes and queries must have the same length, but got {}, {} and {}".format(
                score.shape[0], box.shape[0], query.shape[0])
        results.append({"scores": score, "boxes": box, "queries": query})
    return results
class collator:
    def __init__(self, processor=None, use_negative=True):
        # Note: the `processor` argument is currently unused; the LLMDet processor is always loaded from the hub.
        model_id = "fushh7/llmdet_swin_tiny_hf"
        self.llmdet_processor = GroundingDinoProcessor.from_pretrained(model_id)
        self.use_negative = use_negative
def __call__(self, batch):
# assume batch size is 1
example = batch[0]
image = example['image']
pil_image = example['image']
w, h = image.size
pos_caption = example['pos_caption']
neg_caption = example['neg_caption']
pos_points = example['pos_points']
neg_points = example['neg_points']
pos_count = example['pos_count']
neg_count = example['neg_count']
annotated_pos_count = example['annotated_pos_count']
annotated_neg_count = example['annotated_neg_count']
if 'type' in example:
sample_type = example['type']
else:
sample_type = 'eval'
category = example['category']
image_name = "{}_{}_{}_{}_{}".format(category, pos_caption, neg_caption, pos_count, neg_count)
pos_llm_det_inputs = self.llmdet_processor(images=image, text=pos_caption, return_tensors="pt", padding=True)
neg_llm_det_inputs = self.llmdet_processor(images=image, text=neg_caption, return_tensors="pt", padding=True)
pos_caption = [[pos_caption]]
neg_caption = [[neg_caption]]
shapes = [(w, h)]
pos_points = [pos_points]
neg_points = [neg_points]
        batch_dict = {
            'pos_llm_det_inputs': pos_llm_det_inputs,
            'neg_llm_det_inputs': neg_llm_det_inputs,
            'pos_caption': pos_caption,
            'neg_caption': neg_caption,
            'shapes': shapes,
            'pos_points': pos_points,
            'neg_points': neg_points,
            'pos_count': pos_count,
            'neg_count': neg_count,
            'annotated_pos_count': annotated_pos_count,
            'annotated_neg_count': annotated_neg_count,
            'image': pil_image,
            'category': category,
            'type': sample_type,
            'image_name': image_name,
        }
        # exemplars: normalize box corners to [0, 1] and attach them only when every box is valid
        if ('positive_exemplars' in example and 'negative_exemplars' in example
                and example['positive_exemplars'] is not None
                and example['negative_exemplars'] is not None):
            pos_exemplars = example['positive_exemplars']
            neg_exemplars = example['negative_exemplars']
            img_height, img_width = pil_image.size
            exemplar_valid = True

            def normalize_box(box):
                # exemplar boxes are given as (tly, tlx, bry, brx) in pixels
                nonlocal exemplar_valid
                tly, tlx, bry, brx = box
                tlx, tly = tlx / img_width, tly / img_height
                brx, bry = brx / img_width, bry / img_height
                # flag the sample as invalid if any corner is outside the image or the box is degenerate
                if not (0 <= tlx <= 1 and 0 <= tly <= 1 and 0 <= brx <= 1 and 0 <= bry <= 1):
                    exemplar_valid = False
                if tlx >= brx or tly >= bry:
                    exemplar_valid = False
                # clamp so that every returned box is well-formed
                tlx = min(max(tlx, 0), 1 - 1e-4)
                tly = min(max(tly, 0), 1 - 1e-4)
                brx = max(min(brx, 1), tlx)
                bry = max(min(bry, 1), tly)
                assert 0 <= tlx <= brx <= 1 and 0 <= tly <= bry <= 1, f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
                return [tlx, tly, brx, bry]

            norm_pos_exemplars = [normalize_box(box) for box in pos_exemplars]
            norm_neg_exemplars = [normalize_box(box) for box in neg_exemplars]
            if exemplar_valid:
                batch_dict['pos_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(box)).float() for box in norm_pos_exemplars])
                batch_dict['neg_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(box)).float() for box in norm_neg_exemplars])
        return batch_dict
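# Hedged usage sketch for the collator: with batch_size fixed to 1 (as assumed in __call__), it
# can be plugged straight into a PyTorch DataLoader. `train_dataset` is a placeholder for one of
# the datasets returned by build_dataset below.
def _example_dataloader(train_dataset):
    from torch.utils.data import DataLoader
    loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collator())
    batch = next(iter(loader))
    # batch["pos_llm_det_inputs"] / batch["neg_llm_det_inputs"] feed the detector for the
    # positive and negative captions; points and counts are kept for supervision and metrics.
    return batch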
def rank0_print(*args):
if dist.is_initialized():
if dist.get_rank() == 0:
print(f"Rank {dist.get_rank()}: ", *args)
else:
print(*args)
def build_dataset(data_args):
from datasets import load_from_disk, concatenate_datasets
categories = ["FOO", "FUN", "OFF", "OTR", "HOU"]
if data_args.data_split not in categories:
rank0_print(f"Warning: Invalid data_split '{data_args.data_split}'. Switching to 'all' mode.")
data_args.data_split = "all"
    if data_args.data_split == "all":
        train_dataset = load_from_disk(data_args.train_data_path)
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in categories])
        val_dataset = load_from_disk(data_args.val_data_path)
        val_dataset = concatenate_datasets([val_dataset[cat] for cat in categories])
        test_dataset = load_from_disk(data_args.test_data_path)
        test_dataset = concatenate_datasets([test_dataset[cat] for cat in categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets([weakly_supervised_data[cat] for cat in categories])
        rank0_print("Using 'all' mode: all categories for train/val/test")
    else:
        # Leave-one-category-out cross-validation: train on the remaining categories, evaluate on the held-out one.
        test_category = data_args.data_split
        train_categories = [cat for cat in categories if cat != test_category]
        train_dataset = load_from_disk(data_args.train_data_path)
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in train_categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets([weakly_supervised_data[cat] for cat in train_categories])
        val_dataset = load_from_disk(data_args.val_data_path)[test_category]
        test_dataset = load_from_disk(data_args.test_data_path)[test_category]
        rank0_print(f"Cross-validation mode: using {train_categories} for train, {test_category} for val/test")
rank0_print('train_dataset: ', len(train_dataset))
rank0_print('val_dataset: ', len(val_dataset))
rank0_print('test_dataset: ', len(test_dataset))
rank0_print('weakly_supervised_data: ', len(weakly_supervised_data))
return train_dataset, val_dataset, test_dataset, weakly_supervised_data
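# Hedged usage sketch for build_dataset: `data_args` only needs the attributes read above, so a
# simple namespace is enough here. The paths are placeholders, not real dataset locations.
def _example_build_dataset():
    from types import SimpleNamespace
    data_args = SimpleNamespace(
        data_split="FOO",                           # hold out FOO, train on the remaining categories
        train_data_path="/path/to/train_dataset",   # placeholder paths
        val_data_path="/path/to/val_dataset",
        test_data_path="/path/to/test_dataset",
        weakly_supervised_data_path="/path/to/weak_dataset",
    )
    return build_dataset(data_args)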
def generate_pseudo_density_map(points_norm: torch.Tensor,
output_size: tuple[int, int],
sigma: float = 4.0,
                                normalize: bool = True) -> torch.Tensor:
    """Render a (1, 1, H, W) density map by placing an isotropic Gaussian (std `sigma`, in pixels)
    at each point in `points_norm`, given as (N, 2) normalized (x, y) coordinates in [0, 1].
    If `normalize` is True, the map is rescaled so that it sums to the number of points."""
    device = points_norm.device
H, W = output_size
N = points_norm.shape[0]
ys = torch.arange(H, device=device).float()
xs = torch.arange(W, device=device).float()
grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij') # (H, W)
pts_px = points_norm.clone()
pts_px[:, 0] *= (W - 1) # x
pts_px[:, 1] *= (H - 1) # y
dx = grid_x.unsqueeze(0) - pts_px[:, 0].view(-1, 1, 1) # (N, H, W)
dy = grid_y.unsqueeze(0) - pts_px[:, 1].view(-1, 1, 1) # (N, H, W)
dist2 = dx ** 2 + dy ** 2
gaussians = torch.exp(-dist2 / (2 * sigma ** 2)) # (N, H, W)
density_map = gaussians.sum(dim=0, keepdim=True) # (1, H, W)
if normalize and N > 0:
density_map = density_map * (N / density_map.sum())
return density_map.unsqueeze(0)
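# Hedged usage sketch: build a pseudo density map from a handful of made-up normalized points and
# visualize it with the helpers below. Sizes and sigma are illustrative, not tuned values.
def _example_density_map():
    points_norm = torch.tensor([[0.25, 0.30], [0.70, 0.60], [0.50, 0.80]])  # (N, 2) in [0, 1]
    density = generate_pseudo_density_map(points_norm, output_size=(192, 256), sigma=4.0)
    # density has shape (1, 1, 192, 256) and sums to the number of points (3) when normalize=True.
    show_density_map(density, points_norm)
    return density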
def show_density_map(density_map: torch.Tensor,
points_norm: torch.Tensor | None = None,
figsize: tuple[int, int] = (6, 8),
cmap: str = "jet") -> None:
dm = density_map.squeeze().detach().cpu().numpy() # (H, W)
H, W = dm.shape
plt.figure(figsize=figsize)
plt.imshow(dm, cmap=cmap, origin="upper")
plt.colorbar(label="Density")
if points_norm is not None and points_norm.numel() > 0:
pts = points_norm.detach().cpu().numpy()
xs = pts[:, 0] * (W - 1)
ys = pts[:, 1] * (H - 1)
plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
plt.title(f"Density map (sum = {dm.sum():.2f})")
plt.axis("off")
plt.tight_layout()
plt.show()
def show_image_with_density(pil_img: Image.Image,
density_map: torch.Tensor,
points_norm: torch.Tensor | None = None,
cmap: str = "jet",
alpha: float = 0.45,
figsize: tuple[int, int] = (6, 8)) -> None:
dm = density_map.squeeze().detach().cpu().numpy() # (H, W)
H, W = dm.shape
img_resized = pil_img.resize((W, H), Image.BILINEAR) # or LANCZOS
img_np = np.asarray(img_resized)
plt.figure(figsize=figsize)
plt.imshow(img_np, origin="upper")
plt.imshow(dm, cmap=cmap, alpha=alpha, origin="upper")
if points_norm is not None and points_norm.numel() > 0:
pts = points_norm.detach().cpu().numpy()
xs = pts[:, 0] * (W - 1)
ys = pts[:, 1] * (H - 1)
plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
plt.title(f"Overlay (density sum = {dm.sum():.2f})")
plt.axis("off")
plt.tight_layout()
plt.show()
def build_point_count_map(feat_maps: torch.Tensor,
                          pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
    """Count how many normalized (x, y) points fall into each cell of the (B, H, W) feature grid."""
    assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
B, H, W, _ = feat_maps.shape
device = feat_maps.device
count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)
for b in range(B):
pts = pts_norm_list[b].to(device).float() # (Ni, 2)
if pts.numel() == 0:
continue
idx_xy = (pts * torch.tensor([W, H], device=device)).long()
idx_xy[..., 0].clamp_(0, W - 1) # x
idx_xy[..., 1].clamp_(0, H - 1) # y
lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0] # (Ni,)
one = torch.ones_like(lin_idx, dtype=torch.float32)
flat = torch.zeros(H * W, dtype=torch.float32, device=device)
flat.scatter_add_(0, lin_idx, one)
count_map[b] = flat.view(H, W)
return count_map
def extract_pos_tokens_single(feat_maps: torch.Tensor,
                              count_map: torch.Tensor):
    assert feat_maps.dim() == 4 and count_map.dim() == 3, "expected shapes (B, H, W, D) and (B, H, W)"
    B, H, W, D = feat_maps.shape
    assert B == 1, "this function currently assumes batch_size == 1"
    feat = feat_maps[0]   # (H, W, D)
    cnt = count_map[0]    # (H, W)
    pos_mask = cnt > 0    # bool (H, W)
    if pos_mask.sum() == 0:
        # No positive cells: return empty token and index tensors.
        empty = torch.empty(0, device=feat.device)
        return empty.reshape(0, D), empty.long()
    pos_tokens = feat[pos_mask]                            # (N_pos, D)
    y_idx, x_idx = torch.nonzero(pos_mask, as_tuple=True)
    lin_index = y_idx * W + x_idx                          # (N_pos,)
    return pos_tokens, lin_index
def filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg):
    # Drop tokens whose grid cell contains both positive and negative points.
    pos_only_mask = ~torch.isin(lin_pos, lin_neg)
    neg_only_mask = ~torch.isin(lin_neg, lin_pos)
    return pos_tok[pos_only_mask], neg_tok[neg_only_mask]
# ------------------------------------------------------------
# 2) supervised contrastive loss
# ------------------------------------------------------------
def supcon_pos_neg(pos_tokens, neg_tokens, temperature=0.07):
    """
    Supervised contrastive loss over two token groups.
    pos_tokens : (Np, D) tokens from positive (target-class) locations
    neg_tokens : (Nn, D) tokens from negative (distractor) locations
    Tokens sharing a group label are pulled together; tokens from the other group are pushed apart.
    """
    if pos_tokens.numel() == 0 or neg_tokens.numel() == 0:
        return torch.tensor(0., device=pos_tokens.device, requires_grad=True)
    pos_tokens = F.normalize(pos_tokens, dim=-1)
    neg_tokens = F.normalize(neg_tokens, dim=-1)
    feats = torch.cat([pos_tokens, neg_tokens], dim=0)  # (N, D)
    labels = torch.cat([torch.zeros(len(pos_tokens), device=feats.device, dtype=torch.long),
                        torch.ones(len(neg_tokens), device=feats.device, dtype=torch.long)], dim=0)  # (N,)
    logits = feats @ feats.T / temperature  # (N, N)
    logits.fill_diagonal_(-1e4)             # exclude self-similarity
    mask_pos = labels.unsqueeze(0) == labels.unsqueeze(1)  # (N, N) same-group mask
    mask_pos.fill_diagonal_(False)
    exp_logits = logits.exp()
    denom = exp_logits.sum(dim=1, keepdim=True)  # sum over all a != i of exp(sim)
    log_prob = logits - denom.log()              # log softmax over the remaining entries
    loss_i = -(mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp_min(1)
    loss = loss_i.mean()
    return loss
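# Hedged end-to-end sketch of the contrastive pieces above: scatter positive/negative points onto a
# feature grid, pull out the corresponding tokens, drop overlapping cells, and compute the loss.
# The feature map and points are random placeholders, not real detector features.
def _example_supcon_pipeline():
    feat_maps = torch.randn(1, 24, 24, 256)              # (B, H, W, D) backbone tokens
    pos_pts = [torch.rand(5, 2)]                          # normalized (x, y) positive points
    neg_pts = [torch.rand(4, 2)]                          # normalized (x, y) negative points
    pos_cnt = build_point_count_map(feat_maps, pos_pts)   # (B, H, W) counts per cell
    neg_cnt = build_point_count_map(feat_maps, neg_pts)
    pos_tok, lin_pos = extract_pos_tokens_single(feat_maps, pos_cnt)
    neg_tok, lin_neg = extract_pos_tokens_single(feat_maps, neg_cnt)
    pos_tok, neg_tok = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)
    return supcon_pos_neg(pos_tok, neg_tok, temperature=0.07)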