import argparse
import os
import sys

import torch

from inference.inference import (
    force_CPU,
    generate_text_stream,
    list_checkpoints,
    load_model,
)
from inference.model import ByteTokenizer
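

# Entry point for the command-line interface: parses arguments, resolves a
# checkpoint, selects a device, and runs either single-shot generation or an
# interactive chat loop.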
def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Run in single-shot mode with the given prompt.",
    )
    parser.add_argument(
        "-c", "--chat", action="store_true", help="Run in interactive chat mode."
    )

    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful chatbot.",
        help="System prompt for chat mode.",
    )
    parser.add_argument(
        "--user_role",
        type=str,
        default="user",
        help="Role name for the user in chat mode.",
    )
    parser.add_argument(
        "--assistant_role",
        type=str,
        default="assistant",
        help="Role name for the assistant in chat mode.",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default="model.pt",
        help="Path to the checkpoint file.",
    )
    parser.add_argument(
        "--stop",
        nargs="+",
        default=[],
        help='One or more stop sequences, e.g. --stop "world".',
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.35, help="Sampling temperature."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=7,
        help="Top-k sampling parameter (0 to disable).",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.35,
        help="Repetition penalty (1.0 for no penalty).",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit.",
    )
    args = parser.parse_args()
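
    # At least one of --prompt, --chat, or --list_checkpoints is required.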
    if not args.prompt and not args.chat and not args.list_checkpoints:
        parser.print_help()
        sys.exit(
            "\nError: Either --prompt, --chat, or --list_checkpoints must be specified."
        )

    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found.")
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return
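
    # Fall back to the newest checkpoint under checkpoints/ when the given
    # path does not exist, preferring final ("end.pt") checkpoints.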
    checkpoint_path = args.checkpoint
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Searching for latest checkpoint in 'checkpoints/' directory...")
        checkpoints = list_checkpoints()
        if not checkpoints:
            sys.exit(
                "No checkpoints found. Please train a model or specify a valid path."
            )

        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
        print(f"Using latest checkpoint: {checkpoint_path}")
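
    # Choose the compute device: prefer Apple MPS, then CUDA, otherwise CPU;
    # force_CPU (imported from inference.inference) overrides GPU selection.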
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    tokenizer = ByteTokenizer()
    model = load_model(checkpoint_path, device)
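
    # Interactive chat: maintain a conversation history delimited by
    # <|im_start|>/<|im_end|> markers and stream each assistant reply as it
    # is generated.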
    if args.chat:
        stop_sequences = args.stop + ["<|im_end|>"]
        history = f"<|im_start|>system\n{args.system}<|im_end|>\n"
        print("\n--- Interactive Chat ---")
        print(f"System Prompt: {args.system}")
        print("Type 'exit' or 'quit' to end the session.")
        print("-" * 26)

        while True:
            try:
                user_prompt_display = f"<|im_start|>{args.user_role}\n"
                user_input = input(user_prompt_display)

                if user_input.lower() in ["exit", "quit"]:
                    break

                prompt = (
                    history
                    + f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n"
                )

                print(f"<|im_start|>{args.assistant_role}")
                sys.stdout.flush()

                generated_text_parts = []
                for chunk in generate_text_stream(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    repetition_penalty=args.repetition_penalty,
                    device=device,
                    stop_sequences=stop_sequences,
                ):
                    print(chunk, end="", flush=True)
                    generated_text_parts.append(chunk)

                generated_text = "".join(generated_text_parts)

                history += (
                    f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n{generated_text}<|im_end|>\n"
                )
                print()

            except (KeyboardInterrupt, EOFError):
                print("\nExiting chat.")
                break
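
    # Single-shot mode: stream one completion for --prompt, then echo the
    # full prompt plus generation for reference.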
    else:
        print(f"\nGenerating text with prompt: '{args.prompt}'")
        print(
            f"Parameters: temp={args.temperature}, top_k={args.top_k}, repetition_penalty={args.repetition_penalty}"
        )
        print("\n--- Generation Start ---")

        generated_text_parts = []
        for chunk in generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            device=device,
            stop_sequences=args.stop,
        ):
            print(chunk, end="", flush=True)
            generated_text_parts.append(chunk)

        print("\n--- Generation End ---")

        generated_text = "".join(generated_text_parts)
        full_text = args.prompt + generated_text

        print("\n\nFull generated text (for reference):")
        print("-" * 40)
        print(full_text)
        print("-" * 40)
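

# Example invocations (the script filename below is assumed for illustration):
#   python generate.py --prompt "Once upon a time" --max_tokens 256
#   python generate.py --chat --system "You are a terse assistant."
#   python generate.py --list_checkpoints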
if __name__ == "__main__":
    main()