from typing import Dict, Optional, List, Any, NewType
import numpy as np
import torch
import itertools
from torch.utils.data import Dataset, DataLoader
import json
import random
from collections.abc import Mapping
import pandas as pd
from os.path import join
import os
import gensim.downloader
import h5py
import time
from tqdm import tqdm
def getTokenizedLabelDescriptions(data_args, desc_file, tokenizer):
    padding = "max_length" if data_args.pad_to_max_length else False
    max_seq_length = min(data_args.label_max_seq_length, tokenizer.model_max_length)
    label_descs = json.load(open(desc_file, encoding='utf-8'))
    return {label_key: [
        tokenizer(
            desc,
            truncation=True,
            padding=padding,
            max_length=max_seq_length,
            return_tensors='pt'
        )
        for desc in descs[1]] for label_key, descs in label_descs.items()}
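

# Illustrative only: a minimal sketch of how getTokenizedLabelDescriptions might be called.
# It assumes the description file maps each label to [num_descs, [description, ...]]
# (the function reads descs[1]) and that data_args exposes pad_to_max_length and
# label_max_seq_length; 'descriptions.json' and the values below are hypothetical.
#
#   from types import SimpleNamespace
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   data_args = SimpleNamespace(pad_to_max_length=True, label_max_seq_length=64)
#   # descriptions.json: {"Sports": [2, ["Articles about sports.", "Coverage of games."]], ...}
#   tokenized = getTokenizedLabelDescriptions(data_args, 'descriptions.json', tokenizer)
#   # tokenized["Sports"] is a list of BatchEncoding objects, one per description.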
class SemSupDataset(Dataset):
    def __init__(self, input_dataset, data_args, label_descriptions_file, label_to_id, id_to_label, tokenizer,
                 class_descs_len=None, return_desc_embeddings=False, sampleRandom: int = -1,
                 cl_min_positive_descs=20, useSemSup=True, seen_labels=None, add_label_name=False,
                 max_descs_per_label=999999, use_precomputed_embeddings='', bm_short_file='',
                 ignore_pos_labels_file='', isTrain=True, class_descs_tokenized=None, choice_indexes=None):
        self.input_dataset = input_dataset
        self.sampleRandom = sampleRandom
        self.cl_min_positive_descs = cl_min_positive_descs
        self.semsup = useSemSup
        self.seen_labels = seen_labels
        self.add_label_name = add_label_name
        self.max_descs_per_label = max_descs_per_label
        self.use_precomputed_embeddings = use_precomputed_embeddings
        self.choice_indexes = choice_indexes
        self.bmshortfile = bm_short_file
        self.useBMShort = self.bmshortfile != ''
        self.data_args = data_args
        self.tok_format = 0
        self.isTrain = isTrain

        # Optional mapping from token ids to COIL cluster ids.
        self.coil_cluster_map = None
        try:
            if data_args.coil_cluster_mapping_path:
                self.coil_cluster_map = json.load(open(data_args.coil_cluster_mapping_path))
        except (AttributeError, OSError, json.JSONDecodeError):
            print('Failed to load the COIL cluster map')
            self.coil_cluster_map = None

        self.ignore_pos_labels_file = ignore_pos_labels_file
        if self.ignore_pos_labels_file:
            self.ignored_labels = [[y.strip() for y in x.split('\t') if y.strip() != '']
                                   for x in open(self.ignore_pos_labels_file).readlines()]
        else:
            self.ignored_labels = False

        if self.useBMShort and not data_args.large_dset:
            self.shortlists = [[y.strip() for y in x.split('\t')] for x in open(self.bmshortfile).readlines()]

        if self.semsup and not data_args.large_dset:
            self.data_args = data_args
            self.label_descriptions_file = label_descriptions_file
            self.label_to_id = label_to_id
            self.id_to_label = id_to_label
            if self.seen_labels is not None and isinstance(self.seen_labels[0], str):
                self.seen_labels = np.array([self.label_to_id[x] for x in self.seen_labels])
            self.tokenizer = tokenizer
            if class_descs_len is None:
                js_file = json.load(open(self.label_descriptions_file, encoding='utf-8'))
                self.class_descs_len = self.tokenize_class_descs(js_file, return_lengths=True)
                self.class_descs = self.tokenize_class_descs(js_file)
            else:
                self.class_descs_len = class_descs_len
            self.return_desc_embeddings = return_desc_embeddings
            self.label_max_seq_length = data_args.label_max_seq_length
            if return_desc_embeddings:
                self.save_tokenized_descs(self.add_label_name)
            if self.use_precomputed_embeddings:
                self.computed_desc_inputs_embeds = torch.from_numpy(np.load(self.use_precomputed_embeddings))

        if self.semsup and data_args.large_dset:
            self.data_args = data_args
            self.label_descriptions_file = label_descriptions_file
            self.label_to_id = label_to_id
            self.id_to_label = id_to_label
            # No concept of seen labels here; the shortlists are loaded directly.
            self.tokenizer = tokenizer
            self.return_desc_embeddings = return_desc_embeddings
            self.label_max_seq_length = data_args.label_max_seq_length

            to_save = True
            if os.path.exists(data_args.tokenized_descs_file):
                print('Path Exists')
                if data_args.tok_format == 1:
                    self.tok_format = 1
                if class_descs_tokenized is not None:
                    self.class_descs_tokenized = class_descs_tokenized
                elif data_args.tokenized_descs_file.endswith('h5'):
                    self.class_descs_tokenized = h5py.File(data_args.tokenized_descs_file)
                    self.tok_format = 1
                else:
                    self.class_descs_tokenized = np.load(data_args.tokenized_descs_file, allow_pickle=True)
                to_save = False

            js_file = json.load(open(self.label_descriptions_file, encoding='utf-8'))
            print('Loaded js File')
            self.class_descs_len = self.tokenize_class_descs(js_file, return_lengths=True)

            if to_save:
                self.class_descs = self.tokenize_class_descs(js_file)
                print('Begin Tokenization Process')
                self.save_tokenized_descs(self.add_label_name)
                print('Saving Tokenized Descriptions')
                print(len(self.class_descs_tokenized))
                if data_args.tokenized_descs_file.endswith('h5'):
                    # Store one HDF5 group per label, with input_ids and attention_mask datasets.
                    file = h5py.File(data_args.tokenized_descs_file, 'w')
                    for key in tqdm(self.class_descs_tokenized):
                        key_h5 = key
                        if key.find('/') != -1:
                            # '/' is the HDF5 group separator, so escape it inside label names.
                            print('There may be an issue with', key)
                            key_h5 = key.replace('/', '\\/')
                        file.create_dataset(key_h5 + '/input_ids',
                                            data=np.array(self.class_descs_tokenized[key]['input_ids']))
                        file[key_h5].create_dataset('attention_mask',
                                                    data=np.array(self.class_descs_tokenized[key]['attention_mask']))
                else:
                    import pickle
                    pickle.dump(self.class_descs_tokenized, open(data_args.tokenized_descs_file, 'wb'))

            if isTrain:
                self.shortlists = h5py.File(data_args.train_tfidf_short)['data']
            else:
                print('Test File Loaded')
                self.shortlists = h5py.File(data_args.test_tfidf_short)['data']

        try:
            del self.class_descs
        except AttributeError:
            pass
        if self.tok_format != 1 and hasattr(self, 'class_descs_tokenized'):
            # Keep only input_ids and attention_mask per label; token_type_ids (index 1) are dropped.
            self.class_descs_tokenized = pd.DataFrame(
                {k: [np.array(x) for i, x in enumerate(v.values()) if i != 1]
                 for k, v in self.class_descs_tokenized.items()})
    def tokenize_class_descs(self, label_descs, return_lengths=False):
        if return_lengths:
            # descs[0] is the number of descriptions available for the label.
            return {label_key: min(descs[0], self.max_descs_per_label)
                    for label_key, descs in label_descs.items()}
        else:
            # descs[1] is the list of description strings.
            return {label_key: descs[1][:self.max_descs_per_label]
                    for label_key, descs in label_descs.items()}
    def save_tokenized_descs(self, add_label_name=False):
        self.class_descs_tokenized = dict()
        for label_key in tqdm(list(self.class_descs.keys())):
            descs_len = self.class_descs_len[label_key]
            descs = self.class_descs[label_key]
            # Optionally prepend the label name to every description before tokenizing.
            self.class_descs_tokenized[label_key] = self.tokenizer(
                [label_key + ". " + x for x in descs] if add_label_name else descs,
                max_length=self.label_max_seq_length, padding='max_length', truncation=True)
    def __len__(self):
        return len(self.input_dataset)
    def get_item_for_large_dset(self, idx, item):
        if self.choice_indexes is not None:
            idx = int(self.choice_indexes[idx])
        shortlists = self.shortlists[idx]
        labels_new = item['label']
        if self.sampleRandom != -1:
            if self.sampleRandom < len(shortlists):
                shortlists = np.random.choice(shortlists, self.sampleRandom, replace=False)
            elif self.sampleRandom > len(shortlists):
                # Fill the remainder with labels chosen at random.
                shortlists = shortlists.tolist() + [self.label_to_id[x] for x in np.random.choice(
                    self.seen_labels, self.sampleRandom - len(shortlists), replace=False)]
        if self.isTrain:
            pos_labels = np.where(np.array(labels_new) == 1)[0]
            item['all_candidate_labels'] = np.unique(np.concatenate([pos_labels, shortlists]))[:len(shortlists)]
        else:
            item['all_candidate_labels'] = np.unique(shortlists)
        if self.sampleRandom != -1:
            if len(item['all_candidate_labels']) < self.sampleRandom:
                # np.unique removed duplicates, so pad back up to sampleRandom by repeating entries.
                item['all_candidate_labels'] = np.concatenate([
                    item['all_candidate_labels'],
                    item['all_candidate_labels'][len(item['all_candidate_labels']) - self.sampleRandom:]])
            item['all_candidate_labels'] = item['all_candidate_labels'][:self.sampleRandom]
        l1 = len(item['all_candidate_labels'])
        if self.ignored_labels:
            # Remove the ignored labels, then pad back up to l1 by duplicating elements.
            ignore_list = {self.label_to_id[x] for x in self.ignored_labels[idx]}
            if len(ignore_list) > 0:
                item['all_candidate_labels'] = set(item['all_candidate_labels'].tolist()).difference(ignore_list)
                item['all_candidate_labels'] = sorted(list(item['all_candidate_labels']))
                if len(item['all_candidate_labels']) < l1:
                    item['all_candidate_labels'] += item['all_candidate_labels'][:l1 - len(item['all_candidate_labels'])]
                item['all_candidate_labels'] = np.array(item['all_candidate_labels'])
        item['label'] = np.array(item['label'])[item['all_candidate_labels']]
        # For every candidate label, pick one of its descriptions at random.
        item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]])
                                  for label_key in item['all_candidate_labels']]
        if self.tok_format == 1:
            item['desc_input_ids'] = [self.class_descs_tokenized['input_ids'][label_key][item['label_desc_ids'][i]].astype(np.int32)
                                      for i, label_key in enumerate(item['all_candidate_labels'])]
            item['desc_attention_mask'] = [self.class_descs_tokenized['attention_mask'][label_key][item['label_desc_ids'][i]].astype(np.int32)
                                           for i, label_key in enumerate(item['all_candidate_labels'])]
        else:
            item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][i]]
                                      for i, label_key in enumerate(item['all_candidate_labels'])]
            item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][i]]
                                           for i, label_key in enumerate(item['all_candidate_labels'])]
        if self.coil_cluster_map:
            map_to_cluster = lambda x: self.coil_cluster_map[str(x)]
            if isinstance(item['input_ids'], list):
                item['clustered_input_ids'] = [self.coil_cluster_map[str(x)] for x in item['input_ids']]
            else:
                item['clustered_input_ids'] = np.vectorize(map_to_cluster)(item['input_ids'])
            item['clustered_desc_ids'] = [[self.coil_cluster_map[str(x)] for x in xx] for xx in item['desc_input_ids']]
        return item
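
    # Note on coil_cluster_map (used above and again in __getitem__): when
    # data_args.coil_cluster_mapping_path is set, the JSON file is expected to map
    # token-id strings to COIL cluster ids, e.g. {"2054": 131, "2339": 87} (values made
    # up for illustration). Every input id x is remapped via coil_cluster_map[str(x)]
    # to build 'clustered_input_ids' and 'clustered_desc_ids' next to the original ids.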
    def __getitem__(self, idx):
        item = self.input_dataset.__getitem__(idx)
        if self.data_args.large_dset:
            return self.get_item_for_large_dset(idx, item)

        # Iterate over all the labels of input_dataset and attach a random label
        # description to the item in the same order.
        if self.ignored_labels:
            ignored_labels = self.ignored_labels[idx]
        if self.sampleRandom != -1:
            # Create all_candidate_labels
            if self.seen_labels is None:
                labels_new = item['label']
            else:
                labels_new = np.array(item['label'])[self.seen_labels]
            if self.useBMShort:
                # Instead of choosing purely at random, take the top-ranked entries from the
                # shortlist (up to 80% of sampleRandom), then sample the rest randomly.
                if self.seen_labels is not None:
                    all_candidate_labels = [self.seen_labels.tolist().index(self.label_to_id[x])
                                            for x in self.shortlists[idx]
                                            if self.label_to_id[x] in self.seen_labels][:int(0.8 * self.sampleRandom)]
                    # Choose the remainder randomly from seen_labels minus the candidates already picked.
                    all_candidate_labels += np.random.choice(
                        list(set(range(len(self.seen_labels))).difference(set(all_candidate_labels))),
                        self.sampleRandom - len(all_candidate_labels), replace=False).tolist()
            else:
                all_candidate_labels = np.random.choice(range(len(labels_new)), self.sampleRandom, replace=False)
            # Prepend positive labels
            pos_labels = np.where(np.array(labels_new) == 1)[0]
            all_candidate_labels = np.concatenate([pos_labels, all_candidate_labels])
            # Remove duplicates
            all_candidate_labels = np.unique(all_candidate_labels)[:self.sampleRandom]
            if len(pos_labels) < self.cl_min_positive_descs:
                # Guarantee a minimum number of positive descriptions per example.
                addn_pos_labels = np.random.choice(pos_labels, self.cl_min_positive_descs - len(pos_labels))
                all_candidate_labels = np.concatenate([addn_pos_labels, all_candidate_labels])[:self.sampleRandom]
            np.random.shuffle(all_candidate_labels)
            item['all_candidate_labels'] = all_candidate_labels
            # NOTE: ids will be according to seen labels.
            # The labels are updated below based on all_candidate_labels.
        if self.semsup:
            if 'all_candidate_labels' not in item:
                # One random description id per label.
                item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]])
                                          for label_key in range(len(item['label']))]
                if self.return_desc_embeddings:
                    item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]]
                                              for label_key in range(len(item['label']))]
                    item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]]
                                                   for label_key in range(len(item['label']))]
                if self.use_precomputed_embeddings:
                    # The precomputed embeddings are laid out with a stride of 5 rows per label.
                    new_indices = [i * 5 + x for i, x in enumerate(item['label_desc_ids'])]
                    if self.seen_labels is not None:
                        new_indices = [x for i, x in enumerate(new_indices) if i in self.seen_labels]
                    item['desc_inputs_embeds'] = self.computed_desc_inputs_embeds[new_indices]
                item['all_candidate_labels'] = range(len(item['label']))
                if self.seen_labels is not None:
                    item['label_desc_ids'] = (np.array(item['label_desc_ids'])[self.seen_labels]).tolist()
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = (np.array(item['desc_input_ids']))[self.seen_labels].tolist()
                        item['desc_attention_mask'] = (np.array(item['desc_attention_mask']))[self.seen_labels].tolist()
                    item['all_candidate_labels'] = (np.array(item['all_candidate_labels']))[self.seen_labels].tolist()
                    item['label'] = (np.array(item['label']))[self.seen_labels].tolist()
            else:
                item['label_desc_ids'] = [np.random.randint(0, self.class_descs_len[self.id_to_label[label_key]])
                                          for label_key in range(len(item['label']))]
                if self.seen_labels is not None:
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]]
                                                  for label_key in range(len(item['label']))]
                        item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]]
                                                       for label_key in range(len(item['label']))]
                    if self.use_precomputed_embeddings:
                        new_indices = [i * 5 + x for i, x in enumerate(item['label_desc_ids'])]
                        # Of the 4271 labels, keep only the seen labels.
                        new_indices = [x for i, x in enumerate(new_indices) if i in self.seen_labels]
                        # Then keep only the candidate labels.
                        new_indices = [new_indices[x] for x in sorted(item['all_candidate_labels'])]
                        item['desc_inputs_embeds'] = self.computed_desc_inputs_embeds[new_indices]
                    item['label_desc_ids'] = np.array(item['label_desc_ids'])[self.seen_labels].tolist()
                    item['label'] = np.array(item['label'])[self.seen_labels].tolist()
                    item['label'] = np.array(item['label'])[all_candidate_labels].tolist()
                    item['desc_input_ids'] = np.array(item['desc_input_ids'])[self.seen_labels][item['all_candidate_labels']].tolist()
                    item['desc_attention_mask'] = np.array(item['desc_attention_mask'])[self.seen_labels][item['all_candidate_labels']].tolist()
                else:
                    item['label'] = np.array(item['label'])[all_candidate_labels].tolist()
                    if self.return_desc_embeddings:
                        item['desc_input_ids'] = [self.class_descs_tokenized[self.id_to_label[label_key]][0][item['label_desc_ids'][label_key]]
                                                  for label_key in np.array(item['all_candidate_labels'])]
                        item['desc_attention_mask'] = [self.class_descs_tokenized[self.id_to_label[label_key]][1][item['label_desc_ids'][label_key]]
                                                       for label_key in np.array(item['all_candidate_labels'])]
                    if self.use_precomputed_embeddings:
                        item['desc_inputs_embeds'] = [self.computed_desc_inputs_embeds[item['label_desc_ids'][label_key], self.label_to_id[self.id_to_label[label_key]]]
                                                      for label_key in np.array(item['all_candidate_labels'])]
            if self.ignored_labels:
                if self.sampleRandom != -1 and self.seen_labels is not None:
                    ignored_labels = [self.seen_labels.tolist().index(self.label_to_id[x]) for x in self.ignored_labels[idx]]
                    item['all_candidate_labels'] = item['all_candidate_labels'].tolist()
                else:
                    ignored_labels = [self.label_to_id[x] for x in self.ignored_labels[idx]]
                remove_pts = [item['all_candidate_labels'].index(x) for x in ignored_labels if x in item['all_candidate_labels']]
                keep_pts = [x for x in range(len(item['all_candidate_labels'])) if x not in remove_pts]
                # keep_pts can end up shorter than sampleRandom; pad it by re-sampling kept indices.
                if self.sampleRandom != -1 and len(keep_pts) < self.sampleRandom:
                    keep_pts += np.random.choice(keep_pts, self.sampleRandom - len(keep_pts), replace=False).tolist()
                item['desc_input_ids'] = np.array(item['desc_input_ids'])[keep_pts].tolist()
                item['desc_attention_mask'] = np.array(item['desc_attention_mask'])[keep_pts].tolist()
                if 'desc_inputs_embeds' in item:
                    item['desc_inputs_embeds'] = np.array(item['desc_inputs_embeds'])[keep_pts].tolist()
                item['label_desc_ids'] = np.array(item['label_desc_ids'])[keep_pts].tolist()
                item['label'] = np.array(item['label'])[keep_pts].tolist()
            if self.coil_cluster_map:
                map_to_cluster = lambda x: self.coil_cluster_map[str(x)]
                if isinstance(item['input_ids'], list):
                    item['clustered_input_ids'] = [self.coil_cluster_map[str(x)] for x in item['input_ids']]
                else:
                    item['clustered_input_ids'] = np.vectorize(map_to_cluster)(item['input_ids'])
                item['clustered_desc_ids'] = [[self.coil_cluster_map[str(x)] for x in xx] for xx in item['desc_input_ids']]
            return item
        else:
            return item
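

# Illustrative only: a minimal sketch of constructing SemSupDataset on the
# small-dataset (large_dset=False) path. The label set, 'descriptions.json', and the
# data_args fields below are hypothetical stand-ins for what the training script provides.
if __name__ == '__main__':
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    data_args = SimpleNamespace(large_dset=False, coil_cluster_mapping_path='',
                                label_max_seq_length=64)

    labels = ['Sports', 'Politics']
    label_to_id = {l: i for i, l in enumerate(labels)}
    id_to_label = {i: l for i, l in enumerate(labels)}

    class ToyDataset(Dataset):
        # Any map-style dataset whose items carry 'input_ids' and a multi-hot 'label' works here.
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            enc = tokenizer('example document text', truncation=True,
                            padding='max_length', max_length=32)
            return {'input_ids': enc['input_ids'],
                    'attention_mask': enc['attention_mask'],
                    'label': [1, 0]}

    # 'descriptions.json' is assumed to map each label to [num_descs, [description, ...]].
    dset = SemSupDataset(ToyDataset(), data_args, 'descriptions.json', label_to_id,
                         id_to_label, tokenizer, return_desc_embeddings=True)
    item = dset[0]
    # item now also carries 'label_desc_ids', 'desc_input_ids',
    # 'desc_attention_mask', and 'all_candidate_labels'.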