| """ | |
| Classes reused from: | |
| 1. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fp32_group_norm.py | |
| 2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/same_pad.py | |
| 3. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fairseq_dropout.py | |
| """ | |
import torch.nn as nn
import torch.nn.functional as F


class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that always normalizes in float32 and casts the result back
    to the input dtype. Useful in mixed-precision training, where computing
    the statistics in fp16 can underflow or overflow."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # Compute group statistics in fp32 regardless of the input dtype.
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        # Cast back so the module is dtype-transparent to callers.
        return output.type_as(input)


class SamePad(nn.Module):
    """Trims trailing time steps left over by symmetric convolution padding.

    With padding=kernel_size // 2, an even kernel yields one extra output
    frame, which is removed here. With padding=kernel_size - 1 and
    causal=True, trimming kernel_size - 1 trailing frames makes the
    convolution causal while preserving the input length.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        # x is expected to be (batch, channels, time); trim the time axis.
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active at inference time via
    apply_during_inference (e.g. for Monte Carlo dropout)."""

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name  # kept for parity with fairseq, where it is used for logging
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            # training=True forces dropout even in eval mode when
            # apply_during_inference is set.
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x
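

# A minimal smoke test, not part of the original fairseq modules. It assumes
# these layers operate on (batch, channels, time) tensors, as in wav2vec-style
# convolutional feature extractors; the shapes and hyperparameters here are
# illustrative only.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 8, 16)  # (batch, channels, time)

    # An even kernel with padding=kernel_size // 2 emits one extra frame;
    # SamePad trims the output back to the input length.
    conv = nn.Conv1d(8, 8, kernel_size=4, padding=2)
    same = SamePad(kernel_size=4)
    assert same(conv(x)).shape == x.shape

    # Fp32GroupNorm normalizes in fp32 but preserves the caller's dtype.
    norm = Fp32GroupNorm(num_groups=2, num_channels=8)
    assert norm(x.half()).dtype == torch.half

    # FairseqDropout is the identity at inference unless
    # apply_during_inference is set.
    drop = FairseqDropout(p=0.5)
    drop.eval()
    assert torch.equal(drop(x), x)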