import argparse
import os

import torch
import torch.nn.functional as F
import torch.quantization

from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = False


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

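    # Model hyperparameters are hard-coded here and must match the configuration
    # the checkpoint was trained with (vocab_size 259 is presumably the 256 byte
    # values plus the three special tokens).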
    vocab_size = 259
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 2048
    dropout = 0.1

    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

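    # The checkpoint file is assumed to hold the state_dict directly rather than
    # a wrapper dict with optimizer state or other training metadata.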
    state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

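    # Device-specific optimizations: dynamic int8 quantization of Linear layers
    # on CPU, half precision on CUDA, and default fp32 weights on MPS.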
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        torch.backends.quantized.engine = "qnnpack"
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()
    elif device.type == "mps":
        print("Optimizing for MPS.")

    model = model.to(device)

    print("Model loaded successfully.")
    return model


def generate_text_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=None,
):
    """
    Generate text from a prompt using the trained model, yielding decoded strings in a stream.
    This function is a generator.
    """
    if stop_sequences is None:
        stop_sequences = []

    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

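    # The tokenizer works on raw UTF-8 bytes, so the prompt is converted to bytes
    # before being mapped to token ids.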
    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )

    stop_sequences_ids = [
        tokenizer.encode(seq.encode("utf-8", errors="replace"), add_special_tokens=False)
        for seq in stop_sequences
    ]

    generated_ids = input_ids.clone()
    byte_buffer = b""

    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
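            # If the running sequence exceeds the model's context window, feed
            # only the most recent max_seq_len tokens back in.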
            if generated_ids.size(1) > model.max_seq_len:
                current_input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                current_input_ids = generated_ids

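            # Forward pass; only the logits at the final position are needed to
            # sample the next token.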
            logits = model(current_input_ids)
            next_token_logits = logits[:, -1, :].squeeze(0)

            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            if repetition_penalty > 1.0:
                seen_tokens = set(generated_ids[0].tolist())
                for token_id in seen_tokens:
                    # Penalize already-seen tokens in a sign-aware way so the
                    # penalty always lowers their probability.
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty

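            # Top-k filtering: keep only the k highest-scoring tokens and mask
            # the rest to -inf.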
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                filtered_logits = torch.full_like(next_token_logits, float("-inf"))
                filtered_logits.scatter_(0, top_k_indices, top_k_logits)
                next_token_logits = filtered_logits

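            # Top-p (nucleus) filtering, a sketch of the standard algorithm to
            # back the --top_p CLI option: keep the smallest set of
            # highest-probability tokens whose cumulative probability exceeds
            # top_p and mask out the rest.
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = False
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")
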
            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            token_byte = tokenizer.decode([next_token.item()])
            byte_buffer += token_byte

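            # A single UTF-8 character may span several byte tokens, so bytes are
            # buffered until they decode cleanly before being yielded.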
            try:
                decoded_str = byte_buffer.decode("utf-8")
                yield decoded_str
                byte_buffer = b""
            except UnicodeDecodeError:
                pass

            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)

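            # Stop if the generated sequence now ends with any encoded stop sequence.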
            stop_generation = False
            current_sequence_list = generated_ids.tolist()[0]
            for stop_seq_ids in stop_sequences_ids:
                if len(current_sequence_list) >= len(stop_seq_ids):
                    if current_sequence_list[-len(stop_seq_ids) :] == stop_seq_ids:
                        stop_generation = True
                        break
            if stop_generation:
                break

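    # Flush whatever is left in the buffer, replacing any trailing partial character.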
    if byte_buffer:
        yield byte_buffer.decode("utf-8", errors="replace")


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=None,
):
    """
    Generate text from a prompt using the trained model.
    This is a convenience wrapper around generate_text_stream.
    """
    generated_text = "".join(
        generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            device=device,
            stop_sequences=stop_sequences,
        )
    )
    full_text = prompt + generated_text
    return generated_text, full_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

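    # No checkpoint specified: fall back to the newest checkpoint on disk,
    # preferring final "*end.pt" checkpoints when present.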
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    tokenizer = ByteTokenizer()
    model = load_model(checkpoint_path, device)

    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()