PROJECT_PATH = 'cleaned_code'

import sys
sys.path.append(PROJECT_PATH)

import json
import pickle
from typing import Optional

import torch
from scipy.special import expit
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser

from src import BertForSemanticEmbedding, getLabelModel
from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config
from src import dataset_classification_type

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
    # COIL-style exact-match token scoring between query (label description) and document tokens.
    qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
    doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
    exact_match = (doc_input_ids == qry_input_ids).float()   # Q * LQ * D * LD
    scores_no_masking = torch.matmul(
        qry_reps.view(-1, 16),                  # (Q * LQ) * d, with token_dim d = 16
        doc_reps.view(-1, 16).transpose(0, 1)   # d * (D * LD)
    )
    scores_no_masking = scores_no_masking.view(
        *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
    # Keep only scores where the token ids match exactly, max-pooled over document positions.
    scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
    # Mask out padding query tokens and drop the [CLS] position before summing over query tokens.
    tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)
    return tok_scores
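
# A minimal shape sanity-check for compute_tok_score_cart, assuming token_dim = 16
# as hard-coded above. The sizes are illustrative, not taken from the real Amzn13K
# setup, and the helper name _demo_tok_score is hypothetical; nothing in this
# script calls it.
def _demo_tok_score():
    Q, LQ, D, LD, d = 3, 8, 2, 10, 16
    lab_reps = torch.randn(Q, LQ, d)          # per-token label representations
    doc_reps = torch.randn(D, LD, d)          # per-token document representations
    lab_ids = torch.randint(0, 100, (Q, LQ))  # label token ids
    doc_ids = torch.randint(0, 100, (D, LD))  # document token ids
    mask = torch.ones(Q, LQ)                  # no padding in this toy example
    out = compute_tok_score_cart(doc_reps, doc_ids, lab_reps, lab_ids, mask)
    assert out.shape == (Q, D)                # one exact-match score per (label, document) pair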
def coil_fast_eval_forward(
    input_ids: Optional[torch.Tensor] = None,
    doc_reps=None,
    logits: Optional[torch.Tensor] = None,
    desc_input_ids=None,
    desc_attention_mask=None,
    lab_reps=None,
    label_embeddings=None,
):
    # Exact-match token scores between the document and every label description.
    tok_scores = compute_tok_score_cart(
        doc_reps, input_ids,
        lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
    )
    # Dense similarity between the pooled document embedding and the label embeddings.
    logits = logits.unsqueeze(0) @ label_embeddings.T
    # Re-arrange the per-label token scores so they line up with the logits rows.
    new_tok_scores = torch.zeros(logits.shape, device=logits.device)
    stride = tok_scores.shape[0] // tok_scores.shape[1]
    for i in range(tok_scores.shape[1]):
        new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
    return (logits + new_tok_scores).squeeze()
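
# A toy end-to-end call of coil_fast_eval_forward under stated assumptions: one
# document, 4 labels, hidden size 32, token_dim 16. All tensors are random, so
# this only demonstrates the expected shapes, not real model outputs; the helper
# name _demo_eval_forward is hypothetical and is never called by this script.
def _demo_eval_forward():
    num_labels, hidden, d, LD, LQ = 4, 32, 16, 10, 8
    doc_reps = torch.randn(1, LD, d)                    # per-token document reps
    input_ids = torch.randint(0, 100, (1, LD))          # (clustered) document token ids
    logits = torch.randn(hidden)                        # pooled document embedding
    desc_input_ids = torch.randint(0, 100, (num_labels, LQ))
    desc_attention_mask = torch.ones(num_labels, LQ)
    lab_reps = torch.randn(num_labels, LQ, d)
    label_embeddings = torch.randn(num_labels, hidden)
    scores = coil_fast_eval_forward(input_ids, doc_reps, logits,
                                    desc_input_ids, desc_attention_mask,
                                    lab_reps, label_embeddings)
    assert scores.shape == (num_labels,)                # one combined score per label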
class DemoModel:
    def __init__(self):
        self.label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/all_labels.txt')]
        unseen_label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/unseen_labels_split6500_2.txt')]
        num_labels = len(self.label_list)
        self.label_list.sort()  # sort for a stable label -> index mapping
        l2i = {v: i for i, v in enumerate(self.label_list)}
        unseen_label_indexes = [l2i[x] for x in unseen_label_list]
        # Mapping from BERT token ids to COIL cluster ids, used at inference time.
        self.coil_cluster_map = json.load(open(f'{PROJECT_PATH}/bert_coil_map_dict_lemma255K_isotropic.json'))
        # Load the five precomputed label-description encodings and move each tensor to the device.
        self.all_lab_reps = []
        self.all_label_embeddings = []
        self.all_desc_input_ids_orig = []
        self.all_desc_input_ids = []
        self.all_desc_attention_mask = []
        for i in range(1, 6):
            with open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_{i}.pkl', 'rb') as f:
                lab_reps, label_embeddings, desc_input_ids_orig, desc_input_ids, desc_attention_mask = pickle.load(f)
            self.all_lab_reps.append(lab_reps.to(device))
            self.all_label_embeddings.append(label_embeddings.to(device))
            self.all_desc_input_ids_orig.append(desc_input_ids_orig.to(device))
            self.all_desc_input_ids.append(desc_input_ids.to(device))
            self.all_desc_attention_mask.append(desc_attention_mask.to(device))
        ARGS_FILE = f'{PROJECT_PATH}/configs/ablation_amzn_eda.yml'
        parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
        self.model_args, self.data_args, self.training_args = parser.parse_dict(
            read_yaml_config(ARGS_FILE, output_dir='demo_tmp', extra_args={}))
        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            finetuning_task=self.data_args.task_name,
            cache_dir=self.model_args.cache_dir,
            revision=self.model_args.model_revision,
            use_auth_token=True if self.model_args.use_auth_token else None,
        )
        # Copy the SemSup-specific settings from the parsed arguments onto the config.
        config.model_name_or_path = self.model_args.model_name_or_path
        config.problem_type = dataset_classification_type[self.data_args.task_name]
        config.negative_sampling = self.model_args.negative_sampling
        config.semsup = self.model_args.semsup
        config.encoder_model_type = self.model_args.encoder_model_type
        config.arch_type = self.model_args.arch_type
        config.coil = self.model_args.coil
        config.token_dim = self.model_args.token_dim
        config.colbert = self.model_args.colbert
        label_model, label_tokenizer = getLabelModel(self.data_args, self.model_args)
        config.label_hidden_size = label_model.config.hidden_size
        model = BertForSemanticEmbedding(config)
        model.label_model = label_model
        model.label_tokenizer = label_tokenizer
        model.config.label2id = {l: i for i, l in enumerate(self.label_list)}
        model.config.id2label = {i: l for l, i in model.config.label2id.items()}
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        model.to(device)
        model.eval()
        torch.set_grad_enabled(False)  # the demo only runs inference
        model.load_state_dict(torch.load(f'{PROJECT_PATH}/ckpt/Amzn13K/amzn_main_model.bin', map_location=device))
        self.model = model
        # Decode a human-readable description for every label. If a label has no
        # description in one set, borrow it from the first set that has one.
        self.extracted_descs = [self.extract_descriptions(adi) for adi in self.all_desc_input_ids_orig]
        tot_len = len(self.all_desc_input_ids_orig)
        for i in range(len(self.all_desc_input_ids_orig[0])):
            for j in range(tot_len):
                if self.extracted_descs[j][i] == "":
                    for k in range(tot_len):
                        if self.extracted_descs[k][i] != "":
                            self.extracted_descs[j][i] = self.extracted_descs[k][i]
                            break
    def extract_descriptions(self, input_ids):
        # Decode tokenized label descriptions back to text and keep only the
        # "description is ..." span, dropping the label/parents/children fields.
        descs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        new_descs = []
        NOT_FOUND = 99999999999  # sentinel: field absent after the description span
        for desc in descs:
            a = desc.find('description is')
            if a == -1:
                # There is no description to use; fall back to an empty string.
                new_descs.append("")
                continue
            b = min(desc.find(x, a) if desc.find(x, a) != -1 else NOT_FOUND
                    for x in ['label is', 'parents are', 'children are'])
            new_descs.append(desc[a:].strip() if b == NOT_FOUND else desc[a:b].strip())
        return new_descs
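
    # Illustrative only (the example string is made up, not from the dataset):
    # given a decoded description such as
    #   "description is organic gardening guides. label is gardening. parents are home & garden."
    # extract_descriptions keeps "description is organic gardening guides." and
    # drops the trailing "label is ..." / "parents are ..." / "children are ..." fields.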
    def classify(self, text, unseen_labels=None):
        self.model.eval()
        with torch.no_grad():
            # Tokenize the input document and move it to the device.
            item = self.tokenizer(text, padding='max_length', max_length=self.data_args.max_seq_length, truncation=True)
            item = {k: torch.tensor(v, device=device).unsqueeze(0) for k, v in item.items()}
            outputs_doc, logits = self.model.forward_input_encoder(**item)
            doc_reps = self.model.tok_proj(outputs_doc.last_hidden_state)
            # Replace token ids by their COIL cluster ids before exact-match scoring.
            input_ids = torch.tensor([self.coil_cluster_map[str(x.item())] for x in item['input_ids'][0]],
                                     device=device).unsqueeze(0)
            # Score the document against each of the five description sets and ensemble.
            all_logits = []
            for adi, ada, alr, ale in zip(self.all_desc_input_ids, self.all_desc_attention_mask,
                                          self.all_lab_reps, self.all_label_embeddings):
                all_logits.append(coil_fast_eval_forward(input_ids, doc_reps, logits, adi, ada, alr, ale))
            final_logits = sum([expit(x.cpu()) for x in all_logits]) / len(all_logits)
            # For each label, remember which description set scored it highest.
            max_indices = torch.argmax(torch.stack(all_logits), dim=0).cpu().tolist()
            outs = torch.topk(final_logits, k=50)
            preds_dic = dict()
            descs_dic = dict()
            for i, v in zip(outs.indices, outs.values):
                preds_dic[self.label_list[i]] = v.item()
                descs_dic[self.label_list[i]] = self.extracted_descs[max_indices[i]][i]
        return preds_dic, descs_dic
if __name__ == '__main__':
    model = DemoModel()
    preds, descs = model.classify('Hello')
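    # A more representative query, assuming the checkpoint and precomputed label
    # files referenced above are present. The product text is made up; the top-5
    # slice just keeps the printout short (preds is ordered by descending score).
    preds, descs = model.classify(
        'A stainless steel chef knife with an ergonomic handle, ideal for home kitchens.')
    for label in list(preds)[:5]:
        print(label, round(preds[label], 3), '-', descs[label])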