# Spaces: Sleeping
from typing import Any, Dict

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
# Model identifiers handed to `transformers.pipeline` for the two classifiers.
INTENT_MODEL = "rasouzadev/medgo-intent-classifier"
HATE_MODEL = "unitary/unbiased-toxic-roberta"

# Both pipelines are built once, at import time, so importing this module
# blocks until the two models are loaded.
print("Loading pipelines (this may take a minute)...")
intent_pipe = pipeline(
    task="text-classification",
    model=INTENT_MODEL,
    truncation=True,  # truncate inputs longer than the model's max sequence length
)
hate_pipe = pipeline(
    task="text-classification",
    model=HATE_MODEL,
    truncation=True,
    top_k=None,  # return a score for every label instead of only the best one
)

app = FastAPI(title="MedGo - Intent & Hate Detector API")
class InputText(BaseModel):
    """Request body accepted by ``classify``: the raw text to analyse."""

    # Free-form text; ``classify`` passes it unchanged to both pipelines.
    text: str
class PredictionResponse(BaseModel):
    """Response schema matching the dict returned by ``classify``.

    Either carries the intent model's prediction, or ``"HateSpeech"`` when the
    hate model flagged the text (in which case ``note`` explains why).
    """

    # Predicted intent label, or "HateSpeech" when the hate model flagged the text.
    # Nullable because ``classify`` falls back to None when the intent pipeline
    # yields no prediction; a plain ``str`` would fail response validation there.
    intent: str | None
    # Confidence for ``intent``; equals ``hate_score`` for "HateSpeech" results.
    intent_score: float
    # Hate-model label being reported, or None if the hate model returned nothing.
    hate_label: str | None
    # Score of ``hate_label``; 0.0 when the hate model returned nothing.
    hate_score: float
    # "flagged_by_hate_model" when the hate model decided the result, else None.
    note: str | None
def unify_scores(pipe_output):
    """Normalize a text-classification pipeline result to one flat list.

    The pipeline may return either a flat list of ``{"label", "score"}`` dicts
    or that list wrapped in an outer list. The wrapped form is unwrapped.

    Args:
        pipe_output: Raw pipeline output; may be empty or ``None``.

    Returns:
        ``[]`` for empty/``None`` input; the inner list when the first element
        is itself a list; otherwise ``pipe_output`` unchanged.
    """
    if not pipe_output:
        return []
    head = pipe_output[0]
    if isinstance(head, list):
        # Nested form [[{...}, {...}]]: only the first inner list is kept.
        return head
    return pipe_output
@app.get("/")
def root():
    """Describe the service and list the paths of its endpoints.

    Registered on ``GET /``; without the decorator the route did not exist.

    Returns:
        A dict with a human-readable ``message`` and an ``endpoints`` map
        from endpoint name to URL path.
    """
    return {
        "message": "MedGo API - Intent & Hate Detector",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "docs": "/docs",
        },
    }
@app.get("/health")
def health():
    """Health check registered on ``GET /health``.

    Always returns ``{"status": "ok"}``. It does not call either model
    pipeline, so it only shows that the app is answering requests.
    """
    return {"status": "ok"}
# Substrings (matched case-insensitively) marking a hate-model label as abusive.
_HATE_KEYWORDS = ("toxic", "hate", "offensive", "insult", "abusive")
# Minimum score an abusive label needs before the text is reported as HateSpeech.
_HATE_THRESHOLD = 0.6


def _score_of(entry: Dict[str, Any]) -> float:
    """Return the score of one pipeline result dict as a float (0.0 if absent)."""
    return float(entry.get("score", 0.0))


def _is_hate_label(label: str | None) -> bool:
    """Return True when *label* contains any of ``_HATE_KEYWORDS``."""
    if not label:
        return False
    low = label.lower()
    return any(k in low for k in _HATE_KEYWORDS)


@app.post("/predict", response_model=PredictionResponse)
def classify(input_data: InputText) -> Dict[str, Any]:
    """Classify the request text, letting the hate model short-circuit intent.

    The hate pipeline runs first. If any of its labels matching
    ``_HATE_KEYWORDS`` scores at least ``_HATE_THRESHOLD``, the text is
    reported as ``"HateSpeech"`` and the intent pipeline is skipped.
    Otherwise the top intent-model label is returned together with the
    top-scoring hate-model label.

    Args:
        input_data: Request body carrying the text to classify.

    Returns:
        A dict with keys ``intent``, ``intent_score``, ``hate_label``,
        ``hate_score`` and ``note``. ``note`` is ``"flagged_by_hate_model"``
        when the hate model decided the result, otherwise ``None``.
        ``intent`` is ``None`` only if the intent pipeline returns nothing.
    """
    text = input_data.text

    hate_scores = unify_scores(hate_pipe(text))

    # Overall top hate-model label: reported as-is when the text is not flagged.
    best_hate = max(hate_scores, key=_score_of, default=None)
    hate_label = best_hate.get("label") if best_hate else None
    hate_score = _score_of(best_hate) if best_hate else 0.0

    # The hate pipeline is built with top_k=None, so it scores every label.
    # Flag on the strongest *abusive* label rather than the overall top label:
    # otherwise an abusive label above the threshold is ignored whenever some
    # non-matching label happens to outscore it.
    worst_abuse = max(
        (s for s in hate_scores if _is_hate_label(s.get("label"))),
        key=_score_of,
        default=None,
    )
    if worst_abuse is not None and _score_of(worst_abuse) >= _HATE_THRESHOLD:
        abuse_score = _score_of(worst_abuse)
        return {
            "intent": "HateSpeech",
            "intent_score": abuse_score,
            "hate_label": worst_abuse.get("label"),
            "hate_score": abuse_score,
            "note": "flagged_by_hate_model",
        }

    # Normalize the intent output the same way as the hate output, so a
    # nested result shape cannot break the ``.get`` calls below.
    intent_scores = unify_scores(intent_pipe(text))
    best_intent = max(intent_scores, key=_score_of, default=None)
    intent_label = best_intent.get("label") if best_intent else None
    intent_score = _score_of(best_intent) if best_intent else 0.0

    return {
        "intent": intent_label,
        "intent_score": intent_score,
        "hate_label": hate_label,
        "hate_score": hate_score,
        "note": None,
    }
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    # host 0.0.0.0 binds every network interface, not just localhost.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )