"""This module uses parts of rut5compressed. It shares the same module
structure as the model used in the neural network compression experiments
with rut5compressed.
"""
| from functools import partial | |
| from typing import Optional | |
| import torch as T | |
| from transformers import BartForConditionalGeneration | |
| from .configuration_bart import SVDCompressedBartConfig | |
| from .modules import SVDCompressedLinear | |
| from .util import compress_linear_svd, map_module | |
class SVDCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART-based conditional generation model with SVD-compressed layers.

    At construction time every sub-module of ``self.model`` whose path
    matches :attr:`LAYERS` (the ``fc1``/``fc2`` layers of the encoder and
    decoder blocks) is handed to a replacement function via ``map_module``.

    :param config: Model configuration; ``config.rank`` is used when no
        explicit ``rank`` is given.
    :param rank: Rank of the factorisation. A falsy value (``None`` or
        ``0``) falls back to ``config.rank``.
    :param compress: If true, matched modules are passed to
        ``compress_linear_svd``. Otherwise they are passed to
        :meth:`convert`, which builds same-shaped layers with
        ``SVDCompressedLinear.from_random`` (presumably placeholders to be
        filled from a checkpoint -- TODO confirm).
    """

    # Pattern of module paths that are replaced.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = SVDCompressedBartConfig

    def __init__(self, config: SVDCompressedBartConfig,
                 rank: Optional[int] = None,
                 compress: bool = False) -> None:
        super().__init__(config)
        self.rank = rank or config.rank
        # Keep the stored configuration in sync with the rank actually used
        # to build the layers. Without this, an explicit ``rank`` that
        # differs from ``config.rank`` is lost when the configuration is
        # serialised, and a model re-created from that configuration gets
        # layers whose shapes do not match the saved weights.
        self.config.rank = self.rank

        if compress:
            compress_fn = partial(compress_linear_svd, rank=self.rank)
        else:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Replace a linear layer with an ``SVDCompressedLinear`` one.

        :param module: Candidate module.
        :param path: Path of the module; unused here (presumably part of
            the callback signature expected by ``map_module``).
        :return: ``module`` itself if it is not a ``torch.nn.Linear``;
            otherwise a new layer created by
            ``SVDCompressedLinear.from_random`` with the same input/output
            feature counts and rank ``self.rank``.
        """
        if isinstance(module, T.nn.Linear):
            return SVDCompressedLinear.from_random(module.in_features,
                                                   module.out_features,
                                                   self.rank)
        return module
# Register the model class with the `AutoModelForSeq2SeqLM` auto class.
SVDCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM',
)