Text Generation
Safetensors
English
hudsongouge committed on
Commit
240812e
·
verified ·
1 Parent(s): 784f17f

Upload inference directory

Browse files
inference/model.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from .optimized_diffattn import MultiheadDiffAttn
6
+
7
# --- Tokenizer Definition ---
# Vocabulary: 256 raw byte values followed by the special tokens below.
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"

SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Byte ids 0..255 map to the single byte they encode; special tokens occupy
# the ids immediately after the byte range.
token_to_id = {bytes([b]): b for b in range(256)}
id_to_token = {b: bytes([b]) for b in range(256)}

for offset, special in enumerate(SPECIAL_TOKENS):
    special_id = 256 + offset
    token_to_id[special] = special_id
    id_to_token[special_id] = special

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
32
+
33
+
34
class ByteTokenizer:
    """Byte-level tokenizer: 256 raw byte tokens plus the special tokens."""

    def __init__(self):
        # All instances share the module-level lookup tables.
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        """Map raw bytes to ids, optionally wrapped in im_start/im_end ids."""
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if not add_special_tokens:
            return ids
        return [self.im_start_id, *ids, self.im_end_id]

    def decode(self, ids: list[int]):
        """Map ids back to raw bytes.

        Unknown ids become b"?"; special (string) tokens are silently
        dropped so the result is pure byte content.
        """
        pieces = []
        for token_id in ids:
            token = self.id_to_token.get(token_id)
            if token is None:
                pieces.append(b"?")  # placeholder for out-of-vocab ids
            elif isinstance(token, bytes):
                pieces.append(token)
            # str tokens (the specials) are intentionally skipped
        return b"".join(pieces)
60
+
61
+
62
+ # --- RoPE Embeddings --- (Reused from previous script)
63
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    """Precompute RoPE tables.

    Returns (cos, sin), each of shape (seq_len, dim_model // 2), using the
    standard theta-based inverse-frequency schedule.

    Raises:
        ValueError: if dim_model is odd.
    """
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = positions * inv_freq
    return torch.cos(angles), torch.sin(angles)
