Spaces:
Runtime error
Runtime error
| """ | |
| Metrics for multi-label text classification | |
| """ | |
| import numpy as np | |
| from scipy.special import expit | |
| import itertools | |
| import copy | |
| import torch | |
| import pickle | |
| import os | |
| import json | |
| # Metrics | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| f1_score, | |
| classification_report, | |
| precision_score, | |
| recall_score, | |
| label_ranking_average_precision_score, | |
| coverage_error | |
| ) | |
| import time | |
| def multilabel_metrics(data_args, id2label, label2id, fbr, training_args = None): | |
| """ | |
| Metrics function used for multilabel classification. | |
| :fbr : A dict containing global thresholds to be used for selecting a class. | |
| We use global thresholds because we want to handle unseen classes, | |
| for which the threshold is not known in advance. | |
| """ | |
| func_call_counts = 0 | |
| def compute_metrics(p): | |
| # global func_call_counts | |
| # Save the predictions, maintaining a global counter | |
| if training_args is not None and training_args.local_rank <= 0: | |
| preds_fol = os.path.join(training_args.output_dir, 'predictions') | |
| os.makedirs(preds_fol, exist_ok = True) | |
| func_call_counts = time.time() #np.random.randn() | |
| pickle.dump(p.predictions, open(os.path.join(preds_fol, f'predictions_{(func_call_counts+1) * training_args.eval_steps }.pkl'), 'wb')) | |
| pickle.dump(p.label_ids, open(os.path.join(preds_fol, f'label_ids_{(func_call_counts+1) * training_args.eval_steps }.pkl'), 'wb')) | |
| # func_call_counts += 1 | |
| # print(func_call_counts) | |
| # Collect the logits | |
| print('Here we go!') | |
| preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions | |
| # Compute the logistic sigmoid | |
| preds = expit(preds) | |
| # METRIC 0: Compute P@1, P@3, P@5 | |
| if 'precision' not in fbr.keys(): | |
| top_values = [1,3,5] | |
| # denoms = {k:0 for k in top_values} | |
| tops = {k:0 for k in top_values} | |
| i = 0 | |
| preds_preck = None | |
| print(p.label_ids) | |
| for i, (logit, label) in enumerate(zip(p.predictions[0], p.label_ids)): | |
| logit = torch.from_numpy(logit) | |
| label = torch.from_numpy(label) | |
| _, indexes = torch.topk(logit.float(), k = max(top_values)) | |
| for val in top_values: | |
| if preds_preck is None: | |
| tops[val] += len([x for x in indexes[:val] if label[x]!=0]) | |
| else: | |
| tops[val] += len([x for x in preds_preck[i][indexes[:val]] if label[x]!=0]) | |
| # denoms[val] += min(val, label.nonzero().shape[0]) | |
| precisions_at_k = {k:v/((i+1)*k) for k,v in tops.items()} | |
| # rprecisions_at_k = {k:v/denoms[v] for k,v in tops.items()} | |
| print('Evaluation Result: precision@{} = {}'.format(top_values, precisions_at_k)) | |
| # print('Evaluation Result: rprecision@{} = {}'.format(top_values, rprecisions_at_k)) | |
| # p.predictions = p.predictions[0] | |
| # p.label_ids = p.label_ids[0] | |
| # METRIC 1: Compute accuracy | |
| if 'accuracy' not in fbr.keys(): | |
| performance = {} | |
| for threshold in np.arange(0.1, 1, 0.1): | |
| accuracy_preds = np.where(preds > threshold, 1, 0) | |
| performance[threshold] = np.sum(p.label_ids == accuracy_preds) / accuracy_preds.size * 100 | |
| # Choose the best threshold | |
| best_threshold = max(performance, key=performance.get) | |
| fbr['accuracy'] = best_threshold | |
| accuracy = performance[best_threshold] | |
| else: | |
| accuracy_preds = np.where(preds > fbr['accuracy'], 1, 0) | |
| accuracy = np.sum(p.label_ids == accuracy_preds) / accuracy_preds.size * 100 | |
| # METRIC 2: Compute the subset accuracy | |
| if 'subset_accuracy' not in fbr.keys(): | |
| performance = {} | |
| for threshold in np.arange(0.1, 1, 0.1): | |
| subset_accuracy_preds = np.where(preds > threshold, 1, 0) | |
| performance[threshold] = accuracy_score(p.label_ids, subset_accuracy_preds) | |
| # Choose the best threshold | |
| best_threshold = max(performance, key=performance.get) | |
| fbr['subset_accuracy'] = best_threshold | |
| subset_accuracy = performance[best_threshold] | |
| else: | |
| subset_accuracy_preds = np.where(preds > fbr['subset_accuracy'], 1, 0) | |
| subset_accuracy = accuracy_score(p.label_ids, subset_accuracy_preds) | |
| # METRIC 3: Macro F-1 | |
| if 'macro_f1' not in fbr.keys(): | |
| performance = {} | |
| for threshold in np.arange(0.1, 1, 0.1): | |
| macro_f1_preds = np.where(preds > threshold, 1, 0) | |
| performance[threshold] = f1_score(p.label_ids, macro_f1_preds, average='macro') | |
| # Choose the best threshold | |
| best_threshold = max(performance, key=performance.get) | |
| fbr['macro_f1'] = best_threshold | |
| macro_f1 = performance[best_threshold] | |
| else: | |
| macro_f1_preds = np.where(preds > fbr['macro_f1'], 1, 0) | |
| macro_f1 = f1_score(p.label_ids, macro_f1_preds, average='macro') | |
| # METRIC 4: Micro F-1 | |
| if 'micro_f1' not in fbr.keys(): | |
| performance = {} | |
| for threshold in np.arange(0.1, 1, 0.1): | |
| micro_f1_preds = np.where(preds > threshold, 1, 0) | |
| performance[threshold] = f1_score(p.label_ids, micro_f1_preds, average='micro') | |
| # Choose the best threshold | |
| best_threshold = max(performance, key=performance.get) | |
| fbr['micro_f1'] = best_threshold | |
| micro_f1 = performance[best_threshold] | |
| else: | |
| micro_f1_preds = np.where(preds > fbr['micro_f1'], 1, 0) | |
| micro_f1 = f1_score(p.label_ids, micro_f1_preds, average='micro') | |
| # Multi-label classification report | |
| # Optimized for Micro F-1 | |
| try: | |
| report = classification_report(p.label_ids, micro_f1_preds, target_names=[id2label[i] for i in range(len(id2label))]) | |
| print('Classification Report: \n', report) | |
| except: | |
| report = classification_report(p.label_ids, micro_f1_preds) | |
| print('Classification Report: \n', report) | |
| return_dict = { | |
| "accuracy": accuracy, | |
| "subset_accuracy": subset_accuracy, | |
| "macro_f1": macro_f1, | |
| "micro_f1": micro_f1, | |
| # "hier_micro_f1": hier_micro_f1, | |
| "fbr": fbr | |
| } | |
| for k in precisions_at_k: | |
| return_dict[f'P@{k}'] = precisions_at_k[k] | |
| if training_args is not None and training_args.local_rank <= 0: | |
| try: | |
| metrics_fol = os.path.join(training_args.output_dir, 'metrics') | |
| os.makedirs(metrics_fol, exist_ok = True) | |
| json.dump(return_dict, open(os.path.join(metrics_fol, f'metrics_{(func_call_counts+1) * training_args.eval_steps }.json'), 'w'), indent = 2) | |
| except Exception as e: | |
| print('Error in metrics', e) | |
| return return_dict | |
| return compute_metrics |