DSDUDEd committed
Commit 1c686e6 · verified · 1 Parent(s): 06059b4

Update app.py

Files changed (1): app.py (+17 -15)
app.py CHANGED
@@ -1,23 +1,23 @@
-# app.py
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-# Load model and tokenizer
+# Model name
 model_name = "DSDUDEd/firebase"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
 
-# Ensure model uses GPU if available
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = model.to(device)
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",  # automatically assigns to GPU if available
+    load_in_8bit=True   # load in 8-bit to save memory
+)
 
-# Function to generate responses
+# Function to generate AI responses
 def chat_with_model(user_input, chat_history=[]):
-    # Append user input to history
     chat_history.append({"role": "user", "content": user_input})
 
-    # Prepare prompt
+    # Build the prompt from chat history
     prompt = ""
     for turn in chat_history:
         if turn["role"] == "user":
@@ -25,7 +25,8 @@ def chat_with_model(user_input, chat_history=[]):
         else:
             prompt += f"AI: {turn['content']}\n"
 
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
     outputs = model.generate(
         **inputs,
         max_new_tokens=150,
@@ -33,14 +34,14 @@ def chat_with_model(user_input, chat_history=[]):
         top_p=0.9,
         temperature=0.7,
     )
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Extract AI response (assume last part after "AI: ")
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Get only the AI's response
     response_text = response.split("AI:")[-1].strip()
 
     chat_history.append({"role": "ai", "content": response_text})
 
-    # Prepare chat history for Gradio
+    # Prepare Gradio chat format
     chat_for_gradio = [(turn["content"], "") if turn["role"]=="user" else ("", turn["content"]) for turn in chat_history]
 
     return chat_for_gradio, chat_history
@@ -51,7 +52,8 @@ with gr.Blocks() as demo:
     chatbot = gr.Chatbot()
     msg = gr.Textbox(label="Enter your message")
     submit = gr.Button("Send")
-
+
     submit.click(chat_with_model, inputs=[msg, chat_history_state], outputs=[chatbot, chat_history_state])
 
     demo.launch()
+
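Note: load_in_8bit=True goes through the bitsandbytes integration in transformers, which also needs accelerate for device_map="auto" and, in the common case, a CUDA GPU. Below is a minimal sketch of the same load written against the explicit quantization config (recent transformers releases prefer this over the bare kwarg), with a plain CPU fallback; treat package availability and GPU presence as assumptions, since the commit itself does not pin them:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "DSDUDEd/firebase"
tokenizer = AutoTokenizer.from_pretrained(model_name)

if torch.cuda.is_available():
    # 8-bit weights via bitsandbytes; device_map="auto" places layers automatically.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
else:
    # No CUDA device (e.g. a CPU-only Space): fall back to a regular load.
    model = AutoModelForCausalLM.from_pretrained(model_name)

Either way, moving inputs with .to(model.device), as the commit now does, keeps the tokenized tensors on whatever device the model actually landed on.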
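The hand-rolled "User:/AI:" transcript plus the split("AI:") extraction is fragile: it breaks if the model ever emits "AI:" inside a reply. If the tokenizer ships a chat template, apply_chat_template is the sturdier route. A minimal sketch under that assumption; build_prompt is a hypothetical helper, and whether DSDUDEd/firebase defines a template is not established by this commit:

def build_prompt(chat_history):
    # Prefer the model's own chat template when one is defined.
    if tokenizer.chat_template is not None:
        # Chat templates expect the "assistant" role rather than this app's "ai".
        messages = [
            {"role": "assistant" if t["role"] == "ai" else t["role"],
             "content": t["content"]}
            for t in chat_history
        ]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Otherwise keep the commit's plain "User:/AI:" transcript format.
    prompt = ""
    for turn in chat_history:
        label = "User" if turn["role"] == "user" else "AI"
        prompt += f"{label}: {turn['content']}\n"
    return prompt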
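One more observation on a line the commit leaves unchanged: the chat_for_gradio comprehension emits one-sided tuples, so every exchange renders as two half-empty rows in gr.Chatbot, whose tuple format expects (user, bot) pairs. A minimal sketch of a pairing variant; to_gradio_pairs is a hypothetical name, not part of the committed app:

def to_gradio_pairs(chat_history):
    # Pair each user turn with the AI turn that answers it.
    pairs = []
    for turn in chat_history:
        if turn["role"] == "user":
            pairs.append([turn["content"], None])
        elif pairs:
            pairs[-1][1] = turn["content"]
    return [tuple(p) for p in pairs]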