import math

import torch
import torch.nn as nn

from .optimized_diffattn import MultiheadDiffAttn

# Byte-level vocabulary: the 256 raw byte values plus a few chat-style
# special tokens appended at the end of the id range.
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"

SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Lookup tables between tokens and integer ids.
token_to_id = {}
id_to_token = {}

# Ids 0-255 map to the corresponding single byte.
for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])

# Special tokens occupy the ids after the byte range.
for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]


class ByteTokenizer:
    """Minimal byte-level tokenizer with chat-style special tokens."""

    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Unknown id: substitute a placeholder byte.
                tokens.append(b"?")
            elif isinstance(token, bytes):
                tokens.append(token)
            # Special tokens (str entries) are dropped from the decoded bytes.
        return b"".join(tokens)


def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    """Precompute rotary cos/sin tables of shape (seq_len, dim_model // 2)."""
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb


class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))


class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm residual blocks: differential attention, then feed-forward.
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x


class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Each block receives its depth index, as expected by MultiheadDiffAttn.
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Tie the input embedding and output projection weights.
        self.token_embeddings.weight = self.lm_head.weight

        # Rotary tables sized to the per-head dimension used by differential
        # attention (each head is split in half: embed_dim // num_heads // 2).
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Slice the precomputed rotary tables to the current sequence length.
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Additive causal mask: -inf above the diagonal blocks attention to
        # future positions.
        # NOTE: any caller-supplied attn_mask is currently ignored; attention
        # is always causally masked.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
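

# Illustrative smoke test, not part of the library proper: the hyperparameters
# below (embed_dim=64, num_layers=2, num_heads=4, ffn_hidden_dim=128,
# max_seq_len=128) are arbitrary assumptions chosen only to exercise the code.
if __name__ == "__main__":
    tokenizer = ByteTokenizer()

    # Round-trip a short byte string; special tokens are dropped on decode.
    ids = tokenizer.encode(b"hello", add_special_tokens=True)
    assert tokenizer.decode(ids) == b"hello"

    # Run a tiny model forward and inspect the output shape.
    model = DiffTransformerLLM(
        vocab_size=VOCAB_SIZE,
        embed_dim=64,
        num_layers=2,
        num_heads=4,
        ffn_hidden_dim=128,
        max_seq_len=128,
        dropout=0.0,
    )
    input_ids = torch.tensor([ids], dtype=torch.long)
    logits = model(input_ids)
    print(logits.shape)  # expected: (1, len(ids), VOCAB_SIZE)
    print(f"parameters: {model.count_parameters():,}")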