""" The `InferenceHubertBase` class is a lightweight version of the model from this repository: https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/hubert/hubert.py#L248C5-L248C6 """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from .fairseq_modules import Fp32GroupNorm, SamePad, FairseqDropout class InferenceHubertBase(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.feature_extractor = ConvFeatureExtractor() self.layer_norm = nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True) self.post_extract_proj = nn.Linear(in_features=512, out_features=768, bias=True) self.dropout_input = nn.Dropout(p=0.1, inplace=False) self.dropout_features = nn.Dropout(p=0.1, inplace=False) self.encoder = TransformerEncoder() def extract_features( self, source: Tensor, padding_mask: Optional[Tensor] = None, output_layer: int = 12, ) -> Tuple[Tensor, Tensor]: features = self.feature_extractor(source).transpose(1, 2) features = self.layer_norm(features) if padding_mask is not None: padding_mask = self.__apply_padding_mask(features, padding_mask) features = self.post_extract_proj(features) features = self.dropout_input(features) features = self.encoder( features, padding_mask=padding_mask, tgt_layer=output_layer - 1 ) return features, padding_mask def __apply_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor: extra = padding_mask.size(1) % features.size(1) if extra > 0: padding_mask = padding_mask[:, :-extra] padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) padding_mask = padding_mask.all(-1) return padding_mask class ConvFeatureExtractor(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) conv_layers = [ nn.Sequential( nn.Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False), nn.Dropout(p=0.0, inplace=False), Fp32GroupNorm(512, 512, eps=1e-05, affine=True), nn.GELU(approximate="none"), ), *[ nn.Sequential( nn.Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False), nn.Dropout(p=0.0, inplace=False), nn.GELU(approximate="none"), ) for _ in range(4) ], *[ nn.Sequential( nn.Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False), nn.Dropout(p=0.0, inplace=False), nn.GELU(approximate="none"), ) for _ in range(2) ], ] self.conv_layers = nn.ModuleList(conv_layers) def forward(self, x: Tensor): x = x.unsqueeze(1) for conv in self.conv_layers: x = conv(x) return x class TransformerEncoder(nn.Module): def __init__( self, dropout=0.1, required_seq_len_multiple=2, *args, **kwargs ) -> None: super().__init__(*args, **kwargs) self.dropout = dropout # 0.1 self.required_seq_len_multiple = required_seq_len_multiple # 2 pos_conv = nn.Conv1d( 768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16 ) self.pos_conv = nn.Sequential( nn.utils.weight_norm(pos_conv, name="weight", dim=2), SamePad(128), nn.GELU(approximate="none"), ) self.layers = nn.ModuleList( [TransformerSentenceEncoderLayer() for _ in range(12)] ) self.layer_norm = nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True) @torch.no_grad() def forward(self, x: Tensor, padding_mask=None, tgt_layer=None): if padding_mask is not None: # x = index_put(x, padding_mask, 0) x[padding_mask] = 0 x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) x = x + x_conv x = self.layer_norm(x) # pad to the sequence length dimension x, pad_length = pad_to_multiple( x, self.required_seq_len_multiple, 
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        for i, layer in enumerate(self.layers):
            x, _ = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
            if i == tgt_layer:
                break

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # Undo the padding added above so the output length matches the input,
        # as in the referenced fairseq implementation.
        if pad_length > 0:
            x = x[:, :-pad_length]

        return x


class TransformerSentenceEncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 12,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layer_norm_first: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.num_attention_heads = num_attention_heads

        self.self_attn = MultiheadAttention(
            self.embedding_dim,  # 768
            num_attention_heads,  # 12
            dropout=attention_dropout,  # 0.1
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        # Post-norm path only (layer_norm_first is assumed False):
        # attention -> residual -> norm, then feed-forward -> residual -> norm.
        residual = x
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
        )
        x = self.dropout1(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        residual = x
        x = F.gelu(self.fc1(x).float()).type_as(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        layer_result = x
        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

        return x, (attn, layer_result)


class MultiheadAttention(nn.Module):
    """Thin wrapper around `F.multi_head_attention_forward` with separate Q/K/V projections."""

    def __init__(
        self, embed_dim: int, num_heads: int, dropout=0.1, bias=True, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(p=dropout)

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        src_len, key_bsz, _ = key.size()
        assert (src_len, key_bsz) == value.shape[:2]

        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            None,
            None,
            False,
            self.dropout_module.p,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training or self.dropout_module.apply_during_inference,
            key_padding_mask.bool() if key_padding_mask is not None else None,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )


def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
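

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original fairseq
# code). It assumes 16 kHz mono audio and a checkpoint whose state-dict keys
# match the attributes above; the file name "hubert_base_ls960.pt" and the
# package name used to run this module are placeholders. Because of the
# relative import at the top of the file, run it as a module, e.g.
# `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = InferenceHubertBase().eval()
    # model.load_state_dict(torch.load("hubert_base_ls960.pt", map_location="cpu"))

    # A batch of two 1-second dummy waveforms (B x T). In the padding mask,
    # True marks padded samples; here we pretend the second clip is 0.5 s long.
    wav = torch.randn(2, 16000)
    padding_mask = torch.zeros(2, 16000, dtype=torch.bool)
    padding_mask[1, 8000:] = True

    with torch.no_grad():
        features, frame_mask = model.extract_features(
            wav, padding_mask=padding_mask, output_layer=12
        )

    # One 768-dim vector per ~20 ms frame; frame_mask flags fully padded frames.
    print(features.shape, frame_mask.shape)  # torch.Size([2, 49, 768]) torch.Size([2, 49])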