# coding=utf-8
"""
Negative Grounding DINO model for object detection with negative-caption support.

This module extends the original GroundingDinoForObjectDetection with support
for negative captions, which improves detection by teaching the model what NOT
to detect.
"""
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_grounding_dino import (
    GroundingDinoEncoderOutput,
    GroundingDinoForObjectDetection,
    GroundingDinoObjectDetectionOutput,
)
# ---------------------------------------------------------------------------
# density_fpn_head.py
# ---------------------------------------------------------------------------
def _bilinear(x, size):
    """Bilinear resize helper with align_corners=False."""
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
class DensityFPNHead(nn.Module):
    """FPN-style head that fuses four pyramid levels into a single density map."""

    def __init__(self,
                 in_channels: int = 512,
                 mid_channels: int = 128,
                 act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d):
        super().__init__()
        # ---- 1×1 lateral convs (P3–P6) ----
        self.lateral = nn.ModuleList([
            nn.Conv2d(in_channels, mid_channels, 1) for _ in range(4)
        ])
        # ---- 3×3 smoothing convs applied after each top-down add ----
        self.smooth = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
                norm_layer(mid_channels),
                act_layer(inplace=True),
            ) for _ in range(3)  # P6→P5, P5→P4, P4→P3
        ])
        # ---- three 2× upsample blocks that carry P3 up to full resolution ----
        self.up_blocks = nn.ModuleList([
            nn.Sequential(
                act_layer(inplace=True),
                nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
                norm_layer(mid_channels),
                act_layer(inplace=True),
            ) for _ in range(3)  # e.g. 167×94 → … → 1336×752
        ])
        # ---- output 3×3 conv -> 1 channel ----
        self.out_conv = nn.Conv2d(mid_channels, 1, 3, padding=1, bias=False)
    def forward(self, feats):
        assert len(feats) == 4, "Expected feats list = [P3, P4, P5, P6]"
        # lateral 1×1 projections
        lat = [l(f) for l, f in zip(self.lateral, feats)]
        # top-down FPN fusion, starting from the coarsest level
        x = lat[-1]  # P6
        for i in range(2, -1, -1):  # P5, P4, P3
            x = _bilinear(x, lat[i].shape[-2:])
            x = x + lat[i]
            x = self.smooth[i](x)
        # three-stage 2× upsample + conv
        for up in self.up_blocks:
            h, w = x.shape[-2], x.shape[-1]
            x = _bilinear(x, (h * 2, w * 2))
            x = up(x)
        x = self.out_conv(x)
        return F.relu(x)  # densities are non-negative
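# A minimal, hypothetical shape check for DensityFPNHead (not part of the model).
# It assumes four pyramid levels at strides 8/16/32/64 and 512 input channels,
# matching the concatenated positive+negative 256-d features used further below.
def _demo_density_fpn_head():
    head = DensityFPNHead(in_channels=512, mid_channels=128)
    feats = [
        torch.randn(2, 512, 100, 56),  # P3
        torch.randn(2, 512, 50, 28),   # P4
        torch.randn(2, 512, 25, 14),   # P5
        torch.randn(2, 512, 13, 7),    # P6
    ]
    density = head(feats)
    # P3 resolution doubled three times: (100*8, 56*8)
    assert density.shape == (2, 1, 800, 448)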
def l2norm(x, dim=-1, eps=1e-6):
    """L2-normalize x along `dim`; eps guards against division by zero."""
    return x / (x.norm(dim=dim, keepdim=True) + eps)
# -----------------------------------
# 1) CommonFinderSimple
#    Learns r "common prototypes" that represent what the positive and negative
#    query sets share. Nothing fancy: a single MHA pooling step plus two light
#    regularizers (shareability + diversity).
# -----------------------------------
class CommonFinderSimple(nn.Module):
    """
    Inputs:
        Q_pos: [B, K, D]
        Q_neg: [B, K, D]
    Returns:
        C_rows: [B, r, D]  # r common prototypes broadcast across the batch (L2-normalized)
        loss: scalar       # light regularization: shareability + diversity
        stats: dict
    """
    def __init__(self, d_model=256, r=64, nhead=4,
                 share_w=0.02, div_w=0.02, ln_after=False):
        super().__init__()
        self.r = r
        self.share_w = share_w
        self.div_w = div_w
        proto = torch.randn(r, d_model)
        self.proto = nn.Parameter(l2norm(proto, -1))  # r×D learnable "core queries"
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.post = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model) if ln_after else nn.Identity()
    def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
        B, K, D = Q_pos.shape
        seeds = self.proto[None].expand(B, -1, -1).contiguous()  # [B, r, D]
        X = torch.cat([Q_pos, Q_neg], dim=1)  # [B, 2K, D]
        # one attention-pooling pass of the seeds over the joint positive/negative
        # set yields the r "common prototypes"
        C, _ = self.attn(query=seeds, key=X, value=X)  # [B, r, D]
        C = l2norm(self.ln(self.post(C)), -1)  # L2 normalization
        # ---- light regularization: pull C toward both Q_pos and Q_neg while
        # keeping the prototypes mutually diverse ----
        # Shareability: mean over prototypes of the max cosine similarity to Q_pos/Q_neg.
        cos_pos = torch.einsum('brd,bkd->brk', C, l2norm(Q_pos, -1))  # [B, r, K]
        cos_neg = torch.einsum('brd,bkd->brk', C, l2norm(Q_neg, -1))
        share_term = -(cos_pos.amax(dim=-1).mean() + cos_neg.amax(dim=-1).mean())
        # Diversity: pairwise cosines between prototypes should not collapse.
        C0 = l2norm(self.proto, -1)  # [r, D]
        gram = C0 @ C0.t()  # [r, r]
        div_term = (gram - torch.eye(self.r, device=gram.device)).pow(2).mean()
        loss = self.share_w * share_term + self.div_w * div_term
        stats = {
            'share_term': share_term.detach(),
            'div_term': div_term.detach(),
            'mean_cos_pos': cos_pos.mean().detach(),
            'mean_cos_neg': cos_neg.mean().detach(),
        }
        return C, loss, stats
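# A hedged usage sketch (illustrative shapes only): pool K positive and K
# negative decoder queries into r shared prototypes.
def _demo_common_finder():
    finder = CommonFinderSimple(d_model=256, r=64, nhead=4)
    Q_pos = torch.randn(2, 900, 256)
    Q_neg = torch.randn(2, 900, 256)
    C, reg_loss, stats = finder(Q_pos, Q_neg)
    assert C.shape == (2, 64, 256)  # unit-norm prototypes, one set per batch item
    assert reg_loss.dim() == 0      # scalar regularizer, added to the model loss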
# -----------------------------------
# 2) NegExclusiveSimple
#    Removes the "common" information from the negative queries. Two simple
#    strategies that can be used independently or together:
#    (A) soft removal: subtract the projection onto C (the residual keeps the non-common part)
#    (B) filtering: keep only the Top-M negatives least similar to C
# -----------------------------------
class NegExclusiveSimple(nn.Module):
    """
    Inputs:
        Q_neg:  [B, K, D]
        C_rows: [B, r, D]  # common prototypes
    Args:
        mode:   'residual' | 'filter' | 'both'
        M:      Top-M for 'filter'
        thresh: filter threshold (keep only entries whose max cosine to C is
                below thresh); None means use Top-M alone
    Returns:
        neg_refs: [B, M_or_K, D]  # negative references for the fusion step
        aux: dict
    """
    def __init__(self, mode='residual', M=16, thresh=None):
        super().__init__()
        assert mode in ('residual', 'filter', 'both')
        self.mode = mode
        self.M = M
        self.thresh = thresh
    def forward(self, Q_neg: torch.Tensor, C_rows: torch.Tensor):
        B, K, D = Q_neg.shape
        Qn = l2norm(Q_neg, -1)
        C = l2norm(C_rows, -1)
        # max cosine similarity of each negative query to any common prototype
        sim = torch.einsum('bkd,brd->bkr', Qn, C).amax(dim=-1)  # [B, K]
        outputs = {}
        if self.mode in ('residual', 'both'):
            # proj = (Q · C^T) C -> [B, K, D]: weights [B, K, r] times C [B, r, D]
            w = torch.einsum('bkd,brd->bkr', Qn, C)    # [B, K, r]
            proj = torch.einsum('bkr,brd->bkd', w, C)  # [B, K, D]
            neg_resid = l2norm(Qn - proj, -1)          # non-common residual
            outputs['residual'] = neg_resid
        if self.mode in ('filter', 'both'):
            excl_score = 1.0 - sim  # higher = farther from the common prototypes
            if self.thresh is not None:
                mask = (sim < self.thresh).float()
                excl_score = excl_score * mask + (-1e4) * (1 - mask)
            M = min(self.M, K)
            _, topi = torch.topk(excl_score, k=M, dim=1)  # [B, M]
            neg_top = torch.gather(Qn, 1, topi.unsqueeze(-1).expand(-1, -1, D))
            outputs['filtered'] = neg_top
        if self.mode == 'residual':
            neg_refs = outputs['residual']
        elif self.mode == 'filter':
            neg_refs = outputs['filtered']
        else:  # 'both': keep the Top-M rows of the residual
            R = outputs['residual']  # [B, K, D]
            excl_score = 1.0 - sim
            M = min(self.M, K)
            _, topi = torch.topk(excl_score, k=M, dim=1)
            neg_refs = torch.gather(R, 1, topi.unsqueeze(-1).expand(-1, -1, D))  # [B, M, D]
        aux = {
            'mean_sim_to_common': sim.mean().detach(),
            'kept_M': neg_refs.size(1),
        }
        return neg_refs, aux
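# A hedged sanity check (assumed shapes): 'residual' mode keeps all K queries
# as L2-normalized non-common residuals, while 'both' keeps only the Top-M
# residual rows farthest from the prototypes.
def _demo_neg_exclusive():
    Q_neg = torch.randn(2, 900, 256)
    C = l2norm(torch.randn(2, 64, 256), -1)
    resid, _ = NegExclusiveSimple(mode='residual')(Q_neg, C)
    assert resid.shape == (2, 900, 256)
    both, aux = NegExclusiveSimple(mode='both', M=16)(Q_neg, C)
    assert both.shape == (2, 16, 256) and aux['kept_M'] == 16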
class FusionNoGate(nn.Module):
    """
    Direct fusion (no gating): fuse neg_ref into Q_pos via a single cross-attention.
    Variants:
        - 'residual_sub':  Q_new = Q_pos - scale * LN(Z)
        - 'residual_add':  Q_new = Q_pos + scale * LN(Z)
        - 'concat_linear': Q_new = Q_pos + Linear([Q_pos; Z])
    """
    def __init__(self, d_model=256, nhead=4, fusion_mode='residual_sub',
                 init_scale=0.2, dropout_p=0.0):
        super().__init__()
        assert fusion_mode in ('residual_sub', 'residual_add', 'concat_linear')
        self.fusion_mode = fusion_mode
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ln_z = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout_p) if dropout_p > 0 else nn.Identity()
        self.scale = nn.Parameter(torch.tensor(float(init_scale)))
        if fusion_mode == 'concat_linear':
            self.mix = nn.Linear(2 * d_model, d_model)
            # zero init so the fusion starts as an identity mapping
            nn.init.zeros_(self.mix.weight)
            nn.init.zeros_(self.mix.bias)
    def forward(self, Q_pos: torch.Tensor, neg_ref: torch.Tensor):
        """
        Q_pos:   [B, K, D]
        neg_ref: [B, M, D]
        Returns: Q_new [B, K, D], stats dict
        """
        B, K, D = Q_pos.shape
        M = neg_ref.size(1)
        if M == 0:
            return Q_pos, {'kept': 0, 'scale': self.scale.detach()}
        # 1) cross-attention: each positive query attends over the negative references
        Z, attn_w = self.attn(query=Q_pos, key=neg_ref, value=neg_ref)  # Z: [B, K, D]
        Z = self.ln_z(Z)
        Z = self.drop(Z)
        # 2) fuse without gating
        if self.fusion_mode == 'residual_sub':
            Q_new = Q_pos - self.scale * Z
        elif self.fusion_mode == 'residual_add':
            Q_new = Q_pos + self.scale * Z
        else:  # 'concat_linear'
            fused = torch.cat([Q_pos, Z], dim=-1)  # [B, K, 2D]
            delta = self.mix(fused)                # [B, K, D]
            Q_new = Q_pos + delta
        stats = {
            'kept': M,
            'attn_mean': attn_w.mean().detach(),
            'fusion_scale': self.scale.detach(),
        }
        return Q_new, stats
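# A hedged usage sketch (illustrative shapes): subtract attended negative
# context from the positive queries. With an empty neg_ref the module is a
# no-op by construction.
def _demo_fusion_no_gate():
    fuse = FusionNoGate(d_model=256, nhead=4, fusion_mode='residual_sub')
    Q_pos = torch.randn(2, 900, 256)
    neg_ref = torch.randn(2, 64, 256)
    Q_new, stats = fuse(Q_pos, neg_ref)
    assert Q_new.shape == Q_pos.shape and stats['kept'] == 64
    same, stats = fuse(Q_pos, torch.empty(2, 0, 256))
    assert same is Q_pos and stats['kept'] == 0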
class QuerySideNegNaive(nn.Module):
    def __init__(self, d_model=256, r=64, M=64, nhead=4,
                 excl_mode='both', excl_thresh=0.5, gamma_max=0.7,
                 share_w=0.02, div_w=0.02):
        super().__init__()
        self.common = CommonFinderSimple(d_model, r, nhead, share_w, div_w)
        self.excl = NegExclusiveSimple(mode=excl_mode, M=M, thresh=excl_thresh)
        self.fuse = FusionNoGate(d_model=d_model,
                                 nhead=nhead,
                                 fusion_mode='residual_sub',  # or 'concat_linear'
                                 init_scale=0.25,
                                 dropout_p=0.1)

    def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
        C_rows, l_common, common_stats = self.common(Q_pos, Q_neg)
        neg_refs, excl_stats = self.excl(Q_neg, C_rows)
        Q_new, fuse_stats = self.fuse(Q_pos, neg_refs)
        loss = l_common
        stats = {}
        stats.update(common_stats)
        stats.update(excl_stats)
        stats.update(fuse_stats)
        return Q_new, loss, stats

    def set_fusion_scale(self, scale: float):
        # overwrite in place so device, dtype, and optimizer state are preserved
        with torch.no_grad():
            self.fuse.scale.fill_(scale)
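# A hedged end-to-end sketch of the query-side pipeline (assumed shapes):
# find common prototypes, strip them from the negatives, then fuse what is
# left back into the positive queries.
def _demo_query_side_pipeline():
    pipe = QuerySideNegNaive(d_model=256, r=64, M=64)
    Q_pos = torch.randn(2, 900, 256)
    Q_neg = torch.randn(2, 900, 256)
    Q_new, reg_loss, stats = pipe(Q_pos, Q_neg)
    assert Q_new.shape == Q_pos.shape
    pipe.set_fusion_scale(0.5)  # e.g. driven by an external warm-up schedule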
class CountEX(GroundingDinoForObjectDetection):
    """
    Grounding DINO model with negative-caption support for improved object detection.

    This model extends the original GroundingDinoForObjectDetection by adding
    support for negative captions, which improves detection accuracy by
    learning what NOT to detect.
    """
    def __init__(self, config):
        super().__init__(config)
        # negative-fusion modules
        self.query_side_neg_pipeline = QuerySideNegNaive()
        self.density_head = DensityFPNHead()
        self.config = config
        self.box_threshold = getattr(config, 'box_threshold', 0.4)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        token_type_ids: torch.LongTensor = None,
        attention_mask: torch.LongTensor = None,
        pixel_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
        # negative-prompt parameters
        neg_pixel_values: Optional[torch.FloatTensor] = None,
        neg_input_ids: Optional[torch.LongTensor] = None,
        neg_token_type_ids: Optional[torch.LongTensor] = None,
        neg_attention_mask: Optional[torch.LongTensor] = None,
        neg_pixel_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_neg = kwargs.get('use_neg', True)
        # positive pass
        pos_kwargs = {
            'exemplars': kwargs.get('pos_exemplars', None),
        }
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **pos_kwargs,
        )
        # cumulative token offsets of each pyramid level in the flattened encoder
        # output, so the flat tokens can be sliced back into per-level 2D maps
        spatial_shapes = outputs.spatial_shapes
        token_num = 0
        token_num_list = [0]
        for i in range(len(spatial_shapes)):
            token_num += spatial_shapes[i][0] * spatial_shapes[i][1]
            token_num_list.append(token_num.item())
        positive_feature_maps = []
        encoder_last_hidden_state_vision = outputs.encoder_last_hidden_state_vision
        for i in range(len(spatial_shapes)):
            feature_map = encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i + 1], :]
            spatial_shape = spatial_shapes[i]
            b, t, d = feature_map.shape
            feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
            positive_feature_maps.append(feature_map)
        # negative pass (same backbone, negative caption/image inputs)
        neg_kwargs = {
            'exemplars': kwargs.get('neg_exemplars', None),
        }
        neg_outputs = self.model(
            pixel_values=neg_pixel_values,
            input_ids=neg_input_ids,
            token_type_ids=neg_token_type_ids,
            attention_mask=neg_attention_mask,
            pixel_mask=neg_pixel_mask,
            encoder_outputs=encoder_outputs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **neg_kwargs,
        )
        neg_encoder_last_hidden_state_vision = neg_outputs.encoder_last_hidden_state_vision
        neg_feature_maps = []
        for i in range(len(spatial_shapes)):
            feature_map = neg_encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i + 1], :]
            spatial_shape = spatial_shapes[i]
            b, t, d = feature_map.shape
            feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
            neg_feature_maps.append(feature_map)
        if return_dict:
            hidden_states = outputs.intermediate_hidden_states
            neg_hidden_states = neg_outputs.intermediate_hidden_states
        else:
            hidden_states = outputs[2]
            neg_hidden_states = neg_outputs[2]
        # index of the text encoder states in the tuple output shifts with the
        # optional attention / hidden-state entries
        idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
        enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
        init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
        inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]
        neg_inter_references_points = neg_outputs.intermediate_reference_points if return_dict else neg_outputs[3]
        neg_init_reference_points = neg_outputs.init_reference_points if return_dict else neg_outputs[1]
        neg_enc_text_hidden_state = neg_outputs.encoder_last_hidden_state_text if return_dict else neg_outputs[idx]
        # drop the trailing exemplar tokens, if any were appended
        pos_exemplars = pos_kwargs.get('exemplars', None)
        neg_exemplars = neg_kwargs.get('exemplars', None)
        if pos_exemplars is not None or neg_exemplars is not None or attention_mask.shape[1] != enc_text_hidden_state.shape[1]:
            enc_text_hidden_state = enc_text_hidden_state[:, :enc_text_hidden_state.shape[1] - 3, :]
            neg_enc_text_hidden_state = neg_enc_text_hidden_state[:, :neg_enc_text_hidden_state.shape[1] - 3, :]
        # class logits + predicted bounding boxes
        outputs_classes = []
        outputs_coords = []
        # apply query-side negative fusion
        if use_neg:
            # the pipeline expects 3D tensors; squeezing the batch dim treats
            # the decoder layers as its batch, which assumes batch size 1 here
            hidden_states = hidden_states.squeeze(0)
            neg_hidden_states = neg_hidden_states.squeeze(0)
            hidden_states, extra_loss, logs = self.query_side_neg_pipeline(hidden_states, neg_hidden_states)
            hidden_states = hidden_states.unsqueeze(0)
            neg_hidden_states = neg_hidden_states.unsqueeze(0)
        else:
            extra_loss = None
            logs = None
        # predict class and bounding-box deltas for each decoder stage
        num_levels = hidden_states.shape[1]
        for level in range(num_levels):
            if level == 0:
                reference = init_reference_points
            else:
                reference = inter_references_points[:, level - 1]
            reference = torch.special.logit(reference, eps=1e-5)
            assert attention_mask.shape[1] == enc_text_hidden_state.shape[1], \
                "Attention mask and text hidden state have different lengths: {} != {}".format(
                    attention_mask.shape[1], enc_text_hidden_state.shape[1])
            outputs_class = self.class_embed[level](
                vision_hidden_state=hidden_states[:, level],
                text_hidden_state=enc_text_hidden_state,
                text_token_mask=attention_mask.bool(),
            )
            delta_bbox = self.bbox_embed[level](hidden_states[:, level])
            reference_coordinates = reference.shape[-1]
            if reference_coordinates == 4:
                outputs_coord_logits = delta_bbox + reference
            elif reference_coordinates == 2:
                delta_bbox[..., :2] += reference
                outputs_coord_logits = delta_bbox
            else:
                raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
            outputs_coord = outputs_coord_logits.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)
        logits = outputs_class[-1]
        pred_boxes = outputs_coord[-1]
        # ==== negative branch logits and pred_boxes ====
        neg_outputs_classes = []
        neg_outputs_coords = []
        for level in range(num_levels):
            if level == 0:
                neg_reference = neg_init_reference_points
            else:
                neg_reference = neg_inter_references_points[:, level - 1]
            neg_reference = torch.special.logit(neg_reference, eps=1e-5)
            neg_outputs_class = self.class_embed[level](
                vision_hidden_state=neg_hidden_states[:, level],
                text_hidden_state=neg_enc_text_hidden_state,
                text_token_mask=neg_attention_mask.bool(),
            )
            neg_delta_bbox = self.bbox_embed[level](neg_hidden_states[:, level])
            neg_reference_coordinates = neg_reference.shape[-1]
            if neg_reference_coordinates == 4:
                neg_outputs_coord_logits = neg_delta_bbox + neg_reference
            elif neg_reference_coordinates == 2:
                neg_delta_bbox[..., :2] += neg_reference
                neg_outputs_coord_logits = neg_delta_bbox
            else:
                raise ValueError(f"neg_reference.shape[-1] should be 4 or 2, but got {neg_reference.shape[-1]}")
            neg_outputs_coord = neg_outputs_coord_logits.sigmoid()
            neg_outputs_classes.append(neg_outputs_class)
            neg_outputs_coords.append(neg_outputs_coord)
        neg_outputs_class = torch.stack(neg_outputs_classes)
        neg_outputs_coord = torch.stack(neg_outputs_coords)
        neg_logits = neg_outputs_class[-1]
        neg_pred_boxes = neg_outputs_coord[-1]
        # loss terms are not computed here; they are left to the training loop
        loss, loss_dict, auxiliary_outputs = None, None, None
        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + auxiliary_outputs + outputs
            else:
                output = (logits, pred_boxes) + outputs
            tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
            return tuple_outputs
        # concatenate positive and negative features per level (channels-first)
        # and predict the density map
        all_feats = []
        for pf, npf in zip(positive_feature_maps, neg_feature_maps):
            pf = pf.permute(0, 3, 1, 2)
            npf = npf.permute(0, 3, 1, 2)
            all_feats.append(torch.cat([pf, npf], dim=1))
        # (earlier variants fused only the P3 level, or added a ReLU(pos - neg) channel)
        density_map_pred = self.density_head(all_feats)
        dict_outputs = GroundingDinoObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            last_hidden_state=outputs.last_hidden_state,
            auxiliary_outputs=auxiliary_outputs,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
            encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
            encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
            encoder_text_hidden_states=outputs.encoder_text_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            intermediate_hidden_states=outputs.intermediate_hidden_states,
            intermediate_reference_points=outputs.intermediate_reference_points,
            init_reference_points=outputs.init_reference_points,
            enc_outputs_class=outputs.enc_outputs_class,
            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
            spatial_shapes=outputs.spatial_shapes,
            positive_feature_maps=positive_feature_maps,
            negative_feature_maps=neg_feature_maps,
            density_map_pred=density_map_pred,
            extra_loss=extra_loss,
            extra_logs=logs,
            neg_logits=neg_logits,
            neg_pred_boxes=neg_pred_boxes,
            pos_queries=hidden_states,
            neg_queries=neg_hidden_states,
        )
        return dict_outputs
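# A hedged, self-contained illustration (assumed level sizes) of the slicing
# used in CountEX.forward above: flattened multi-scale encoder tokens are cut
# at the cumulative per-level offsets and reshaped into [B, H, W, D] maps.
def _demo_token_split():
    spatial_shapes = torch.tensor([[100, 56], [50, 28], [25, 14], [13, 7]])
    total = int(spatial_shapes.prod(dim=-1).sum())
    tokens = torch.randn(2, total, 256)  # stands in for encoder_last_hidden_state_vision
    offsets = [0] + torch.cumsum(spatial_shapes.prod(dim=-1), dim=0).tolist()
    maps = [
        tokens[:, offsets[i]:offsets[i + 1]].reshape(2, int(h), int(w), 256)
        for i, (h, w) in enumerate(spatial_shapes.tolist())
    ]
    assert [m.shape[1:3] for m in maps] == [(100, 56), (50, 28), (25, 14), (13, 7)]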