VeuReu commited on
Commit
e5dde7c
·
verified ·
1 Parent(s): 0c4bed4

Upload 5 files

Browse files
finetuning/finetuning.py ADDED
@@ -0,0 +1,762 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import json
4
+ import logging
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import TypedDict, Annotated, List, Dict, Union
8
+ from langgraph.graph import StateGraph, END
9
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
10
+ from langchain_openai import ChatOpenAI
11
+ from operator import itemgetter
12
+
13
+ # --- Configuración y Herramientas ---
14
+
15
+ # Directorios de trabajo
16
+ BASE_DIR = Path(__file__).resolve().parent
17
+ TEMP_DIR = BASE_DIR / "temp"
18
+ TEMP_DIR.mkdir(exist_ok=True)
19
+
20
+ LOG_FILE = TEMP_DIR / "finetuning.log"
21
+
22
+ # Configurar el logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format='%(levelname)s: %(message)s',
26
+ handlers=[
27
+ logging.StreamHandler(),
28
+ logging.FileHandler(LOG_FILE, encoding="utf-8")
29
+ ],
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Asegúrate de configurar tu API Key en la variable de entorno OPENAI_API_KEY.
34
+ api_key = os.environ.get("OPENAI_API_KEY")
35
+ if not api_key:
36
+ raise EnvironmentError("OPENAI_API_KEY no está configurada. Define la variable de entorno antes de ejecutar finetuning.py.")
37
+
38
+ # Inicializar LLM (se usa GPT-4o por su capacidad de razonamiento)
39
+ # En producción, considera un modelo que soporte tus tokens y latencia requeridas.
40
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.3)
41
+
42
+ # --- Ficheros de Ejemplo ---
43
+
44
+ # Fichero SRT inicial (Narrador)
45
+ INITIAL_SRT_CONTENT = """
46
+ 1
47
+ 00:00:00,000 --> 00:00:05,340
48
+ [Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final.
49
+
50
+ 2
51
+ 00:00:04,340 --> 00:00:05,790
52
+ [Lucía] Hem de donar-ho tot.
53
+
54
+ 3
55
+ 00:00:05,790 --> 00:00:08,790
56
+ [Sandra] Ho sé, ho sé.
57
+
58
+ 4
59
+ 00:00:08,000 --> 00:00:10,000
60
+ (AD) De sobte, són al parc.
61
+
62
+ 5
63
+ 00:00:10,000 --> 00:00:14,000
64
+ (AD) Ara tallen menjar i fan una amanida a una cuina.
65
+ """
66
+
67
+ # Fichero JSON de contexto (ejemplo de la respuesta anterior, pero simplificado para el Narrador)
68
+ CONTEXT_JSON_CONTENT = """
69
+ {
70
+ "segments": [
71
+ {"id": 1, "start": "00:00:00,000", "end": "00:00:05,340", "type": "dialog", "text": "[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final."},
72
+ {"id": 2, "start": "00:00:04,340", "end": "00:00:05,790", "type": "dialog", "text": "[Lucía] Hem de donar-ho tot."},
73
+ {"id": 3, "start": "00:00:05,790", "end": "00:00:08,790", "type": "dialog", "text": "[Sandra] Ho sé, ho sé."},
74
+ {"id": 4, "start": "00:00:08,000", "end": "00:00:10,000", "type": "visual_context", "text": "Cambio de escena a un parque. Personajes caminando."},
75
+ {"id": 5, "start": "00:00:10,000", "end": "00:00:14,000", "type": "visual_context", "text": "Escena en una cocina. Los personajes están cortando vegetales y haciendo una ensalada."}
76
+ ]
77
+ }
78
+ """
79
+
80
+ # Fichero de Reglas UNE (Norma Técnica para el Crítico)
81
+ # Nota: Aquí se usa un resumen de las reglas pertinentes para un LLM.
82
+ UNE_RULES = """
83
+ ### Reglas UNE de Audiodescripción (Para el Crítico)
84
+ 1. **Objetividad y Foco Visual:** La descripción debe ser puramente objetiva, describiendo solo lo que se ve. Debe priorizar la acción y los elementos relevantes (personajes, objetos, localización).
85
+ 2. **Tiempo y Espacio (Sincronización):** Las audiodescripciones (AD) deben insertarse en los silencios del diálogo. El tiempo de la AD (entre START y END) debe ser suficiente para narrar el contenido sin solaparse con el diálogo o la música importante.
86
+ 3. **Concisión y Claridad:** Usar lenguaje simple y conciso. Evitar redundancias y juicios de valor.
87
+ 4. **Formato:** Cada segmento de AD debe tener un formato SRT válido, incluyendo el marcador (AD) al principio de la línea de texto.
88
+ 5. **Utilidad:** Cada segmento de AD debe ser útil para la comprensión y nunca ser redundante. En caso de repetir algo ya explicado antes, mejor no decir nada.
89
+ """
90
+
91
+ EVALUATION_CRITERIA = [
92
+ "Precisió Descriptiva",
93
+ "Sincronització Temporal",
94
+ "Claredat i Concisió",
95
+ "Inclusió de Diàleg/So",
96
+ "Contextualització",
97
+ "Flux i Ritme de la Narració",
98
+ ]
99
+
100
+ CRITERIA_WEIGHTS = {
101
+ "Precisió Descriptiva": 1,
102
+ "Sincronització Temporal": 4,
103
+ "Claredat i Concisió": 1,
104
+ "Inclusió de Diàleg/So": 1,
105
+ "Contextualització": 1,
106
+ "Flux i Ritme de la Narració": 1,
107
+ }
108
+
109
+ # Inicializar ficheros para la ejecución
110
+ def setup_files(initial_srt_content: str, context_json_content: str):
111
+ """Crea los ficheros iniciales necesarios en el sistema de archivos local."""
112
+ (TEMP_DIR / "une_ad_0.srt").write_text(initial_srt_content, encoding="utf-8")
113
+ (TEMP_DIR / "json_ad.json").write_text(context_json_content, encoding="utf-8")
114
+ logger.info("Ficheros iniciales 'une_ad_0.srt' y 'json_ad.json' creados.")
115
+
116
+ # --- Utilidades ---
117
+ def _strip_markdown_fences(content: str) -> str:
118
+ """Elimina fences ```...``` alrededor de una respuesta JSON si existen."""
119
+ text = content.strip()
120
+ if text.startswith("```"):
121
+ lines = text.splitlines()
122
+ # descartar primera línea con ``` o ```json
123
+ lines = lines[1:]
124
+ # eliminar el cierre ``` (pueden existir varias líneas en blanco finales)
125
+ while lines and lines[-1].strip() == "```":
126
+ lines.pop()
127
+ text = "\n".join(lines).strip()
128
+ return text
129
+
130
+
131
+ def generate_evaluation_report(srt_content: str, iteration: int) -> tuple[float, float, Path]:
132
+ """Solicita al LLM una avaluació estructurada i guarda'n el CSV."""
133
+ criteria_formatted = "\n".join(f"- {name}" for name in EVALUATION_CRITERIA)
134
+ prompt = (
135
+ "Actua com un auditor UNE. Avalua l'SRT generat, puntuant cada característica de 0 a 7 "
136
+ "segons la qualitat observada. Dónega justificació breve però concreta per a cada cas. "
137
+ "Les característiques obligatòries són:\n"
138
+ f"{criteria_formatted}\n"
139
+ "Retorna ÚNICAMENT un array JSON d'objectes amb les claus: "
140
+ "'caracteristica', 'valoracio' (nombre enter de 0 a 7) i 'justificacio'."
141
+ )
142
+
143
+ response = llm.invoke(
144
+ [
145
+ SystemMessage(content=prompt),
146
+ HumanMessage(
147
+ content=(
148
+ "# SRT AVALUAT\n"
149
+ f"{srt_content}\n\n"
150
+ "Assegura't de complir el format indicat."
151
+ )
152
+ ),
153
+ ]
154
+ )
155
+
156
+ cleaned = _strip_markdown_fences(response.content)
157
+ try:
158
+ data = json.loads(cleaned)
159
+ if not isinstance(data, list):
160
+ raise ValueError("La resposta no és una llista.")
161
+ except Exception as exc:
162
+ logger.error(
163
+ "Error al generar l'avaluació estructurada: %s. Resposta original: %s",
164
+ exc,
165
+ response.content,
166
+ )
167
+ data = [
168
+ {
169
+ "caracteristica": "Avaluació fallida",
170
+ "valoracio": 1,
171
+ "justificacio": "No s'ha pogut obtenir l'avaluació del LLM.",
172
+ }
173
+ ]
174
+
175
+ eval_path = TEMP_DIR / f"eval_{iteration}.csv"
176
+ with eval_path.open("w", encoding="utf-8", newline="") as csvfile:
177
+ writer = csv.writer(csvfile)
178
+ writer.writerow(["Caracteristica", "Valoracio (0-7)", "Justificacio"])
179
+ for item in data:
180
+ writer.writerow(
181
+ [
182
+ item.get("caracteristica", ""),
183
+ item.get("valoracio", 0),
184
+ item.get("justificacio", ""),
185
+ ]
186
+ )
187
+
188
+ scores = []
189
+ weighted_sum = 0.0
190
+ total_weight = 0.0
191
+
192
+ for entry in data:
193
+ if not isinstance(entry, dict):
194
+ continue
195
+ try:
196
+ score = float(entry.get("valoracio", 0))
197
+ except (TypeError, ValueError):
198
+ score = 0.0
199
+ scores.append(score)
200
+
201
+ weight = CRITERIA_WEIGHTS.get(entry.get("caracteristica", ""), 1)
202
+ weighted_sum += score * weight
203
+ total_weight += weight
204
+
205
+ mean_score = sum(scores) / len(scores) if scores else 0.0
206
+ weighted_mean = weighted_sum / total_weight if total_weight else mean_score
207
+ return mean_score, weighted_mean, eval_path
208
+
209
+ # --- Definición del Estado de la Gráfica (StateGraph) ---
210
+ class ReflectionState(TypedDict):
211
+ """Representa el estado del bucle de reflexión."""
212
+ iteration: int # Ciclo actual (empezando en 0)
213
+ current_srt_path: str # Ruta al archivo SRT actual (e.g., une_ad_0.srt, une_ad_1.srt)
214
+ critic_report: Dict[str, Union[float, str]] # Último informe del crítico (puntuación y texto)
215
+ history: List[SystemMessage] # Historial de mensajes entre agentes
216
+ evaluation_mean: float
217
+ best_iteration: int
218
+ best_weighted_mean: float
219
+ best_srt_path: str
220
+ best_eval_path: str
221
+
222
+ # --- Nodos/Agentes de la Gráfica ---
223
+ def narrator_agent(state: ReflectionState):
224
+ """
225
+ Agente que genera o reescribe el SRT.
226
+ - En el ciclo 0, genera el SRT inicial.
227
+ - En ciclos > 0, reescribe el SRT basándose en el critic_report.
228
+ """
229
+ iteration = state["iteration"]
230
+ critic_report = state["critic_report"]
231
+ history = state["history"]
232
+
233
+ # Cargar contexto y último SRT
234
+ json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")
235
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
236
+
237
+ # 1. Definir el prompt
238
+ if iteration == 0:
239
+ # Tarea inicial (aunque en este caso ya se proporciona une_ad_0.srt)
240
+ # Aquí se simula la generación inicial.
241
+ prompt = (
242
+ "Ets un Narrador expert en Audiodescripció (AD). La teva tasca inicial és generar "
243
+ "un fitxer SRT d'audiodescripcions basat en el JSON de context visual. "
244
+ "TOT I AIXÍ, per a aquesta primera iteració, l'SRT ja s'ha generat. "
245
+ "Simplement retorna el contingut de 'une_ad_0.srt' com si fos la teva sortida. "
246
+ "Assegura't que totes les audiodescripcions estiguin en català i que cadascuna pugui ser locutada "
247
+ "dins del temps disponible (utilitza un màxim aproximat d'11 caràcters per segon). Si el tram de temps "
248
+ "és massa curt (<1.5s), combina'l amb el bloc d'AD més proper i ajusta els timestamps perquè la narració sigui fluida. "
249
+ "Evita redundàncies: no repeteixis informació ja descrita en segments d'AD anteriors o al diàleg, i elimina qualsevol detall que no sigui essencial."
250
+ )
251
+ output_srt = current_srt
252
+ reflection_text = "Generación inicial. No hay reflexión."
253
+ else:
254
+ # Tarea de reflexión
255
+ prompt = (
256
+ "Ets un Narrador expert en Audiodescripció (AD). Has rebut una crítica sobre la teva última versió de l'SRT. "
257
+ "La teva tasca és REESCRIURE el contingut d'audiodescripció (línies amb '(AD)') del fitxer SRT, "
258
+ "assegurant que sigui coherent amb el JSON de context i, sobretot, que CORREGEIXIS TOTS els problemes "
259
+ "mencionats a l'Informe Crític adjunt. Mantén intactes els diàlegs (línies amb [Nom]) i escriu totes les audiodescripcions en català natural. "
260
+ "Garanteix que cada bloc d'AD pugui ser locutat dins del seu interval temporal disponible considerant un màxim d'11 caràcters per segon. "
261
+ "Si l'interval és massa curt (<1.5s), fusiona'l amb el bloc d'AD anterior o posterior més proper i ajusta els timestamps perquè quedin contínues. "
262
+ "Prefereix frases concises i accionables, prioritzant la informació visual essencial, i elimina redundàncies amb AD anteriors o amb els diàlegs."
263
+ )
264
+
265
+ # Concatenar la entrada para el LLM
266
+ input_content = f"""
267
+ # INFORME CRÍTICO
268
+ Porcentaje de Fiabilidad Anterior: {critic_report.get('reliability_percentage')}
269
+ Crítica Cualitativa: {critic_report.get('qualitative_critique')}
270
+
271
+ # JSON DE CONTEXTO VISUAL (Guía para la AD)
272
+ {json_context}
273
+
274
+ # ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration-1}.srt)
275
+ {current_srt}
276
+
277
+ REGLAS: Tu respuesta debe ser *SOLAMENTE* el contenido completo del nuevo archivo SRT (incluyendo diálogos), sin ningún comentario o explicación adicional.
278
+ """
279
+
280
+ # Llamada al LLM
281
+ response = llm.invoke(
282
+ [
283
+ SystemMessage(content=prompt),
284
+ HumanMessage(content=input_content)
285
+ ]
286
+ )
287
+
288
+ output_srt = response.content
289
+ reflection_text = f"Reescrito en base al informe crítico: {critic_report.get('qualitative_critique', 'N/A')}"
290
+
291
+ # 2. Guardar la nueva salida
292
+ new_srt_path = TEMP_DIR / f"une_ad_{iteration}.srt"
293
+ new_srt_path.write_text(output_srt, encoding="utf-8")
294
+
295
+ # 3. Guardar el pensamiento (reflection_text)
296
+ (TEMP_DIR / f"thinking_{iteration}.txt").write_text(reflection_text, encoding="utf-8")
297
+
298
+ logger.info(f"Narrador: Generada la versión {iteration} del SRT en '{new_srt_path}'.")
299
+
300
+ # 4. Actualizar el estado
301
+ new_history = history + [AIMessage(content=f"Narrador v{iteration} completado. Razón de reflexión: {reflection_text}")]
302
+ return {
303
+ "iteration": iteration,
304
+ "current_srt_path": str(new_srt_path),
305
+ "history": new_history,
306
+ "evaluation_mean": state.get("evaluation_mean", 0.0),
307
+ "best_iteration": state.get("best_iteration", -1),
308
+ "best_weighted_mean": state.get("best_weighted_mean", 0.0),
309
+ "best_srt_path": state.get("best_srt_path", str(new_srt_path)),
310
+ "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
311
+ }
312
+
313
+ def identity_manager_agent(state: ReflectionState):
314
+ """
315
+ Agente que gestiona la identidad del usuario.
316
+ """
317
+ iteration = state["iteration"]
318
+ history = state["history"]
319
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
320
+
321
+ prompt = (
322
+ "Ets un gestor d'identitats. La teva tasca és verificar la identitat de l'usuari "
323
+ "i assegurar-te que les seves dades estiguin actualitzades."
324
+ )
325
+
326
+ input_content = f"""
327
+ # ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration}.srt):
328
+ {current_srt}
329
+
330
+ REGLAS: Tu respuesta debe ser *SOLAMENTE* un objeto JSON con la información de la identidad del usuario.
331
+ """
332
+
333
+ # Llamada al LLM
334
+ response = llm.invoke(
335
+ [
336
+ SystemMessage(content=prompt),
337
+ HumanMessage(content=input_content)
338
+ ]
339
+ )
340
+
341
+ # Intentar parsear la respuesta del LLM (puede fallar, por eso se usa un try/except)
342
+ try:
343
+ cleaned_response = _strip_markdown_fences(response.content)
344
+ identity_info = json.loads(cleaned_response)
345
+ if not isinstance(identity_info, dict):
346
+ raise ValueError("Estructura JSON incorrecta.")
347
+ except Exception as e:
348
+ logger.error(f"Error al parsear el JSON de la identidad: {e}. Respuesta: {response.content}")
349
+ identity_info = {"error": "No s'ha pogut obtenir la informació d'identitat."}
350
+
351
+ logger.info(f"Identity Manager: Información de identidad actualizada.")
352
+
353
+ new_history = history + [AIMessage(content=f"Identity Manager v{iteration} completado.")]
354
+ return {
355
+ "iteration": iteration,
356
+ "current_srt_path": state["current_srt_path"],
357
+ "history": new_history,
358
+ "evaluation_mean": state.get("evaluation_mean", 0.0),
359
+ "best_iteration": state.get("best_iteration", -1),
360
+ "best_weighted_mean": state.get("best_weighted_mean", 0.0),
361
+ "best_srt_path": state.get("best_srt_path", state["current_srt_path"]),
362
+ "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
363
+ }
364
+
365
+ def critic_agent(state: ReflectionState):
366
+ """
367
+ Agente que evalúa la calidad del SRT generado por el Narrador basándose en las Reglas UNE.
368
+ Devuelve una puntuación y una crítica cualitativa.
369
+ """
370
+ iteration = state["iteration"]
371
+ history = state["history"]
372
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
373
+
374
+ prompt = (
375
+ "Ets un Crític d'Audiodescripció molt estricte. La teva tasca és avaluar l'SRT adjunt "
376
+ "únicament segons les Regles UNE proporcionades. L'avaluació ha de ser doble: "
377
+ "1. **Numèrica**: Un percentatge de fiabilitat (ex. 85.5) de 0 a 100%. "
378
+ "2. **Qualitativa**: Una crítica constructiva sobre les principals mancances de les AD respecte a les regles. "
379
+ "Has de ser EXTREMADAMENT estricte amb la sincronització (sense solapament amb el diàleg), "
380
+ "amb l'adequació temporal (velocitat màxima recomanada d'11 caràcters per segon) i amb l'absència de redundàncies. "
381
+ "Comprova també que totes les audiodescripcions estan escrites en català natural."
382
+ )
383
+
384
+ input_content = f"""
385
+ # REGLAS UNE DE AUDIODESCRIPCIÓN:
386
+ {UNE_RULES}
387
+
388
+ # ARCHIVO SRT A EVALUAR (une_ad_{iteration}.srt):
389
+ {current_srt}
390
+
391
+ REGLAS DE RESPUESTA:
392
+ Tu respuesta debe ser *SOLAMENTE* un objeto JSON con dos claves:
393
+ 1. "reliability_percentage": (float) El porcentaje de fiabilidad.
394
+ 2. "qualitative_critique": (string) La crítica cualitativa y sugerencias de mejora.
395
+ Ejemplo de respuesta: {{"reliability_percentage": 75.0, "qualitative_critique": "El segmento 4 se solapa 0.34s con el diálogo de Sandra. El segmento 5 es demasiado genérico y no describe bien la acción."}}
396
+ """
397
+
398
+ # Llamada al LLM
399
+ response = llm.invoke(
400
+ [
401
+ SystemMessage(content=prompt),
402
+ HumanMessage(content=input_content)
403
+ ]
404
+ )
405
+
406
+ # Intentar parsear la respuesta del LLM (puede fallar, por eso se usa un try/except)
407
+ try:
408
+ cleaned_response = _strip_markdown_fences(response.content)
409
+ report = json.loads(cleaned_response)
410
+ if not isinstance(report, dict) or 'reliability_percentage' not in report:
411
+ raise ValueError("Estructura JSON incorrecta.")
412
+ except Exception as e:
413
+ logger.error(f"Error al parsear el JSON del Crítico: {e}. Respuesta: {response.content}")
414
+ report = {"reliability_percentage": 1.0, "qualitative_critique": "El Crítico no devolvió un JSON válido. Reintentar."}
415
+
416
+ logger.info(f"Crítico: Evaluación completada. Fiabilidad: {report.get('reliability_percentage')}%.")
417
+
418
+ mean_score, weighted_mean, eval_path = generate_evaluation_report(current_srt, iteration)
419
+
420
+ thinking_path = TEMP_DIR / f"thinking_{iteration}.txt"
421
+ if thinking_path.exists():
422
+ previous_text = thinking_path.read_text(encoding="utf-8")
423
+ thinking_path.write_text(
424
+ (
425
+ f"{previous_text}\n\nMitjana simple d'avaluació: {mean_score:.2f} / 7"
426
+ f"\nMitjana ponderada d'avaluació: {weighted_mean:.2f} / 7"
427
+ ),
428
+ encoding="utf-8",
429
+ )
430
+
431
+ best_iteration = state.get("best_iteration", -1)
432
+ best_weighted_mean = state.get("best_weighted_mean", -1.0)
433
+ best_srt_path = state.get("best_srt_path", state["current_srt_path"])
434
+ best_eval_path = state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv"))
435
+
436
+ if weighted_mean > best_weighted_mean:
437
+ best_iteration = iteration
438
+ best_weighted_mean = weighted_mean
439
+ best_srt_path = state["current_srt_path"]
440
+ best_eval_path = str(eval_path)
441
+
442
+ new_history = history + [
443
+ AIMessage(
444
+ content=(
445
+ "Crítico v{iter} completado. Fiabilidad: {reliab}%. "
446
+ "Mitjana simple: {mean:.2f}/7. Mitjana ponderada: {wmean:.2f}/7"
447
+ ).format(
448
+ iter=iteration,
449
+ reliab=report.get("reliability_percentage"),
450
+ mean=mean_score,
451
+ wmean=weighted_mean,
452
+ )
453
+ )
454
+ ]
455
+ return {
456
+ "iteration": iteration + 1,
457
+ "critic_report": report,
458
+ "history": new_history,
459
+ "evaluation_mean": weighted_mean,
460
+ "best_iteration": best_iteration,
461
+ "best_weighted_mean": best_weighted_mean,
462
+ "best_srt_path": best_srt_path,
463
+ "best_eval_path": best_eval_path,
464
+ }
465
+
466
+ def identity_manager_agent(state: ReflectionState):
467
+ """
468
+ Agente que verifica coherencia entre hablantes en SRT, casting.csv y contexto visual.
469
+ Corrige asignaciones de hablantes y genera log de cambios.
470
+ """
471
+ iteration = state["iteration"]
472
+
473
+ # Cargar archivos
474
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
475
+ casting_path = TEMP_DIR / "casting.csv"
476
+ json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")
477
+
478
+ # Verificar existencia de casting.csv
479
+ if not casting_path.exists():
480
+ logger.warning("Casting.csv no encontrado. Saltando identity_manager.")
481
+ return state
482
+
483
+ casting_content = casting_path.read_text(encoding="utf-8")
484
+
485
+ prompt = (
486
+ "Ets un Identity Manager. La teva tasca és:\n"
487
+ "1. Verificar que les assignacions de parlants a l'SRT coincideixen amb casting.csv\n"
488
+ "2. Comprovar que els parlants assignats són coherents amb el context visual de json_ad.json\n"
489
+ "3. Si trobes inconsistències, re-assigna els parlants corregint les etiquetes [Nom]\n"
490
+ "4. Justifica canvis al fitxer identity_log.txt\n"
491
+ "\n"
492
+ "Dades d'entrada:\n"
493
+ f"- CASTING.CSV:\n{casting_content}\n"
494
+ f"- JSON CONTEXT:\n{json_context}\n"
495
+ f"- SRT ACTUAL:\n{current_srt}\n"
496
+ "\n"
497
+ "REGLES:\n"
498
+ "- Només modifica les línies de diàleg (ex: [Nom])\n"
499
+ "- Manté la numeració i timestamps\n"
500
+ "- Si no hi ha canvis, retorna l'SRT original\n"
501
+ "\n"
502
+ "Format de sortida:\n"
503
+ "```json\n"
504
+ "{{\n"
505
+ " \"srt_content\": \"<nou contingut SRT>\",\n"
506
+ " \"log_message\": \"<explicació canvis o 'Sense canvis'>\"\n"
507
+ "}}\n"
508
+ "```"
509
+ )
510
+
511
+ response = llm.invoke([SystemMessage(content=prompt)])
512
+
513
+ try:
514
+ # Parsejar resposta JSON
515
+ cleaned = _strip_markdown_fences(response.content)
516
+ data = json.loads(cleaned)
517
+ new_srt = data["srt_content"]
518
+ log_msg = data["log_message"]
519
+
520
+ # Escriure log
521
+ log_path = TEMP_DIR / f"identity_log_{iteration}.txt"
522
+ log_path.write_text(f"Iteració {iteration}: {log_msg}", encoding="utf-8")
523
+
524
+ # Actualitzar SRT si hi ha canvis
525
+ if new_srt != current_srt:
526
+ new_srt_path = TEMP_DIR / f"une_ad_{iteration}_corrected.srt"
527
+ new_srt_path.write_text(new_srt, encoding="utf-8")
528
+ logger.info(f"Identity Manager: Correccions aplicades. Detalls: {log_msg}")
529
+ return {
530
+ **state,
531
+ "current_srt_path": str(new_srt_path)
532
+ }
533
+
534
+ except Exception as e:
535
+ logger.error(f"Error en identity_manager: {e}")
536
+
537
+ return state
538
+
539
+ def background_descriptor_agent(state: ReflectionState):
540
+ """
541
+ Agente que verifica coherencia entre escenarios en SRT y scenarios.csv.
542
+ Corrige nombres de escenarios usando descripciones coherentes.
543
+ """
544
+ iteration = state["iteration"]
545
+
546
+ # Cargar archivos
547
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
548
+ scenarios_path = TEMP_DIR / "scenarios.csv"
549
+
550
+ # Verificar existencia de scenarios.csv
551
+ if not scenarios_path.exists():
552
+ logger.warning("Scenarios.csv no encontrado. Saltando background_descriptor.")
553
+ return state
554
+
555
+ scenarios_content = scenarios_path.read_text(encoding="utf-8")
556
+
557
+ prompt = (
558
+ "Ets un Background Descriptor. La teva tasca és:\n"
559
+ "1. Verificar que les descripcions d'escenaris a l'SRT coincideixen amb scenarios.csv\n"
560
+ "2. Si trobes coincidències, reemplaça les descripcions genèriques pel nom oficial de l'escenari\n"
561
+ "3. Justifica canvis al fitxer background_log.txt\n"
562
+ "\n"
563
+ "Dades d'entrada:\n"
564
+ f"- SCENARIOS.CSV:\n{scenarios_content}\n"
565
+ f"- SRT ACTUAL:\n{current_srt}\n"
566
+ "\n"
567
+ "REGLES:\n"
568
+ "- Només modifica línies d'audiodescripció (ex: (AD) ...)\n"
569
+ "- Manté la numeració i timestamps\n"
570
+ "- Si no hi ha canvis, retorna l'SRT original\n"
571
+ "\n"
572
+ "Format de sortida:\n"
573
+ "```json\n"
574
+ "{{\n"
575
+ " \"srt_content\": \"<nou contingut SRT>\",\n"
576
+ " \"log_message\": \"<explicació canvis o 'Sense canvis'>\"\n"
577
+ "}}\n"
578
+ "```"
579
+ )
580
+
581
+ response = llm.invoke([SystemMessage(content=prompt)])
582
+
583
+ try:
584
+ # Parsejar resposta JSON
585
+ cleaned = _strip_markdown_fences(response.content)
586
+ data = json.loads(cleaned)
587
+ new_srt = data["srt_content"]
588
+ log_msg = data["log_message"]
589
+
590
+ # Escriure log
591
+ log_path = TEMP_DIR / f"background_log_{iteration}.txt"
592
+ log_path.write_text(f"Iteració {iteration}: {log_msg}", encoding="utf-8")
593
+
594
+ # Actualitzar SRT si hi ha canvis
595
+ if new_srt != current_srt:
596
+ new_srt_path = TEMP_DIR / f"une_ad_{iteration}_scenario_corrected.srt"
597
+ new_srt_path.write_text(new_srt, encoding="utf-8")
598
+ logger.info(f"Background Descriptor: Correccions aplicades. Detalls: {log_msg}")
599
+ return {
600
+ **state,
601
+ "current_srt_path": str(new_srt_path)
602
+ }
603
+
604
+ except Exception as e:
605
+ logger.error(f"Error en background_descriptor: {e}")
606
+
607
+ return state
608
+
609
+ # --- Condición de Salida del Bucle ---
610
+
611
+ def should_continue(state: ReflectionState) -> str:
612
+ """
613
+ Función de chequeo que decide si continuar iterando o finalizar.
614
+ """
615
+ MAX_ITERATIONS = 5 # Número máximo de ciclos
616
+ MIN_AVERAGE_SCORE = 6.0 # Umbral de calidad sobre 7
617
+
618
+ iteration = state["iteration"]
619
+ mean_score = state.get("evaluation_mean", 0.0)
620
+
621
+ if mean_score >= MIN_AVERAGE_SCORE:
622
+ logger.info(f"FIN: Mitjana ponderada d'avaluació assolida ({mean_score:.2f} >= {MIN_AVERAGE_SCORE}).")
623
+ return "end"
624
+
625
+ if iteration >= MAX_ITERATIONS:
626
+ logger.info(f"FIN: S'ha assolit el màxim d'iteracions ({iteration} / {MAX_ITERATIONS}).")
627
+ return "end"
628
+
629
+ logger.info(f"CONTINUAR: Iteració {iteration} / {MAX_ITERATIONS}. Mitjana ponderada actual: {mean_score:.2f} / 7.")
630
+ return "continue"
631
+
632
+ # --- Construcción de la Gráfica ---
633
+
634
+ # 1. Configurar el estado inicial
635
+ initial_state: ReflectionState = {
636
+ "iteration": 0,
637
+ "current_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
638
+ "critic_report": {"reliability_percentage": 0.0, "qualitative_critique": "Inicializando el proceso."},
639
+ "history": [],
640
+ "evaluation_mean": 0.0,
641
+ "best_iteration": -1,
642
+ "best_weighted_mean": -1.0,
643
+ "best_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
644
+ "best_eval_path": str(TEMP_DIR / "eval_0.csv"),
645
+ }
646
+
647
+ # 2. Definir la gráfica
648
+ workflow = StateGraph(ReflectionState)
649
+
650
+ # Nodos
651
+ workflow.add_node("narrator", narrator_agent)
652
+ workflow.add_node("identity_manager", identity_manager_agent)
653
+ workflow.add_node("background_descriptor", background_descriptor_agent)
654
+ workflow.add_node("critic", critic_agent)
655
+
656
+ # Estructura del bucle: Narrator -> Identity Manager -> Background Descriptor -> Critic -> Check
657
+ workflow.set_entry_point("narrator")
658
+ workflow.add_edge("narrator", "identity_manager")
659
+ workflow.add_edge("identity_manager", "background_descriptor")
660
+ workflow.add_edge("background_descriptor", "critic")
661
+
662
+ # Condición (puente de ramificación)
663
+ workflow.add_conditional_edges(
664
+ "critic",
665
+ should_continue,
666
+ {
667
+ "continue": "narrator", # Si no se cumple el umbral/ciclo, vuelve al narrador
668
+ "end": END # Si se cumple, termina
669
+ }
670
+ )
671
+
672
+ # Compilar la gráfica
673
+ app = workflow.compile()
674
+
675
+
676
+ def generate_free_ad_from_srt(srt_path: Path) -> Path:
677
+ """Genera una narración libre detallada a partir del SRT final."""
678
+ srt_content = srt_path.read_text(encoding="utf-8")
679
+ prompt = (
680
+ "Actua com una narradora professional d'audiodescripcions lliures. "
681
+ "A partir de l'SRT proporcionat, escriu un text narratiu en català que descrigui "
682
+ "de manera exhaustiva i fluida tot el que succeeix a la peça audiovisual. "
683
+ "Inclou accions, aparença, gestos, canvis d'escena i qualsevol detall rellevant, "
684
+ "sense limitar-te a les restriccions temporals del format SRT. "
685
+ "Evita repetir literalment els diàlegs, però contextualitza'ls quan sigui útil. "
686
+ "La narració ha de ser clara, coherent i apta per ser locutada com una narració lliure."
687
+ )
688
+
689
+ response = llm.invoke(
690
+ [
691
+ SystemMessage(content=prompt),
692
+ HumanMessage(
693
+ content=(
694
+ "# SRT FINAL\n"
695
+ f"{srt_content}\n\n"
696
+ "Respon únicamente con la narració lliure sin cap comentario adicional."
697
+ )
698
+ ),
699
+ ]
700
+ )
701
+
702
+ free_ad_path = TEMP_DIR / "free_ad.txt"
703
+ free_ad_path.write_text(response.content, encoding="utf-8")
704
+ logger.info(f"Narració lliure generada en '{free_ad_path}'.")
705
+ return free_ad_path
706
+
707
+ # --- Ejecución Principal ---
708
+
709
+ if __name__ == "__main__":
710
+ # Inicializar el entorno
711
+ setup_files(INITIAL_SRT_CONTENT, CONTEXT_JSON_CONTENT)
712
+
713
+ logger.info("--- Comenzando el Bucle de Finetuning ---")
714
+
715
+ # Ejecutar la gráfica
716
+ final_state = app.invoke(initial_state)
717
+
718
+ logger.info("\n--- Bucle Finalizado ---")
719
+
720
+ best_iteration = final_state.get("best_iteration", -1)
721
+ best_weighted_mean = final_state.get("best_weighted_mean", 0.0)
722
+ best_srt_path = Path(final_state.get("best_srt_path", final_state['current_srt_path']))
723
+ best_eval_path = Path(final_state.get("best_eval_path", TEMP_DIR / "eval_0.csv"))
724
+
725
+ final_srt_path = TEMP_DIR / "une_ad.srt"
726
+ final_eval_path = TEMP_DIR / "eval.csv"
727
+
728
+ try:
729
+ shutil.copy(best_srt_path, final_srt_path)
730
+ logger.info(f"SRT final copiado a '{final_srt_path}'.")
731
+ except Exception as exc:
732
+ logger.error(f"No se pudo copiar el SRT final: {exc}")
733
+
734
+ try:
735
+ shutil.copy(best_eval_path, final_eval_path)
736
+ logger.info(f"Evaluación final copiada a '{final_eval_path}'.")
737
+ except Exception as exc:
738
+ logger.error(f"No se pudo copiar el CSV final: {exc}")
739
+
740
+ free_ad_path: Union[Path, None] = None
741
+ try:
742
+ free_ad_path = generate_free_ad_from_srt(final_srt_path)
743
+ except Exception as exc:
744
+ logger.error(f"No s'ha pogut generar la narració lliure: {exc}")
745
+
746
+ # Mostrar resultados
747
+ print(f"Número final de ciclos: {final_state['iteration']}")
748
+ print(f"Iteración óptima: {best_iteration} (mitjana ponderada {best_weighted_mean:.2f}/7)")
749
+ print(f"Ruta al SRT final: {final_srt_path}")
750
+ print(f"Ruta a l'avaluació final: {final_eval_path}")
751
+ if free_ad_path is not None:
752
+ print(f"Ruta a la narració lliure: {free_ad_path}")
753
+ else:
754
+ print("No s'ha pogut generar la narració lliure.")
755
+
756
+ # Mostrar el SRT final generado
757
+ print("\n--- Contenido del SRT Final ---")
758
+ print(final_srt_path.read_text(encoding="utf-8"))
759
+
760
+ if free_ad_path is not None:
761
+ print("\n--- Narració Lliure ---")
762
+ print(free_ad_path.read_text(encoding="utf-8"))
finetuning/lora.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+ from typing import List, Dict
5
+
6
+ from datasets import Dataset
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForCausalLM,
10
+ TrainingArguments,
11
+ Trainer,
12
+ )
13
+ from peft import LoraConfig, get_peft_model
14
+
15
+
16
+ BASE_DIR = Path(__file__).resolve().parent
17
+ DATA_DIR = BASE_DIR / "data"
18
+
19
+
20
+ def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
21
+ """Recorre las subcarpetas de data_dir y busca pares target_une_ad.srt / free_ad.txt.
22
+
23
+ Cada ejemplo se formatea como una instrucción estilo instruct, usando el SRT como entrada
24
+ y la narración libre como salida.
25
+ """
26
+ examples: List[Dict[str, str]] = []
27
+
28
+ if not data_dir.exists():
29
+ raise FileNotFoundError(f"Data dir not found: {data_dir}")
30
+
31
+ for item in sorted(data_dir.iterdir()):
32
+ if not item.is_dir():
33
+ continue
34
+
35
+ srt_path = item / "target_une_ad.srt"
36
+ free_path = item / "free_ad.txt"
37
+
38
+ if not srt_path.exists() or not free_path.exists():
39
+ continue
40
+
41
+ srt_text = srt_path.read_text(encoding="utf-8")
42
+ free_text = free_path.read_text(encoding="utf-8")
43
+
44
+ # Formato tipo instruction-tuning, en catalán, coherente con la tarea
45
+ prompt = (
46
+ "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
47
+ "en una narració lliure detallada en català, sense límits de temps. "
48
+ "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
49
+ "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
50
+ )
51
+
52
+ examples.append({"prompt": prompt, "output": free_text.strip()})
53
+
54
+ if not examples:
55
+ raise RuntimeError(f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)")
56
+
57
+ return examples
58
+
59
+
60
+ def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
61
+ """Construye un Dataset de Hugging Face a partir de los pares prompt/output.
62
+
63
+ Se concatena en una sola secuencia para entrenamiento causal:
64
+ [PROMPT] + [OUTPUT] + eos
65
+ y se enmascaran los tokens del prompt para que la loss sólo se compute sobre la salida.
66
+ """
67
+
68
+ def _gen():
69
+ for ex in pairs:
70
+ yield {"prompt": ex["prompt"], "output": ex["output"]}
71
+
72
+ raw_ds = Dataset.from_generator(_gen)
73
+
74
+ def tokenize_fn(batch):
75
+ prompts = batch["prompt"]
76
+ outputs = batch["output"]
77
+
78
+ input_ids_list = []
79
+ labels_list = []
80
+
81
+ for p, o in zip(prompts, outputs):
82
+ full_text = p + "\n" + o + tokenizer.eos_token
83
+ enc = tokenizer(
84
+ full_text,
85
+ truncation=True,
86
+ max_length=max_length,
87
+ padding="max_length",
88
+ )
89
+
90
+ # Máscara: ignorar loss en tokens del prompt
91
+ prompt_ids = tokenizer(p + "\n", truncation=True, max_length=max_length)["input_ids"]
92
+ prompt_len = min(len(prompt_ids), max_length)
93
+
94
+ labels = enc["input_ids"].copy()
95
+ for i in range(prompt_len):
96
+ labels[i] = -100
97
+
98
+ input_ids_list.append(enc["input_ids"])
99
+ labels_list.append(labels)
100
+
101
+ return {"input_ids": input_ids_list, "attention_mask": [([1] * max_length)] * len(input_ids_list), "labels": labels_list}
102
+
103
+ tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"])
104
+ return tokenized
105
+
106
+
107
+ def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05):
108
+ model = AutoModelForCausalLM.from_pretrained(
109
+ base_model_name,
110
+ torch_dtype="auto",
111
+ device_map="auto",
112
+ )
113
+
114
+ lora_config = LoraConfig(
115
+ r=r,
116
+ lora_alpha=alpha,
117
+ lora_dropout=dropout,
118
+ bias="none",
119
+ task_type="CAUSAL_LM",
120
+ )
121
+
122
+ model = get_peft_model(model, lora_config)
123
+ return model
124
+
125
+
126
+ def parse_args() -> argparse.Namespace:
127
+ parser = argparse.ArgumentParser(description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD")
128
+ parser.add_argument(
129
+ "--base_model",
130
+ type=str,
131
+ default="projecte-aina/salamandra-instruct-7b",
132
+ help="Nom o ruta del model base (HF hub o path local)",
133
+ )
134
+ parser.add_argument(
135
+ "--data_dir",
136
+ type=str,
137
+ default=str(DATA_DIR),
138
+ help="Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt",
139
+ )
140
+ parser.add_argument(
141
+ "--output_dir",
142
+ type=str,
143
+ default=str(BASE_DIR / "lora_output"),
144
+ help="Directori on desar l'adapter LoRA",
145
+ )
146
+ parser.add_argument("--batch_size", type=int, default=1)
147
+ parser.add_argument("--gradient_accumulation", type=int, default=8)
148
+ parser.add_argument("--epochs", type=int, default=3)
149
+ parser.add_argument("--lr", type=float, default=2e-4)
150
+ parser.add_argument("--max_length", type=int, default=2048)
151
+ parser.add_argument("--warmup_ratio", type=float, default=0.03)
152
+ parser.add_argument("--logging_steps", type=int, default=10)
153
+ parser.add_argument("--save_steps", type=int, default=200)
154
+ parser.add_argument("--eval_steps", type=int, default=200)
155
+ parser.add_argument("--r", type=int, default=16, help="Rank de LoRA")
156
+ parser.add_argument("--alpha", type=int, default=32, help="Alpha de LoRA")
157
+ parser.add_argument("--dropout", type=float, default=0.05, help="Dropout de LoRA")
158
+ return parser.parse_args()
159
+
160
+
161
+ def main():
162
+ args = parse_args()
163
+
164
+ data_dir = Path(args.data_dir)
165
+ output_dir = Path(args.output_dir)
166
+ output_dir.mkdir(parents=True, exist_ok=True)
167
+
168
+ print(f"[lora] Buscant dades a: {data_dir}")
169
+ pairs = find_training_pairs(data_dir)
170
+ print(f"[lora] Nombre d'exemples trobats: {len(pairs)}")
171
+
172
+ print(f"[lora] Carregant tokenizer de {args.base_model}")
173
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
174
+ if tokenizer.pad_token is None:
175
+ tokenizer.pad_token = tokenizer.eos_token
176
+
177
+ print("[lora] Construint dataset tokenitzat...")
178
+ dataset = build_dataset(pairs, tokenizer, max_length=args.max_length)
179
+
180
+ print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...")
181
+ model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout)
182
+
183
+ training_args = TrainingArguments(
184
+ output_dir=str(output_dir),
185
+ per_device_train_batch_size=args.batch_size,
186
+ gradient_accumulation_steps=args.gradient_accumulation,
187
+ num_train_epochs=args.epochs,
188
+ learning_rate=args.lr,
189
+ warmup_ratio=args.warmup_ratio,
190
+ logging_steps=args.logging_steps,
191
+ save_steps=args.save_steps,
192
+ evaluation_strategy="steps",
193
+ eval_steps=args.eval_steps,
194
+ save_total_limit=2,
195
+ bf16=True,
196
+ gradient_checkpointing=True,
197
+ report_to=[],
198
+ )
199
+
200
+ trainer = Trainer(
201
+ model=model,
202
+ args=training_args,
203
+ train_dataset=dataset,
204
+ eval_dataset=None,
205
+ tokenizer=tokenizer,
206
+ )
207
+
208
+ print("[lora] Iniciant entrenament...")
209
+ trainer.train()
210
+
211
+ print("[lora] Guardant adapter LoRA...")
212
+ model.save_pretrained(str(output_dir))
213
+ tokenizer.save_pretrained(str(output_dir))
214
+
215
+ print(f"[lora] Entrenament completat. Adapter guardat a {output_dir}")
216
+
217
+
218
+ if __name__ == "__main__":
219
+ main()
finetuning/reflection.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import json
4
+ import logging
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import TypedDict, Annotated, List, Dict, Union
8
+ from langgraph.graph import StateGraph, END
9
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
10
+ from langchain_openai import ChatOpenAI
11
+ from operator import itemgetter
12
+
13
+ # --- Configuración y Herramientas ---
14
+
15
+ # Directorios de trabajo
16
+ BASE_DIR = Path(__file__).resolve().parent
17
+ TEMP_DIR = BASE_DIR / "temp"
18
+ TEMP_DIR.mkdir(exist_ok=True)
19
+
20
+ LOG_FILE = TEMP_DIR / "reflection.log"
21
+
22
+ # Configurar el logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format='%(levelname)s: %(message)s',
26
+ handlers=[
27
+ logging.StreamHandler(),
28
+ logging.FileHandler(LOG_FILE, encoding="utf-8")
29
+ ],
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Asegúrate de configurar tu API Key.
34
+ # En un entorno real, usa os.environ["OPENAI_API_KEY"]
35
+ # Aquí usamos un placeholder para la demostración.
36
+ if "OPENAI_API_KEY" not in os.environ:
37
+ logger.warning("OPENAI_API_KEY no está configurada. Usando un placeholder.")
38
+ os.environ["OPENAI_API_KEY"] = "sk-..."
39
+
40
+ # Inicializar LLM (se usa GPT-4o por su capacidad de razonamiento)
41
+ # En producción, considera un modelo que soporte tus tokens y latencia requeridas.
42
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.3)
43
+
44
+ # --- Ficheros de Ejemplo ---
45
+
46
+ # Fichero SRT inicial (Narrador)
47
+ INITIAL_SRT_CONTENT = """
48
+ 1
49
+ 00:00:00,000 --> 00:00:05,340
50
+ [Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final.
51
+
52
+ 2
53
+ 00:00:04,340 --> 00:00:05,790
54
+ [Lucía] Hem de donar-ho tot.
55
+
56
+ 3
57
+ 00:00:05,790 --> 00:00:08,790
58
+ [Sandra] Ho sé, ho sé.
59
+
60
+ 4
61
+ 00:00:08,000 --> 00:00:10,000
62
+ (AD) De sobte, són al parc.
63
+
64
+ 5
65
+ 00:00:10,000 --> 00:00:14,000
66
+ (AD) Ara tallen menjar i fan una amanida a una cuina.
67
+ """
68
+
69
+ # Fichero JSON de contexto (ejemplo de la respuesta anterior, pero simplificado para el Narrador)
70
+ CONTEXT_JSON_CONTENT = """
71
+ {
72
+ "segments": [
73
+ {"id": 1, "start": "00:00:00,000", "end": "00:00:05,340", "type": "dialog", "text": "[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final."},
74
+ {"id": 2, "start": "00:00:04,340", "end": "00:00:05,790", "type": "dialog", "text": "[Lucía] Hem de donar-ho tot."},
75
+ {"id": 3, "start": "00:00:05,790", "end": "00:00:08,790", "type": "dialog", "text": "[Sandra] Ho sé, ho sé."},
76
+ {"id": 4, "start": "00:00:08,000", "end": "00:00:10,000", "type": "visual_context", "text": "Cambio de escena a un parque. Personajes caminando."},
77
+ {"id": 5, "start": "00:00:10,000", "end": "00:00:14,000", "type": "visual_context", "text": "Escena en una cocina. Los personajes están cortando vegetales y haciendo una ensalada."}
78
+ ]
79
+ }
80
+ """
81
+
82
+ # Fichero de Reglas UNE (Norma Técnica para el Crítico)
83
+ # Nota: Aquí se usa un resumen de las reglas pertinentes para un LLM.
84
+ UNE_RULES = """
85
+ ### Reglas UNE de Audiodescripción (Para el Crítico)
86
+ 1. **Objetividad y Foco Visual:** La descripción debe ser puramente objetiva, describiendo solo lo que se ve. Debe priorizar la acción y los elementos relevantes (personajes, objetos, localización).
87
+ 2. **Tiempo y Espacio (Sincronización):** Las audiodescripciones (AD) deben insertarse en los silencios del diálogo. El tiempo de la AD (entre START y END) debe ser suficiente para narrar el contenido sin solaparse con el diálogo o la música importante.
88
+ 3. **Concisión y Claridad:** Usar lenguaje simple y conciso. Evitar redundancias y juicios de valor.
89
+ 4. **Formato:** Cada segmento de AD debe tener un formato SRT válido, incluyendo el marcador (AD) al principio de la línea de texto.
90
+ 5. **Utilidad:** Cada segmento de AD debe ser útil para la comprensión y nunca ser redundante. En caso de repetir algo ya explicado antes, mejor no decir nada.
91
+ """
92
+
93
+ EVALUATION_CRITERIA = [
94
+ "Precisió Descriptiva",
95
+ "Sincronització Temporal",
96
+ "Claredat i Concisió",
97
+ "Inclusió de Diàleg/So",
98
+ "Contextualització",
99
+ "Flux i Ritme de la Narració",
100
+ ]
101
+
102
+ CRITERIA_WEIGHTS = {
103
+ "Precisió Descriptiva": 1,
104
+ "Sincronització Temporal": 4,
105
+ "Claredat i Concisió": 1,
106
+ "Inclusió de Diàleg/So": 1,
107
+ "Contextualització": 1,
108
+ "Flux i Ritme de la Narració": 1,
109
+ }
110
+
111
+ # Inicializar ficheros para la ejecución
112
+ def setup_files(initial_srt_content: str, context_json_content: str):
113
+ """Crea los ficheros iniciales necesarios en el sistema de archivos local."""
114
+ (TEMP_DIR / "une_ad_0.srt").write_text(initial_srt_content, encoding="utf-8")
115
+ (TEMP_DIR / "json_ad.json").write_text(context_json_content, encoding="utf-8")
116
+ logger.info("Ficheros iniciales 'une_ad_0.srt' y 'json_ad.json' creados.")
117
+
118
+ # --- Utilidades ---
119
+ def _strip_markdown_fences(content: str) -> str:
120
+ """Elimina fences ```...``` alrededor de una respuesta JSON si existen."""
121
+ text = content.strip()
122
+ if text.startswith("```"):
123
+ lines = text.splitlines()
124
+ # descartar primera línea con ``` o ```json
125
+ lines = lines[1:]
126
+ # eliminar el cierre ``` (pueden existir varias líneas en blanco finales)
127
+ while lines and lines[-1].strip() == "```":
128
+ lines.pop()
129
+ text = "\n".join(lines).strip()
130
+ return text
131
+
132
+
133
+ def generate_evaluation_report(srt_content: str, iteration: int) -> tuple[float, float, Path]:
134
+ """Solicita al LLM una avaluació estructurada i guarda'n el CSV."""
135
+ criteria_formatted = "\n".join(f"- {name}" for name in EVALUATION_CRITERIA)
136
+ prompt = (
137
+ "Actua com un auditor UNE. Avalua l'SRT generat, puntuant cada característica de 0 a 7 "
138
+ "segons la qualitat observada. Dónega justificació breve però concreta per a cada cas. "
139
+ "Les característiques obligatòries són:\n"
140
+ f"{criteria_formatted}\n"
141
+ "Retorna ÚNICAMENT un array JSON d'objectes amb les claus: "
142
+ "'caracteristica', 'valoracio' (nombre enter de 0 a 7) i 'justificacio'."
143
+ )
144
+
145
+ response = llm.invoke(
146
+ [
147
+ SystemMessage(content=prompt),
148
+ HumanMessage(
149
+ content=(
150
+ "# SRT AVALUAT\n"
151
+ f"{srt_content}\n\n"
152
+ "Assegura't de complir el format indicat."
153
+ )
154
+ ),
155
+ ]
156
+ )
157
+
158
+ cleaned = _strip_markdown_fences(response.content)
159
+ try:
160
+ data = json.loads(cleaned)
161
+ if not isinstance(data, list):
162
+ raise ValueError("La resposta no és una llista.")
163
+ except Exception as exc:
164
+ logger.error(
165
+ "Error al generar l'avaluació estructurada: %s. Resposta original: %s",
166
+ exc,
167
+ response.content,
168
+ )
169
+ data = [
170
+ {
171
+ "caracteristica": "Avaluació fallida",
172
+ "valoracio": 1,
173
+ "justificacio": "No s'ha pogut obtenir l'avaluació del LLM.",
174
+ }
175
+ ]
176
+
177
+ eval_path = TEMP_DIR / f"eval_{iteration}.csv"
178
+ with eval_path.open("w", encoding="utf-8", newline="") as csvfile:
179
+ writer = csv.writer(csvfile)
180
+ writer.writerow(["Caracteristica", "Valoracio (0-7)", "Justificacio"])
181
+ for item in data:
182
+ writer.writerow(
183
+ [
184
+ item.get("caracteristica", ""),
185
+ item.get("valoracio", 0),
186
+ item.get("justificacio", ""),
187
+ ]
188
+ )
189
+
190
+ scores = []
191
+ weighted_sum = 0.0
192
+ total_weight = 0.0
193
+
194
+ for entry in data:
195
+ if not isinstance(entry, dict):
196
+ continue
197
+ try:
198
+ score = float(entry.get("valoracio", 0))
199
+ except (TypeError, ValueError):
200
+ score = 0.0
201
+ scores.append(score)
202
+
203
+ weight = CRITERIA_WEIGHTS.get(entry.get("caracteristica", ""), 1)
204
+ weighted_sum += score * weight
205
+ total_weight += weight
206
+
207
+ mean_score = sum(scores) / len(scores) if scores else 0.0
208
+ weighted_mean = weighted_sum / total_weight if total_weight else mean_score
209
+ return mean_score, weighted_mean, eval_path
210
+
211
+ # --- Definición del Estado de la Gráfica (StateGraph) ---
212
+ class ReflectionState(TypedDict):
213
+ """Representa el estado del bucle de reflexión."""
214
+ iteration: int # Ciclo actual (empezando en 0)
215
+ current_srt_path: str # Ruta al archivo SRT actual (e.g., une_ad_0.srt, une_ad_1.srt)
216
+ critic_report: Dict[str, Union[float, str]] # Último informe del crítico (puntuación y texto)
217
+ history: List[SystemMessage] # Historial de mensajes entre agentes
218
+ evaluation_mean: float
219
+ best_iteration: int
220
+ best_weighted_mean: float
221
+ best_srt_path: str
222
+ best_eval_path: str
223
+
224
+ # --- Nodos/Agentes de la Gráfica ---
225
+ def narrator_agent(state: ReflectionState):
226
+ """
227
+ Agente que genera o reescribe el SRT.
228
+ - En el ciclo 0, genera el SRT inicial.
229
+ - En ciclos > 0, reescribe el SRT basándose en el critic_report.
230
+ """
231
+ iteration = state["iteration"]
232
+ critic_report = state["critic_report"]
233
+ history = state["history"]
234
+
235
+ # Cargar contexto y último SRT
236
+ json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")
237
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
238
+
239
+ # 1. Definir el prompt
240
+ if iteration == 0:
241
+ # Tarea inicial (aunque en este caso ya se proporciona une_ad_0.srt)
242
+ # Aquí se simula la generación inicial.
243
+ prompt = (
244
+ "Ets un Narrador expert en Audiodescripció (AD). La teva tasca inicial és generar "
245
+ "un fitxer SRT d'audiodescripcions basat en el JSON de context visual. "
246
+ "TOT I AIXÍ, per a aquesta primera iteració, l'SRT ja s'ha generat. "
247
+ "Simplement retorna el contingut de 'une_ad_0.srt' com si fos la teva sortida. "
248
+ "Assegura't que totes les audiodescripcions estiguin en català i que cadascuna pugui ser locutada "
249
+ "dins del temps disponible (utilitza un màxim aproximat d'11 caràcters per segon). Si el tram de temps "
250
+ "és massa curt (<1.5s), combina'l amb el bloc d'AD més proper i ajusta els timestamps perquè la narració sigui fluida. "
251
+ "Evita redundàncies: no repeteixis informació ja descrita en segments d'AD anteriors o al diàleg, i elimina qualsevol detall que no sigui essencial."
252
+ )
253
+ output_srt = current_srt
254
+ reflection_text = "Generación inicial. No hay reflexión."
255
+ else:
256
+ # Tarea de reflexión
257
+ prompt = (
258
+ "Ets un Narrador expert en Audiodescripció (AD). Has rebut una crítica sobre la teva última versió de l'SRT. "
259
+ "La teva tasca és REESCRIURE el contingut d'audiodescripció (línies amb '(AD)') del fitxer SRT, "
260
+ "assegurant que sigui coherent amb el JSON de context i, sobretot, que CORREGEIXIS TOTS els problemes "
261
+ "mencionats a l'Informe Crític adjunt. Mantén intactes els diàlegs (línies amb [Nom]) i escriu totes les audiodescripcions en català natural. "
262
+ "Garanteix que cada bloc d'AD pugui ser locutat dins del seu interval temporal disponible considerant un màxim d'11 caràcters per segon. "
263
+ "Si l'interval és massa curt (<1.5s), fusiona'l amb el bloc d'AD anterior o posterior més proper i ajusta els timestamps perquè quedin contínues. "
264
+ "Prefereix frases concises i accionables, prioritzant la informació visual essencial, i elimina redundàncies amb AD anteriors o amb els diàlegs."
265
+ )
266
+
267
+ # Concatenar la entrada para el LLM
268
+ input_content = f"""
269
+ # INFORME CRÍTICO
270
+ Porcentaje de Fiabilidad Anterior: {critic_report.get('reliability_percentage')}
271
+ Crítica Cualitativa: {critic_report.get('qualitative_critique')}
272
+
273
+ # JSON DE CONTEXTO VISUAL (Guía para la AD)
274
+ {json_context}
275
+
276
+ # ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration-1}.srt)
277
+ {current_srt}
278
+
279
+ REGLAS: Tu respuesta debe ser *SOLAMENTE* el contenido completo del nuevo archivo SRT (incluyendo diálogos), sin ningún comentario o explicación adicional.
280
+ """
281
+
282
+ # Llamada al LLM
283
+ response = llm.invoke(
284
+ [
285
+ SystemMessage(content=prompt),
286
+ HumanMessage(content=input_content)
287
+ ]
288
+ )
289
+
290
+ output_srt = response.content
291
+ reflection_text = f"Reescrito en base al informe crítico: {critic_report.get('qualitative_critique', 'N/A')}"
292
+
293
+ # 2. Guardar la nueva salida
294
+ new_srt_path = TEMP_DIR / f"une_ad_{iteration}.srt"
295
+ new_srt_path.write_text(output_srt, encoding="utf-8")
296
+
297
+ # 3. Guardar el pensamiento (reflection_text)
298
+ (TEMP_DIR / f"thinking_{iteration}.txt").write_text(reflection_text, encoding="utf-8")
299
+
300
+ logger.info(f"Narrador: Generada la versión {iteration} del SRT en '{new_srt_path}'.")
301
+
302
+ # 4. Actualizar el estado
303
+ new_history = history + [AIMessage(content=f"Narrador v{iteration} completado. Razón de reflexión: {reflection_text}")]
304
+ return {
305
+ "iteration": iteration,
306
+ "current_srt_path": str(new_srt_path),
307
+ "history": new_history,
308
+ "evaluation_mean": state.get("evaluation_mean", 0.0),
309
+ "best_iteration": state.get("best_iteration", -1),
310
+ "best_weighted_mean": state.get("best_weighted_mean", 0.0),
311
+ "best_srt_path": state.get("best_srt_path", str(new_srt_path)),
312
+ "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
313
+ }
314
+
315
+ def critic_agent(state: ReflectionState):
316
+ """
317
+ Agente que evalúa la calidad del SRT generado por el Narrador basándose en las Reglas UNE.
318
+ Devuelve una puntuación y una crítica cualitativa.
319
+ """
320
+ iteration = state["iteration"]
321
+ history = state["history"]
322
+ current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
323
+
324
+ prompt = (
325
+ "Ets un Crític d'Audiodescripció molt estricte. La teva tasca és avaluar l'SRT adjunt "
326
+ "únicament segons les Regles UNE proporcionades. L'avaluació ha de ser doble: "
327
+ "1. **Numèrica**: Un percentatge de fiabilitat (ex. 85.5) de 0 a 100%. "
328
+ "2. **Qualitativa**: Una crítica constructiva sobre les principals mancances de les AD respecte a les regles. "
329
+ "Has de ser EXTREMADAMENT estricte amb la sincronització (sense solapament amb el diàleg), "
330
+ "amb l'adequació temporal (velocitat màxima recomanada d'11 caràcters per segon) i amb l'absència de redundàncies. "
331
+ "Comprova també que totes les audiodescripcions estan escrites en català natural."
332
+ )
333
+
334
+ input_content = f"""
335
+ # REGLAS UNE DE AUDIODESCRIPCIÓN:
336
+ {UNE_RULES}
337
+
338
+ # ARCHIVO SRT A EVALUAR (une_ad_{iteration}.srt):
339
+ {current_srt}
340
+
341
+ REGLAS DE RESPUESTA:
342
+ Tu respuesta debe ser *SOLAMENTE* un objeto JSON con dos claves:
343
+ 1. "reliability_percentage": (float) El porcentaje de fiabilidad.
344
+ 2. "qualitative_critique": (string) La crítica cualitativa y sugerencias de mejora.
345
+ Ejemplo de respuesta: {{"reliability_percentage": 75.0, "qualitative_critique": "El segmento 4 se solapa 0.34s con el diálogo de Sandra. El segmento 5 es demasiado genérico y no describe bien la acción."}}
346
+ """
347
+
348
+ # Llamada al LLM
349
+ response = llm.invoke(
350
+ [
351
+ SystemMessage(content=prompt),
352
+ HumanMessage(content=input_content)
353
+ ]
354
+ )
355
+
356
+ # Intentar parsear la respuesta del LLM (puede fallar, por eso se usa un try/except)
357
+ try:
358
+ cleaned_response = _strip_markdown_fences(response.content)
359
+ report = json.loads(cleaned_response)
360
+ if not isinstance(report, dict) or 'reliability_percentage' not in report:
361
+ raise ValueError("Estructura JSON incorrecta.")
362
+ except Exception as e:
363
+ logger.error(f"Error al parsear el JSON del Crítico: {e}. Respuesta: {response.content}")
364
+ report = {"reliability_percentage": 1.0, "qualitative_critique": "El Crítico no devolvió un JSON válido. Reintentar."}
365
+
366
+ logger.info(f"Crítico: Evaluación completada. Fiabilidad: {report.get('reliability_percentage')}%.")
367
+
368
+ mean_score, weighted_mean, eval_path = generate_evaluation_report(current_srt, iteration)
369
+
370
+ thinking_path = TEMP_DIR / f"thinking_{iteration}.txt"
371
+ if thinking_path.exists():
372
+ previous_text = thinking_path.read_text(encoding="utf-8")
373
+ thinking_path.write_text(
374
+ (
375
+ f"{previous_text}\n\nMitjana simple d'avaluació: {mean_score:.2f} / 7"
376
+ f"\nMitjana ponderada d'avaluació: {weighted_mean:.2f} / 7"
377
+ ),
378
+ encoding="utf-8",
379
+ )
380
+
381
+ best_iteration = state.get("best_iteration", -1)
382
+ best_weighted_mean = state.get("best_weighted_mean", -1.0)
383
+ best_srt_path = state.get("best_srt_path", state["current_srt_path"])
384
+ best_eval_path = state.get("best_eval_path", str(eval_path))
385
+
386
+ if weighted_mean > best_weighted_mean:
387
+ best_iteration = iteration
388
+ best_weighted_mean = weighted_mean
389
+ best_srt_path = state["current_srt_path"]
390
+ best_eval_path = str(eval_path)
391
+
392
+ new_history = history + [
393
+ AIMessage(
394
+ content=(
395
+ "Crítico v{iter} completado. Fiabilidad: {reliab}%. "
396
+ "Mitjana simple: {mean:.2f}/7. Mitjana ponderada: {wmean:.2f}/7"
397
+ ).format(
398
+ iter=iteration,
399
+ reliab=report.get("reliability_percentage"),
400
+ mean=mean_score,
401
+ wmean=weighted_mean,
402
+ )
403
+ )
404
+ ]
405
+ return {
406
+ "iteration": iteration + 1,
407
+ "critic_report": report,
408
+ "history": new_history,
409
+ "evaluation_mean": weighted_mean,
410
+ "best_iteration": best_iteration,
411
+ "best_weighted_mean": best_weighted_mean,
412
+ "best_srt_path": best_srt_path,
413
+ "best_eval_path": best_eval_path,
414
+ }
415
+
416
+
417
+ # --- Condición de Salida del Bucle ---
418
+
419
+ def should_continue(state: ReflectionState) -> str:
420
+ """
421
+ Función de chequeo que decide si continuar iterando o finalizar.
422
+ """
423
+ MAX_ITERATIONS = 5 # Número máximo de ciclos
424
+ MIN_AVERAGE_SCORE = 6.0 # Umbral de calidad sobre 7
425
+
426
+ iteration = state["iteration"]
427
+ mean_score = state.get("evaluation_mean", 0.0)
428
+
429
+ if mean_score >= MIN_AVERAGE_SCORE:
430
+ logger.info(f"FIN: Mitjana ponderada d'avaluació assolida ({mean_score:.2f} >= {MIN_AVERAGE_SCORE}).")
431
+ return "end"
432
+
433
+ if iteration >= MAX_ITERATIONS:
434
+ logger.info(f"FIN: S'ha assolit el màxim d'iteracions ({iteration} / {MAX_ITERATIONS}).")
435
+ return "end"
436
+
437
+ logger.info(f"CONTINUAR: Iteració {iteration} / {MAX_ITERATIONS}. Mitjana ponderada actual: {mean_score:.2f} / 7.")
438
+ return "continue"
439
+
440
+ # --- Construcción de la Gráfica ---
441
+
442
+ # 1. Configurar el estado inicial
443
+ initial_state: ReflectionState = {
444
+ "iteration": 0,
445
+ "current_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
446
+ "critic_report": {"reliability_percentage": 0.0, "qualitative_critique": "Inicializando el proceso."},
447
+ "history": [],
448
+ "evaluation_mean": 0.0,
449
+ "best_iteration": -1,
450
+ "best_weighted_mean": -1.0,
451
+ "best_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
452
+ "best_eval_path": str(TEMP_DIR / "eval_0.csv"),
453
+ }
454
+
455
+ # 2. Definir la gráfica
456
+ workflow = StateGraph(ReflectionState)
457
+
458
+ # Nodos
459
+ workflow.add_node("narrator", narrator_agent)
460
+ workflow.add_node("critic", critic_agent)
461
+
462
+ # Estructura del bucle: Narrator -> Critic -> Check
463
+ workflow.set_entry_point("narrator")
464
+ workflow.add_edge("narrator", "critic")
465
+
466
+ # Condición (puente de ramificación)
467
+ workflow.add_conditional_edges(
468
+ "critic",
469
+ should_continue,
470
+ {
471
+ "continue": "narrator", # Si no se cumple el umbral/ciclo, vuelve al narrador
472
+ "end": END # Si se cumple, termina
473
+ }
474
+ )
475
+
476
+ # Compilar la gráfica
477
+ app = workflow.compile()
478
+
479
+ # --- Ejecución Principal ---
480
+
481
+ if __name__ == "__main__":
482
+ # Inicializar el entorno
483
+ setup_files(INITIAL_SRT_CONTENT, CONTEXT_JSON_CONTENT)
484
+
485
+ logger.info("--- Comenzando el Bucle de Reflexión ---")
486
+
487
+ # Ejecutar la gráfica
488
+ final_state = app.invoke(initial_state)
489
+
490
+ logger.info("\n--- Bucle Finalizado ---")
491
+
492
+ best_iteration = final_state.get("best_iteration", -1)
493
+ best_weighted_mean = final_state.get("best_weighted_mean", 0.0)
494
+ best_srt_path = Path(final_state.get("best_srt_path", final_state['current_srt_path']))
495
+ best_eval_path = Path(final_state.get("best_eval_path", TEMP_DIR / "eval_0.csv"))
496
+
497
+ final_srt_path = TEMP_DIR / "une_ad.srt"
498
+ final_eval_path = TEMP_DIR / "eval.csv"
499
+
500
+ try:
501
+ shutil.copy(best_srt_path, final_srt_path)
502
+ logger.info(f"SRT final copiado a '{final_srt_path}'.")
503
+ except Exception as exc:
504
+ logger.error(f"No se pudo copiar el SRT final: {exc}")
505
+
506
+ try:
507
+ shutil.copy(best_eval_path, final_eval_path)
508
+ logger.info(f"Evaluación final copiada a '{final_eval_path}'.")
509
+ except Exception as exc:
510
+ logger.error(f"No se pudo copiar el CSV final: {exc}")
511
+
512
+ # Mostrar resultados
513
+ print(f"Número final de ciclos: {final_state['iteration']}")
514
+ print(f"Iteración òptima: {best_iteration} (mitjana ponderada {best_weighted_mean:.2f}/7)")
515
+ print(f"Ruta al SRT final: {final_srt_path}")
516
+ print(f"Ruta a l'avaluació final: {final_eval_path}")
517
+
518
+ # Mostrar el SRT final generado
519
+ print("\n--- Contenido del SRT Final ---")
520
+ print(final_srt_path.read_text(encoding="utf-8"))
finetuning/video_analysis.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from datetime import timedelta
6
+ from typing import List, Optional, Dict, Any
7
+
8
+
9
+ TIME_RE = re.compile(
10
+ r"(?P<start>\d{2}:\d{2}:\d{2}[,\.]\d{3})\s*-->\s*(?P<end>\d{2}:\d{2}:\d{2}[,\.]\d{3})"
11
+ )
12
+
13
+
14
+ @dataclass
15
+ class SRTBlock:
16
+ index: int
17
+ start: float # seconds
18
+ end: float # seconds
19
+ text: str
20
+
21
+
22
+ def _parse_timestamp(ts: str) -> float:
23
+ """Convierte 'HH:MM:SS,mmm' o 'HH:MM:SS.mmm' a segundos (float)."""
24
+ ts = ts.replace(",", ".")
25
+ h, m, s = ts.split(":")
26
+ seconds, millis = (s.split("." ) + ["0"])[:2]
27
+ td = timedelta(
28
+ hours=int(h),
29
+ minutes=int(m),
30
+ seconds=int(seconds),
31
+ milliseconds=int(millis.ljust(3, "0")),
32
+ )
33
+ return td.total_seconds()
34
+
35
+
36
+ def _parse_srt(srt_text: str) -> List[SRTBlock]:
37
+ """Parsea texto SRT en una lista de bloques SRTBlock."""
38
+ srt_text = srt_text.replace("\r\n", "\n").replace("\r", "\n")
39
+ chunks = [c.strip() for c in re.split(r"\n\s*\n", srt_text) if c.strip()]
40
+ blocks: List[SRTBlock] = []
41
+
42
+ for chunk in chunks:
43
+ lines = chunk.split("\n")
44
+ idx_line = 0
45
+ index = None
46
+
47
+ if lines and lines[0].strip().isdigit():
48
+ index = int(lines[0].strip())
49
+ idx_line = 1
50
+
51
+ time_match = None
52
+ time_line_idx = None
53
+ for i in range(idx_line, min(idx_line + 3, len(lines))):
54
+ m = TIME_RE.search(lines[i])
55
+ if m:
56
+ time_match = m
57
+ time_line_idx = i
58
+ break
59
+
60
+ if not time_match or time_line_idx is None:
61
+ continue
62
+
63
+ start = _parse_timestamp(time_match.group("start"))
64
+ end = _parse_timestamp(time_match.group("end"))
65
+ if index is None:
66
+ index = len(blocks) + 1
67
+
68
+ text = "\n".join(lines[time_line_idx + 1 :]).strip()
69
+ blocks.append(SRTBlock(index=index, start=start, end=end, text=text))
70
+
71
+ return blocks
72
+
73
+
74
+ def analyze_srt(
75
+ srt_text: str,
76
+ *,
77
+ ad_markers: Optional[List[str]] = None,
78
+ ) -> Dict[str, Any]:
79
+ """Analiza un SRT y devuelve métricas básicas.
80
+
81
+ Métricas devueltas:
82
+ - duration_sec: duración total estimada del vídeo (segundos)
83
+ - words_per_min: número de palabras por minuto
84
+ - speakers_blocks_per_min: número de bloques de diálogo por minuto
85
+ - ad_time_ratio: porcentaje (0..1) del tiempo total con bloques marcados como AD
86
+ - blocks_per_min: número total de bloques por minuto
87
+
88
+ Heurísticas:
89
+ - Se asume que la duración del vídeo es el final del último bloque.
90
+ - Un "bloque de AD" es aquel cuya primera línea contiene alguno de los
91
+ marcadores indicados en `ad_markers` (por ejemplo: "[AD]", "AD:", "(AD)").
92
+ """
93
+
94
+ blocks = _parse_srt(srt_text)
95
+ if not blocks:
96
+ return {
97
+ "duration_sec": 0.0,
98
+ "words_per_min": 0.0,
99
+ "speakers_blocks_per_min": 0.0,
100
+ "ad_time_ratio": 0.0,
101
+ "blocks_per_min": 0.0,
102
+ }
103
+
104
+ duration_sec = max(b.end for b in blocks)
105
+ duration_min = max(duration_sec / 60.0, 1e-6)
106
+
107
+ # Palabras totales
108
+ total_words = 0
109
+ for b in blocks:
110
+ total_words += len(b.text.split())
111
+
112
+ # Bloques considerados de "hablante" (no AD)
113
+ if ad_markers is None:
114
+ ad_markers = ["[AD]", "AD:", "(AD)"]
115
+
116
+ def is_ad_block(block: SRTBlock) -> bool:
117
+ first_line = (block.text.splitlines() or [""])[0].strip().upper()
118
+ for mk in ad_markers:
119
+ if mk.upper() in first_line:
120
+ return True
121
+ return False
122
+
123
+ ad_time = 0.0
124
+ speech_blocks = 0
125
+ for b in blocks:
126
+ if is_ad_block(b):
127
+ ad_time += max(0.0, b.end - b.start)
128
+ else:
129
+ speech_blocks += 1
130
+
131
+ words_per_min = total_words / duration_min
132
+ speakers_blocks_per_min = speech_blocks / duration_min
133
+ blocks_per_min = len(blocks) / duration_min
134
+ ad_time_ratio = ad_time / duration_sec if duration_sec > 0 else 0.0
135
+
136
+ return {
137
+ "duration_sec": float(duration_sec),
138
+ "words_per_min": float(words_per_min),
139
+ "speakers_blocks_per_min": float(speakers_blocks_per_min),
140
+ "ad_time_ratio": float(ad_time_ratio),
141
+ "blocks_per_min": float(blocks_per_min),
142
+ }
143
+
144
+
145
+ def embed_srt_sentences(
146
+ srt_text: str,
147
+ *,
148
+ model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
149
+ ) -> Dict[str, Any]:
150
+ """Devuelve embeddings para las frases de un SRT.
151
+
152
+ Args:
153
+ srt_text: Contenido completo del archivo SRT como string.
154
+ model_name: Nombre del modelo de sentence-transformers a usar.
155
+
156
+ Returns:
157
+ Diccionario con:
158
+ - "model_name": nombre del modelo utilizado
159
+ - "sentences": lista de strings (una por bloque)
160
+ - "embeddings": lista de listas de floats con los embeddings
161
+
162
+ NOTA: Requiere instalar `sentence-transformers` y un backend de PyTorch
163
+ compatible. Si no está instalado, lanzará ImportError.
164
+ """
165
+
166
+ blocks = _parse_srt(srt_text)
167
+ sentences = [b.text.replace("\n", " ").strip() for b in blocks if b.text.strip()]
168
+
169
+ if not sentences:
170
+ return {"model_name": model_name, "sentences": [], "embeddings": []}
171
+
172
+ try:
173
+ from sentence_transformers import SentenceTransformer
174
+ except ImportError as exc:
175
+ raise ImportError(
176
+ "sentence-transformers no está instalado. "
177
+ "Instala la dependencia para poder generar embeddings."
178
+ ) from exc
179
+
180
+ model = SentenceTransformer(model_name)
181
+ embs = model.encode(sentences, convert_to_numpy=False)
182
+
183
+ embeddings = [list(map(float, vec)) for vec in embs]
184
+
185
+ return {
186
+ "model_name": model_name,
187
+ "sentences": sentences,
188
+ "embeddings": embeddings,
189
+ }
storage/pending_videos_routers.py CHANGED
@@ -1,244 +1,244 @@
1
- import os
2
- import io
3
- import shutil
4
-
5
- import sqlite3
6
-
7
- from pathlib import Path
8
-
9
- from fastapi import APIRouter, UploadFile, File, Query, HTTPException
10
- from fastapi.responses import FileResponse, JSONResponse
11
-
12
-
13
- from storage.files.file_manager import FileManager
14
- from storage.common import validate_token
15
-
16
- router = APIRouter(prefix="/peding_videos", tags=["Pending Videos Manager"])
17
- MEDIA_ROOT = Path("/data/peding_videos")
18
- file_manager = FileManager(MEDIA_ROOT)
19
- HF_TOKEN = os.getenv("HF_TOKEN")
20
-
21
-
22
- @router.delete("/clear_pending_videos", tags=["Pending Videos Manager"])
23
- def clear_media(token: str = Query(..., description="Token required for authorization")):
24
- """
25
- Delete all contents of the /data/peding_videos folder.
26
- Steps:
27
- - Validate the token.
28
- - Ensure the folder exists.
29
- - Delete all files and subfolders inside /data/peding_videos.
30
- - Return a JSON response confirming the deletion.
31
- Warning: This will remove all stored videos, clips, and cast CSV files.
32
- """
33
- validate_token(token)
34
-
35
- if not MEDIA_ROOT.exists() or not MEDIA_ROOT.is_dir():
36
- raise HTTPException(status_code=404, detail="/data/peding_videos folder does not exist")
37
-
38
- # Delete contents
39
- for item in MEDIA_ROOT.iterdir():
40
- try:
41
- if item.is_dir():
42
- shutil.rmtree(item)
43
- else:
44
- item.unlink()
45
- except Exception as e:
46
- raise HTTPException(status_code=500, detail=f"Failed to delete {item}: {e}")
47
-
48
- return {"status": "ok", "message": "All peding_videos files deleted successfully"}
49
-
50
-
51
- @router.delete("/clear_pending_video", tags=["Pending Videos Manager"])
52
- def clear_pending_video(
53
- sha1: str = Query(..., description="SHA1 folder to delete inside pending_videos"),
54
- token: str = Query(..., description="Token required for authorization")
55
- ):
56
- """
57
- Delete a specific SHA1 folder inside /data/pending_videos.
58
- Steps:
59
- - Validate the token.
60
- - Ensure the folder exists.
61
- - Delete the folder and all its contents.
62
- - Return a JSON response confirming the deletion.
63
- """
64
- validate_token(token)
65
-
66
- PENDING_ROOT = Path("/data/pending_videos")
67
- target_folder = PENDING_ROOT / sha1
68
-
69
- if not target_folder.exists() or not target_folder.is_dir():
70
- raise HTTPException(status_code=404, detail=f"Folder {sha1} does not exist in pending_videos")
71
-
72
- try:
73
- shutil.rmtree(target_folder)
74
- except Exception as e:
75
- raise HTTPException(status_code=500, detail=f"Failed to delete {sha1}: {e}")
76
-
77
- return {"status": "ok", "message": f"Pending video folder {sha1} deleted successfully"}
78
-
79
-
80
- @router.post("/upload_pending_video", tags=["Pending Videos Manager"])
81
- async def upload_video(
82
- video: UploadFile = File(...),
83
- token: str = Query(..., description="Token required for authorization")
84
- ):
85
- """
86
- Saves an uploaded video by hashing it with SHA1 and placing it under:
87
- /data/media/<sha1>/<original_filename>
88
- Behavior:
89
- - Compute SHA1 of the uploaded video.
90
- - Ensure folder structure exists.
91
- - Delete any existing .mp4 files under sha1.
92
- - Save the uploaded video in the folder.
93
- """
94
- # Read content into memory (needed to compute hash twice)
95
- file_bytes = await video.read()
96
-
97
- # Create an in-memory file handler for hashing
98
- file_handler = io.BytesIO(file_bytes)
99
-
100
- # Compute SHA1
101
- try:
102
- sha1 = file_manager.compute_sha1(file_handler)
103
- except Exception as exc:
104
- raise HTTPException(status_code=500, detail=f"SHA1 computation failed: {exc}")
105
-
106
- # Ensure /data/media exists
107
- MEDIA_ROOT.mkdir(parents=True, exist_ok=True)
108
-
109
- # Path: /data/media/<sha1>
110
- video_root = MEDIA_ROOT / sha1
111
- video_root.mkdir(parents=True, exist_ok=True)
112
-
113
- # Delete old MP4 files
114
- try:
115
- for old_mp4 in video_root.glob("*.mp4"):
116
- old_mp4.unlink()
117
- except Exception as exc:
118
- raise HTTPException(status_code=500, detail=f"Failed to delete old videos: {exc}")
119
-
120
- # Save new video path
121
- final_path = video_root / video.filename
122
-
123
- # Save file
124
- save_result = file_manager.upload_file(io.BytesIO(file_bytes), final_path)
125
-
126
- if not save_result["operation_success"]:
127
- raise HTTPException(status_code=500, detail=save_result["error"])
128
-
129
- return JSONResponse(
130
- status_code=200,
131
- content={
132
- "status": "ok",
133
- "sha1": sha1,
134
- "saved_to": str(final_path)
135
- }
136
- )
137
-
138
-
139
- @router.get("/download_pending_video", tags=["Pending Videos Manager"])
140
- def download_video(
141
- sha1: str,
142
- token: str = Query(..., description="Token required for authorization")
143
- ):
144
- """
145
- Download a stored video by its SHA-1 directory name.
146
- This endpoint looks for a video stored under the path:
147
- /data/media/<sha1>/clip/
148
- and returns the first MP4 file found in that folder.
149
- The method performs the following steps:
150
- - Checks if the SHA-1 folder exists inside the media root.
151
- - Validates that the "clip" subfolder exists.
152
- - Searches for the first .mp4 file inside the clip folder.
153
- - Uses the FileManager.get_file method to ensure the file is accessible.
154
- - Returns the video directly as a FileResponse.
155
- Parameters
156
- ----------
157
- sha1 : str
158
- The SHA-1 hash corresponding to the directory where the video is stored.
159
- Returns
160
- -------
161
- FileResponse
162
- A streaming response containing the MP4 video.
163
- Raises
164
- ------
165
- HTTPException
166
- - 404 if the SHA-1 folder does not exist.
167
- - 404 if the clip folder is missing.
168
- - 404 if no MP4 files are found.
169
- - 404 if the file cannot be retrieved using FileManager.
170
- """
171
- sha1_folder = MEDIA_ROOT / sha1
172
-
173
- if not sha1_folder.exists() or not sha1_folder.is_dir():
174
- raise HTTPException(status_code=404, detail="SHA1 folder not found")
175
-
176
- # Find first MP4 file
177
- mp4_files = list(sha1_folder.glob("*.mp4"))
178
- if not mp4_files:
179
- raise HTTPException(status_code=404, detail="No MP4 files found")
180
-
181
- video_path = mp4_files[0]
182
-
183
- # Convert to relative path for FileManager
184
- relative_path = video_path.relative_to(MEDIA_ROOT)
185
-
186
- handler = file_manager.get_file(relative_path)
187
- if handler is None:
188
- raise HTTPException(status_code=404, detail="Video not accessible")
189
-
190
- handler.close()
191
-
192
- return FileResponse(
193
- path=video_path,
194
- media_type="video/mp4",
195
- filename=video_path.name
196
- )
197
-
198
- @router.get("/list_pending_videos", tags=["Pending Videos Manager"])
199
- def list_all_videos(
200
- token: str = Query(..., description="Token required for authorization")
201
- ):
202
- """
203
- List all videos stored under /data/media.
204
- For each SHA1 folder, the endpoint returns:
205
- - sha1: folder name
206
- - video_files: list of mp4 files inside /clip
207
- - latest_video: the most recently modified mp4
208
- - video_count: total number of mp4 files
209
- Notes:
210
- - Videos may not have a /clip folder.
211
- - SHA1 folders without mp4 files are still returned.
212
- """
213
- validate_token(token)
214
-
215
- results = []
216
-
217
- # If media root does not exist, return empty list
218
- if not MEDIA_ROOT.exists():
219
- return []
220
-
221
- for sha1_dir in MEDIA_ROOT.iterdir():
222
- if not sha1_dir.is_dir():
223
- continue # skip non-folders
224
-
225
- videos = []
226
- latest_video = None
227
-
228
- if sha1_dir.exists() and sha1_dir.is_dir():
229
- mp4_files = list(sha1_dir.glob("*.mp4"))
230
-
231
- # Sort by modification time (newest first)
232
- mp4_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
233
-
234
- videos = [f.name for f in mp4_files]
235
-
236
- if mp4_files:
237
- latest_video = mp4_files[0].name
238
-
239
- results.append({
240
- "sha1": sha1_dir.name,
241
- "video_name": latest_video
242
- })
243
-
244
  return results
 
