Spaces:
Runtime error
| from torch import nn | |
| import transformers | |
| from typing import List | |
def get_class(_model_package, _model_class):
    """Import ``_model_package`` and return its attribute ``_model_class``.

    Args:
        _model_package: Dotted module path to import, e.g. ``"os.path"``.
        _model_class: Name of the attribute (typically a class) to fetch
            from that module.

    Returns:
        The object bound to ``_model_class`` in the imported module.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute of that name.
    """
    # A non-empty fromlist makes __import__ return the leaf module of a dotted
    # path (``a.b`` -> module ``a.b``) instead of the top-level package ``a``.
    module = __import__(_model_package, fromlist=[_model_class])
    target = getattr(module, _model_class)
    return target
class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head implemented as a small MLP.

    Maps the pooled encoder output through
    ``hidden_size -> 256 -> 64 -> 2`` linear layers, with an activation
    between consecutive linear layers, producing two logits per example.

    Args:
        config: Model config object; only ``config.hidden_size`` is read.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = self._build_layer(config.hidden_size, layer_dimensions=[256, 64])

    def forward(self, pooled_output):
        """Return the 2-way NSP logits for ``pooled_output``.

        The last dimension of ``pooled_output`` must equal
        ``config.hidden_size``; presumably it is ``(batch, hidden_size)`` —
        TODO confirm against the caller.
        """
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

    def _build_layer(self, init_size, layer_dimensions: List, activation=None):
        """Build ``Linear -> act -> ... -> Linear(..., 2)`` as an ``nn.Sequential``.

        Args:
            init_size: Input feature size of the first linear layer.
            layer_dimensions: Output sizes of the hidden linear layers, in
                order. An empty list yields a single ``Linear(init_size, 2)``.
            activation: Activation module placed after every hidden linear
                layer. ``None`` (the default) inserts a fresh ``nn.ReLU()``
                at each position. A module passed in is deep-copied per
                position, so stateful activations (e.g. ``nn.PReLU``) do not
                end up sharing one set of parameters across layers.

        Returns:
            ``nn.Sequential`` whose final linear layer outputs 2 features.
        """
        import copy

        def _make_activation():
            # A fresh module per position. The old ``activation=nn.ReLU()``
            # default was created once at definition time and reused
            # everywhere, across layers and across instances.
            return nn.ReLU() if activation is None else copy.deepcopy(activation)

        module_list = []
        in_size = init_size
        for out_size in layer_dimensions:
            module_list.append(nn.Linear(in_size, out_size))
            module_list.append(_make_activation())
            in_size = out_size
        # Final projection to the two NSP classes.
        module_list.append(nn.Linear(in_size, 2))
        return nn.Sequential(*module_list)
class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """``BertForNextSentencePrediction`` with a deeper NSP classification head.

    Behaves like the parent class except that its ``cls`` attribute is
    replaced with an ``OwnBertOnlyNSPHead`` MLP head.
    """

    def __init__(self, config):
        super(OwnBertForNextSentencePrediction, self).__init__(config)
        # Swap out whatever head the parent constructor installed as ``cls``
        # for the MLP head, giving the classifier more capacity.
        nsp_head = OwnBertOnlyNSPHead(config)
        self.cls = nsp_head
        # Re-run weight initialisation / final processing so the new head is
        # covered too (the parent constructor presumably ran it already,
        # before the head was replaced).
        self.post_init()