# DAT-Byte-Small / run.py
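"""Command-line entry point for DAT-Byte-Small text generation.

Supports single-shot generation (--prompt), interactive chat (--chat), and
listing available checkpoints (--list_checkpoints). Typical usage:

    python run.py --prompt "Once upon a time"
    python run.py --chat --system "You are a helpful chatbot."
"""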
import argparse
import os
import sys

import torch

from inference.inference import (
    force_CPU,
    generate_text_stream,
    list_checkpoints,
    load_model,
)
from inference.model import ByteTokenizer

def main():
parser = argparse.ArgumentParser(
description="Text generation with DiffAttention LLM",
formatter_class=argparse.RawTextHelpFormatter,
)
# Generation mode arguments
parser.add_argument(
"--prompt",
type=str,
default="",
help="Run in single-shot mode with the given prompt.",
)
parser.add_argument(
"-c", "--chat", action="store_true", help="Run in interactive chat mode."
)
# Chat mode arguments
parser.add_argument(
"--system",
type=str,
default="You are a helpful chatbot.",
help="System prompt for chat mode.",
)
parser.add_argument(
"--user_role",
type=str,
default="user",
help="Role name for the user in chat mode.",
)
parser.add_argument(
"--assistant_role",
type=str,
default="assistant",
help="Role name for the assistant in chat mode.",
)
# Common arguments
parser.add_argument(
"--checkpoint",
type=str,
default="model.pt",
help="Path to the checkpoint file.",
)
parser.add_argument(
"--stop",
nargs="+",
default=[],
help='One or more stop sequences, e.g. --stop "world".',
)
parser.add_argument(
"--max_tokens",
type=int,
default=512,
help="Maximum number of new tokens to generate.",
)
parser.add_argument(
"--temperature", type=float, default=0.35, help="Sampling temperature."
)
parser.add_argument(
"--top_k",
type=int,
default=7,
help="Top-k sampling parameter (0 to disable).",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.35,
help="Repetition penalty (1.0 for no penalty).",
)
parser.add_argument(
"--list_checkpoints",
action="store_true",
help="List available checkpoints and exit.",
)
args = parser.parse_args()
if not args.prompt and not args.chat and not args.list_checkpoints:
parser.print_help()
sys.exit(
"\nError: Either --prompt, --chat, or --list_checkpoints must be specified."
)
# List checkpoints if requested
if args.list_checkpoints:
print("Available checkpoints:")
checkpoints = list_checkpoints()
if not checkpoints:
print("No checkpoints found.")
for i, ckpt in enumerate(checkpoints):
print(f"{i+1}. {ckpt}")
return
checkpoint_path = args.checkpoint
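    # If the requested checkpoint does not exist, fall back to the newest
    # checkpoint found in the checkpoints/ directory.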
if not os.path.exists(checkpoint_path):
print(f"Checkpoint file not found: {checkpoint_path}")
print("Searching for latest checkpoint in 'checkpoints/' directory...")
checkpoints = list_checkpoints()
if not checkpoints:
sys.exit(
"No checkpoints found. Please train a model or specify a valid path."
)
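        # Prefer end-of-training ("end.pt") checkpoints; max() picks the
        # lexicographically greatest filename as the latest one.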
end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
if end_checkpoints:
latest_checkpoint = max(end_checkpoints)
else:
latest_checkpoint = max(checkpoints)
checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
print(f"Using latest checkpoint: {checkpoint_path}")
# Set device
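    # Prefer Apple Silicon (MPS), then CUDA, then CPU; force_CPU overrides both accelerators.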
if torch.backends.mps.is_available() and not force_CPU:
device = torch.device("mps")
else:
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
)
print(f"Using device: {device}")
tokenizer = ByteTokenizer()
# Load model
model = load_model(checkpoint_path, device)
# --- Mode Handling ---
if args.chat:
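        # Conversation history uses ChatML-style <|im_start|>/<|im_end|> markers,
        # and <|im_end|> is always added as a stop sequence so generation halts
        # at the end of the assistant turn.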
stop_sequences = args.stop + ["<|im_end|>"]
history = f"<|im_start|>system\n{args.system}<|im_end|>\n"
print("\n--- Interactive Chat ---")
print(f"System Prompt: {args.system}")
print("Type 'exit' or 'quit' to end the session.")
print("-" * 26)
while True:
try:
user_prompt_display = f"<|im_start|>{args.user_role}\n"
user_input = input(user_prompt_display)
if user_input.lower() in ["exit", "quit"]:
break
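                # Re-send the full conversation so far and open a new
                # assistant turn for the model to complete.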
prompt = (
history
+ f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
+ f"<|im_start|>{args.assistant_role}\n"
)
print(f"<|im_start|>{args.assistant_role}")
sys.stdout.flush()
generated_text_parts = []
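                # Stream tokens as they are produced and echo them immediately.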
for chunk in generate_text_stream(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
device=device,
stop_sequences=stop_sequences,
):
print(chunk, end="", flush=True)
generated_text_parts.append(chunk)
generated_text = "".join(generated_text_parts)
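                # Append the completed exchange to the history so later turns
                # keep the full conversational context.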
history += (
f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
+ f"<|im_start|>{args.assistant_role}\n{generated_text}<|im_end|>\n"
)
print() # Newline after assistant output
except (KeyboardInterrupt, EOFError):
print("\nExiting chat.")
break
else:
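        # Single-shot mode: generate one completion for --prompt and stream it to stdout.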
print(f"\nGenerating text with prompt: '{args.prompt}'")
print(
f"Parameters: temp={args.temperature}, top_k={args.top_k}, repetition_penalty={args.repetition_penalty}"
)
print("\n--- Generation Start ---")
generated_text_parts = []
for chunk in generate_text_stream(
model=model,
tokenizer=tokenizer,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
device=device,
stop_sequences=args.stop,
):
print(chunk, end="", flush=True)
generated_text_parts.append(chunk)
print("\n--- Generation End ---")
generated_text = "".join(generated_text_parts)
full_text = args.prompt + generated_text
print("\n\nFull generated text (for reference):")
print("-" * 40)
print(full_text)
print("-" * 40)
if __name__ == "__main__":
main()