import einops
import torch
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from torch import nn

# Map activation names to constructors; leaky_relu uses a lambda so that every entry
# is a callable returning a fresh module instance.
ACTIVATION = {'gelu': nn.GELU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU,
              'leaky_relu': lambda: nn.LeakyReLU(0.1),
              'softplus': nn.Softplus, 'ELU': nn.ELU, 'silu': nn.SiLU}

class ContinuousSincosEmbed(nn.Module):
    """Embedding layer for continuous coordinates using sine and cosine functions, as used in transformers.

    This implementation can handle arbitrary coordinate dimensions (e.g., 2D and 3D coordinate systems).

    Args:
        dim: Dimensionality of the embedded input coordinates.
        ndim: Number of dimensions of the input domain.
        max_wavelength: Maximum wavelength. Defaults to 10000.
        assert_positive: If True, assert that all input coordinates are positive. Defaults to True.
    """
    def __init__(
        self,
        dim: int,
        ndim: int,
        max_wavelength: int = 10000,
        assert_positive: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.ndim = ndim
        # if dim is not cleanly divisible -> cut away trailing dimensions
        self.ndim_padding = dim % ndim
        dim_per_ndim = (dim - self.ndim_padding) // ndim
        self.sincos_padding = dim_per_ndim % 2
        self.max_wavelength = max_wavelength
        self.padding = self.ndim_padding + self.sincos_padding * ndim
        self.assert_positive = assert_positive
        effective_dim_per_wave = (self.dim - self.padding) // ndim
        assert effective_dim_per_wave > 0
        arange = torch.arange(0, effective_dim_per_wave, 2, dtype=torch.float32)
        self.register_buffer(
            "omega",
            1.0 / max_wavelength ** (arange / effective_dim_per_wave),
        )
        self.surface_bias = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Forward method of the ContinuousSincosEmbed layer.

        Args:
            coords: Tensor of coordinates. The shape of the tensor should be
                (batch size, number of points, coordinate dimension) or (number of points, coordinate dimension).

        Returns:
            Tensor with embedded coordinates.
        """
        if self.assert_positive:
            # check that all coords are positive
            assert torch.all(coords >= 0)
        # fp32 to avoid numerical imprecision
        coords = coords.float()
        with torch.autocast(device_type=coords.device.type, enabled=False):
            coordinate_ndim = coords.shape[-1]
            assert self.ndim == coordinate_ndim
            out = coords.unsqueeze(-1) @ self.omega.unsqueeze(0)
            emb = torch.concat([torch.sin(out), torch.cos(out)], dim=-1)
            if coords.ndim == 3:
                emb = einops.rearrange(emb, "bs num_points ndim dim -> bs num_points (ndim dim)")
            elif coords.ndim == 2:
                emb = einops.rearrange(emb, "num_points ndim dim -> num_points (ndim dim)")
            else:
                raise NotImplementedError
            if self.padding > 0:
                padding = torch.zeros(*emb.shape[:-1], self.padding, device=emb.device, dtype=emb.dtype)
                emb = torch.concat([emb, padding], dim=-1)
        emb = self.surface_bias(emb)
        return emb
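
# A minimal usage sketch (not part of the original file): shows the expected input and
# output shapes of ContinuousSincosEmbed. All sizes below are illustrative assumptions.
def _example_continuous_sincos_embed():
    embed = ContinuousSincosEmbed(dim=128, ndim=3)
    coords = torch.rand(2, 100, 3)   # (batch, num_points, ndim), non-negative coordinates
    emb = embed(coords)              # -> (2, 100, 128): sin/cos features per axis, then the surface_bias MLP
    return emb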

class MLP(nn.Module):
    def __init__(self, n_input, n_hidden, n_output, n_layers=0, res=False):
        super().__init__()
        act = nn.GELU
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.n_layers = n_layers
        self.res = res
        self.linear_pre = nn.Sequential(nn.Linear(n_input, n_hidden), act())
        self.linear_post = nn.Linear(n_hidden, n_output)
        self.linears = nn.ModuleList([nn.Sequential(nn.Linear(n_hidden, n_hidden), act()) for _ in range(n_layers)])

    def forward(self, x):
        x = self.linear_pre(x)
        for i in range(self.n_layers):
            if self.res:
                x = self.linears[i](x) + x
            else:
                x = self.linears[i](x)
        x = self.linear_post(x)
        return x
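
# A minimal usage sketch (not part of the original file): MLP maps n_input -> n_hidden
# (through n_layers optional hidden blocks, residual if res=True) -> n_output. Sizes are illustrative.
def _example_mlp():
    mlp = MLP(n_input=3, n_hidden=64, n_output=128, n_layers=2, res=True)
    x = torch.rand(2, 100, 3)
    y = mlp(x)   # -> (2, 100, 128)
    return y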

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Separate projections for Q, K, V
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Note: this dropout module is stored but not applied inside forward below.
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k=None, v=None):
        # Defaults give self-attention (k = v = q); passing k (and optionally v) gives cross-attention.
        if k is None:
            k = q
        if v is None:
            v = k
        batch_size = q.size(0)
        # Project inputs to Q, K, V and split into heads: (batch, heads, seq, head_dim)
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=True,
        #     enable_math=False,
        #     enable_mem_efficient=False
        # ):
        #     output = F.scaled_dot_product_attention(q, k, v)
        output = F.scaled_dot_product_attention(q, k, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_proj(output)
        return output
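
# A minimal usage sketch (not part of the original file): the same module handles
# self-attention and cross-attention depending on which arguments are passed.
def _example_multi_head_attention():
    attn = MultiHeadAttention(d_model=128, num_heads=4)
    x = torch.rand(2, 100, 128)      # query tokens
    ctx = torch.rand(2, 50, 128)     # context tokens
    self_out = attn(x)               # self-attention: -> (2, 100, 128)
    cross_out = attn(x, ctx)         # cross-attention over ctx: -> (2, 100, 128)
    return self_out, cross_out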

class TransformerSelfBlock(nn.Module):
    def __init__(self, n_hidden, n_heads, mlp_ratio=1, dropout=0.0):
        super().__init__()
        # n_hidden is split across n_heads, so head_dim = n_hidden // n_heads;
        # more heads means a smaller per-head dimension at the same total width.
        self.self_attn = MultiHeadAttention(n_hidden, n_heads, dropout)
        self.ffn = MLP(n_hidden, n_hidden * mlp_ratio, n_hidden)
        self.norm1 = nn.LayerNorm(n_hidden)
        self.norm2 = nn.LayerNorm(n_hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-norm self-attention with residual connection
        normx = self.norm1(x)
        self_output = self.self_attn(
            q=normx,
            k=normx,
            v=normx,
        )
        x = x + self.dropout(self_output)
        # Pre-norm feedforward network (n_hidden -> n_hidden * mlp_ratio -> n_hidden) with residual connection
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_output)
        return x
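
# A minimal usage sketch (not part of the original file): a pre-norm transformer block
# keeps the token shape unchanged, so blocks can be stacked freely. Sizes are illustrative.
def _example_transformer_self_block():
    block = TransformerSelfBlock(n_hidden=128, n_heads=4, mlp_ratio=2)
    x = torch.rand(2, 100, 128)
    y = block(x)   # -> (2, 100, 128)
    return y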

class ansysLPFMs(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        in_dim = cfg.indim
        out_dim = cfg.outdim
        self.n_decoder = cfg.n_decoder
        n_hidden = cfg.hidden_dim
        n_heads = cfg.n_heads
        mlp_ratio = cfg.mlp_ratio
        self.save_latent = getattr(cfg, "save_latent", False)
        if cfg.pos_embed_sincos:
            self.pos_embed = ContinuousSincosEmbed(dim=n_hidden, ndim=in_dim)
        else:
            self.pos_embed = MLP(in_dim, n_hidden * 2, n_hidden, n_layers=0, res=False)
        self.decoders = nn.ModuleList([TransformerSelfBlock(n_hidden, n_heads, mlp_ratio) for _ in range(self.n_decoder)])
        self.linear_proj_out = nn.Linear(n_hidden, out_dim)
        # Initialize weights properly for stability
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Truncated normal init with std 0.02; timm's trunc_normal_ clips values to [a, b] = [-2, 2] by default
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # def _init_weights(self, m):
    #     if isinstance(m, nn.Linear):
    #         nn.init.xavier_uniform_(m.weight, gain=1.0)
    #         if m.bias is not None:
    #             nn.init.constant_(m.bias, 0)
    #     elif isinstance(m, nn.LayerNorm):
    #         nn.init.constant_(m.bias, 0)
    #         nn.init.constant_(m.weight, 1.0)
    def forward(self, data):
        input_pos = data['input_pos']
        x = self.pos_embed(input_pos)
        for i, decoder in enumerate(self.decoders):
            x = decoder(x)
            if i == self.n_decoder // 2:
                mid = x
        out = self.linear_proj_out(x)
        if self.save_latent:
            return out, mid
        return out
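
# A minimal end-to-end sketch (not part of the original file): builds ansysLPFMs from a
# plain namespace standing in for cfg. The field names mirror those read in __init__;
# the concrete values are illustrative assumptions.
def _example_ansys_lpfms():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        indim=3,              # coordinate dimension
        outdim=4,             # number of predicted fields per point
        n_decoder=4,          # number of stacked TransformerSelfBlocks
        hidden_dim=128,
        n_heads=4,
        mlp_ratio=2,
        pos_embed_sincos=True,
        save_latent=False,
    )
    model = ansysLPFMs(cfg)
    data = {'input_pos': torch.rand(2, 100, 3)}   # non-negative coords for the sincos embedding
    out = model(data)                             # -> (2, 100, 4)
    return out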