File size: 2,775 Bytes
2f33571
 
c1acc0b
136b8fc
 
 
 
 
 
2f33571
136b8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f33571
 
136b8fc
2f33571
b7fd636
 
 
 
 
 
 
 
 
 
 
 
 
2f33571
 
 
 
b7fd636
136b8fc
 
2f33571
990c2ea
136b8fc
 
 
 
 
 
990c2ea
2f33571
 
136b8fc
 
 
 
 
 
 
990c2ea
 
2f33571
 
136b8fc
2f33571
 
990c2ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import logging

import gradio as gr
from transformers import pipeline

# Selectable models: dropdown display label -> Hugging Face Hub model id.
# Insertion order matters: the first entry is the dropdown's default.
MODELS = {
    "DistilBERT (English, binary sentiment)": (
        "distilbert-base-uncased-finetuned-sst-2-english"
    ),
    "Twitter RoBERTa (multi-sentiment)": (
        "cardiffnlp/twitter-roberta-base-sentiment"
    ),
    "Multilingual BERT (5-class)": (
        "nlptown/bert-base-multilingual-uncased-sentiment"
    ),
}

# Function to load selected model
def load_model(model_name):
    """Build a Hugging Face sentiment-analysis pipeline for *model_name*.

    Args:
        model_name: Model id passed straight to ``transformers.pipeline``
            (e.g. one of the values in ``MODELS``).

    Returns:
        The pipeline object, or ``None`` if building it raised for any
        reason. The failure is logged with its traceback rather than
        silently discarded.
    """
    try:
        return pipeline("sentiment-analysis", model=model_name)
    except Exception:
        # Best-effort boundary: callers turn None into a friendly UI message,
        # but keep the real cause in the logs so failures can be diagnosed.
        logging.getLogger(__name__).exception(
            "Failed to load sentiment model %r", model_name
        )
        return None

# Lazily-filled cache of pipelines created by analyze_text, keyed by the
# dropdown label (a key of MODELS), so each model is only loaded once.
loaded_models: dict = {}

def _format_label(model_name, label):
    """Turn a raw pipeline label into a friendlier display string.

    Args:
        model_name: Hub id of the model that produced *label*.
        label: The ``"label"`` field of the pipeline's top prediction.

    Returns:
        The display label, truncated before any ``https://`` text.
    """
    if "cardiffnlp/twitter-roberta-base-sentiment" in model_name:
        # This checkpoint reports generic LABEL_n ids; map them to names.
        label_map = {"LABEL_0": "Negative", "LABEL_1": "Neutral", "LABEL_2": "Positive"}
        label = label_map.get(label, label)
    elif "nlptown/bert-base-multilingual-uncased-sentiment" in model_name:
        # Optional, fun visual. NOTE(review): this model's card shows labels
        # such as "4 stars", in which case this replace changes nothing —
        # confirm the intended display.
        label = label.replace("LABEL_", "⭐️ ")

    # Strip out any extra text or URLs that sneak into outputs.
    return str(label).split("https://")[0].strip()


# Function to run inference
def analyze_text(model_label, text):
    """Run the selected sentiment model on *text* and format its top prediction.

    Args:
        model_label: Dropdown display label, expected to be a key of ``MODELS``.
        text: User input. ``None``, empty or whitespace-only input is rejected.

    Returns:
        A display string of the form ``"<label> (<score rounded to 3 dp>)"``,
        or a human-readable message when the input, model selection, model
        load, or inference fails. Errors are returned, not raised, so the
        output box always has something to show.
    """
    # Presumably Gradio can pass None for an unset textbox; treat it as empty.
    if not text or not text.strip():
        return "Please enter some text."

    model_name = MODELS.get(model_label)
    if model_name is None:
        return f"Unknown model: {model_label}"

    analyzer = loaded_models.get(model_label)
    if analyzer is None:
        analyzer = load_model(model_name)
        if analyzer is None:
            # Deliberately not cached: caching None would make a transient
            # failure permanent; leaving it out lets the next request retry.
            return "Error loading model."
        loaded_models[model_label] = analyzer

    try:
        result = analyzer(text)[0]
        label = _format_label(model_name, result.get("label", "UNKNOWN"))
        score = round(result.get("score", 0), 3)
        return f"{label} ({score})"
    except Exception as e:
        # UI boundary: report the problem in the output box instead of crashing.
        return f"Error during inference: {e}"


# Assemble the Gradio UI. Components are laid out in creation order:
# intro text, model picker, input box, output box, then the run button.
with gr.Blocks(title="CIS1160 – Inference Explorer") as demo:
    gr.Markdown(
        """
        ### Explore AI Model Inference
        Select a model and enter text to see how it interprets sentiment.

        - **DistilBERT:** Trained on English movie reviews (2 classes: Positive/Negative)  
        - **RoBERTa (Twitter):** Trained on tweets (3 classes: Negative/Neutral/Positive)  
        - **Multilingual BERT:** Handles multiple languages (5-star rating output)
        """
    )

    # Model picker, defaulting to the first entry of MODELS.
    with gr.Row():
        model_choice = gr.Dropdown(
            label="Choose a model",
            choices=[*MODELS],
            value=next(iter(MODELS)),
        )

    input_text = gr.Textbox(lines=2, label="Enter text to analyze")
    output_text = gr.Textbox(label="Model Prediction")

    # Clicking the button feeds (model label, text) to analyze_text and
    # writes its returned string into the prediction box.
    run_button = gr.Button("Run Inference")
    run_button.click(
        analyze_text,
        inputs=[model_choice, input_text],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()