import torch
from typing import Dict, List


class NextSentencePredictionTokenizer:
    def __init__(self, _tokenizer, **_tokenizer_args):
        self.tokenizer = _tokenizer
        self.tokenizer_args = _tokenizer_args
        self.max_length_ctx = self.tokenizer_args.get("max_length_ctx")
        self.max_length_res = self.tokenizer_args.get("max_length_res")
        self.special_token = self.tokenizer_args.get("special_token")
        self.tokenizer_args["max_length"] = self.max_length_ctx + self.max_length_res
        # Drop the custom keys so only arguments the underlying tokenizer understands are forwarded
        for key_to_delete in ["special_token", "naive_approach", "max_length_ctx", "max_length_res", "approach"]:
            if key_to_delete in self.tokenizer_args:
                del self.tokenizer_args[key_to_delete]
    def get_item(self, context: List[str], actual_sentence: str):
        # Join the context turns with the configured separator token (plain spaces if none is set)
        context_str = f" {self.special_token} ".join(context) if self.special_token != " " else " ".join(context)
        actual_item = {"ctx": context_str, "res": actual_sentence}
        tokenized = self._tokenize_row(actual_item)
        # Reshape each returned NumPy array into a batched (1, seq_len) torch tensor
        # (this assumes the tokenizer was configured with return_tensors="np")
        for key in tokenized.data.keys():
            tokenized.data[key] = torch.reshape(torch.from_numpy(tokenized.data[key]), (1, -1))
        return tokenized
    def _tokenize_row(self, row: Dict):
        ctx_tokens = row["ctx"].split(" ")
        res_tokens = row["res"].split(" ")
        # Keep only the most recent whitespace tokens of context and response;
        # special tokens such as [CLS] and [SEP] are added by the tokenizer itself
        ctx_tokens = ctx_tokens[-self.max_length_ctx:]
        res_tokens = res_tokens[-self.max_length_res:]
        _args = (ctx_tokens, res_tokens)
        tokenized_row = self.tokenizer(*_args, **self.tokenizer_args)
        return tokenized_row
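

# A minimal usage sketch, not part of the original file: it assumes a Hugging Face
# "transformers" tokenizer and shows how the class might be wired up. The checkpoint
# name, the length limits, and the tokenizer keyword arguments below are illustrative
# placeholders, not values taken from the source.
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
nsp_tokenizer = NextSentencePredictionTokenizer(
    hf_tokenizer,
    max_length_ctx=64,                      # illustrative context budget (whitespace tokens)
    max_length_res=32,                      # illustrative response budget
    special_token=hf_tokenizer.sep_token,   # separator placed between context turns
    is_split_into_words=True,               # ctx/res are passed as pre-split word lists
    truncation=True,
    padding="max_length",
    return_tensors="np",                    # get_item converts the NumPy output to torch tensors
)

batch = nsp_tokenizer.get_item(
    context=["How are you?", "I am fine, thanks."],
    actual_sentence="Glad to hear that!",
)
print(batch.data["input_ids"].shape)  # (1, max_length_ctx + max_length_res) -> (1, 96)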