# Page header captured with this file (not code): author dfbowers,
# commit b7fd636 "Update app.py" (verified).
import logging

import gradio as gr
from transformers import pipeline
# Models offered in the UI: dropdown label -> Hugging Face Hub model id that
# load_model() hands to transformers.pipeline(). Insertion order is the order
# the labels appear in the dropdown.
MODELS = dict(
    [
        (
            "DistilBERT (English, binary sentiment)",
            "distilbert-base-uncased-finetuned-sst-2-english",
        ),
        (
            "Twitter RoBERTa (multi-sentiment)",
            "cardiffnlp/twitter-roberta-base-sentiment",
        ),
        (
            "Multilingual BERT (5-class)",
            "nlptown/bert-base-multilingual-uncased-sentiment",
        ),
    ]
)
# Function to load selected model
def load_model(model_name):
try:
return pipeline("sentiment-analysis", model=model_name)
except Exception as e:
return None
# Lazily filled cache keyed by dropdown label, so a model's pipeline is not
# rebuilt on every request once it has been loaded.
loaded_models = dict()
# Function to run inference
def analyze_text(model_label, text):
if not text.strip():
return "Please enter some text."
model_name = MODELS[model_label]
if model_label not in loaded_models:
loaded_models[model_label] = load_model(model_name)
analyzer = loaded_models[model_label]
if not analyzer:
return "Error loading model."
try:
result = analyzer(text)[0]
label = result.get("label", "UNKNOWN")
score = round(result.get("score", 0), 3)
# Fix model-specific labels for clarity
if "cardiffnlp/twitter-roberta-base-sentiment" in model_name:
label_map = {"LABEL_0": "Negative", "LABEL_1": "Neutral", "LABEL_2": "Positive"}
label = label_map.get(label, label)
elif "nlptown/bert-base-multilingual-uncased-sentiment" in model_name:
label = label.replace("LABEL_", "⭐️ ") # optional, fun visual
# Strip out any extra text or URLs that sneak into outputs
label = str(label).split("https://")[0].strip()
return f"{label} ({score})"
except Exception as e:
return f"Error during inference: {e}"
# Gradio page. Components render in creation order: intro text, model picker,
# input box, prediction box, then the button that triggers analyze_text().
with gr.Blocks(title="CIS1160 – Inference Explorer") as demo:
    gr.Markdown(
        """
### Explore AI Model Inference
Select a model and enter text to see how it interprets sentiment.
- **DistilBERT:** Trained on English movie reviews (2 classes: Positive/Negative)
- **RoBERTa (Twitter):** Trained on tweets (3 classes: Negative/Neutral/Positive)
- **Multilingual BERT:** Handles multiple languages (5-star rating output)
"""
    )

    with gr.Row():
        # The dropdown offers every MODELS label and defaults to the first one.
        model_choice = gr.Dropdown(
            label="Choose a model",
            choices=[*MODELS],
            value=[*MODELS][0],
        )

    input_text = gr.Textbox(lines=2, label="Enter text to analyze")
    output_text = gr.Textbox(label="Model Prediction")
    run_button = gr.Button("Run Inference")

    # On click, pass (selected label, input text) to analyze_text and show
    # its returned string in the prediction box.
    run_button.click(
        fn=analyze_text,
        inputs=[model_choice, input_text],
        outputs=output_text,
    )


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()