| import torch | |
| import torch.nn as nn | |
| from .attentions import MultiHeadAttention | |
class VAEMemoryBank(nn.Module):
    """Learned memory bank queried via cross-attention.

    Holds a fixed-size bank of learnable latent vectors. The input
    sequence attends over the bank, and the attended result is projected
    down to ``output_channels`` with a 1x1 convolution.
    """

    def __init__(
        self,
        bank_size=1000,
        n_hidden_dims=512,
        n_attn_heads=2,
        init_values=None,
        output_channels=192,
    ):
        super().__init__()
        self.bank_size = bank_size
        self.n_hidden_dims = n_hidden_dims
        self.n_attn_heads = n_attn_heads

        # Cross-attention: queries come from the input, keys/values from
        # the memory bank.
        self.encoder = MultiHeadAttention(
            channels=n_hidden_dims,
            out_channels=n_hidden_dims,
            n_heads=n_attn_heads,
        )

        # Learnable bank, shape (n_hidden_dims, bank_size).
        self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
        self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)

        if init_values is not None:
            # Overwrite the random initialization with the provided values
            # without recording the copy in autograd.
            with torch.no_grad():
                self.memory_bank.copy_(init_values)

    def forward(self, z: torch.Tensor):
        """Attend over the memory bank with ``z`` and project the result.

        NOTE(review): presumably ``z`` is (batch, n_hidden_dims, time) given
        the Conv1d projection — confirm against ``MultiHeadAttention``'s
        expected layout.
        """
        batch = z.size(0)
        # Broadcast the shared bank to one copy per batch element.
        keys = self.memory_bank.unsqueeze(0).repeat(batch, 1, 1)
        attended = self.encoder(z, keys, attn_mask=None)
        return self.proj(attended)