import torch
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import GroundingDinoProcessor


def prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor):
    # Normalize (x, y) points to [0, 1] by the image (w, h). `shapes` holds a
    # single (w, h) tuple, so the [::-1] over the batch axis is a no-op here.
    gt_points_b = [np.array(points) / np.array(shapes)[::-1]]
    gt_points_b[0] = gt_points_b[0].squeeze(0)
    gt_points = [torch.from_numpy(img_points).float() for img_points in gt_points_b]
    gt_logits = [torch.zeros((img_points.shape[0], emb_size)) for img_points in gt_points]
    tokenized = llmdet_processor.tokenizer(caption[0], padding="longest", return_tensors="pt")
    # Token id 1012 is "." in the BERT vocabulary; the last period marks the end
    # of the caption, so every token before it is labeled positive.
    end_idxes = [torch.where(ids == 1012)[0][-1] for ids in tokenized['input_ids']]
    for i, end_idx in enumerate(end_idxes):
        gt_logits[i][:, :end_idx] = 1.0
    caption_sizes = [idx + 2 for idx in end_idxes]
    targets = [{"points": p.to(device), "labels": l.to(device), "caption_size": c}
               for p, l, c in zip(gt_points, gt_logits, caption_sizes)]
    return targets


def post_process_grounded_object_detection(
        outputs,
        box_threshold: float = 0.4,
):
    # For the fine-tuned model, the box threshold should be set to 0.50.
    logits, boxes = outputs.logits, outputs.pred_boxes
    probs = torch.sigmoid(logits)  # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b in zip(scores, boxes):
        keep = s > box_threshold
        results.append({"scores": s[keep], "boxes": b[keep]})
    return results


def post_process_grounded_object_detection_with_queries(
        outputs,
        queries,
        box_threshold: float = 0.4,
):
    """
    Post-process grounded object detection outputs.
    Also returns the query embeddings for each kept prediction.
    """
    logits, boxes = outputs.logits, outputs.pred_boxes
    assert logits.shape[:2] == queries.shape[:2], \
        "logits and queries must share batch and query dims, but got {} and {}".format(
            logits.shape, queries.shape)
    probs = torch.sigmoid(logits)  # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b, q in zip(scores, boxes, queries):
        mask = s > box_threshold
        score = s[mask]
        box = b[mask]
        query = q[mask]
        assert score.shape[0] == box.shape[0] == query.shape[0], \
            "scores, boxes and queries must have the same length, but got {}, {} and {}".format(
                score.shape[0], box.shape[0], query.shape[0])
        results.append({"scores": score, "boxes": box, "queries": query})
    return results
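# Minimal sanity-check sketch for the two post-processors above (illustrative
# only, not part of the training pipeline): fake model outputs with random
# logits/boxes stand in for the real LLMDet forward pass, and the 256-d query
# embedding size is an assumption.
from types import SimpleNamespace


def _demo_post_process():
    fake_outputs = SimpleNamespace(
        logits=torch.randn(1, 900, 256),     # (B, num_queries, text_dim)
        pred_boxes=torch.rand(1, 900, 4),    # (B, num_queries, cxcywh)
    )
    fake_queries = torch.randn(1, 900, 256)  # decoder query embeddings (assumed dim)
    res = post_process_grounded_object_detection(fake_outputs, box_threshold=0.9)
    res_q = post_process_grounded_object_detection_with_queries(
        fake_outputs, fake_queries, box_threshold=0.9)
    # Same scores and threshold, so both keep the same number of predictions.
    assert res[0]["boxes"].shape[0] == res_q[0]["queries"].shape[0]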
class collator:
    def __init__(self, processor=None, use_negative=True):
        model_id = "fushh7/llmdet_swin_tiny_hf"
        # Reuse an externally constructed processor if one is passed in.
        self.llmdet_processor = processor if processor is not None \
            else GroundingDinoProcessor.from_pretrained(model_id)
        self.use_negative = use_negative

    @staticmethod
    def _normalize_exemplar(exemplar, img_width, img_height):
        """Normalize a (tly, tlx, bry, brx) pixel box to a clamped
        [tlx, tly, brx, bry] box in [0, 1]. Returns (box, valid)."""
        tly, tlx, bry, brx = exemplar
        tlx, tly = tlx / img_width, tly / img_height
        brx, bry = brx / img_width, bry / img_height
        valid = (0 <= tlx <= 1.0 and 0 <= tly <= 1.0 and
                 0 <= brx <= 1.0 and 0 <= bry <= 1.0 and
                 tlx < brx and tly < bry)
        tlx = min(max(tlx, 0), 1 - 1e-4)
        tly = min(max(tly, 0), 1 - 1e-4)
        brx = max(min(brx, 1), tlx)
        bry = max(min(bry, 1), tly)
        assert 0 <= tlx <= brx <= 1 and 0 <= tly <= bry <= 1, \
            f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
        return [tlx, tly, brx, bry], valid

    def __call__(self, batch):
        # Assumes batch size is 1.
        example = batch[0]
        image = example['image']
        pil_image = example['image']
        w, h = image.size
        pos_caption = example['pos_caption']
        neg_caption = example['neg_caption']
        pos_points = example['pos_points']
        neg_points = example['neg_points']
        pos_count = example['pos_count']
        neg_count = example['neg_count']
        annotated_pos_count = example['annotated_pos_count']
        annotated_neg_count = example['annotated_neg_count']
        sample_type = example['type'] if 'type' in example else 'eval'
        category = example['category']
        image_name = "{}_{}_{}_{}_{}".format(category, pos_caption, neg_caption,
                                             pos_count, neg_count)

        pos_llm_det_inputs = self.llmdet_processor(images=image, text=pos_caption,
                                                   return_tensors="pt", padding=True)
        neg_llm_det_inputs = self.llmdet_processor(images=image, text=neg_caption,
                                                   return_tensors="pt", padding=True)
        pos_caption = [[pos_caption]]
        neg_caption = [[neg_caption]]
        shapes = [(w, h)]
        pos_points = [pos_points]
        neg_points = [neg_points]

        batch_dict = {
            'pos_llm_det_inputs': pos_llm_det_inputs,
            'neg_llm_det_inputs': neg_llm_det_inputs,
            'pos_caption': pos_caption,
            'neg_caption': neg_caption,
            'shapes': shapes,
            'pos_points': pos_points,
            'neg_points': neg_points,
            'pos_count': pos_count,
            'neg_count': neg_count,
            'annotated_pos_count': annotated_pos_count,
            'annotated_neg_count': annotated_neg_count,
            'image': pil_image,
            'category': category,
            'type': sample_type,
            'image_name': image_name,
        }

        # Exemplar boxes are optional; attach them only when both sets exist and
        # every box is still valid after normalization.
        if ('positive_exemplars' in example and 'negative_exemplars' in example
                and example['positive_exemplars'] is not None
                and example['negative_exemplars'] is not None):
            pos_exemplars = example['positive_exemplars']
            neg_exemplars = example['negative_exemplars']
            # PIL's .size is (width, height).
            img_width, img_height = pil_image.size
            exemplar_valid = True
            norm_pos_exemplars = []
            norm_neg_exemplars = []
            for ex in pos_exemplars:
                box, valid = self._normalize_exemplar(ex, img_width, img_height)
                exemplar_valid = exemplar_valid and valid
                norm_pos_exemplars.append(box)
            for ex in neg_exemplars:
                box, valid = self._normalize_exemplar(ex, img_width, img_height)
                exemplar_valid = exemplar_valid and valid
                norm_neg_exemplars.append(box)
            if exemplar_valid:
                batch_dict['pos_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(ex)).float() for ex in norm_pos_exemplars])
                batch_dict['neg_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(ex)).float() for ex in norm_neg_exemplars])
        return batch_dict
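# Illustrative sketch of the batch layout the collator expects (all field values
# here are made up; instantiating the collator downloads the HF processor on
# first use). Exemplar boxes are (tly, tlx, bry, brx) in pixels, matching
# _normalize_exemplar above.
def _demo_collator():
    example = {
        'image': Image.new("RGB", (640, 480)),
        'pos_caption': "apple.",
        'neg_caption': "orange.",
        'pos_points': [[100.0, 200.0]],
        'neg_points': [[300.0, 150.0]],
        'pos_count': 1, 'neg_count': 1,
        'annotated_pos_count': 1, 'annotated_neg_count': 1,
        'category': "FOO",
        'positive_exemplars': [[50, 60, 120, 140]],   # (tly, tlx, bry, brx)
        'negative_exemplars': [[200, 210, 260, 300]],
    }
    batch_dict = collator()([example])
    print(sorted(batch_dict.keys()))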
def rank0_print(*args):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)


def build_dataset(data_args):
    from datasets import load_from_disk, concatenate_datasets

    categories = ["FOO", "FUN", "OFF", "OTR", "HOU"]
    # 'all' is a legitimate mode, so only warn on genuinely unknown splits.
    if data_args.data_split != "all" and data_args.data_split not in categories:
        rank0_print(f"Warning: Invalid data_split '{data_args.data_split}'. Switching to 'all' mode.")
        data_args.data_split = "all"

    if data_args.data_split == "all":
        train_dataset = load_from_disk(data_args.train_data_path)
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in categories])
        val_dataset = load_from_disk(data_args.val_data_path)
        val_dataset = concatenate_datasets([val_dataset[cat] for cat in categories])
        test_dataset = load_from_disk(data_args.test_data_path)
        test_dataset = concatenate_datasets([test_dataset[cat] for cat in categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets(
            [weakly_supervised_data[cat] for cat in categories])
        rank0_print("Using 'all' mode: all categories for train/val/test")
    else:
        # Leave-one-category-out cross-validation: train on the other four.
        test_category = data_args.data_split
        train_categories = [cat for cat in categories if cat != test_category]
        train_dataset = load_from_disk(data_args.train_data_path)
        rank0_print(train_categories, train_dataset.keys())
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in train_categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets(
            [weakly_supervised_data[cat] for cat in train_categories])
        val_dataset = load_from_disk(data_args.val_data_path)[test_category]
        test_dataset = load_from_disk(data_args.test_data_path)[test_category]
        rank0_print(f"Cross-validation mode: using {train_categories} for train, "
                    f"{test_category} for val/test")

    rank0_print('train_dataset: ', len(train_dataset))
    rank0_print('val_dataset: ', len(val_dataset))
    rank0_print('test_dataset: ', len(test_dataset))
    rank0_print('weakly_supervised_data: ', len(weakly_supervised_data))
    return train_dataset, val_dataset, test_dataset, weakly_supervised_data


def generate_pseudo_density_map(points_norm: torch.Tensor,
                                output_size: tuple[int, int],
                                sigma: float = 4.0,
                                normalize: bool = True) -> torch.Tensor:
    """Render normalized (x, y) points as a sum of isotropic Gaussians.
    Returns a (1, 1, H, W) map; with normalize=True its sum equals N."""
    device = points_norm.device
    H, W = output_size
    N = points_norm.shape[0]
    ys = torch.arange(H, device=device).float()
    xs = torch.arange(W, device=device).float()
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')  # (H, W)
    pts_px = points_norm.clone()
    pts_px[:, 0] *= (W - 1)  # x
    pts_px[:, 1] *= (H - 1)  # y
    dx = grid_x.unsqueeze(0) - pts_px[:, 0].view(-1, 1, 1)  # (N, H, W)
    dy = grid_y.unsqueeze(0) - pts_px[:, 1].view(-1, 1, 1)  # (N, H, W)
    dist2 = dx ** 2 + dy ** 2
    gaussians = torch.exp(-dist2 / (2 * sigma ** 2))  # (N, H, W)
    density_map = gaussians.sum(dim=0, keepdim=True)  # (1, H, W)
    if normalize and N > 0:
        # Rescale so the map integrates to the point count N.
        density_map = density_map * (N / density_map.sum())
    return density_map.unsqueeze(0)
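# Quick sanity check (illustrative sketch): with normalize=True the pseudo map
# should integrate to the number of points, which is the property a counting
# loss relies on.
def _demo_density_map():
    pts = torch.tensor([[0.25, 0.5], [0.75, 0.5]])  # normalized (x, y)
    dm = generate_pseudo_density_map(pts, output_size=(120, 160), sigma=4.0)
    assert dm.shape == (1, 1, 120, 160)
    assert torch.isclose(dm.sum(), torch.tensor(2.0))
    # show_density_map(dm, pts)  # uncomment to visualize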
def show_density_map(density_map: torch.Tensor,
                     points_norm: torch.Tensor | None = None,
                     figsize: tuple[int, int] = (6, 8),
                     cmap: str = "jet") -> None:
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape
    plt.figure(figsize=figsize)
    plt.imshow(dm, cmap=cmap, origin="upper")
    plt.colorbar(label="Density")
    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
    plt.title(f"Density map (sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()


def show_image_with_density(pil_img: Image.Image,
                            density_map: torch.Tensor,
                            points_norm: torch.Tensor | None = None,
                            cmap: str = "jet",
                            alpha: float = 0.45,
                            figsize: tuple[int, int] = (6, 8)) -> None:
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape
    img_resized = pil_img.resize((W, H), Image.BILINEAR)  # or LANCZOS
    img_np = np.asarray(img_resized)
    plt.figure(figsize=figsize)
    plt.imshow(img_np, origin="upper")
    plt.imshow(dm, cmap=cmap, alpha=alpha, origin="upper")
    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
    plt.title(f"Overlay (density sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()


def build_point_count_map(feat_maps: torch.Tensor,
                          pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
    """Rasterize normalized (x, y) points into a per-cell hit-count map
    aligned with a (B, H, W, D) feature grid."""
    assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
    B, H, W, _ = feat_maps.shape
    device = feat_maps.device
    count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)
    for b in range(B):
        pts = pts_norm_list[b].to(device).float()  # (Ni, 2)
        if pts.numel() == 0:
            continue
        idx_xy = (pts * torch.tensor([W, H], device=device)).long()
        idx_xy[..., 0].clamp_(0, W - 1)  # x
        idx_xy[..., 1].clamp_(0, H - 1)  # y
        lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0]  # (Ni,)
        one = torch.ones_like(lin_idx, dtype=torch.float32)
        flat = torch.zeros(H * W, dtype=torch.float32, device=device)
        flat.scatter_add_(0, lin_idx, one)
        count_map[b] = flat.view(H, W)
    return count_map


def extract_pos_tokens_single(feat_maps: torch.Tensor, count_map: torch.Tensor):
    """Gather the feature tokens at grid cells hit by at least one point.
    Returns the tokens and their flattened (y * W + x) indices."""
    assert feat_maps.dim() == 4 and count_map.dim() == 3, \
        "expected shapes (B,H,W,D) / (B,H,W)"
    B, H, W, D = feat_maps.shape
    assert B == 1, "this function assumes batch_size == 1"
    feat = feat_maps[0]  # (H, W, D)
    cnt = count_map[0]   # (H, W)
    pos_mask = cnt > 0   # bool (H, W)
    if pos_mask.sum() == 0:
        empty = torch.empty(0, device=feat.device)
        return empty.reshape(0, D), empty.long()
    pos_tokens = feat[pos_mask]  # (N_pos, D)
    y_idx, x_idx = torch.nonzero(pos_mask, as_tuple=True)
    lin_index = y_idx * W + x_idx  # (N_pos,)
    return pos_tokens, lin_index


def filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg):
    """Drop tokens whose grid cell appears in both the positive and the negative
    sets, keeping only cell-disjoint tokens for the contrastive loss."""
    pos_only_mask = ~torch.isin(lin_pos, lin_neg)
    neg_only_mask = ~torch.isin(lin_neg, lin_pos)
    return pos_tok[pos_only_mask], neg_tok[neg_only_mask]
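# Illustrative end-to-end sketch with random features (assumed shapes, not the
# real model's): rasterize points to a count map, pull out the hit tokens, and
# drop cells claimed by both classes before computing the contrastive loss.
def _demo_token_extraction():
    feat_maps = torch.randn(1, 25, 25, 256)           # (B, H, W, D) encoder grid
    pos_pts = torch.tensor([[0.1, 0.2], [0.8, 0.7]])  # normalized (x, y)
    neg_pts = torch.tensor([[0.1, 0.2], [0.5, 0.5]])  # first point overlaps pos
    pos_cnt = build_point_count_map(feat_maps, [pos_pts])
    neg_cnt = build_point_count_map(feat_maps, [neg_pts])
    pos_tok, lin_pos = extract_pos_tokens_single(feat_maps, pos_cnt)
    neg_tok, lin_neg = extract_pos_tokens_single(feat_maps, neg_cnt)
    pos_tok, neg_tok = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)
    print(pos_tok.shape, neg_tok.shape)  # the shared cell is dropped from both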
# ------------------------------------------------------------
# 2) supervised contrastive loss
# ------------------------------------------------------------
def supcon_pos_neg(pos_tokens, neg_tokens, temperature=0.07):
    """
    pos_tokens : (Np, D) positive tokens
    neg_tokens : (Nn, D) negative tokens
    """
    if pos_tokens.numel() == 0 or neg_tokens.numel() == 0:
        # One side is empty: return a zero loss that still participates in
        # autograd so the training step stays well-defined.
        return torch.tensor(0., device=pos_tokens.device, requires_grad=True)
    pos_tokens = F.normalize(pos_tokens, dim=-1)
    neg_tokens = F.normalize(neg_tokens, dim=-1)
    feats = torch.cat([pos_tokens, neg_tokens], dim=0)  # (N, D)
    labels = torch.cat([torch.zeros(len(pos_tokens), device=feats.device, dtype=torch.long),
                        torch.ones(len(neg_tokens), device=feats.device, dtype=torch.long)],
                       dim=0)  # (N,)
    logits = feats @ feats.T / temperature  # (N, N)
    logits.fill_diagonal_(-1e4)  # exclude self-similarity
    mask_pos = labels.unsqueeze(0) == labels.unsqueeze(1)  # (N, N) same-class pairs
    mask_pos.fill_diagonal_(False)
    exp_logits = logits.exp()
    denom = exp_logits.sum(dim=1, keepdim=True)  # sum over a != i of exp(logit)
    log_prob = logits - denom.log()  # log softmax over non-self pairs
    loss_i = -(mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp_min(1)
    loss = loss_i.mean()
    return loss
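# Small usage sketch (random embeddings, assumed 256-d): checks that the loss is
# finite and differentiable, which is all a smoke test can verify here.
def _demo_supcon():
    pos = torch.randn(8, 256, requires_grad=True)
    neg = torch.randn(5, 256)
    loss = supcon_pos_neg(pos, neg, temperature=0.07)
    loss.backward()
    print(float(loss), pos.grad.shape)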