# DINO-HuVITS/src/inference_hubert/fairseq_modules.py
# Author: SazerLife
# Commit 36a67ca: "feat: added model"
"""
Classes reused from:
1. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fp32_group_norm.py
2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/same_pad.py
3. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fairseq_dropout.py
"""
import torch.nn as nn
import torch.nn.functional as F
class Fp32GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` variant that always normalizes in float32.

    The input and the optional affine parameters are upcast with
    ``.float()`` before calling ``F.group_norm``. The result is then cast
    back with ``type_as`` so the output dtype matches the input dtype.
    All constructor arguments are forwarded unchanged to ``nn.GroupNorm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        """Group-normalize ``input`` in float32; return it in ``input``'s dtype."""
        # The affine parameters may be None (e.g. nn.GroupNorm(..., affine=False)),
        # so only upcast them when present.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.group_norm(
            input.float(),
            self.num_groups,
            weight,
            bias,
            self.eps,
        )
        return normalized.type_as(input)
class SamePad(nn.Module):
    """Drop trailing entries along dimension 2 of a tensor.

    Presumably used after a padded convolution so that the output length
    matches the input length (as in fairseq's ``same_pad``). Confirm this
    at the call site.

    Args:
        kernel_size: kernel size of the associated convolution.
        causal: if True, trim ``kernel_size - 1`` trailing steps. Otherwise
            trim one step when ``kernel_size`` is even and none when it is odd.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        self.remove = kernel_size - 1 if causal else (1 if kernel_size % 2 == 0 else 0)

    def forward(self, x):
        """Return ``x`` without its last ``self.remove`` entries on dim 2.

        ``x`` must have at least 3 dimensions, because dim 2 is sliced.
        When there is nothing to trim, ``x`` itself is returned.
        """
        trim = self.remove
        if trim <= 0:
            return x
        return x[:, :, :-trim]
class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active in eval mode.

    Args:
        p: drop probability passed to ``F.dropout``. Values ``<= 0`` disable it.
        module_name: stored on the instance. It is not used within this class.

    Dropout runs when the module is training, or when
    ``apply_during_inference`` is set to True (it defaults to False).
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        """Apply dropout to ``x`` if enabled; otherwise return ``x`` unchanged."""
        if not (self.p > 0 and (self.training or self.apply_during_inference)):
            return x
        # training=True is passed explicitly so dropout also fires in eval
        # mode when apply_during_inference is set.
        return F.dropout(x, p=self.p, training=True, inplace=inplace)