lorebianchi98 committed
Commit c26362f · 1 Parent(s): e7d7e74

Simplified model loading

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
README.md CHANGED
@@ -43,14 +43,14 @@ Open-Vocabulary Segmentation (OVS) aims at segmenting images from free-form text
  ### Mapping CLIP Text Embeddings to DINOv2 space with Talk2DINO
  We can use Talk2DINO to map CLIP text embeddings into the DINOv2 patch embedding space.
  ```python
- from hf_model.talk2dino import Talk2DINO
+ from transformers import AutoModel
  from torchvision.io import read_image

  # Device setup
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # Model Loading
- model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTL").to(device).eval()
+ model = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL").to(device).eval()

  # Embedding generation
  with torch.no_grad():
config.json CHANGED
@@ -1,4 +1,10 @@
  {
+ "architectures": ["Talk2DINO"],
+ "model_type": "talk2dino",
+ "auto_map": {
+ "AutoConfig": "configuration_talk2dino.Talk2DINOConfig",
+ "AutoModel": "modeling_talk2dino.Talk2DINO"
+ },
  "avg_self_attn_token": false,
  "clip_model_name": "ViT-B/16",
  "disentangled_self_attn_token": true,
configuration_talk2dino.py ADDED
@@ -0,0 +1,49 @@
+
+ from transformers import PretrainedConfig
+
+ class Talk2DINOConfig(PretrainedConfig):
+     model_type = "talk2dino"
+
+     def __init__(
+         self,
+         avg_self_attn_token=False,
+         clip_model_name="ViT-B/16",
+         disentangled_self_attn_token=True,
+         is_eval=True,
+         keep_cls=False,
+         keep_end_seq=False,
+         loss=None,
+         model_name="dinov2_vitb14_reg",
+         pre_trained=True,
+         proj_class="vitb_mlp_infonce",
+         proj_model="ProjectionLayer",
+         proj_name="vitb_mlp_infonce",
+         resize_dim=518,
+         type="DINOText",
+         unfreeze_last_image_layer=False,
+         unfreeze_last_text_layer=False,
+         use_avg_text_token=False,
+         with_bg_clean=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Store all parameters
+         self.avg_self_attn_token = avg_self_attn_token
+         self.clip_model_name = clip_model_name
+         self.disentangled_self_attn_token = disentangled_self_attn_token
+         self.is_eval = is_eval
+         self.keep_cls = keep_cls
+         self.keep_end_seq = keep_end_seq
+         self.loss = loss
+         self.model_name = model_name
+         self.pre_trained = pre_trained
+         self.proj_class = proj_class
+         self.proj_model = proj_model
+         self.proj_name = proj_name
+         self.resize_dim = resize_dim
+         self.type = type
+         self.unfreeze_last_image_layer = unfreeze_last_image_layer
+         self.unfreeze_last_text_layer = unfreeze_last_text_layer
+         self.use_avg_text_token = use_avg_text_token
+         self.with_bg_clean = with_bg_clean
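
Note: since `Talk2DINOConfig` subclasses `PretrainedConfig`, the standard save/load round-trip applies. A small sanity-check sketch (the output directory name is illustrative):

```python
from configuration_talk2dino import Talk2DINOConfig

# Build a config with the defaults above and write it to disk;
# save_pretrained/from_pretrained are inherited from PretrainedConfig.
cfg = Talk2DINOConfig()
cfg.save_pretrained("./talk2dino_config")             # hypothetical output directory
reloaded = Talk2DINOConfig.from_pretrained("./talk2dino_config")
assert reloaded.model_name == "dinov2_vitb14_reg"
assert reloaded.resize_dim == 518
```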
hf_demo.ipynb CHANGED
The diff for this file is too large to render.
 
modeling_talk2dino.py ADDED
@@ -0,0 +1,42 @@
+ from src.dinotext import DINOText
+ from transformers import PreTrainedModel
+ from configuration_talk2dino import Talk2DINOConfig
+ import clip
+ import torch
+
+ class Talk2DINO(DINOText, PreTrainedModel):
+     config_class = Talk2DINOConfig
+
+     def __init__(self, config: Talk2DINOConfig):
+         # Store the config
+         self.config = config
+
+         # Convert config to a dict (works for PretrainedConfig subclasses)
+         cfg_dict = config.to_dict()
+
+         # Initialize parent (DINOText) with unpacked kwargs
+         super().__init__(**cfg_dict)
+
+     def encode_text(self, texts):
+         """ texts: string or list of strings
+         returns: text embeddings (N, D) where N is the number of texts, D is the embedding dimension
+         """
+         text_tokens = clip.tokenize(texts).to(self.parameters().__next__().device)
+         txt_embed = self.clip_model.encode_text(text_tokens)
+         txt_embed = self.proj.project_clip_txt(txt_embed)
+         return txt_embed
+
+     def encode_image(self, images):
+         """ images: PIL image or list of PIL images
+         returns: image embeddings (N, L, D) where N is the number of images, L is the number of patches, D is the embedding dimension
+         """
+         if type(images) is not list:
+             images = [images]
+         img_preprocessed = [self.image_transforms(img).to(next(self.parameters()).device) for img in images]
+         img_preprocessed = torch.stack(img_preprocessed)
+         if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
+             img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
+         elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
+             img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]
+
+         return img_embed
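
Note: once a `Talk2DINO` instance is loaded (see the `AutoModel` sketch above), the two encoders defined here can be paired for a rough patch-text similarity map. A minimal sketch with an illustrative image path; this is not the repo's full masking pipeline, which goes through `DINOTextMasker`:

```python
import torch
import torch.nn.functional as F
from PIL import Image

image = Image.open("example.jpg").convert("RGB")         # hypothetical input image

with torch.no_grad():
    txt = model.encode_text(["a dog", "grass", "sky"])   # (3, D) embeddings in DINOv2 space
    patches = model.encode_image(image)                   # (1, L, D) patch tokens

# Cosine similarity between every patch and every prompt, then a per-patch label.
sim = F.normalize(patches.float(), dim=-1) @ F.normalize(txt.float(), dim=-1).T   # (1, L, 3)
per_patch_label = sim.argmax(dim=-1)                                               # (1, L)
```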
{hf_model → src}/__init__.py RENAMED
File without changes
hf_model/talk2dino.py → src/dinotext.py RENAMED
@@ -16,14 +16,14 @@ from transformers import BertModel, AutoTokenizer
  import torchvision.transforms as T
  import clip
  import importlib
- import hf_model.us as us
+ import src.us as us

- from hf_model.pamr import PAMR
- from hf_model.masker import DINOTextMasker
- from hf_model.templates import get_template
+ from src.pamr import PAMR
+ from src.masker import DINOTextMasker
+ from src.templates import get_template

- from hf_model.model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
- from hf_model.hooks import average_text_tokens, get_vit_out, feats
+ from src.model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
+ from src.hooks import average_text_tokens, get_vit_out, feats

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -45,7 +45,8 @@ class DINOText(nn.Module):
  self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
  unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
  ):
- super().__init__()
+ nn.Module.__init__(self)
+
  self.feats = {}
  self.model_name = model_name
  # loading the model
@@ -82,7 +83,7 @@
  T.Normalize(mean, std),
  ])

- self.model.to(device)
+ self.model
  self.model.requires_grad_(False)

  self.clip_model_name = clip_model_name
@@ -91,7 +92,7 @@
  # load the corresponding wordtokenizer
  self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
  else:
- self.clip_model, _ = clip.load(clip_model_name, device=device)
+ self.clip_model, _ = clip.load(clip_model_name, device='meta')
  self.clip_model.eval()
  self.clip_model.requires_grad_(False)
  if unfreeze_last_text_layer:
@@ -118,13 +119,11 @@
  }

  self.proj = ProjectionLayer.from_config(config)
- if type(self.proj) == CLIPLastLayer:
- self.clip_model.transformer.resblocks[-2].register_forward_hook(self.get_clip_second_last_dense_out)


  # if pre_trained:
  # self.proj.load_state_dict(torch.load(os.path.join("weights", f"{proj_name}.pth"), 'cpu'))
- self.proj.to(device)
+ self.proj

  self.masker = DINOTextMasker(similarity_type="cosine")
  self.masker = self.masker.eval()
@@ -166,12 +165,7 @@
  return self_attn

  def encode_text(self, tokenized_texts):
- if type(self.proj) == CLIPLastLayer:
- self.clip_model.encode_text(tokenized_texts)
- x = self.feats['clip_second_last_out']
- x = x.to(dtype=torch.float32)
- else:
- x = self.clip_model.encode_text(tokenized_texts)
+ x = self.clip_model.encode_text(tokenized_texts)
  return x

  def encode_image(self, images):
@@ -404,29 +398,4 @@
  return mask_output


- from huggingface_hub import PyTorchModelHubMixin

- class Talk2DINO(DINOText, PyTorchModelHubMixin):
- def encode_text(self, texts):
- """ texts: string or list of strings
- returns: text embeddings (N, D) where N is the number of texts, D is the embedding dimension
- """
- text_tokens = clip.tokenize(texts).to(self.parameters().__next__().device)
- txt_embed = self.clip_model.encode_text(text_tokens)
- txt_embed = self.proj.project_clip_txt(txt_embed)
- return txt_embed
-
- def encode_image(self, images):
- """ images: PIL image or list of PIL images
- returns: image embeddings (N, L, D) where N is the number of images, L is the number of patches, D is the embedding dimension
- """
- if type(images) is not list:
- images = [images]
- img_preprocessed = [self.image_transforms(img).to(next(self.parameters()).device) for img in images]
- img_preprocessed = torch.stack(img_preprocessed)
- if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
- img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
- elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
- img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]
-
- return img_embed
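Note on the `nn.Module.__init__(self)` call above: with the new `Talk2DINO(DINOText, PreTrainedModel)` multiple inheritance, a cooperative `super().__init__()` inside `DINOText` would resolve to `PreTrainedModel.__init__` (which expects a config argument) rather than to `nn.Module.__init__`. Calling `nn.Module.__init__` directly sidesteps that. The sketch below only illustrates the MRO behaviour with stand-in class names; it is not code from the repo:

```python
class Base:            # stands in for nn.Module
    def __init__(self):
        print("Base.__init__")

class A(Base):         # stands in for DINOText
    def __init__(self):
        super().__init__()   # resolved against the instance's MRO, not A's direct base

class B(Base):         # stands in for PreTrainedModel (in reality it requires a config)
    def __init__(self):
        print("B.__init__")
        super().__init__()

class C(A, B):         # stands in for Talk2DINO
    pass

C()  # prints "B.__init__" then "Base.__init__": A's super() reaches B first
```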
{hf_model → src}/hooks.py RENAMED
File without changes
{hf_model → src}/masker.py RENAMED
@@ -8,11 +8,11 @@ import torch
  import torch.distributed as dist
  import torch.nn as nn
  import torch.nn.functional as F
- import hf_model.us as us
+ import src.us as us
  from einops import rearrange, repeat

  # from models.dinotext.gumbel import gumbel_sigmoid
- from hf_model.modules import FeatureEncoder
+ from src.modules import FeatureEncoder
  from omegaconf import OmegaConf

{hf_model → src}/model.py RENAMED
@@ -4,7 +4,7 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from hf_model.hooks import get_self_attention, process_self_attention, feats
+ from src.hooks import get_self_attention, process_self_attention, feats

  class VisualProjectionLayer(nn.Module):
  """
{hf_model → src}/modules.py RENAMED
File without changes
{hf_model → src}/pamr.py RENAMED
File without changes
{hf_model → src}/templates.py RENAMED
File without changes
{hf_model → src}/us.py RENAMED
File without changes