Text Generation
Safetensors
English
hudsongouge committed on
Commit 3e34d8e · verified · 1 Parent(s): 874c445

Update inference/inference.py

Files changed (1)
  1. inference/inference.py +335 -335
inference/inference.py CHANGED
@@ -1,335 +1,335 @@
import torch
import torch.nn.functional as F
import os
import torch.quantization
from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = False


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    import torch

    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparams
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
-    max_seq_len = 512
+    max_seq_len = 2048
    dropout = 0.1  # For inference you can set dropout=0

    # Model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself
    state_dict = checkpoint

    # Load the state dict into the float32 model first
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()
    elif device.type == "mps":
        print("Optimizing for MPS.")

    model = model.to(device)

    print("Model loaded successfully.")
    return model


def generate_text_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model, yielding decoded strings in a stream.
    This function is a generator.
    """
    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )

    stop_sequences_ids = [
        tokenizer.encode(seq.encode("utf-8", errors="replace"), add_special_tokens=False)
        for seq in stop_sequences
    ]

    generated_ids = input_ids.clone()
    byte_buffer = b""

    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            if generated_ids.size(1) > model.max_seq_len:
                current_input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                current_input_ids = generated_ids

            logits = model(current_input_ids)
            next_token_logits = logits[:, -1, :].squeeze(0)

            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            if repetition_penalty > 1.0:
                seen_tokens = set(generated_ids[0].tolist())
                for token_id in seen_tokens:
                    next_token_logits[token_id] /= repetition_penalty

            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                filtered_logits = torch.full_like(next_token_logits, float("-inf"))
                filtered_logits.scatter_(0, top_k_indices, top_k_logits)
                next_token_logits = filtered_logits

            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Decode the token and handle the byte buffer FIRST.
            token_byte = tokenizer.decode([next_token.item()])
            byte_buffer += token_byte

            try:
                decoded_str = byte_buffer.decode("utf-8")
                yield decoded_str
                byte_buffer = b""
            except UnicodeDecodeError:
                # Incomplete character, continue to accumulate bytes.
                pass

            # THEN, update the generated IDs and check for a stop sequence.
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)

            stop_generation = False
            current_sequence_list = generated_ids.tolist()[0]
            for stop_seq_ids in stop_sequences_ids:
                if len(current_sequence_list) >= len(stop_seq_ids):
                    if current_sequence_list[-len(stop_seq_ids) :] == stop_seq_ids:
                        stop_generation = True
                        break
            if stop_generation:
                break

    # If there's anything left in the buffer, decode it with replacement for errors.
    if byte_buffer:
        yield byte_buffer.decode("utf-8", errors="replace")


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model.
    This is a convenience wrapper around generate_text_stream.
    """
    generated_text = "".join(
        generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            device=device,
            stop_sequences=stop_sequences,
        )
    )
    full_text = prompt + generated_text
    return generated_text, full_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Find the latest epoch_end checkpoint
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    import argparse

    main()
 
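For context, here is a minimal sketch of how the functions in this file fit together from a separate driver script. It is illustrative only: the import paths (inference.inference, inference.model), the checkpoint directory layout, the prompt, and the stop string are assumptions, and the snippet assumes at least one .pt file exists under checkpoints/.

# Hypothetical driver; import paths and checkpoint names are assumptions.
from inference.inference import list_checkpoints, load_model, generate_text_stream
from inference.model import ByteTokenizer

checkpoints = list_checkpoints("checkpoints")          # sorted list of *.pt filenames
model = load_model("checkpoints/" + checkpoints[-1])   # picks MPS/CUDA/CPU automatically
tokenizer = ByteTokenizer()

# Stream the completion; each yielded chunk is a complete UTF-8 string,
# because generate_text_stream buffers partial multi-byte characters.
for chunk in generate_text_stream(
    model,
    tokenizer,
    prompt="How many 'b's are in \"barber\"?",
    max_new_tokens=200,
    temperature=0.7,
    top_k=10,
    repetition_penalty=1.2,
    stop_sequences=["\n\n"],  # example stop string; any list of strings works
):
    print(chunk, end="", flush=True)
print()

When streaming is not needed, generate_text wraps the same generator and returns the completion and the prompt plus completion as a pair.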
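One detail worth noting when reading main(): it defines a --top_p flag and forwards top_p=args.top_p to generate_text, but neither generate_text nor generate_text_stream declares a top_p parameter, so that call would raise a TypeError as written. Below is a minimal sketch of the nucleus (top-p) filter the flag implies, written to mirror the existing top_k masking of next_token_logits. The helper name top_p_filter and the plumbing that would pass the value through generate_text and generate_text_stream are assumptions, not part of this file.

import torch
import torch.nn.functional as F


def top_p_filter(next_token_logits, top_p):
    """Nucleus filtering for a 1-D logits vector, analogous to the top_k block
    in generate_text_stream: logits outside the nucleus are set to -inf."""
    if not 0.0 < top_p < 1.0:
        return next_token_logits
    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens once cumulative probability exceeds top_p, always keeping the top token.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[1:] = sorted_mask[:-1].clone()
    sorted_mask[0] = False

    filtered = next_token_logits.clone()
    filtered[sorted_indices[sorted_mask]] = float("-inf")
    return filtered

Inside the sampling loop this would be applied right after the top_k block, e.g. next_token_logits = top_p_filter(next_token_logits, top_p), before the softmax and the multinomial draw.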