| | |
| | import copy |
| | import numbers |
| | from functools import partial |
| | from typing import Any |
| | from typing import Callable |
| | from typing import List |
| | from typing import Optional |
| | from typing import Tuple |
| | from typing import Union |
| |
|
| | import torch |
| | from .activation import MultiheadAttention |
| | from .scaling import BalancedDoubleSwish |
| | from torch import nn |
| | from torch import Tensor |
| | from torch.nn import functional as F |
| |
|
# Accepted forms of ``normalized_shape``: a single int, or a list / tuple /
# ``torch.Size`` of ints. ``LayerNorm`` converts every form to a tuple.
_shape_t = Union[int, List[int], Tuple[int, ...], torch.Size]
| |
|
| |
|
class LayerNorm(nn.Module):
    r"""Layer normalization that can carry an auxiliary embedding alongside its input.

    Numerically this is ``F.layer_norm`` over the trailing ``normalized_shape``
    dimensions, with optional learnable per-element scale (``weight``) and
    shift (``bias``).

    ``forward`` accepts either a plain tensor or an ``(input, embedding)``
    tuple. For a tuple, only ``input`` is normalized and
    ``(normalized, embedding)`` is returned, so this module can sit in a stack
    whose layers pass such pairs along. For a plain tensor the normalized
    tensor is returned and no separate ``embedding`` may be supplied.

    Args:
        normalized_shape: size of the trailing dimension(s) to normalize over;
            an int is treated as a one-element shape.
        eps: value added to the variance for numerical stability.
        elementwise_affine: if True, create ``weight`` (initialised to ones)
            and ``bias`` (initialised to zeros) parameters.
        device: forwarded to the parameter tensors' constructor.
        dtype: forwarded to the parameter tensors' constructor.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            tensor_opts = dict(device=device, dtype=dtype)
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_opts)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_opts)
            )
        else:
            # Register explicit ``None`` slots so ``self.weight`` / ``self.bias``
            # always exist and can be handed straight to ``F.layer_norm``.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters to the identity transform (scale 1, shift 0)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input`` (or the first element of an ``(input, embedding)`` pair).

        Returns the normalized tensor, or ``(normalized, embedding)`` when a
        tuple was passed in.
        """
        packed = isinstance(input, tuple)
        if packed:
            # The embedding travels inside the tuple and replaces the argument.
            input, embedding = input
        else:
            assert embedding is None

        normed = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return (normed, embedding) if packed else normed

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
| |
|
| |
|
class IdentityNorm(nn.Module):
    """A no-op stand-in for a normalization layer.

    Takes the same constructor arguments that ``TransformerEncoderLayer``
    passes to its ``layer_norm_cls`` (``d_model, eps=..., device=..., dtype=...``),
    so it can be used in place of ``LayerNorm``. ``forward`` returns its input
    unchanged.

    Args:
        d_model: accepted for signature compatibility; unused.
        eps: stored as ``self.eps`` so that wrappers which read ``norm.eps``
            (``AdaptiveLayerNorm`` does) also accept this class.
        device: accepted for signature compatibility; ignored.
        dtype: accepted for signature compatibility; ignored.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()
        # Expose eps like LayerNorm does; previously missing, which made
        # AdaptiveLayerNorm(d_model, IdentityNorm(...)) raise AttributeError.
        self.eps = eps

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` unchanged.

        A tuple (e.g. ``(x, embedding)``) is returned as-is. For a plain
        tensor, no separate ``embedding`` may be supplied.
        """
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input
| |
|
| |
|
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
            It is deep-copied ``num_layers`` times, so the layers do not share
            parameters.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional), applied to the
            output of the last layer.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        cache=None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required). Either a tensor or an
                ``(x, stage_embedding)`` tuple, which is handed to the layers as-is.
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: if True, also return every layer's output
                (optional).
            cache: forwarded unchanged to every layer (optional).

        Returns:
            The output of the last layer, passed through ``self.norm`` when one
            is set. When ``return_layer_states`` is True, returns
            ``(layer_states, output)`` instead, where ``layer_states`` holds
            each layer's un-normalized output (only the tensor part when the
            layers return ``(x, embedding)`` tuples).

        Shape:
            see the docs in Transformer class.
        """
        layer_states = [] if return_layer_states else None
        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                cache=cache,
            )
            if layer_states is not None:
                # Only strip the tuple wrapper; indexing a plain tensor with
                # [0] would drop its first slice instead.
                layer_states.append(
                    output[0] if isinstance(output, tuple) else output
                )

        if self.norm is not None:
            output = self.norm(output)

        if return_layer_states:
            return layer_states, output
        return output
| |
|
| |
|
class TransformerEncoderLayer(nn.Module):
    r"""One encoder layer: a self-attention block followed by a feed-forward block.

    Each sub-block has a residual connection and a normalization layer. With
    ``norm_first=True`` the norm is applied to the sub-block input
    (``x + f(norm(x))``); otherwise it is applied after the residual add
    (``norm(x + f(x))``).

    ``forward`` accepts a plain tensor or an ``(x, stage_embedding)`` tuple.
    The embedding is passed as the second argument to both norms (which lets
    ``AdaptiveLayerNorm`` condition on it). It is returned alongside the
    output when a tuple was given.

    Args:
        d_model: model / embedding size.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability used inside attention and on both
            sub-block outputs.
        activation: feed-forward activation. It is resolved as follows:
            - a string ("relu" or "gelu") maps to the matching
              ``torch.nn.functional`` function;
            - a ``functools.partial`` is called with ``d_model``;
            - the ``BalancedDoubleSwish`` class is instantiated with ``d_model``;
            - any other callable is used as-is.
        batch_first: forwarded to ``MultiheadAttention``.
        norm_first: pre-norm if True, post-norm if False (see above).
        device: forwarded to sub-module constructors.
        dtype: forwarded to sub-module constructors.
        linear1_self_attention_cls: forwarded to ``MultiheadAttention`` as
            ``linear1_cls``.
        linear2_self_attention_cls: forwarded to ``MultiheadAttention`` as
            ``linear2_cls``.
        linear1_feedforward_cls: class for the first feed-forward projection.
        linear2_feedforward_cls: class for the second feed-forward projection.
        layer_norm_cls: class used to build both norms.
        layer_norm_eps: ``eps`` passed to ``layer_norm_cls``.
        adaptive_layer_norm: if True, wrap each norm in ``AdaptiveLayerNorm``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()

        # Self-attention sub-block.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Feed-forward sub-block: linear1 -> activation -> dropout -> linear2.
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        # Dropout on each sub-block's output, applied before the residual add.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = self._resolve_activation(activation, d_model)

        # NOTE(review): an earlier revision built norm2 as ``BalancedBasicNorm``
        # when layer_norm_cls is IdentityNorm, but that class is neither defined
        # nor imported in this module, so that path raised NameError. Both norms
        # now come from layer_norm_cls; import and restore the special case if
        # balanced normalization is required there.
        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    @staticmethod
    def _resolve_activation(activation, d_model: int):
        """Turn the ``activation`` constructor argument into a callable."""
        if isinstance(activation, str):
            # Legacy string support, matching torch.nn.TransformerEncoderLayer
            # (only "relu" and "gelu" are recognised).
            if activation == "relu":
                return F.relu
            if activation == "gelu":
                return F.gelu
            raise RuntimeError(
                "activation should be relu/gelu, not {}".format(activation)
            )
        if isinstance(activation, partial):
            return activation(d_model)
        if activation is BalancedDoubleSwish:
            return BalancedDoubleSwish(d_model)
        return activation

    def __setstate__(self, state):
        # Older pickled layers may lack ``activation``; default it to ReLU.
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        cache=None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required), or an
                ``(x, stage_embedding)`` tuple.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional);
                must be bool or floating point.
            cache: forwarded to ``self_attn`` (optional).

        Returns:
            The layer output, or ``(output, stage_embedding)`` when ``src`` was
            a tuple.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.

        Shape:
            see the docs in Transformer class.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            # Pre-norm: normalize the input of each sub-block.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                cache=cache,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # Post-norm: normalize after each residual add.
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        cache=None,
    ) -> Tensor:
        # Attention weights are not requested; keep only the attention output.
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            cache=cache,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
| |
|
| |
|
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Normalizes ``input`` with the wrapped ``norm`` module, then applies a scale
    and shift predicted from a conditioning ``embedding``:
    ``weight * norm(input) + bias``. Here ``weight`` and ``bias`` are the two
    ``d_model``-sized halves (last dim) of
    ``Linear(d_model, 2 * d_model)(embedding)``.

    Args:
        d_model: feature size; the embedding's last dimension must be
            ``d_model`` because it feeds the projection layer.
        norm: inner normalization module, called as ``norm(input)``.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        # Mirror the wrapped norm's eps when it has one. Norms without an
        # ``eps`` attribute are accepted instead of raising AttributeError.
        self.eps = getattr(norm, "eps", None)

    def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
        """Apply the embedding-conditioned affine transform to ``norm(input)``.

        ``input`` may also be an ``(input, embedding)`` tuple, in which case the
        result is returned as ``(output, embedding)``.
        """
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        output = weight * self.norm(input) + bias
        return (output, embedding) if packed else output
| |
|
| |
|
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy has its own parameters; none of them is ``module`` itself.
    """
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)
| |
|