# mmdet to groundingdino
import argparse
from collections import OrderedDict

import torch
from mmengine.runner import CheckpointLoader


# Helpers that invert the unfold-order corrections mmdet applies to Swin
# patch-merging weights, so the weights go back to the GroundingDINO layout.
def correct_unfold_reduction_order(x):
    """Reorder a patch-merging reduction weight back to the official layout."""
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
    x = x[:, [0, 2, 1, 3], :]
    x = x.reshape(out_channel, in_channel)
    return x


def correct_unfold_norm_order(x):
    """Reorder a patch-merging norm parameter back to the official layout."""
    in_channel = x.shape[0]
    x = x.reshape(in_channel // 4, 4).transpose(0, 1)
    x = x[[0, 2, 1, 3], :]
    x = x.reshape(in_channel)
    return x
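

def _unfold_roundtrip_check():
    """Optional sanity check (an added sketch, never called by this script):
    the two helpers above are intended to invert the unfold-order correction
    from mmdet's official GroundingDINO-to-mmdet converter; that forward step
    is re-implemented inline here under that assumption."""
    x = torch.randn(8, 16)
    out_channel, in_channel = x.shape
    # Forward correction as applied when converting official -> mmdet.
    y = x.reshape(out_channel, 4, in_channel // 4)
    y = y[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
    # Round trip: the inverse above must recover the original tensor.
    assert torch.equal(correct_unfold_reduction_order(y), x)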


def convert(ckpt):
    """Inverse mapping of checkpoint parameters to their original names."""
    # Create a dictionary to hold the reversed checkpoint
    new_ckpt = OrderedDict()
    for k, v in list(ckpt.items()):
        new_v = v  # Start with the original value
        # Inverse rules based on the convert function (from specific to general)
        if k.startswith('decoder'):
            new_k = k.replace('decoder', 'transformer.decoder')
            if 'norms.2' in new_k:
                new_k = new_k.replace('norms.2', 'norm1')
            if 'norms.1' in new_k:
                new_k = new_k.replace('norms.1', 'catext_norm')
            if 'norms.0' in new_k:
                new_k = new_k.replace('norms.0', 'norm2')
            if 'norms.3' in new_k:
                new_k = new_k.replace('norms.3', 'norm3')
            if 'cross_attn_text' in new_k:
                new_k = new_k.replace('cross_attn_text', 'ca_text')
                new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
                new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
                new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
                new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
            if 'ffn.layers.0.0' in new_k:
                new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            if 'ffn.layers.1' in new_k:
                new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'self_attn.attn' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')
            new_ckpt[new_k] = new_v  # Register the remapped key in the new checkpoint
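            # e.g. 'decoder.layers.0.norms.2.weight'
            #   -> 'transformer.decoder.layers.0.norm1.weight' (illustrative key)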
        #########################################################################
        # The last reg_layer_id in the encoder part is 6, which distinguishes
        # it from the decoder branches (0-5).
        elif k.startswith('bbox_head.reg_branches.6'):
            if k.startswith('bbox_head.reg_branches.6.0'):
                new_k = k.replace('bbox_head.reg_branches.6.0',
                                  'transformer.enc_out_bbox_embed.layers.0')
            if k.startswith('bbox_head.reg_branches.6.2'):
                new_k = k.replace('bbox_head.reg_branches.6.2',
                                  'transformer.enc_out_bbox_embed.layers.1')
            if k.startswith('bbox_head.reg_branches.6.4'):
                new_k = k.replace('bbox_head.reg_branches.6.4',
                                  'transformer.enc_out_bbox_embed.layers.2')
            new_ckpt[new_k] = new_v
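            # e.g. 'bbox_head.reg_branches.6.0.weight'
            #   -> 'transformer.enc_out_bbox_embed.layers.0.weight' (illustrative)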
        #########################################################################
        elif k.startswith('query_embedding'):
            new_k = k.replace('query_embedding', 'transformer.tgt_embed')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('bbox_head.reg_branches'):
            # mmdet drops part of the parameter name; inspecting the
            # GroundingDINO checkpoint shows two groups of parameters with
            # identical values, namely bbox_embed and
            # transformer.decoder.bbox_embed, so mmdet simply "merged" the
            # two groups into one.
            reg_layer_id = int(k.split('.')[2])
            linear_id = int(k.split('.')[3])
            weight_or_bias = k.split('.')[-1]
            new_k1 = 'transformer.decoder.bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
            new_k2 = 'bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
            new_ckpt[new_k1] = new_v
            new_ckpt[new_k2] = new_v
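            # e.g. 'bbox_head.reg_branches.0.2.weight' is written to both
            # 'transformer.decoder.bbox_embed.0.layers.1.weight' and
            # 'bbox_embed.0.layers.1.weight' (illustrative key)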
        #########################################################################
        elif k.startswith('bbox_head.cls_branches.6'):
            # mmdet adds a bias term to contrastive_embed, but the decoder
            # branches are 0-5, so index 6 should correspond to the two-stage
            # enc_out class embedding.
            new_k = 'transformer.enc_out_class_embed.bias'
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('bbox_head.cls_branches'):
            # mmdet adds a bias term to contrastive_embed
            new_k1 = 'transformer.decoder.class_embed.' + k[-6:]
            new_k2 = 'class_embed.' + k[-6:]
            new_ckpt[new_k1] = new_v
            new_ckpt[new_k2] = new_v
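            # e.g. 'bbox_head.cls_branches.0.bias' (k[-6:] == '0.bias') maps to
            # 'transformer.decoder.class_embed.0.bias' and 'class_embed.0.bias';
            # note the fixed-width slice assumes single-digit branch indices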
        #########################################################################
        elif k.startswith('memory_trans_'):
            if k.startswith('memory_trans_fc'):
                new_k = k.replace('memory_trans_fc', 'transformer.enc_output')
            elif k.startswith('memory_trans_norm'):
                new_k = k.replace('memory_trans_norm', 'transformer.enc_output_norm')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('encoder'):
            new_k = k.replace('encoder', 'transformer.encoder')
            new_k = new_k.replace('norms.0', 'norm1')
            new_k = new_k.replace('norms.1', 'norm2')
            new_k = new_k.replace('norms.2', 'norm3')
            new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'text_layers' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')
            new_ckpt[new_k] = new_v
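            # e.g. 'encoder.layers.0.norms.0.weight'
            #   -> 'transformer.encoder.layers.0.norm1.weight' (illustrative)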
        #########################################################################
        elif k.startswith('level_embed'):
            new_k = k.replace('level_embed', 'transformer.level_embed')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('neck.convs'):
            new_k = k.replace('neck.convs', 'input_proj')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')
            new_ckpt[new_k] = new_v
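            # e.g. 'neck.convs.0.conv.weight' -> 'input_proj.0.0.weight' and
            # 'neck.convs.0.gn.weight' -> 'input_proj.0.1.weight' (illustrative)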
        #########################################################################
        elif 'neck.extra_convs.0' in k:
            new_k = k.replace('neck.extra_convs.0', 'neck.convs.4')
            new_k = new_k.replace('neck.convs', 'input_proj')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('text_feat_map'):
            new_k = k.replace('text_feat_map', 'feat_map')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('language_model.language_backbone.body.model'):
            new_k = k.replace('language_model.language_backbone.body.model', 'bert')
            new_ckpt[new_k] = new_v
        #########################################################################
        elif k.startswith('backbone'):
            new_k = k.replace('backbone', 'backbone.0')
            if 'patch_embed.projection' in new_k:
                new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
            elif 'drop_after_pos' in new_k:
                new_k = new_k.replace('drop_after_pos', 'pos_drop')
            if 'stages' in new_k:
                new_k = new_k.replace('stages', 'layers')
                if 'ffn.layers.0.0' in new_k:
                    new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
                elif 'ffn.layers.1' in new_k:
                    new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
                elif 'attn.w_msa' in new_k:
                    new_k = new_k.replace('attn.w_msa', 'attn')
                # Patch-merging weights also need their unfold order restored.
                if 'downsample' in k:
                    if 'reduction.' in k:
                        new_v = correct_unfold_reduction_order(v)
                    elif 'norm.' in k:
                        new_v = correct_unfold_norm_order(v)
            new_ckpt[new_k] = new_v
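            # e.g. 'backbone.stages.0.blocks.0.attn.w_msa.qkv.weight'
            #   -> 'backbone.0.layers.0.blocks.0.attn.qkv.weight' (illustrative)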
        #########################################################################
        else:
            print('skip:', k)
            continue
        # if 'transformer.decoder.bbox_embed' in new_k:
        #     new_k = new_k.replace('transformer.decoder.bbox_embed', 'bbox_embed')
        # if new_k.startswith('module.'):
        #     new_k = new_k.replace('module.', '')
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys to GroundingDINO style.')
    parser.add_argument(
        'src',
        nargs='?',
        default='grounding_dino_swin-l_pretrain_all-56d69e78.pth',
        help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument(
        'dst',
        nargs='?',
        default='mmdet_swinl.pth_groundingdino.pth',
        help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # mmdet checkpoints store the weights under 'state_dict' rather than 'model'
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert(state_dict)
    torch.save(weight, args.dst)
    # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
    # sha = calculate_sha256(args.dst)
    # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
    # subprocess.Popen(['mv', args.dst, final_file])
    print(f'Done! Saved to {args.dst}')


if __name__ == '__main__':
    main()

# skip: dn_query_generator.label_embedding.weight
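
# Hedged usage sketch: loading the converted file with the official
# GroundingDINO repo (assumes that repo is installed; the config path below
# is a placeholder, not something shipped with this script):
#   from groundingdino.util.inference import load_model
#   model = load_model('path/to/GroundingDINO_cfg.py',
#                      'mmdet_swinl.pth_groundingdino.pth')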