dfbowers commited on
Commit
990c2ea
·
verified ·
1 Parent(s): 9c2e6c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -86
app.py CHANGED
@@ -1,108 +1,37 @@
1
  import gradio as gr
2
- import time
3
- import random
4
- import matplotlib.pyplot as plt
5
  from transformers import pipeline
6
 
 
7
  MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
8
 
9
- # -------- Load the real inference model once at startup --------
10
- try:
11
- sentiment_analyzer = pipeline("sentiment-analysis", model=MODEL_NAME)
12
- except Exception as e:
13
- sentiment_analyzer = None
14
- print("Model load failed:", e)
15
 
16
- # -------- Simulated Training with Chart + Log --------
17
- def simulate_training(epochs, learning_rate):
18
- import matplotlib.pyplot as plt
19
- import random, time
20
-
21
- accuracies = []
22
- logs = []
23
-
24
- # Start lower for visible climb
25
- current_acc = 0.55 + random.uniform(0, 0.05)
26
-
27
- for epoch in range(epochs):
28
- time.sleep(0.3)
29
-
30
- # Improvement scaled for realistic visible growth
31
- improvement = learning_rate * random.uniform(20, 80)
32
- current_acc += improvement
33
-
34
- # Add small random noise and plateau effect
35
- if epoch > epochs * 0.6:
36
- current_acc += random.uniform(-0.015, 0.005)
37
-
38
- # Clamp values between 0.55 and 0.95
39
- current_acc = max(min(current_acc, 0.95), 0.55)
40
- accuracies.append(round(current_acc, 3))
41
- logs.append(f"Epoch {epoch+1}: Validation Accuracy = {current_acc:.3f}")
42
-
43
- # Create the chart
44
- plt.figure(figsize=(4, 2))
45
- plt.plot(range(1, epochs + 1), accuracies, marker="o", color="tab:blue")
46
- plt.axhline(y=accuracies[0], color="gray", linestyle="--", linewidth=1, label="Starting accuracy")
47
- plt.title("Simulated Validation Accuracy per Epoch")
48
- plt.xlabel("Epoch")
49
- plt.ylabel("Accuracy")
50
- plt.ylim(0.5, 1.0)
51
- plt.grid(True)
52
- plt.legend()
53
-
54
- final_acc = round(accuracies[-1], 3)
55
- return plt, "\n".join(logs), final_acc
56
-
57
-
58
- # -------- Real Inference Function --------
59
  def analyze_text(text):
60
- if not sentiment_analyzer:
61
- return "Model not loaded. Refresh the page and try again."
62
  try:
63
  result = sentiment_analyzer(text)[0]
64
  label = result.get("label", "UNKNOWN")
65
- score = round(result.get("score", 0), 3)
66
  return f"{label} ({score})"
67
  except Exception as e:
68
  return f"Error during inference: {e}"
69
 
70
- # -------- Gradio Interface Layout --------
71
- with gr.Blocks(title="CIS1160 Training & Inference Demo") as demo:
72
- gr.Markdown(
73
- "### Part 1 – Simulated Training\n"
74
- "Adjust the settings and click **Train** to visualize accuracy improvement.\n"
75
- "_(This is a simulation – no real data is trained.)_"
76
- )
77
-
78
- with gr.Row():
79
- epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
80
- lr = gr.Slider(0.001, 0.01, value=0.005, step=0.001, label="Learning Rate")
81
-
82
- train_button = gr.Button("Train (Simulated)")
83
- train_chart = gr.Plot(label="Simulated Validation Accuracy Chart")
84
- train_log = gr.Textbox(label="Training Log", lines=8)
85
- final_acc = gr.Number(label="Final Accuracy")
86
-
87
- train_button.click(
88
- simulate_training,
89
- inputs=[epochs, lr],
90
- outputs=[train_chart, train_log, final_acc]
91
- )
92
-
93
  gr.Markdown(
94
- "### Part 2 – Real Inference\n"
95
- "Use the trained model to make predictions on new text.\n"
96
- "_(This uses a real pre-trained model for sentiment analysis.)_"
 
 
97
  )
98
 
99
- with gr.Row():
100
- input_text = gr.Textbox(label="Enter text to analyze")
101
- output_text = gr.Textbox(label="Sentiment Prediction")
102
 
103
  run_button = gr.Button("Run Inference")
104
  run_button.click(analyze_text, inputs=input_text, outputs=output_text)
105
 
106
- # -------- Launch Application --------
107
  if __name__ == "__main__":
108
- demo.launch()
 
1
  import gradio as gr
 
 
 
2
  from transformers import pipeline
3
 
4
+ # Choose a reliable model (you can change this)
5
  MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
6
 
7
+ # Load the model once
8
+ sentiment_analyzer = pipeline("sentiment-analysis", model=MODEL_NAME)
 
 
 
 
9
 
10
+ # Function to analyze text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def analyze_text(text):
 
 
12
  try:
13
  result = sentiment_analyzer(text)[0]
14
  label = result.get("label", "UNKNOWN")
15
+ score = round(result.get("score", 3))
16
  return f"{label} ({score})"
17
  except Exception as e:
18
  return f"Error during inference: {e}"
19
 
20
+ # Gradio interface
21
+ with gr.Blocks(title="CIS1160 LLM Inference Demo") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  gr.Markdown(
23
+ """
24
+ ### Explore Inference
25
+ Enter any sentence and see how a trained model interprets it.
26
+ Try clearly positive, negative, and neutral examples.
27
+ """
28
  )
29
 
30
+ input_text = gr.Textbox(label="Enter text to analyze", lines=2)
31
+ output_text = gr.Textbox(label="Model Prediction")
 
32
 
33
  run_button = gr.Button("Run Inference")
34
  run_button.click(analyze_text, inputs=input_text, outputs=output_text)
35
 
 
36
  if __name__ == "__main__":
37
+ demo.launch()