"""
The `InferenceHubertBase` class is a lightweight version of the model from this repository:
https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/hubert/hubert.py#L248C5-L248C6
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .fairseq_modules import Fp32GroupNorm, SamePad, FairseqDropout
class InferenceHubertBase(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
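        # Pipeline: waveform -> conv feature extractor -> LayerNorm -> 512->768
        # projection -> dropout -> transformer encoder.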
self.feature_extractor = ConvFeatureExtractor()
self.layer_norm = nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True)
self.post_extract_proj = nn.Linear(in_features=512, out_features=768, bias=True)
self.dropout_input = nn.Dropout(p=0.1, inplace=False)
self.dropout_features = nn.Dropout(p=0.1, inplace=False)
self.encoder = TransformerEncoder()
def extract_features(
self,
source: Tensor,
padding_mask: Optional[Tensor] = None,
output_layer: int = 12,
) -> Tuple[Tensor, Tensor]:
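        """
        Extract hidden states from the transformer encoder.

        `output_layer` is the 1-based index of the last transformer layer to run
        (12 = final layer). Returns the selected layer's features (B x T x 768)
        together with the frame-level padding mask (or None).
        """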
features = self.feature_extractor(source).transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.__apply_padding_mask(features, padding_mask)
features = self.post_extract_proj(features)
features = self.dropout_input(features)
features = self.encoder(
features, padding_mask=padding_mask, tgt_layer=output_layer - 1
)
return features, padding_mask
def __apply_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor:
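        # The conv extractor downsamples time, so trim and reshape the sample-level
        # padding mask to one entry per output frame (True = frame is fully padded).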
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
class ConvFeatureExtractor(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
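        # Seven 1-D conv blocks with strides (5, 2, 2, 2, 2, 2, 2): raw waveform ->
        # 512-dim frames with an overall stride of 320 samples (~50 Hz at 16 kHz).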
conv_layers = [
nn.Sequential(
nn.Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False),
nn.Dropout(p=0.0, inplace=False),
Fp32GroupNorm(512, 512, eps=1e-05, affine=True),
nn.GELU(approximate="none"),
),
*[
nn.Sequential(
nn.Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False),
nn.Dropout(p=0.0, inplace=False),
nn.GELU(approximate="none"),
)
for _ in range(4)
],
*[
nn.Sequential(
nn.Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False),
nn.Dropout(p=0.0, inplace=False),
nn.GELU(approximate="none"),
)
for _ in range(2)
],
]
self.conv_layers = nn.ModuleList(conv_layers)
def forward(self, x: Tensor):
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x
class TransformerEncoder(nn.Module):
def __init__(
self, dropout=0.1, required_seq_len_multiple=2, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.dropout = dropout # 0.1
self.required_seq_len_multiple = required_seq_len_multiple # 2
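        # Convolutional relative positional embedding (grouped conv over time),
        # as in wav2vec 2.0 / HuBERT.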
pos_conv = nn.Conv1d(
768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
)
self.pos_conv = nn.Sequential(
nn.utils.weight_norm(pos_conv, name="weight", dim=2),
SamePad(128),
nn.GELU(approximate="none"),
)
self.layers = nn.ModuleList(
[TransformerSentenceEncoderLayer() for _ in range(12)]
)
self.layer_norm = nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True)
@torch.no_grad()
def forward(self, x: Tensor, padding_mask=None, tgt_layer=None):
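        # x: B x T x 768; padding_mask: B x T bool mask (True = padded frame);
        # tgt_layer: 0-based index of the last layer to run (None = run all layers).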
if padding_mask is not None:
            # Zero out features at padded positions (fairseq does this with index_put).
            x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
for i, layer in enumerate(self.layers):
x, _ = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
if i == tgt_layer:
break
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        # Undo the padding added above so the output length matches the conv features.
        if pad_length > 0:
            x = x[:, :-pad_length]
        return x
class TransformerSentenceEncoderLayer(nn.Module):
def __init__(
self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
num_attention_heads: int = 12,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
layer_norm_first: bool = False,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.embedding_dim = embedding_dim
self.ffn_embedding_dim = ffn_embedding_dim
self.num_attention_heads = num_attention_heads
self.self_attn = MultiheadAttention(
self.embedding_dim, # 768
num_attention_heads, # 12
dropout=attention_dropout, # 0.1
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
):
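        # Post-LN block (layer_norm_first=False): attention -> residual -> LayerNorm,
        # then feed-forward -> residual -> LayerNorm.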
residual = x
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = F.gelu(self.fc1(x).float()).type_as(x)
x = self.dropout2(x)
x = self.fc2(x)
layer_result = x
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, (attn, layer_result)
class MultiheadAttention(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, dropout=0.1, bias=True, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout_module = FairseqDropout(p=dropout)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
src_len, key_bsz, _ = key.size()
        assert (src_len, key_bsz) == value.shape[:2]
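        # Delegate to PyTorch's fused multi-head attention, passing the separate
        # q/k/v projection weights kept as individual Linear layers (fairseq-style).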
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
None,
None,
False,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask.bool() if key_padding_mask is not None else None,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
def pad_to_multiple(x, multiple, dim=-1, value=0):
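    """Right-pad `x` along `dim` so its length becomes a multiple of `multiple`."""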
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
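

# Minimal usage sketch (not part of the original module). It assumes weights
# converted from a HuBERT-base checkpoint whose state_dict keys match the module
# names above; the checkpoint path below is hypothetical.
if __name__ == "__main__":
    model = InferenceHubertBase().eval()
    # state_dict = torch.load("hubert_base_inference.pt", map_location="cpu")  # hypothetical path
    # model.load_state_dict(state_dict)
    waveform = torch.randn(1, 16000)  # batch of one, one second of 16 kHz audio
    with torch.no_grad():
        features, mask = model.extract_features(waveform, padding_mask=None, output_layer=12)
    print(features.shape)  # torch.Size([1, 49, 768])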