rasouzadev's picture
trying fast api again
d89a33f verified
raw
history blame
2.53 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from typing import Dict, Any
import uvicorn
# Model identifiers passed to transformers.pipeline() below.
INTENT_MODEL = "rasouzadev/medgo-intent-classifier"
HATE_MODEL = "unitary/unbiased-toxic-roberta"
# Both pipelines are built once, at import time, before the app object exists.
# Start-up is therefore slow (models are loaded, and presumably downloaded on a
# cold cache), but every request reuses these module-level pipelines.
print("Loading pipelines (this may take a minute)...")
# truncation=True: inputs longer than the model's maximum length are truncated
# by the tokenizer instead of failing.
intent_pipe = pipeline("text-classification", model=INTENT_MODEL, truncation=True)
# top_k=None asks the pipeline for the score of every label, not only the best
# one; classify() scans all of them.
hate_pipe = pipeline("text-classification", model=HATE_MODEL, truncation=True, top_k=None)
app = FastAPI(title="MedGo - Intent & Hate Detector API")
class InputText(BaseModel):
    """Request body accepted by ``POST /predict``."""

    # Text to classify; handed to the hate pipeline and, unless flagged, the intent pipeline.
    text: str
class PredictionResponse(BaseModel):
    """Response body returned by ``POST /predict``."""

    # Predicted intent label, or "HateSpeech" when the hate model flagged the text.
    # NOTE(review): classify() falls back to None when the intent pipeline returns
    # nothing; this non-optional field would then fail response validation.
    intent: str
    # Intent model confidence, or the triggering hate score when flagged.
    intent_score: float
    # Label chosen from the hate model's output; None when it produced no scores.
    hate_label: str | None
    # Score of hate_label; 0.0 when the hate model produced no scores.
    hate_score: float
    # "flagged_by_hate_model" when the hate check short-circuited, otherwise None.
    note: str | None
def unify_scores(pipe_output):
    """Normalise a text-classification pipeline result to a flat list of entries.

    A single-input call may come back either flat (``[{...}, ...]``) or wrapped
    in an extra list (``[[{...}, ...]]``). If the first element is itself a
    list, that inner list is returned; otherwise *pipe_output* is returned
    as-is. Any falsy input yields an empty list.
    """
    if pipe_output:
        first = pipe_output[0]
        return first if isinstance(first, list) else pipe_output
    return []
@app.get("/")
def root():
    """Describe the service and list the paths of its endpoints."""
    endpoint_paths = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    return {"message": "MedGo API - Intent & Hate Detector", "endpoints": endpoint_paths}
@app.get("/health")
def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
# Substrings that mark a hate-model label as hate/toxicity related (matched case-insensitively).
HATE_KEYWORDS = ("toxic", "hate", "offensive", "insult", "abusive")
# Minimum score a hate-related label must reach for the request to be flagged.
HATE_THRESHOLD = 0.6


def _is_hate_label(label):
    """Return True if *label* contains any of HATE_KEYWORDS, ignoring case."""
    low = label.lower()
    return any(keyword in low for keyword in HATE_KEYWORDS)


def _score_of(entry):
    """Return the score of a pipeline result entry, or 0.0 when it has none."""
    return entry.get("score", 0.0)


@app.post("/predict", response_model=PredictionResponse)
def classify(input_data: InputText) -> Dict[str, Any]:
    """Run hate detection on the request text, then intent classification if not flagged.

    The hate pipeline runs first. If any hate-related label (per HATE_KEYWORDS)
    scores at least HATE_THRESHOLD, the request is answered immediately with
    intent "HateSpeech" and the intent pipeline is not called. Otherwise the
    intent pipeline's first result is returned together with the hate model's
    overall top-scoring label and score.

    Args:
        input_data: Request body; only ``text`` is used.

    Returns:
        A dict with the fields of PredictionResponse.
    """
    text = input_data.text

    hate_scores = unify_scores(hate_pipe(text))
    best_hate = max(hate_scores, key=_score_of, default=None)
    hate_label = best_hate.get("label") if best_hate else None
    hate_score = float(best_hate.get("score", 0.0)) if best_hate else 0.0

    # Check the strongest *hate-related* label rather than only the overall top
    # label: if the hate model scores labels independently (multi-label), a
    # non-hate label such as an identity mention can outrank a high toxicity
    # score, and a top-label-only check would let that text through.
    hate_hits = [
        entry for entry in hate_scores
        if entry.get("label") and _is_hate_label(entry["label"])
    ]
    strongest_hit = max(hate_hits, key=_score_of, default=None)
    if strongest_hit is not None:
        hit_score = float(strongest_hit.get("score", 0.0))
        if hit_score >= HATE_THRESHOLD:
            return {
                "intent": "HateSpeech",
                "intent_score": hit_score,
                "hate_label": strongest_hit["label"],
                "hate_score": hit_score,
                "note": "flagged_by_hate_model",
            }

    # Accept both flat and nested pipeline output, as for the hate model.
    intent_scores = unify_scores(intent_pipe(text))
    top_intent = intent_scores[0] if intent_scores else None
    # NOTE(review): if the intent pipeline returns nothing, intent_label stays
    # None, which PredictionResponse.intent (typed ``str``) rejects at response
    # validation.
    intent_label = top_intent.get("label") if top_intent else None
    intent_score = float(top_intent.get("score", 0.0)) if top_intent else 0.0
    return {
        "intent": intent_label,
        "intent_score": intent_score,
        "hate_label": hate_label,
        "hate_score": hate_score,
        "note": None,
    }
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")