1
+ import os
2
+ import io
3
+ import shutil
4
+
5
+ import sqlite3
6
+
7
+ from pathlib import Path
8
+
9
+ from fastapi import APIRouter, UploadFile, File, Query, HTTPException
10
+ from fastapi.responses import FileResponse, JSONResponse
11
+
12
+
13
+ from storage.files.file_manager import FileManager
14
+ from storage.common import validate_token
15
+
16
+ router = APIRouter(prefix="/pending_videos", tags=["Pending Videos Manager"])
17
+ MEDIA_ROOT = Path("/data/pending_videos")
18
+ file_manager = FileManager(MEDIA_ROOT)
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+
21
+
22
+ @router.delete("/clear_pending_videos", tags=["Pending Videos Manager"])
23
+ def clear_media(token: str = Query(..., description="Token required for authorization")):
24
+ """
25
+ Delete all contents of the /data/pending_videos folder.
26
+ Steps:
27
+ - Validate the token.
28
+ - Ensure the folder exists.
29
+ - Delete all files and subfolders inside /data/pending_videos.
30
+ - Return a JSON response confirming the deletion.
31
+ Warning: This will remove all stored videos, clips, and cast CSV files.
32
+ """
33
+ validate_token(token)
34
+
35
+ if not MEDIA_ROOT.exists() or not MEDIA_ROOT.is_dir():
36
+ raise HTTPException(status_code=404, detail="/data/pending_videos folder does not exist")
37
+
38
+ # Delete contents
39
+ for item in MEDIA_ROOT.iterdir():
40
+ try:
41
+ if item.is_dir():
42
+ shutil.rmtree(item)
43
+ else:
44
+ item.unlink()
45
+ except Exception as e:
46
+ raise HTTPException(status_code=500, detail=f"Failed to delete {item}: {e}")
47
+
48
+ return {"status": "ok", "message": "All pending_videos files deleted successfully"}
49
+
50
+
51
+ @router.delete("/clear_pending_video", tags=["Pending Videos Manager"])
52
+ def clear_pending_video(
53
+ sha1: str = Query(..., description="SHA1 folder to delete inside pending_videos"),
54
+ token: str = Query(..., description="Token required for authorization")
55
+ ):
56
+ """
57
+ Delete a specific SHA1 folder inside /data/pending_videos.
58
+ Steps:
59
+ - Validate the token.
60
+ - Ensure the folder exists.
61
+ - Delete the folder and all its contents.
62
+ - Return a JSON response confirming the deletion.
63
+ """
64
+ validate_token(token)
65
+
66
+ PENDING_ROOT = Path("/data/pending_videos")
67
+ target_folder = PENDING_ROOT / sha1
68
+
69
+ if not target_folder.exists() or not target_folder.is_dir():
70
+ raise HTTPException(status_code=404, detail=f"Folder {sha1} does not exist in pending_videos")
71
+
72
+ try:
73
+ shutil.rmtree(target_folder)
74
+ except Exception as e:
75
+ raise HTTPException(status_code=500, detail=f"Failed to delete {sha1}: {e}")
76
+
77
+ return {"status": "ok", "message": f"Pending video folder {sha1} deleted successfully"}
78
+
79
+
80
+ @router.post("/upload_pending_video", tags=["Pending Videos Manager"])
81
+ async def upload_video(
82
+ video: UploadFile = File(...),
83
+ token: str = Query(..., description="Token required for authorization")
84
+ ):
85
+ """
86
+ Saves an uploaded video by hashing it with SHA1 and placing it under:
87
+ /data/media/<sha1>/<original_filename>
88
+ Behavior:
89
+ - Compute SHA1 of the uploaded video.
90
+ - Ensure folder structure exists.
91
+ - Delete any existing .mp4 files under sha1.
92
+ - Save the uploaded video in the folder.
93
+ """
94
+ # Read content into memory (needed to compute hash twice)
95
+ file_bytes = await video.read()
96
+
97
+ # Create an in-memory file handler for hashing
98
+ file_handler = io.BytesIO(file_bytes)
99
+
100
+ # Compute SHA1
101
+ try:
102
+ sha1 = file_manager.compute_sha1(file_handler)
103
+ except Exception as exc:
104
+ raise HTTPException(status_code=500, detail=f"SHA1 computation failed: {exc}")
105
+
106
+ # Ensure /data/media exists
107
+ MEDIA_ROOT.mkdir(parents=True, exist_ok=True)
108
+
109
+ # Path: /data/media/<sha1>
110
+ video_root = MEDIA_ROOT / sha1
111
+ video_root.mkdir(parents=True, exist_ok=True)
112
+
113
+ # Delete old MP4 files
114
+ try:
115
+ for old_mp4 in video_root.glob("*.mp4"):
116
+ old_mp4.unlink()
117
+ except Exception as exc:
118
+ raise HTTPException(status_code=500, detail=f"Failed to delete old videos: {exc}")
119
+
120
+ # Save new video path
121
+ final_path = video_root / video.filename
122
+
123
+ # Save file
124
+ save_result = file_manager.upload_file(io.BytesIO(file_bytes), final_path)
125
+
126
+ if not save_result["operation_success"]:
127
+ raise HTTPException(status_code=500, detail=save_result["error"])
128
+
129
+ return JSONResponse(
130
+ status_code=200,
131
+ content={
132
+ "status": "ok",
133
+ "sha1": sha1,
134
+ "saved_to": str(final_path)
135
+ }
136
+ )
137
+
138
+
139
+ @router.get("/download_pending_video", tags=["Pending Videos Manager"])
140
+ def download_video(
141
+ sha1: str,
142
+ token: str = Query(..., description="Token required for authorization")
143
+ ):
144
+ """
145
+ Download a stored video by its SHA-1 directory name.
146
+ This endpoint looks for a video stored under the path:
147
+ /data/media/<sha1>/clip/
148
+ and returns the first MP4 file found in that folder.
149
+ The method performs the following steps:
150
+ - Checks if the SHA-1 folder exists inside the media root.
151
+ - Validates that the "clip" subfolder exists.
152
+ - Searches for the first .mp4 file inside the clip folder.
153
+ - Uses the FileManager.get_file method to ensure the file is accessible.
154
+ - Returns the video directly as a FileResponse.
155
+ Parameters
156
+ ----------
157
+ sha1 : str
158
+ The SHA-1 hash corresponding to the directory where the video is stored.
159
+ Returns
160
+ -------
161
+ FileResponse
162
+ A streaming response containing the MP4 video.
163
+ Raises
164
+ ------
165
+ HTTPException
166
+ - 404 if the SHA-1 folder does not exist.
167
+ - 404 if the clip folder is missing.
168
+ - 404 if no MP4 files are found.
169
+ - 404 if the file cannot be retrieved using FileManager.
170
+ """
171
+ sha1_folder = MEDIA_ROOT / sha1
172
+
173
+ if not sha1_folder.exists() or not sha1_folder.is_dir():
174
+ raise HTTPException(status_code=404, detail="SHA1 folder not found")
175
+
176
+ # Find first MP4 file
177
+ mp4_files = list(sha1_folder.glob("*.mp4"))
178
+ if not mp4_files:
179
+ raise HTTPException(status_code=404, detail="No MP4 files found")
180
+
181
+ video_path = mp4_files[0]
182
+
183
+ # Convert to relative path for FileManager
184
+ relative_path = video_path.relative_to(MEDIA_ROOT)
185
+
186
+ handler = file_manager.get_file(relative_path)
187
+ if handler is None:
188
+ raise HTTPException(status_code=404, detail="Video not accessible")
189
+
190
+ handler.close()
191
+
192
+ return FileResponse(
193
+ path=video_path,
194
+ media_type="video/mp4",
195
+ filename=video_path.name
196
+ )
197
+
198
+ @router.get("/list_pending_videos", tags=["Pending Videos Manager"])
199
+ def list_all_videos(
200
+ token: str = Query(..., description="Token required for authorization")
201
+ ):
202
+ """
203
+ List all videos stored under /data/media.
204
+ For each SHA1 folder, the endpoint returns:
205
+ - sha1: folder name
206
+ - video_files: list of mp4 files inside /clip
207
+ - latest_video: the most recently modified mp4
208
+ - video_count: total number of mp4 files
209
+ Notes:
210
+ - Videos may not have a /clip folder.
211
+ - SHA1 folders without mp4 files are still returned.
212
+ """
213
+ validate_token(token)
214
+
215
+ results = []
216
+
217
+ # If media root does not exist, return empty list
218
+ if not MEDIA_ROOT.exists():
219
+ return []
220
+
221
+ for sha1_dir in MEDIA_ROOT.iterdir():
222
+ if not sha1_dir.is_dir():
223
+ continue # skip non-folders
224
+
225
+ videos = []
226
+ latest_video = None
227
+
228
+ if sha1_dir.exists() and sha1_dir.is_dir():
229
+ mp4_files = list(sha1_dir.glob("*.mp4"))
230
+
231
+ # Sort by modification time (newest first)
232
+ mp4_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
233
+
234
+ videos = [f.name for f in mp4_files]
235
+
236
+ if mp4_files:
237
+ latest_video = mp4_files[0].name
238
+
239
+ results.append({
240
+ "sha1": sha1_dir.name,
241
+ "video_name": latest_video
242
+ })
243
+
244
  return results