"""
The `InferenceHubertBase` class is a lightweight, inference-only version of the
HuBERT model implemented in fairseq:
https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/hubert/hubert.py#L248C5-L248C6

A minimal usage sketch is included at the bottom of this file.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .fairseq_modules import Fp32GroupNorm, SamePad, FairseqDropout


class InferenceHubertBase(nn.Module):
    """Inference-only HuBERT base: a convolutional waveform feature extractor
    followed by a 12-layer transformer encoder (768-dim, 12 attention heads)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.feature_extractor = ConvFeatureExtractor()
        self.layer_norm = nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        self.post_extract_proj = nn.Linear(in_features=512, out_features=768, bias=True)
        self.dropout_input = nn.Dropout(p=0.1, inplace=False)
        # Kept to mirror the original fairseq module; not used by extract_features.
        self.dropout_features = nn.Dropout(p=0.1, inplace=False)
        self.encoder = TransformerEncoder()

    def extract_features(
        self,
        source: Tensor,
        padding_mask: Optional[Tensor] = None,
        output_layer: int = 12,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Waveform -> frame features: (B, T_samples) -> (B, T_frames, 512).
        features = self.feature_extractor(source).transpose(1, 2)
        features = self.layer_norm(features)
        if padding_mask is not None:
            # Downsample the sample-level padding mask to frame resolution.
            padding_mask = self.__apply_padding_mask(features, padding_mask)
        features = self.post_extract_proj(features)
        features = self.dropout_input(features)
        # Run the transformer encoder up to (and including) the requested layer.
        features = self.encoder(
            features, padding_mask=padding_mask, tgt_layer=output_layer - 1
        )
        return features, padding_mask

    def __apply_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor:
        # Trim the sample-level mask so it divides evenly into frames, then mark a
        # frame as padded only if every sample that produced it is padded.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask


class ConvFeatureExtractor(nn.Module):
    """HuBERT base waveform front-end: seven 1-D convolutions with strides
    5, 2, 2, 2, 2, 2, 2 (total downsampling factor 320, i.e. one 512-dim frame
    per 20 ms of 16 kHz audio)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        conv_layers = [
            # First block: kernel 10, stride 5, with fp32 group norm.
            nn.Sequential(
                nn.Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False),
                nn.Dropout(p=0.0, inplace=False),
                Fp32GroupNorm(512, 512, eps=1e-05, affine=True),
                nn.GELU(approximate="none"),
            ),
            # Four blocks: kernel 3, stride 2.
            *[
                nn.Sequential(
                    nn.Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False),
                    nn.Dropout(p=0.0, inplace=False),
                    nn.GELU(approximate="none"),
                )
                for _ in range(4)
            ],
            # Two blocks: kernel 2, stride 2.
            *[
                nn.Sequential(
                    nn.Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False),
                    nn.Dropout(p=0.0, inplace=False),
                    nn.GELU(approximate="none"),
                )
                for _ in range(2)
            ],
        ]
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, x: Tensor) -> Tensor:
        # (B, T_samples) -> (B, 1, T_samples) -> (B, 512, T_frames)
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x


class TransformerEncoder(nn.Module):
    """12-layer transformer encoder with a convolutional relative positional
    embedding (HuBERT base configuration)."""

    def __init__(
        self, dropout=0.1, required_seq_len_multiple=2, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.dropout = dropout
        self.required_seq_len_multiple = required_seq_len_multiple

        # Convolutional positional embedding: grouped, weight-normalised conv over time.
        pos_conv = nn.Conv1d(
            768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
        )
        self.pos_conv = nn.Sequential(
            nn.utils.weight_norm(pos_conv, name="weight", dim=2),
            SamePad(128),
            nn.GELU(approximate="none"),
        )
        self.layers = nn.ModuleList(
            [TransformerSentenceEncoderLayer() for _ in range(12)]
        )
        self.layer_norm = nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True)

    @torch.no_grad()
    def forward(self, x: Tensor, padding_mask=None, tgt_layer=None):
        if padding_mask is not None:
            # Zero out padded positions before computing positional embeddings.
            x[padding_mask] = 0

        # Add the convolutional positional embedding.
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        x = self.layer_norm(x)

        # Pad the time dimension to a multiple of `required_seq_len_multiple`,
        # creating or extending the padding mask accordingly.
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C (the attention layers expect time-first input).
        x = x.transpose(0, 1)

        for i, layer in enumerate(self.layers):
            x, _ = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
            if i == tgt_layer:
                break

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        return x


class TransformerSentenceEncoderLayer(nn.Module):
    """A single post-norm (layer_norm_first=False) transformer encoder layer:
    self-attention followed by a GELU feed-forward block."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 12,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layer_norm_first: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.num_attention_heads = num_attention_heads

        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        att_args=None,
    ):
        # Self-attention block with residual connection and post-norm.
        residual = x
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
        )

        x = self.dropout1(x)
        x = residual + x

        x = self.self_attn_layer_norm(x)

        # Feed-forward block (GELU computed in fp32 for numerical stability).
        residual = x
        x = F.gelu(self.fc1(x).float()).type_as(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        layer_result = x

        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

        return x, (attn, layer_result)


class MultiheadAttention(nn.Module):
    """Multi-head self-attention with separate q/k/v projections, evaluated via
    ``F.multi_head_attention_forward``."""

    def __init__(
        self, embed_dim: int, num_heads: int, dropout=0.1, bias=True, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.dropout_module = FairseqDropout(p=dropout)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Inputs are time-first: (T, B, C).
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        src_len, key_bsz, _ = key.size()
        assert (src_len, key_bsz) == value.shape[:2]

        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),  # in_proj_weight is unused with separate q/k/v projections
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            None,
            None,
            False,
            self.dropout_module.p,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training or self.dropout_module.apply_during_inference,
            key_padding_mask.bool() if key_padding_mask is not None else None,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )


def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` (a negative index) so its size becomes a
    multiple of ``multiple``; returns the padded tensor and the pad length."""
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    # F.pad expects (left, right) pairs starting from the last dimension.
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
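

# Minimal usage sketch, not part of the original fairseq model: it pushes random
# 16 kHz "audio" through a randomly initialised model and prints output shapes.
# Loading pretrained fairseq weights is out of scope here; the sample rate, batch
# size, and padding mask below are illustrative assumptions.
if __name__ == "__main__":
    model = InferenceHubertBase().eval()

    with torch.inference_mode():
        source = torch.randn(2, 16000)  # (batch, samples): two 1-second clips of noise
        padding_mask = torch.zeros(2, 16000, dtype=torch.bool)
        padding_mask[1, 12000:] = True  # pretend the second clip is shorter

        features, frame_padding_mask = model.extract_features(
            source, padding_mask=padding_mask, output_layer=12
        )
        # Roughly one 768-dim frame per 20 ms of audio; the encoder pads the
        # 49 frames up to a multiple of required_seq_len_multiple (here 50).
        print(features.shape)            # expected: torch.Size([2, 50, 768])
        print(frame_padding_mask.shape)  # expected: torch.Size([2, 49])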