Dave-test-1 / app.py
DSDUDEd's picture
Update app.py
56f888d verified
raw
history blame
1.38 kB
import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast
import os
# -----------------------------
# Load tokenizer
# -----------------------------
# Resolve asset paths relative to this file instead of the current working
# directory, so the app also starts when launched from elsewhere
# (e.g. `python some/dir/app.py`). When the CWD is the app directory this
# points at exactly the same files as the previous "./Dave/..." paths.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

tokenizer_path = os.path.join(BASE_DIR, "Dave", "tokenizer.json")
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)

# -----------------------------
# Load custom model
# -----------------------------
from dave_model import SimpleTransformer  # put your model class in dave_model.py

# NOTE(review): `tokenizer.vocab_size` excludes tokens added on top of the base
# vocabulary (`len(tokenizer)` includes them). Presumably SimpleTransformer
# sizes its embedding/output layers from this value, so it must match the size
# used when the checkpoint was trained — confirm against the training script.
vocab_size = tokenizer.vocab_size
model_path = os.path.join(BASE_DIR, "Dave", "pytorch_model.bin")
model = SimpleTransformer(vocab_size=vocab_size)
# torch.load unpickles the checkpoint, which can execute arbitrary code: only
# load files you trust. NOTE(review): on PyTorch versions that support it,
# consider passing weights_only=True to restrict unpickling to tensors.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # evaluation mode (affects layers such as dropout / batch norm)
# -----------------------------
# Inference function
# -----------------------------
def generate_response(prompt, *, tok=None, lm=None):
    """Run the model once on *prompt* and decode its per-position predictions.

    The prompt is tokenized, passed through the model in a single forward
    pass, and for every input position the highest-scoring token id (argmax
    over the last output dimension) is taken. That id sequence is decoded
    back to text. This is not autoregressive generation: the output has one
    token per input token.

    Args:
        prompt: Text from the Gradio textbox. ``None`` is treated as ``""``.
        tok: Optional tokenizer override (needs ``encode``/``decode``);
            defaults to the module-level ``tokenizer``.
        lm: Optional model override (callable on a batch of token ids);
            defaults to the module-level ``model``.

    Returns:
        The decoded prediction string, or ``""`` when the prompt encodes to
        no tokens.
    """
    # Keyword-only overrides let the function run in isolation (e.g. with
    # stubs); Gradio still calls it with the single prompt argument.
    tok = tokenizer if tok is None else tok
    lm = model if lm is None else lm

    # PreTrainedTokenizerFast.encode returns a plain list[int]. Unlike the raw
    # `tokenizers` Encoding object it has no `.ids` attribute, so the previous
    # `.encode(prompt).ids` raised AttributeError on every call.
    ids = tok.encode(prompt or "")
    if not ids:
        # Nothing to feed the model; avoid building a zero-length input tensor.
        return ""

    input_ids = torch.tensor([ids])  # wrap in a batch of size 1
    with torch.no_grad():
        output = lm(input_ids)
    # NOTE(review): assumes `output` is a logits tensor shaped
    # (batch, seq_len, vocab_size) — confirm against SimpleTransformer.forward.
    predicted_ids = output.argmax(-1)[0].tolist()
    return tok.decode(predicted_ids)
# -----------------------------
# Gradio interface
# -----------------------------
# One two-line prompt box feeds generate_response; whatever it returns is
# displayed as plain text.
prompt_box = gr.Textbox(lines=2, placeholder="Type your prompt here...")

iface = gr.Interface(
    fn=generate_response,
    inputs=prompt_box,
    outputs="text",
    title="Dave – Fully Custom AI",
    description="Interact with your fully custom AI model.",
)


def main():
    """Launch the Gradio interface built above."""
    iface.launch()


if __name__ == "__main__":
    main()