Spaces:
Sleeping
Sleeping
File size: 2,533 Bytes
d89a33f 180cbe7 d89a33f 180cbe7 cd82daf d89a33f cd82daf 180cbe7 d89a33f 180cbe7 cb32b30 d89a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from typing import Dict, Any
import uvicorn
# Hugging Face model identifiers handed to `pipeline(...)` below.
INTENT_MODEL = "rasouzadev/medgo-intent-classifier"
HATE_MODEL = "unitary/unbiased-toxic-roberta"

# Both pipelines are constructed once at import time; the /predict handler
# reuses these module-level objects on every request.
print("Loading pipelines (this may take a minute)...")
intent_pipe = pipeline("text-classification", model=INTENT_MODEL, truncation=True)
# NOTE(review): top_k=None presumably makes the pipeline return a score for
# every label (the handler takes the max over them) - confirm against the
# installed transformers version.
hate_pipe = pipeline("text-classification", model=HATE_MODEL, truncation=True, top_k=None)

# ASGI application object that the route decorators below register on.
app = FastAPI(title="MedGo - Intent & Hate Detector API")
class InputText(BaseModel):
    """Request body accepted by the POST /predict endpoint."""

    # Raw text to run through the hate and intent classifiers.
    text: str
class PredictionResponse(BaseModel):
    """Response body produced by the POST /predict endpoint."""

    # Predicted intent label; "HateSpeech" when the hate model flags the text.
    # Nullable because the handler falls back to None when the intent
    # pipeline yields no usable result - a plain `str` here would turn that
    # fallback into a response-validation error.
    intent: str | None
    # Score attached to `intent` (the hate score when the text was flagged).
    intent_score: float
    # Highest-scoring label from the hate model, or None if it gave no scores.
    hate_label: str | None
    # Score of `hate_label`; 0.0 when the hate model gave no scores.
    hate_score: float
    # "flagged_by_hate_model" when the text was flagged, otherwise None.
    note: str | None
def unify_scores(pipe_output):
    """Normalize a pipeline result to a flat list of score entries.

    Args:
        pipe_output: Whatever the classification pipeline returned; may be
            empty/None, a flat list of entries, or a list whose first
            element is itself a list of entries.

    Returns:
        ``[]`` for a falsy input, the first inner list when the input is
        nested, otherwise the input object unchanged.
    """
    if pipe_output:
        head = pipe_output[0]
        # Nested shape: only the first inner list is kept.
        return head if isinstance(head, list) else pipe_output
    return []
@app.get("/")
def root():
    """Return the service banner together with the paths of its endpoints."""
    endpoint_paths = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    return {
        "message": "MedGo API - Intent & Hate Detector",
        "endpoints": endpoint_paths,
    }
@app.get("/health")
def health():
    """Health-check endpoint: always responds with a fixed status payload."""
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/predict", response_model=PredictionResponse)
def classify(input_data: InputText) -> Dict[str, Any]:
    """Classify the submitted text, screening it with the hate model first.

    The hate pipeline runs first and its highest-scoring label is taken. If
    that label contains one of the flag keywords and scores at least 0.6,
    the response is returned immediately with intent "HateSpeech" and the
    intent pipeline is never invoked. Otherwise the first result of the
    intent pipeline supplies the intent label and score.

    Args:
        input_data: Request body carrying the text to classify.

    Returns:
        A dict with the keys "intent", "intent_score", "hate_label",
        "hate_score" and "note".
    """
    text = input_data.text

    # Pick the single highest-scoring entry from the hate model, if any.
    candidates = unify_scores(hate_pipe(text))
    top_entry = max(
        candidates,
        key=lambda entry: entry.get("score", 0.0),
        default=None,
    )
    if top_entry:
        hate_label = top_entry.get("label")
        hate_score = float(top_entry.get("score", 0.0))
    else:
        hate_label = None
        hate_score = 0.0

    if hate_label:
        lowered = hate_label.lower()
        flag_keywords = ("toxic", "hate", "offensive", "insult", "abusive")
        keyword_hit = any(word in lowered for word in flag_keywords)
        if keyword_hit and hate_score >= 0.6:
            # Flagged: short-circuit before running the intent classifier.
            return {
                "intent": "HateSpeech",
                "intent_score": hate_score,
                "hate_label": hate_label,
                "hate_score": hate_score,
                "note": "flagged_by_hate_model",
            }

    # Not flagged: fall through to the intent classifier.
    intent_result = intent_pipe(text)
    if intent_result and isinstance(intent_result, list):
        first_hit = intent_result[0]
        intent_label = first_hit.get("label")
        intent_score = float(first_hit.get("score", 0.0))
    else:
        intent_label = None
        intent_score = 0.0

    return {
        "intent": intent_label,
        "intent_score": intent_score,
        "hate_label": hate_label,
        "hate_score": hate_score,
        "note": None,
    }
if __name__ == "__main__":
    # Run the app directly with uvicorn when executed as a script, listening
    # on all interfaces at port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)