import argparse
import os

import torch
import torch.nn.functional as F
import torch.quantization

from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = False


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying device-specific optimizations."""
    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparameters (must match the training configuration)
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 2048
    dropout = 0.1  # For inference, dropout can be set to 0

    # Build the model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself; load it into the float32 model first.
    state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda":
        if fp16:
            print("Casting model to fp16 for CUDA.")
            model = model.half()
        model = model.to(device)
    elif device.type == "mps":
        print("Optimizing for MPS.")
        model = model.to(device)

    print("Model loaded successfully.")
    return model


def generate_text_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=None,
):
    """
    Generate text from a prompt using the trained model, yielding decoded strings
    as they become available. This function is a generator.
    """
    if stop_sequences is None:
        stop_sequences = []

    if device is None:
        if torch.backends.mps.is_available() and not force_CPU:
            device = torch.device("mps")
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
            )

    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )

    stop_sequences_ids = [
        tokenizer.encode(seq.encode("utf-8", errors="replace"), add_special_tokens=False)
        for seq in stop_sequences
    ]

    generated_ids = input_ids.clone()
    byte_buffer = b""

    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Truncate the context to the model's maximum sequence length.
            if generated_ids.size(1) > model.max_seq_len:
                current_input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                current_input_ids = generated_ids

            logits = model(current_input_ids)
            next_token_logits = logits[:, -1, :].squeeze(0)

            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            if repetition_penalty > 1.0:
                # Penalize already-seen tokens: divide positive logits and
                # multiply negative ones so both become less likely.
                seen_tokens = set(generated_ids[0].tolist())
                for token_id in seen_tokens:
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty

            if top_k > 0:
                # Keep only the top-k logits; mask everything else out.
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                filtered_logits = torch.full_like(next_token_logits, float("-inf"))
                filtered_logits.scatter_(0, top_k_indices, top_k_logits)
                next_token_logits = filtered_logits

            if 0.0 < top_p < 1.0:
                # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p, always keeping the best one.
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)
                sorted_mask = cumulative_probs > top_p
                sorted_mask[1:] = sorted_mask[:-1].clone()
                sorted_mask[0] = False
                next_token_logits[sorted_indices[sorted_mask]] = float("-inf")

            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Decode the token and handle the byte buffer FIRST.
            token_byte = tokenizer.decode([next_token.item()])
            byte_buffer += token_byte
            try:
                decoded_str = byte_buffer.decode("utf-8")
                yield decoded_str
                byte_buffer = b""
            except UnicodeDecodeError:
                # Incomplete character, continue to accumulate bytes.
                pass

            # THEN, update the generated IDs and check for a stop sequence.
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)

            stop_generation = False
            current_sequence_list = generated_ids.tolist()[0]
            for stop_seq_ids in stop_sequences_ids:
                if len(current_sequence_list) >= len(stop_seq_ids):
                    if current_sequence_list[-len(stop_seq_ids) :] == stop_seq_ids:
                        stop_generation = True
                        break
            if stop_generation:
                break

    # If there's anything left in the buffer, decode it with replacement for errors.
    if byte_buffer:
        yield byte_buffer.decode("utf-8", errors="replace")


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=None,
):
    """
    Generate text from a prompt using the trained model.
    This is a convenience wrapper around generate_text_stream.
    """
    generated_text = "".join(
        generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            device=device,
            stop_sequences=stop_sequences,
        )
    )
    full_text = prompt + generated_text
    return generated_text, full_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"?\n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i + 1}. {ckpt}")
        return

    # If no checkpoint is specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return
        # Prefer the latest end-of-epoch checkpoint if one exists
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, "
        f"top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()