Spaces:
Sleeping
Sleeping
| # app.py | |
| import torch | |
| import torch.nn as nn | |
| from tokenizers import Tokenizer | |
| import gradio as gr | |
| # ----------------------------- | |
| # Load tokenizer | |
| # ----------------------------- | |
# Relative path: resolved against the process's current working directory.
tokenizer_path = "tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)
# Vocabulary size including added tokens (the library default, spelled out);
# it sizes both the embedding table and the output projection of the model.
vocab_size = tokenizer.get_vocab_size(with_added_tokens=True)
| # ----------------------------- | |
| # Define the same transformer as used in training | |
| # ----------------------------- | |
class SimpleTransformer(nn.Module):
    """Token embedding, a stack of Transformer encoder layers, then a vocab projection.

    The attribute names ``embedding``, ``transformer`` and ``fc`` determine the
    state-dict keys, so they must match the checkpoint produced in training.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # The encoder clones this template layer ``num_layers`` times.
        template = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(template, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return logits of shape (batch, seq, vocab_size) for token ids ``x``.

        ``x`` is expected as (batch, seq) integer ids, as built by the caller.
        """
        # Encoder layers use the default sequence-first layout, so the batch
        # and sequence axes are swapped going in and swapped back coming out.
        seq_first = self.embedding(x).transpose(0, 1)
        encoded = self.transformer(seq_first)
        batch_first = encoded.transpose(0, 1)
        return self.fc(batch_first)
| # ----------------------------- | |
| # Load model weights | |
| # ----------------------------- | |
model = SimpleTransformer(vocab_size)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code on load; newer PyTorch releases
# warn when the flag is omitted (and default to True from 2.6). The state dict
# is passed inline so no module-level reference keeps a second copy alive.
model.load_state_dict(
    torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
)
# Inference only: switches the encoder layers' dropout off.
model.eval()
| # ----------------------------- | |
| # Generation function | |
| # ----------------------------- | |
def generate(prompt):
    """Run one forward pass over *prompt* and decode the per-position argmax.

    Args:
        prompt: text from the Gradio textbox; ``None`` is treated as empty.

    Returns:
        The decoded string of the highest-scoring token at every input
        position, or ``""`` when the prompt encodes to no tokens.
    """
    # NOTE(review): this is a single forward pass, not autoregressive decoding:
    # the output has exactly one token per input token and nothing is appended
    # after the prompt. Confirm this matches how the model was trained.
    input_ids = tokenizer.encode(prompt or "").ids
    if not input_ids:
        # torch.tensor([[]]) would infer float32 with zero length, which
        # nn.Embedding rejects (indices must be integers); nothing to predict.
        return ""
    # Explicit integer dtype documents that these are embedding indices.
    input_tensor = torch.tensor([input_ids], dtype=torch.long)
    with torch.no_grad():
        logits = model(input_tensor)
    # Shape (1, seq, vocab) -> best token id per position for the single batch row.
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    return tokenizer.decode(predicted_ids)
| # ----------------------------- | |
| # Gradio Interface | |
| # ----------------------------- | |
# Two-line prompt box feeding generate(); the result is shown as plain text.
prompt_box = gr.Textbox(lines=2, placeholder="Type a prompt for Dave...")
# NOTE(review): "β" in the title may be a mis-encoded em dash; confirm intent.
iface = gr.Interface(
    fn=generate,
    inputs=prompt_box,
    outputs="text",
    title="Dave β Fully Custom AI",
    description="Interact with your fully custom AI trained from scratch.",
)
iface.launch()