DSDUDEd committed on
Commit
c46d75e
·
verified ·
1 Parent(s): 320dd2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -28
app.py CHANGED
@@ -1,49 +1,74 @@
1
- # app.py – Hugging Face Space for Fully Custom "Dave" Model
2
  import torch
3
- from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM
 
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  import uvicorn
 
7
 
8
  # -----------------------------
9
- # Load tokenizer and model
10
  # -----------------------------
11
- tokenizer_path = "tokenizer.json"
12
- model_path = "pytorch_model.bin"
 
 
 
13
 
14
- tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
15
- model = AutoModelForCausalLM.from_pretrained(
16
- pretrained_model_name_or_path=".",
17
- config="config.json",
18
- state_dict=torch.load(model_path, map_location="cpu")
19
- )
 
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model.eval()
22
 
23
  # -----------------------------
24
- # FastAPI app
25
  # -----------------------------
26
  app = FastAPI()
27
 
28
- class Prompt(BaseModel):
29
- text: str
 
30
 
31
  @app.post("/generate")
32
- def generate_text(prompt: Prompt):
33
- inputs = tokenizer(prompt.text, return_tensors="pt")
 
34
  with torch.no_grad():
35
- outputs = model.generate(
36
- **inputs,
37
- max_length=64,
38
- do_sample=True,
39
- temperature=0.7,
40
- top_p=0.9
41
- )
42
- decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
43
- return {"response": decoded[0]}
44
-
45
- # -----------------------------
46
- # Run the app (for local testing)
47
  # -----------------------------
48
  if __name__ == "__main__":
49
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ from tokenizers import Tokenizer
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  import uvicorn
7
+ import json
8
 
9
  # -----------------------------
10
+ # Settings
11
  # -----------------------------
12
# Locations of the checkpoint, tokenizer and hyper-parameter files, all
# resolved relative to the process working directory.
MODEL_PATH = "./pytorch_model.bin"
TOKENIZER_PATH = "./tokenizer.json"
CONFIG_PATH = "./config.json"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
17
 
18
+ # -----------------------------
19
+ # Load config
20
+ # -----------------------------
21
# Read the model hyper-parameters from the JSON file at CONFIG_PATH.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on whatever the platform's locale encoding happens to be.
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = json.load(f)
23
+
24
+ # -----------------------------
25
+ # Define the same architecture
26
+ # -----------------------------
27
class SimpleTransformer(nn.Module):
    """Embedding -> transformer encoder stack -> linear projection to the vocabulary.

    The attribute names ``embedding``, ``transformer`` and ``fc`` are also the
    state-dict key prefixes, so they must stay as they are for saved weights
    to load.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=4):
        """Create the three sub-modules.

        Args:
            vocab_size: Size of the embedding table and of the output projection.
            d_model: Width of the embeddings and of the encoder layers.
            nhead: Number of attention heads in each encoder layer.
            num_layers: Number of encoder layers stacked in the encoder.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Map a tensor of token ids to logits over the vocabulary.

        Axes 0 and 1 are swapped before the encoder and swapped back after it,
        because the encoder layer is built with its default (non batch-first)
        layout.
        """
        hidden = self.embedding(x).transpose(0, 1)
        encoded = self.transformer(hidden).transpose(0, 1)
        return self.fc(encoded)
41
+
42
+ # -----------------------------
43
+ # Load tokenizer and model
44
+ # -----------------------------
45
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Rebuild the network from the values stored in the config so the parameter
# shapes match the checkpoint that is loaded next.
vocab_size = config["vocab_size"]
model = SimpleTransformer(
    vocab_size,
    config["d_model"],
    config["nhead"],
    config["num_layers"],
).to(device)

# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state dict needs, and prevents a tampered checkpoint file from
# executing arbitrary code while it is being loaded.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
49
  model.eval()
50
 
51
  # -----------------------------
52
+ # FastAPI setup
53
  # -----------------------------
54
  app = FastAPI()
55
 
56
class Query(BaseModel):
    # Request body for the /generate endpoint.

    # Text that is tokenized and fed to the model.
    prompt: str

    # Length limit supplied by the client; 64 when omitted.
    max_length: int = 64
59
 
60
  @app.post("/generate")
61
def generate(query: Query):
    """Run the prompt through the model once and decode the top-scoring tokens.

    The prompt is tokenized, passed through ``model`` in a single forward
    pass, and the highest-scoring token id at each input position is decoded
    back to text.

    Args:
        query: Request body carrying ``prompt`` and ``max_length``.

    Returns:
        A dict of the form ``{"response": <decoded text>}``. The text is empty
        when the prompt tokenizes to no ids.
    """
    # NOTE(review): query.max_length is accepted but never consulted below; the
    # output always has one token per prompt token. Confirm the intended use.
    input_ids = tokenizer.encode(query.prompt).ids
    if not input_ids:
        # An empty id list would produce a float tensor, which an embedding
        # lookup cannot index with, so answer directly instead of failing.
        return {"response": ""}

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        output = model(input_tensor)

    # Drop only the batch axis. A bare squeeze() would also collapse a
    # one-token sequence to a scalar, and decode() needs a list of ids.
    predicted_ids = torch.argmax(output, dim=-1)[0].tolist()
    response = tokenizer.decode(predicted_ids, skip_special_tokens=True)
    return {"response": response}
69
+
70
+ # -----------------------------
71
+ # For running locally
 
 
 
 
 
72
  # -----------------------------
73
  if __name__ == "__main__":
74
  uvicorn.run(app, host="0.0.0.0", port=7860)