engine / main_process /refinement_router.py
VeuReu's picture
Update main_process/refinement_router.py
2148855 verified
from __future__ import annotations
from pathlib import Path
from typing import Optional
import os
import yaml
from fastapi import FastAPI, HTTPException, APIRouter, UploadFile, File, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from refinement.multiagent_refinement import (
execute_refinement,
execute_refinement_for_video,
_load_refinement_flags,
)
# --- Config y autenticación sencilla por token ---
ROOT = Path(__file__).resolve().parent
CONFIG_PATH = "config.yaml"
router = APIRouter(prefix="/refinement", tags=["Refinement Process"])
def _load_engine_token() -> Optional[str]:
"""Carga el token compartido del engine desde config.yaml o variables de entorno.
Sigue la misma convención que engine/api.py: usa API_SHARED_TOKEN si está
definido por entorno; en caso contrario, intenta leer demo/config.yaml.
"""
env_token = os.getenv("API_SHARED_TOKEN")
if env_token:
return env_token
# Fallback: leer demo/config.yaml desde la raíz del repo
try:
repo_root = ROOT
# Intentar detectar la carpeta demo en el mismo repo
demo_cfg = repo_root / "demo" / "config.yaml"
if demo_cfg.exists():
with demo_cfg.open("r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
api_cfg = cfg.get("api", {}) or {}
token = api_cfg.get("token")
if token and isinstance(token, str):
# Cuando viene de YAML con "${API_SHARED_TOKEN}" puede no estar
# resuelto; en ese caso preferimos None para forzar uso de entorno.
if "${" in token and "}" in token:
return None
return token
except Exception:
pass
return None
ENGINE_TOKEN = _load_engine_token()
def _assert_valid_token(token: str | None) -> None:
expected = ENGINE_TOKEN
if not expected:
# Si no hay token configurado, consideramos que la auth está desactivada
return
if not token or token != expected:
raise HTTPException(status_code=401, detail="Invalid or missing engine token")
# --- Esquemas de entrada/salida ---
class ApplyRefinementRequest(BaseModel):
token: Optional[str] = Field(None, description="Engine shared token")
srt_content: Optional[str] = Field(
None,
description=(
"Contenido del SRT a refinar. Opcional si se proporciona sha1sum+version, "
"en cuyo caso se leerá el SRT desde audiodescriptions.db."
),
)
sha1sum: Optional[str] = Field(
None,
description=(
"Identificador sha1sum del vídeo. Si se proporciona junto con version, "
"se utilizará el pipeline basat en BDs (audiodescriptions.db, casting.db, scenarios.db)."
),
)
version: Optional[str] = Field(
None,
description=(
"Versió de l'audiodescripció (p.ex. 'MoE', 'Salamandra', 'HITL'). "
"Necessària si s'especifica sha1sum per utilitzar el pipeline de vídeo."
),
)
reflection_enabled: bool = Field(True, description="Activar paso de reflection")
reflexion_enabled: bool = Field(False, description="Activar paso de reflexion")
introspection_enabled: bool = Field(False, description="Activar paso de introspection")
class ApplyRefinementResponse(BaseModel):
refined_srt: str
class TrainMultiagentRefinementRequest(BaseModel):
audiodescriptions_db_path: str = Field(..., description="Ruta a la base de datos tipo audiodescriptions.db")
videos_db_path: str = Field(..., description="Ruta a la base de datos tipo videos.db")
casting_db_path: str = Field(..., description="Ruta a la base de datos tipo casting.db")
scenarios_db_path: str = Field(..., description="Ruta a la base de datos tipo scenarios.db")
system_to_train: str = Field(..., pattern="^(reflexion|introspection)$", description="Sistema a entrenar: 'reflexion' o 'introspection'")
class TrainMultiagentRefinementResponse(BaseModel):
ok: bool
detail: str
# --- Endpoints ---
@router.post("/apply_refinement", tags=["Refinement Process"], response_model=ApplyRefinementResponse)
def apply_refinement(payload: ApplyRefinementRequest) -> ApplyRefinementResponse:
"""Aplica el pipeline multi‑agente de refinamiento sobre un SRT.
- Valida el token del engine.
- Aplica los pasos reflection/reflexion/introspection según los flags
recibidos en la petición.
- Devuelve el SRT refinado.
"""
_assert_valid_token(payload.token)
# Partimos de los flags por defecto de config.yaml y los sobreescribimos con
# los que llegan en la petición para este job concreto.
flags = _load_refinement_flags()
flags["reflection_enabled"] = bool(payload.reflection_enabled)
flags["reflexion_enabled"] = bool(payload.reflexion_enabled)
flags["introspection_enabled"] = bool(payload.introspection_enabled)
# Ejecutar el pipeline con los flags actuales. Como execute_refinement y
# execute_refinement_for_video actualmente solo leen flags desde
# config.yaml, para no romper sus firmas guardamos temporalmente una copia
# de demo/config.yaml con los flags ajustados para esta llamada.
# NOTA: esta implementación asume ús en contextos de un sol procés.
# Localizar demo/config.yaml en la raíz del repo
repo_root = ROOT
demo_cfg = repo_root / "demo" / "config.yaml"
if not demo_cfg.exists():
raise HTTPException(status_code=500, detail="demo/config.yaml not found")
original_yaml = demo_cfg.read_text(encoding="utf-8")
try:
cfg = yaml.safe_load(original_yaml) or {}
ref_cfg = cfg.get("refinement", {}) or {}
ref_cfg["reflection_enabled"] = flags["reflection_enabled"]
ref_cfg["reflexion_enabled"] = flags["reflexion_enabled"]
ref_cfg["introspection_enabled"] = flags["introspection_enabled"]
cfg["refinement"] = ref_cfg
demo_cfg.write_text(yaml.safe_dump(cfg, allow_unicode=True), encoding="utf-8")
# Decidir el flux segons si tenim sha1sum+version o bé un SRT pla
if payload.sha1sum and payload.version:
refined = execute_refinement_for_video(
payload.sha1sum,
payload.version,
config_path=demo_cfg,
)
else:
if not payload.srt_content:
raise HTTPException(
status_code=400,
detail=(
"Cal proporcionar o bé sha1sum+version, o bé srt_content "
"per poder aplicar el refinament."
),
)
refined = execute_refinement(payload.srt_content, config_path=demo_cfg)
finally:
# Restaurar el YAML original para no afectar a otras llamadas
demo_cfg.write_text(original_yaml, encoding="utf-8")
return ApplyRefinementResponse(refined_srt=refined)
@router.post("/train_multiagent_refinement", tags=["Refinement Process"], response_model=TrainMultiagentRefinementResponse)
def train_multiagent_refinement(payload: TrainMultiagentRefinementRequest) -> TrainMultiagentRefinementResponse:
"""Endpoint placeholder para entrenar els sistemes de reflexion / introspection.
De moment no implementa cap lògica; simplement valida la càrrega i retorna
un missatge indicant que és un stub.
"""
# Aquí en el futur es podrà afegir la lògica d'entrenament que utilitzi
# les bases de dades proporcionades i el flag system_to_train.
return TrainMultiagentRefinementResponse(
ok=True,
detail=(
"train_multiagent_refinement està definit com a stub; encara no s'ha "
"implementat la lògica d'entrenament per als sistemes 'reflexion' o 'introspection'."
),
)