VeuReu commited on
Commit
fe0e2a2
·
verified ·
1 Parent(s): 102b5e7

Create refinement_router.py

Browse files
Files changed (1) hide show
  1. main_process/refinement_router.py +205 -0
main_process/refinement_router.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import os
7
+ import yaml
8
+ from fastapi import FastAPI, HTTPException, APIRouter, UploadFile, File, Query
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel, Field
11
+
12
+ from .refinement.multiagent_refinement import (
13
+ execute_refinement,
14
+ execute_refinement_for_video,
15
+ _load_refinement_flags,
16
+ )
17
+
18
+
19
+
20
+ # --- Config y autenticación sencilla por token ---
21
+
22
+ ROOT = Path(__file__).resolve().parent
23
+ CONFIG_PATH = "config.yaml"
24
+ router = APIRouter(prefix="/refinement", tags=["Refinement Process"])
25
+
26
+
27
+ def _load_engine_token() -> Optional[str]:
28
+ """Carga el token compartido del engine desde config.yaml o variables de entorno.
29
+
30
+ Sigue la misma convención que engine/api.py: usa API_SHARED_TOKEN si está
31
+ definido por entorno; en caso contrario, intenta leer demo/config.yaml.
32
+ """
33
+
34
+ env_token = os.getenv("API_SHARED_TOKEN")
35
+ if env_token:
36
+ return env_token
37
+
38
+ # Fallback: leer demo/config.yaml desde la raíz del repo
39
+ try:
40
+ repo_root = ROOT
41
+ # Intentar detectar la carpeta demo en el mismo repo
42
+ demo_cfg = repo_root / "demo" / "config.yaml"
43
+ if demo_cfg.exists():
44
+ with demo_cfg.open("r", encoding="utf-8") as f:
45
+ cfg = yaml.safe_load(f) or {}
46
+ api_cfg = cfg.get("api", {}) or {}
47
+ token = api_cfg.get("token")
48
+ if token and isinstance(token, str):
49
+ # Cuando viene de YAML con "${API_SHARED_TOKEN}" puede no estar
50
+ # resuelto; en ese caso preferimos None para forzar uso de entorno.
51
+ if "${" in token and "}" in token:
52
+ return None
53
+ return token
54
+ except Exception:
55
+ pass
56
+
57
+ return None
58
+
59
+
60
+ ENGINE_TOKEN = _load_engine_token()
61
+
62
+
63
+ def _assert_valid_token(token: str | None) -> None:
64
+ expected = ENGINE_TOKEN
65
+ if not expected:
66
+ # Si no hay token configurado, consideramos que la auth está desactivada
67
+ return
68
+ if not token or token != expected:
69
+ raise HTTPException(status_code=401, detail="Invalid or missing engine token")
70
+
71
+
72
+ # --- Esquemas de entrada/salida ---
73
+
74
+
75
+ class ApplyRefinementRequest(BaseModel):
76
+ token: Optional[str] = Field(None, description="Engine shared token")
77
+ srt_content: Optional[str] = Field(
78
+ None,
79
+ description=(
80
+ "Contenido del SRT a refinar. Opcional si se proporciona sha1sum+version, "
81
+ "en cuyo caso se leerá el SRT desde audiodescriptions.db."
82
+ ),
83
+ )
84
+ sha1sum: Optional[str] = Field(
85
+ None,
86
+ description=(
87
+ "Identificador sha1sum del vídeo. Si se proporciona junto con version, "
88
+ "se utilizará el pipeline basat en BDs (audiodescriptions.db, casting.db, scenarios.db)."
89
+ ),
90
+ )
91
+ version: Optional[str] = Field(
92
+ None,
93
+ description=(
94
+ "Versió de l'audiodescripció (p.ex. 'MoE', 'Salamandra', 'HITL'). "
95
+ "Necessària si s'especifica sha1sum per utilitzar el pipeline de vídeo."
96
+ ),
97
+ )
98
+ reflection_enabled: bool = Field(True, description="Activar paso de reflection")
99
+ reflexion_enabled: bool = Field(False, description="Activar paso de reflexion")
100
+ introspection_enabled: bool = Field(False, description="Activar paso de introspection")
101
+
102
+
103
+ class ApplyRefinementResponse(BaseModel):
104
+ refined_srt: str
105
+
106
+
107
+ class TrainMultiagentRefinementRequest(BaseModel):
108
+ audiodescriptions_db_path: str = Field(..., description="Ruta a la base de datos tipo audiodescriptions.db")
109
+ videos_db_path: str = Field(..., description="Ruta a la base de datos tipo videos.db")
110
+ casting_db_path: str = Field(..., description="Ruta a la base de datos tipo casting.db")
111
+ scenarios_db_path: str = Field(..., description="Ruta a la base de datos tipo scenarios.db")
112
+ system_to_train: str = Field(..., regex="^(reflexion|introspection)$", description="Sistema a entrenar: 'reflexion' o 'introspection'")
113
+
114
+
115
+ class TrainMultiagentRefinementResponse(BaseModel):
116
+ ok: bool
117
+ detail: str
118
+
119
+
120
+ # --- Endpoints ---
121
+
122
+
123
+ @router.post("/apply_refinement", tags=["Refinement Process"], response_model=ApplyRefinementResponse)
124
+ def apply_refinement(payload: ApplyRefinementRequest) -> ApplyRefinementResponse:
125
+ """Aplica el pipeline multi‑agente de refinamiento sobre un SRT.
126
+
127
+ - Valida el token del engine.
128
+ - Aplica los pasos reflection/reflexion/introspection según los flags
129
+ recibidos en la petición.
130
+ - Devuelve el SRT refinado.
131
+ """
132
+
133
+ _assert_valid_token(payload.token)
134
+
135
+ # Partimos de los flags por defecto de config.yaml y los sobreescribimos con
136
+ # los que llegan en la petición para este job concreto.
137
+ flags = _load_refinement_flags()
138
+ flags["reflection_enabled"] = bool(payload.reflection_enabled)
139
+ flags["reflexion_enabled"] = bool(payload.reflexion_enabled)
140
+ flags["introspection_enabled"] = bool(payload.introspection_enabled)
141
+
142
+ # Ejecutar el pipeline con los flags actuales. Como execute_refinement y
143
+ # execute_refinement_for_video actualmente solo leen flags desde
144
+ # config.yaml, para no romper sus firmas guardamos temporalmente una copia
145
+ # de demo/config.yaml con los flags ajustados para esta llamada.
146
+ # NOTA: esta implementación asume ús en contextos de un sol procés.
147
+
148
+ # Localizar demo/config.yaml en la raíz del repo
149
+ repo_root = ROOT
150
+ demo_cfg = repo_root / "demo" / "config.yaml"
151
+ if not demo_cfg.exists():
152
+ raise HTTPException(status_code=500, detail="demo/config.yaml not found")
153
+
154
+ original_yaml = demo_cfg.read_text(encoding="utf-8")
155
+ try:
156
+ cfg = yaml.safe_load(original_yaml) or {}
157
+ ref_cfg = cfg.get("refinement", {}) or {}
158
+ ref_cfg["reflection_enabled"] = flags["reflection_enabled"]
159
+ ref_cfg["reflexion_enabled"] = flags["reflexion_enabled"]
160
+ ref_cfg["introspection_enabled"] = flags["introspection_enabled"]
161
+ cfg["refinement"] = ref_cfg
162
+ demo_cfg.write_text(yaml.safe_dump(cfg, allow_unicode=True), encoding="utf-8")
163
+
164
+ # Decidir el flux segons si tenim sha1sum+version o bé un SRT pla
165
+ if payload.sha1sum and payload.version:
166
+ refined = execute_refinement_for_video(
167
+ payload.sha1sum,
168
+ payload.version,
169
+ config_path=demo_cfg,
170
+ )
171
+ else:
172
+ if not payload.srt_content:
173
+ raise HTTPException(
174
+ status_code=400,
175
+ detail=(
176
+ "Cal proporcionar o bé sha1sum+version, o bé srt_content "
177
+ "per poder aplicar el refinament."
178
+ ),
179
+ )
180
+ refined = execute_refinement(payload.srt_content, config_path=demo_cfg)
181
+ finally:
182
+ # Restaurar el YAML original para no afectar a otras llamadas
183
+ demo_cfg.write_text(original_yaml, encoding="utf-8")
184
+
185
+ return ApplyRefinementResponse(refined_srt=refined)
186
+
187
+
188
+ @router.post("/train_multiagent_refinement", tags=["Refinement Process"], response_model=TrainMultiagentRefinementResponse)
189
+ def train_multiagent_refinement(payload: TrainMultiagentRefinementRequest) -> TrainMultiagentRefinementResponse:
190
+ """Endpoint placeholder para entrenar els sistemes de reflexion / introspection.
191
+
192
+ De moment no implementa cap lògica; simplement valida la càrrega i retorna
193
+ un missatge indicant que és un stub.
194
+ """
195
+
196
+ # Aquí en el futur es podrà afegir la lògica d'entrenament que utilitzi
197
+ # les bases de dades proporcionades i el flag system_to_train.
198
+
199
+ return TrainMultiagentRefinementResponse(
200
+ ok=True,
201
+ detail=(
202
+ "train_multiagent_refinement està definit com a stub; encara no s'ha "
203
+ "implementat la lògica d'entrenament per als sistemes 'reflexion' o 'introspection'."
204
+ ),
205
+ )