import argparse
import os

import torch
import torch.nn.functional as F
import torch.quantization

from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = False


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparameters (must match the architecture used during training)
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 2048
    dropout = 0.1  # Inactive at inference time because the model is put in eval mode

    # Model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself
    state_dict = checkpoint
    # Load the state dict into the float32 model first
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()
    elif device.type == "mps":
        print("Optimizing for MPS.")
    # Move the model to the target device (a no-op for the quantized CPU model)
    model = model.to(device)
    print("Model loaded successfully.")
    return model
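

# A minimal sketch of loading the most recent checkpoint for ad-hoc experiments.
# It assumes a local "checkpoints" directory containing at least one .pt file;
# list_checkpoints() returns filenames in sorted order, so the last entry is the newest.
def _example_load_latest(checkpoint_dir="checkpoints"):
    """Minimal sketch: load the lexicographically latest checkpoint."""
    checkpoints = list_checkpoints(checkpoint_dir)
    if not checkpoints:
        raise FileNotFoundError(f"No .pt checkpoints found in {checkpoint_dir}")
    return load_model(os.path.join(checkpoint_dir, checkpoints[-1]))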


def generate_text_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model, yielding decoded strings in a stream.
    This function is a generator.
    """
    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    # Encode the prompt into byte-level token ids
    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )
    stop_sequences_ids = [
        tokenizer.encode(seq.encode("utf-8", errors="replace"), add_special_tokens=False)
        for seq in stop_sequences
    ]
    generated_ids = input_ids.clone()
    byte_buffer = b""
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Truncate the context to the model's maximum sequence length
            if generated_ids.size(1) > model.max_seq_len:
                current_input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                current_input_ids = generated_ids
            logits = model(current_input_ids)
            next_token_logits = logits[:, -1, :].squeeze(0)
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if repetition_penalty > 1.0:
                seen_tokens = set(generated_ids[0].tolist())
                for token_id in seen_tokens:
                    # Penalize already-seen tokens: shrink positive logits and scale
                    # down negative ones (dividing a negative logit would otherwise
                    # make repetition more likely).
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                filtered_logits = torch.full_like(next_token_logits, float("-inf"))
                filtered_logits.scatter_(0, top_k_indices, top_k_logits)
                next_token_logits = filtered_logits
            if 0.0 < top_p < 1.0:
                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_mask = cumulative_probs > top_p
                sorted_mask[1:] = sorted_mask[:-1].clone()
                sorted_mask[0] = False  # Always keep the most probable token
                next_token_logits[sorted_indices[sorted_mask]] = float("-inf")
            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Decode the token and handle the byte buffer FIRST.
            token_byte = tokenizer.decode([next_token.item()])
            byte_buffer += token_byte
            try:
                decoded_str = byte_buffer.decode("utf-8")
                yield decoded_str
                byte_buffer = b""
            except UnicodeDecodeError:
                # Incomplete character, continue to accumulate bytes.
                pass

            # THEN, update the generated IDs and check for a stop sequence.
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
            stop_generation = False
            current_sequence_list = generated_ids.tolist()[0]
            for stop_seq_ids in stop_sequences_ids:
                if len(current_sequence_list) >= len(stop_seq_ids):
                    if current_sequence_list[-len(stop_seq_ids) :] == stop_seq_ids:
                        stop_generation = True
                        break
            if stop_generation:
                break

    # If there's anything left in the buffer, decode it with replacement for errors.
    if byte_buffer:
        yield byte_buffer.decode("utf-8", errors="replace")
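

# The generator above yields decoded chunks as soon as complete UTF-8 characters are
# available, which is useful for printing output incrementally. The sketch below
# assumes a checkpoint at "checkpoints/latest_end.pt"; substitute a real file, e.g.
# one returned by list_checkpoints().
def _example_stream_usage(checkpoint_path="checkpoints/latest_end.pt"):
    """Minimal sketch: stream a completion to stdout chunk by chunk."""
    tokenizer = ByteTokenizer()
    model = load_model(checkpoint_path)
    for chunk in generate_text_stream(
        model,
        tokenizer,
        prompt="Hello, world!",
        max_new_tokens=50,
        temperature=0.7,
        top_k=10,
    ):
        print(chunk, end="", flush=True)
    print()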


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model.
    This is a convenience wrapper around generate_text_stream.
    """
    generated_text = "".join(
        generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            device=device,
            stop_sequences=stop_sequences,
        )
    )
    full_text = prompt + generated_text
    return generated_text, full_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return
        # Find the latest epoch_end checkpoint
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()
    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")
    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )
    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)
    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()