import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from transformers import GroundingDinoProcessor
def prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor):
    """Convert pixel-space point annotations into per-image training targets."""
    # Normalize points to [0, 1] by the image size (shapes holds (w, h) tuples).
    gt_points_b = [np.array(points) / np.array(shapes)[::-1]]
    gt_points_b[0] = gt_points_b[0].squeeze(0)
    gt_points = [torch.from_numpy(img_points).float() for img_points in gt_points_b]
    gt_logits = [torch.zeros((img_points.shape[0], emb_size)) for img_points in gt_points]
    tokenized = llmdet_processor.tokenizer(caption[0], padding="longest", return_tensors="pt")
    # Token id 1012 is "." in the BERT vocabulary; the caption ends at its last occurrence.
    end_idxes = [torch.where(ids == 1012)[0][-1] for ids in tokenized["input_ids"]]
    for i, end_idx in enumerate(end_idxes):
        # Mark every text-token column up to the final period as a positive label.
        gt_logits[i][:, :end_idx] = 1.0
    caption_sizes = [idx + 2 for idx in end_idxes]  # include the "." and [SEP] tokens
    targets = [{"points": p.to(device), "labels": l.to(device), "caption_size": c}
               for p, l, c in zip(gt_points, gt_logits, caption_sizes)]
    return targets
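

# Example usage (a minimal sketch; the point/caption nesting mirrors what
# `collator.__call__` below produces, and all values here are made up):
#
#   processor = GroundingDinoProcessor.from_pretrained("fushh7/llmdet_swin_tiny_hf")
#   targets = prepare_targets(
#       points=[[[12.0, 34.0], [56.0, 78.0]]],  # pixel (x, y) points for one image
#       caption=[["red apples."]],              # nested as [[caption]]
#       shapes=[(640, 480)],                    # (w, h)
#       emb_size=256,
#       device="cpu",
#       llmdet_processor=processor,
#   )
#   # targets[0]["points"] -> (2, 2) normalized points
#   # targets[0]["labels"] -> (2, 256) with the text-token columns set to 1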
def post_process_grounded_object_detection(
    outputs,
    box_threshold: float = 0.4,
):
    # For the fine-tuned model, the box threshold should be set to 0.50.
    logits, boxes = outputs.logits, outputs.pred_boxes
    probs = torch.sigmoid(logits)         # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b in zip(scores, boxes):
        keep = s > box_threshold
        results.append({"scores": s[keep], "boxes": b[keep]})
    return results
def post_process_grounded_object_detection_with_queries(
    outputs,
    queries,
    box_threshold: float = 0.4,
):
    """
    Post-process grounded object detection outputs.

    Also returns the query embeddings for each kept prediction.
    """
    logits, boxes = outputs.logits, outputs.pred_boxes
    assert logits.shape[:2] == queries.shape[:2], (
        "logits and queries must agree on (batch_size, num_queries), "
        "but got {} and {}".format(tuple(logits.shape), tuple(queries.shape))
    )
    probs = torch.sigmoid(logits)         # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
    results = []
    for s, b, q in zip(scores, boxes, queries):
        keep = s > box_threshold
        score, box, query = s[keep], b[keep], q[keep]
        assert score.shape[0] == box.shape[0] == query.shape[0], (
            "scores, boxes and queries must have the same length, but got "
            "{}, {} and {}".format(score.shape[0], box.shape[0], query.shape[0])
        )
        results.append({"scores": score, "boxes": box, "queries": query})
    return results
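

# Example usage (a sketch; `model` and `inputs` stand in for a Grounding
# DINO-style model and processor outputs, neither of which is defined here):
#
#   with torch.no_grad():
#       outputs = model(**inputs)
#   results = post_process_grounded_object_detection(outputs, box_threshold=0.4)
#   # results[0]["boxes"] holds the kept (cx, cy, w, h) boxes, normalized to [0, 1]
#
#   # The with-queries variant also needs per-query embeddings; taking them from
#   # outputs.last_hidden_state (B, num_queries, D) is an assumption, not verified:
#   results = post_process_grounded_object_detection_with_queries(
#       outputs, queries=outputs.last_hidden_state)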
class collator:
    def __init__(self, processor=None, use_negative=True):
        model_id = "fushh7/llmdet_swin_tiny_hf"
        # Fall back to the LLMDet processor when no processor is supplied.
        self.llmdet_processor = processor if processor is not None \
            else GroundingDinoProcessor.from_pretrained(model_id)
        self.use_negative = use_negative
    def __call__(self, batch):
        # The dataloader is assumed to use batch_size == 1.
        example = batch[0]
        image = example['image']
        w, h = image.size  # PIL: (width, height)
        pos_caption = example['pos_caption']
        neg_caption = example['neg_caption']
        pos_points = example['pos_points']
        neg_points = example['neg_points']
        pos_count = example['pos_count']
        neg_count = example['neg_count']
        annotated_pos_count = example['annotated_pos_count']
        annotated_neg_count = example['annotated_neg_count']
        sample_type = example.get('type', 'eval')
        category = example['category']
        image_name = "{}_{}_{}_{}_{}".format(category, pos_caption, neg_caption, pos_count, neg_count)
        pos_llm_det_inputs = self.llmdet_processor(images=image, text=pos_caption, return_tensors="pt", padding=True)
        neg_llm_det_inputs = self.llmdet_processor(images=image, text=neg_caption, return_tensors="pt", padding=True)
        batch_dict = {
            'pos_llm_det_inputs': pos_llm_det_inputs,
            'neg_llm_det_inputs': neg_llm_det_inputs,
            'pos_caption': [[pos_caption]],
            'neg_caption': [[neg_caption]],
            'shapes': [(w, h)],
            'pos_points': [pos_points],
            'neg_points': [neg_points],
            'pos_count': pos_count,
            'neg_count': neg_count,
            'annotated_pos_count': annotated_pos_count,
            'annotated_neg_count': annotated_neg_count,
            'image': image,
            'category': category,
            'type': sample_type,
            'image_name': image_name,
        }

        def normalize_exemplars(exemplar_boxes):
            """Normalize (y1, x1, y2, x2) pixel boxes to [0, 1] (x1, y1, x2, y2) boxes."""
            norm, valid = [], True
            for tly, tlx, bry, brx in exemplar_boxes:
                tlx, brx = tlx / w, brx / w
                tly, bry = tly / h, bry / h
                # Flag out-of-bounds or degenerate boxes; the sample then falls
                # back to the no-exemplar path.
                if not (0 <= tlx <= 1 and 0 <= tly <= 1 and 0 <= brx <= 1 and 0 <= bry <= 1):
                    valid = False
                if tlx >= brx or tly >= bry:
                    valid = False
                # Clamp so downstream code always receives well-formed boxes.
                tlx = min(max(tlx, 0), 1 - 1e-4)
                tly = min(max(tly, 0), 1 - 1e-4)
                brx = max(min(brx, 1), tlx)
                bry = max(min(bry, 1), tly)
                assert 0 <= tlx <= brx <= 1 and 0 <= tly <= bry <= 1, \
                    f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
                norm.append([tlx, tly, brx, bry])
            return norm, valid

        # Exemplar boxes are optional; attach them only when both sides are valid.
        if example.get('positive_exemplars') is not None and example.get('negative_exemplars') is not None:
            norm_pos, pos_valid = normalize_exemplars(example['positive_exemplars'])
            norm_neg, neg_valid = normalize_exemplars(example['negative_exemplars'])
            if pos_valid and neg_valid:
                batch_dict['pos_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(e)).float() for e in norm_pos])
                batch_dict['neg_exemplars'] = torch.stack(
                    [torch.from_numpy(np.array(e)).float() for e in norm_neg])
        return batch_dict
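

# Example usage (a minimal sketch; the dataset is assumed to yield dicts with
# the keys read above, e.g. the output of `build_dataset` below):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator())
#   batch = next(iter(loader))
#   batch["pos_llm_det_inputs"]  # processor-ready tensors for the positive caption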
def rank0_print(*args):
    # Print only on rank 0 under distributed training; otherwise print normally.
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)
def build_dataset(data_args):
    from datasets import load_from_disk, concatenate_datasets

    categories = ["FOO", "FUN", "OFF", "OTR", "HOU"]
    if data_args.data_split != "all" and data_args.data_split not in categories:
        rank0_print(f"Warning: Invalid data_split '{data_args.data_split}'. Switching to 'all' mode.")
        data_args.data_split = "all"
    if data_args.data_split == "all":
        train_dataset = load_from_disk(data_args.train_data_path)
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in categories])
        val_dataset = load_from_disk(data_args.val_data_path)
        val_dataset = concatenate_datasets([val_dataset[cat] for cat in categories])
        test_dataset = load_from_disk(data_args.test_data_path)
        test_dataset = concatenate_datasets([test_dataset[cat] for cat in categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets(
            [weakly_supervised_data[cat] for cat in categories])
        rank0_print("Using 'all' mode: all categories for train/val/test")
    else:
        # Leave-one-category-out: the held-out category supplies val/test.
        test_category = data_args.data_split
        train_categories = [cat for cat in categories if cat != test_category]
        train_dataset = load_from_disk(data_args.train_data_path)
        rank0_print(train_categories, train_dataset.keys())
        train_dataset = concatenate_datasets([train_dataset[cat] for cat in train_categories])
        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets(
            [weakly_supervised_data[cat] for cat in train_categories])
        val_dataset = load_from_disk(data_args.val_data_path)[test_category]
        test_dataset = load_from_disk(data_args.test_data_path)[test_category]
        rank0_print(f"Cross-validation mode: using {train_categories} for train, {test_category} for val/test")
    rank0_print('train_dataset: ', len(train_dataset))
    rank0_print('val_dataset: ', len(val_dataset))
    rank0_print('test_dataset: ', len(test_dataset))
    rank0_print('weakly_supervised_data: ', len(weakly_supervised_data))
    return train_dataset, val_dataset, test_dataset, weakly_supervised_data
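

# Example usage (a sketch; `data_args` only needs the attributes read above,
# and the paths here are placeholders):
#
#   from types import SimpleNamespace
#   data_args = SimpleNamespace(
#       data_split="FOO",  # hold out FOO, or "all" to pool every category
#       train_data_path="path/to/train",
#       val_data_path="path/to/val",
#       test_data_path="path/to/test",
#       weakly_supervised_data_path="path/to/weak",
#   )
#   train_ds, val_ds, test_ds, weak_ds = build_dataset(data_args)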
def generate_pseudo_density_map(points_norm: torch.Tensor,
                                output_size: tuple[int, int],
                                sigma: float = 4.0,
                                normalize: bool = True) -> torch.Tensor:
    """Render normalized (x, y) points as a Gaussian density map of shape (1, 1, H, W)."""
    device = points_norm.device
    H, W = output_size
    N = points_norm.shape[0]
    ys = torch.arange(H, device=device).float()
    xs = torch.arange(W, device=device).float()
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')  # (H, W)
    pts_px = points_norm.clone()
    pts_px[:, 0] *= (W - 1)  # x
    pts_px[:, 1] *= (H - 1)  # y
    dx = grid_x.unsqueeze(0) - pts_px[:, 0].view(-1, 1, 1)  # (N, H, W)
    dy = grid_y.unsqueeze(0) - pts_px[:, 1].view(-1, 1, 1)  # (N, H, W)
    dist2 = dx ** 2 + dy ** 2
    gaussians = torch.exp(-dist2 / (2 * sigma ** 2))  # (N, H, W)
    density_map = gaussians.sum(dim=0, keepdim=True)  # (1, H, W)
    if normalize and N > 0:
        # Rescale so the map integrates to the ground-truth count N.
        density_map = density_map * (N / density_map.sum())
    return density_map.unsqueeze(0)  # (1, 1, H, W)
def show_density_map(density_map: torch.Tensor,
                     points_norm: torch.Tensor | None = None,
                     figsize: tuple[int, int] = (6, 8),
                     cmap: str = "jet") -> None:
    """Plot a density map, optionally overlaying the normalized ground-truth points."""
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape
    plt.figure(figsize=figsize)
    plt.imshow(dm, cmap=cmap, origin="upper")
    plt.colorbar(label="Density")
    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
    plt.title(f"Density map (sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
def show_image_with_density(pil_img: Image.Image,
                            density_map: torch.Tensor,
                            points_norm: torch.Tensor | None = None,
                            cmap: str = "jet",
                            alpha: float = 0.45,
                            figsize: tuple[int, int] = (6, 8)) -> None:
    """Overlay a density map (and optional points) on the source image."""
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape
    img_resized = pil_img.resize((W, H), Image.BILINEAR)  # or Image.LANCZOS
    img_np = np.asarray(img_resized)
    plt.figure(figsize=figsize)
    plt.imshow(img_np, origin="upper")
    plt.imshow(dm, cmap=cmap, alpha=alpha, origin="upper")
    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
    plt.title(f"Overlay (density sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
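

# Example (a minimal sketch): render three normalized points into a 48x64 map
# and visualize it; `img` stands in for any PIL image.
#
#   pts = torch.tensor([[0.25, 0.30], [0.50, 0.55], [0.75, 0.80]])  # (x, y) in [0, 1]
#   dmap = generate_pseudo_density_map(pts, output_size=(48, 64), sigma=2.0)
#   float(dmap.sum())  # ~3.0: normalization rescales the map to the point count
#   show_density_map(dmap, points_norm=pts)
#   show_image_with_density(img, dmap, points_norm=pts)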
def build_point_count_map(feat_maps: torch.Tensor,
                          pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
    """Bin normalized points into a per-cell count map aligned with the feature grid."""
    assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
    B, H, W, _ = feat_maps.shape
    device = feat_maps.device
    count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)
    for b in range(B):
        pts = pts_norm_list[b].to(device).float()  # (Ni, 2)
        if pts.numel() == 0:
            continue
        idx_xy = (pts * torch.tensor([W, H], device=device)).long()
        idx_xy[..., 0].clamp_(0, W - 1)  # x
        idx_xy[..., 1].clamp_(0, H - 1)  # y
        lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0]  # (Ni,)
        one = torch.ones_like(lin_idx, dtype=torch.float32)
        flat = torch.zeros(H * W, dtype=torch.float32, device=device)
        flat.scatter_add_(0, lin_idx, one)
        count_map[b] = flat.view(H, W)
    return count_map
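

# Example (a sketch with random features; the NHWC layout is assumed): two
# points land in distinct cells of an 8x8 grid, so the count map sums to 2.
#
#   feats = torch.randn(1, 8, 8, 256)
#   pts = [torch.tensor([[0.10, 0.10], [0.90, 0.90]])]
#   cmap = build_point_count_map(feats, pts)  # (1, 8, 8), cmap.sum() == 2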
def extract_pos_tokens_single(feat_maps: torch.Tensor,
                              count_map: torch.Tensor):
    """Return the feature tokens (and their linear indices) at cells with count > 0."""
    assert feat_maps.dim() == 4 and count_map.dim() == 3, "expected shapes (B,H,W,D) / (B,H,W)"
    B, H, W, D = feat_maps.shape
    assert B == 1, "this function assumes batch_size == 1"
    feat = feat_maps[0]  # (H, W, D)
    cnt = count_map[0]   # (H, W)
    pos_mask = cnt > 0   # bool (H, W)
    if pos_mask.sum() == 0:
        empty = torch.empty(0, device=feat.device)
        return empty.reshape(0, D), empty.long()
    pos_tokens = feat[pos_mask]  # (N_pos, D)
    y_idx, x_idx = torch.nonzero(pos_mask, as_tuple=True)
    lin_index = y_idx * W + x_idx  # (N_pos,)
    return pos_tokens, lin_index
def filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg):
    """Drop tokens whose grid cell holds both positive and negative points."""
    pos_only_mask = ~torch.isin(lin_pos, lin_neg)
    neg_only_mask = ~torch.isin(lin_neg, lin_pos)
    return pos_tok[pos_only_mask], neg_tok[neg_only_mask]
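

# A toy check of the overlap filter (all values made up): cell 9 holds both a
# positive and a negative point, so it is dropped from both sides.
#
#   pos_tok, neg_tok = torch.randn(3, 8), torch.randn(2, 8)
#   lin_pos, lin_neg = torch.tensor([5, 9, 12]), torch.tensor([9, 20])
#   p, n = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)
#   # p.shape == (2, 8), n.shape == (1, 8)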
# ------------------------------------------------------------
# 2) supervised contrastive loss
# ------------------------------------------------------------
def supcon_pos_neg(pos_tokens, neg_tokens, temperature=0.07):
    """
    pos_tokens : (Np, D) positive tokens
    neg_tokens : (Nn, D) negative tokens
    """
    if pos_tokens.numel() == 0 or neg_tokens.numel() == 0:
        return torch.tensor(0., device=pos_tokens.device, requires_grad=True)
    pos_tokens = F.normalize(pos_tokens, dim=-1)
    neg_tokens = F.normalize(neg_tokens, dim=-1)
    feats = torch.cat([pos_tokens, neg_tokens], dim=0)  # (N, D)
    labels = torch.cat([torch.zeros(len(pos_tokens), device=feats.device, dtype=torch.long),
                        torch.ones(len(neg_tokens), device=feats.device, dtype=torch.long)], dim=0)  # (N,)
    logits = feats @ feats.T / temperature  # (N, N)
    logits.fill_diagonal_(-1e4)  # exclude self-similarity
    mask_pos = labels.unsqueeze(0) == labels.unsqueeze(1)  # (N, N) same-label pairs
    mask_pos.fill_diagonal_(False)
    exp_logits = logits.exp()
    denom = exp_logits.sum(dim=1, keepdim=True)  # Σ_{a≠i} exp(sim_ia / τ)
    log_prob = logits - denom.log()              # log softmax over all other tokens
    loss_i = -(mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp_min(1)
    loss = loss_i.mean()
    return loss
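

# End-to-end sketch of the contrastive step (shapes assumed, features random):
# bin positive/negative points onto the feature grid, pull out their tokens,
# drop cells claimed by both sides, then score the separation.
#
#   feats = torch.randn(1, 8, 8, 256)
#   pos_cnt = build_point_count_map(feats, [torch.tensor([[0.1, 0.1], [0.3, 0.3]])])
#   neg_cnt = build_point_count_map(feats, [torch.tensor([[0.7, 0.7], [0.9, 0.9]])])
#   pos_tok, lin_pos = extract_pos_tokens_single(feats, pos_cnt)
#   neg_tok, lin_neg = extract_pos_tokens_single(feats, neg_cnt)
#   pos_tok, neg_tok = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)
#   loss = supcon_pos_neg(pos_tok, neg_tok, temperature=0.07)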