74
+
75
+
76
+ # --- Model Definition ---
77
+ class FeedForward(nn.Module):
78
+ def __init__(self, embed_dim, hidden_dim, dropout=0.1):
79
+ super().__init__()
80
+ self.fc1 = nn.Linear(embed_dim, hidden_dim)
81
+ self.fc2 = nn.Linear(hidden_dim, embed_dim)
82
+ self.dropout = nn.Dropout(dropout)
83
+ self.act = nn.GELU()
84
+
85
+ def forward(self, x):
86
+ return self.fc2(self.dropout(self.act(self.fc1(x))))
87
+
88
+
89
class DiffTransformerBlock(nn.Module):
    """Pre-norm transformer block: DiffAttention sub-layer then FFN sub-layer,
    each wrapped in a residual connection with dropout."""

    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        """Run both sub-layers; `rel_pos` is the (cos, sin) RoPE pair."""
        # Pre-norm: normalise the input to each sub-layer, add the raw residual.
        x = x + self.dropout(self.attn(self.norm1(x), rel_pos, attn_mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x
105
+
106
+
107
class DiffTransformerLLM(nn.Module):
    """Decoder-only language model built from DiffTransformerBlocks.

    Positions are encoded with RoPE inside attention (no learned position
    table); input embeddings and the output projection share weights.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positional information comes from RoPE; no nn.Embedding for positions.
        self.dropout = nn.Dropout(dropout)

        # `depth` is passed to each block so DiffAttention's lambda_init
        # schedule can vary with layer depth.
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Weight tying: embedding matrix and LM head share one Parameter.
        self.token_embeddings.weight = self.lm_head.weight

        # RoPE precomputation. MultiheadDiffAttn splits each head into a
        # +/- pair, so its per-head rotary dim is embed_dim // num_heads // 2.
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        # Non-persistent: recomputed on construction, excluded from state_dict.
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        """Return next-token logits of shape (batch, seq_len, vocab_size).

        NOTE: `attn_mask` is currently ignored — a standard causal mask is
        always used, and padding is expected to be handled via the loss
        function's ignore_index. (The original code built the identical
        causal mask in both branches of an if/else on `attn_mask`; the
        duplication is consolidated here with no behavior change. The
        parameter is kept for interface compatibility.)
        """
        batch_size, seq_len = input_ids.shape

        # Scale embeddings by sqrt(d) as in the original Transformer.
        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Slice RoPE tables to the current length; match device *and* dtype
        # of the activations (important under autocast).
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Causal mask in the format MultiheadDiffAttn expects:
        # -inf above the diagonal (masked), 0 elsewhere (attend).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        return self.lm_head(x)

    def count_parameters(self):
        """Number of trainable parameters (tied weights are counted once,
        since nn.Module.parameters() deduplicates shared Parameters)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
inference/optimized_diffattn.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ # Re-use rotary embedding helper from the original codebase
9
+ from .rotary import apply_rotary_emb
10
+
11
+ # -----------------------------------------------------------------------------
12
+ # Utility helpers (copied from the original implementation)
13
+ # -----------------------------------------------------------------------------
14
+
15
+
16
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head `n_rep` times along the head axis (for GQA).

    Equivalent to torch.repeat_interleave(x, n_rep, dim=1). The expand step
    is a zero-copy view; the final reshape materialises the repeated tensor
    only when n_rep > 1 (n_rep == 1 returns the input unchanged).
    """
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x[:, :, None, :, :].expand(bs, n_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
26
+
27
+
28
def lambda_init_fn(depth: int) -> float:
    """Depth-dependent initial lambda from the DiffAttention paper.

    Starts at 0.2 for depth 0 and approaches 0.8 for deep layers.
    """
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
31
+
32
+
33
+ # -----------------------------------------------------------------------------
34
+ # Optimised Multi-head DiffAttention implementation
35
+ # -----------------------------------------------------------------------------
36
+
37
+
38
class MultiheadDiffAttn(nn.Module):
    """Optimised DiffAttention block.

    Differences from the original implementation:
    1. Removes the dependency on Apex / FusedRMSNorm; uses native LayerNorm.
    2. Keeps all tensors on-device and works well with autocast fp16/bf16.
    3. Minimises Python-side tensor reshapes and kernel launches.

    NOTE(review): the `attn_mask` argument of forward() is accepted but not
    used — both SDPA calls rely on `is_causal=True`, so attention is always
    causal and any padding mask passed in is ignored.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads  # query heads (will be doubled internally)
        self.num_kv_heads = num_kv_heads or num_heads
        # Replication factor for keys / values (GQA).
        self.n_rep = self.num_heads // self.num_kv_heads
        self.attn_dropout = dropout  # attention-probability dropout rate

        # One half of a traditional head — DiffAttention uses pairs of heads.
        self.head_dim = embed_dim // self.num_heads // 2
        assert (
            self.head_dim * self.num_heads * 2 == embed_dim
        ), "embed_dim must be divisible by num_heads * 2"
        # Kept for reference; SDPA's default scale 1/sqrt(head_dim) equals it.
        self.scaling = self.head_dim**-0.5

        # Projections. K/V are smaller when using GQA (n_rep > 1).
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Output (residual-path) dropout; distinct from attention dropout.
        self.dropout = nn.Dropout(dropout)

        # DiffAttention lambda parameters (learnable).
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)

        # Standard LayerNorm has a highly-optimised CUDA kernel.
        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,  # [bsz, seq_len, embed_dim]
        rel_pos: tuple[torch.Tensor, torch.Tensor],
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Causal differential attention; returns [bsz, seq_len, embed_dim].

        `rel_pos` is the (cos, sin) RoPE pair of shape (seq_len, head_dim/2).
        `attn_mask` is ignored (see class docstring).
        """
        bsz, seq_len, _ = x.size()

        # ---- Projections (inside the outer autocast context, so they stay
        # in the low-precision dtype and use tensor cores).
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape into paired heads (2 × heads); v keeps full 2*head_dim.
        q = q.view(bsz, seq_len, 2 * self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, 2 * self.num_kv_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 2 * self.head_dim)

        # Rotary position encodings (ensure dtype matches q under autocast).
        cos, sin = rel_pos
        cos = cos.to(dtype=q.dtype)
        sin = sin.to(dtype=q.dtype)
        q = apply_rotary_emb(q, cos, sin, interleaved=True)
        k = apply_rotary_emb(k, cos, sin, interleaved=True)

        # SDPA layout: (bsz, heads, seq, head_dim).
        q = q.transpose(1, 2)  # [bsz, 2*heads, seq, head_dim]
        k = k.transpose(1, 2)  # [bsz, 2*kv_heads, seq, head_dim]
        v = v.transpose(1, 2)  # [bsz, kv_heads, seq, 2*head_dim]

        # Replicate k/v heads when using GQA.
        k = repeat_kv(k, self.n_rep)  # [bsz, 2*heads, seq, head_dim]
        v = repeat_kv(v, self.n_rep)  # [bsz, heads, seq, 2*head_dim]

        # Split paired heads: [bsz, 2*H, S, D] -> [bsz, H, 2, S, D], then
        # take the positive/negative halves of each pair.
        q_pairs = q.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )
        k_pairs = k.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )
        q_pos, q_neg = q_pairs[:, :, 0], q_pairs[:, :, 1]  # [bsz, H, S, D]
        k_pos, k_neg = k_pairs[:, :, 0], k_pairs[:, :, 1]

        # λ scalar (identical across heads / sequence).
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)).type_as(q_pos)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)).type_as(q_pos)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init  # scalar tensor

        # --- Fused attention (only TWO SDPA calls) -------------------------
        # BUG FIX: SDPA applies dropout_p unconditionally, so the original
        # code dropped attention weights even in eval()/inference mode.
        # Gate the rate on self.training, matching nn.Dropout semantics.
        drop_p = self.attn_dropout if self.training else 0.0
        ctx_pos = F.scaled_dot_product_attention(
            q_pos, k_pos, v, dropout_p=drop_p, is_causal=True
        )  # [bsz, H, S, 2*D]
        ctx_neg = F.scaled_dot_product_attention(
            q_neg, k_neg, v, dropout_p=drop_p, is_causal=True
        )  # [bsz, H, S, 2*D]

        # DiffAttention combination: subtract the negative-head context.
        attn_out = ctx_pos - lambda_full * ctx_neg  # [bsz, H, S, 2*D]

        # Per-head LayerNorm and residual scaling from the paper.
        attn_out = self.subln(attn_out) * (1.0 - self.lambda_init)

        # Collapse heads back to embed_dim and project out.
        attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim)
        out = self.out_proj(attn_out)
        return self.dropout(out)
inference/rotary.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+
7
+
8
def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    """Pure-PyTorch rotary embedding over the first `rotary_dim` features.

    Rotates interleaved (even, odd) feature pairs; features beyond
    `rotary_dim` (= 2 * cos.shape[1]) pass through unchanged. The
    `interleaved`, `inplace`, `seqlen_offsets`, `cu_seqlens` and
    `max_seqlen` arguments are accepted for API compatibility but are not
    honoured by this implementation.

    Shapes: x is [batch, seqlen, nheads, headdim]; cos/sin are
    [seqlen, rotary_dim / 2].
    """
    rotary_dim = 2 * cos.shape[1]
    rot_part = x[..., :rotary_dim]
    pass_part = x[..., rotary_dim:]

    # Interleaved pairs: even and odd feature columns, width rotary_dim/2.
    even = rot_part[..., ::2]
    odd = rot_part[..., 1::2]

    # Reshape cos/sin to [1, seqlen, 1, rotary_dim/2] for broadcasting
    # over batch and heads.
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    # Re-interleave: (..., rotary_dim/2, 2) -> (..., rotary_dim).
    rotated = torch.stack([rotated_even, rotated_odd], dim=-1).reshape_as(rot_part)
    return torch.cat([rotated, pass_part], dim=-1)
+ return out
39
+
40
+
41
+ def apply_rotary_emb(
42
+ x,
43
+ cos,
44
+ sin,
45
+ interleaved=False,
46
+ inplace=False,
47
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
48
+ cu_seqlens: Optional[torch.Tensor] = None,
49
+ max_seqlen: Optional[int] = None,
50
+ ):
51
+ """
52
+ Arguments:
53
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
54
+ else (total_seqlen, nheads, headdim)
55
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
56
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
57
+ of 1st half and 2nd half (GPT-NeoX style).
58
+ inplace: if True, apply rotary embedding in-place.
59
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
60
+ Most commonly used in inference when we have KV cache.
61
+ cu_seqlens: (batch + 1,) or None
62
+ max_seqlen: int
63
+ Return:
64
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
65
+ else (total_seqlen, nheads, headdim)
66
+ rotary_dim must be <= headdim
67
+ Apply rotary embedding to the first rotary_dim of x.
68
+ """
69
+ # We are forcing the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
70
+ # for all devices. The custom Triton kernel (`ApplyRotaryEmb`) was causing a graph
71
+ # break in `torch.compile`, pushing expensive operations to the CPU.
72
+ # By using the pure PyTorch version, `torch.compile` can create a single, fully-optimized
73
+ # graph, which should resolve the CPU bottleneck and improve GPU utilization.
74
+ return apply_rotary_emb_torch(
75
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
76
+ )