Upload 5 files
Browse files- finetuning/finetuning.py +762 -0
- finetuning/lora.py +219 -0
- finetuning/reflection.py +520 -0
- finetuning/video_analysis.py +189 -0
- storage/pending_videos_routers.py +243 -243
finetuning/finetuning.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TypedDict, Annotated, List, Dict, Union
|
| 8 |
+
from langgraph.graph import StateGraph, END
|
| 9 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 10 |
+
from langchain_openai import ChatOpenAI
|
| 11 |
+
from operator import itemgetter
|
| 12 |
+
|
| 13 |
+
# --- Configuration and tooling ---

# Working directories: every artifact this script produces (SRT versions,
# evaluation CSVs, logs) lives under ./temp next to this file.
BASE_DIR = Path(__file__).resolve().parent
TEMP_DIR = BASE_DIR / "temp"
TEMP_DIR.mkdir(exist_ok=True)

LOG_FILE = TEMP_DIR / "finetuning.log"

# Configure logging: mirror every message to the console and to temp/finetuning.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(LOG_FILE, encoding="utf-8")
    ],
)
logger = logging.getLogger(__name__)

# Fail fast at import time if OpenAI credentials are missing.
# (The user-facing error message is intentionally kept in Spanish.)
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise EnvironmentError("OPENAI_API_KEY no está configurada. Define la variable de entorno antes de ejecutar finetuning.py.")

# Initialize the LLM (GPT-4o is used for its reasoning capability).
# In production, pick a model that meets your token and latency budgets.
llm = ChatOpenAI(model="gpt-4o", temperature=0.3)
|
| 41 |
+
|
| 42 |
+
# --- Ficheros de Ejemplo ---
|
| 43 |
+
|
| 44 |
+
# Initial SRT file (Narrator input). Dialog lines are tagged [Name];
# audio-description lines are tagged (AD).
# NOTE(review): segments 1 and 2 overlap in time (2 starts at 00:00:04,340,
# before 1 ends at 00:00:05,340) — presumably deliberate test data for the
# critic's synchronization checks; confirm.
INITIAL_SRT_CONTENT = """
1
00:00:00,000 --> 00:00:05,340
[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final.

2
00:00:04,340 --> 00:00:05,790
[Lucía] Hem de donar-ho tot.

3
00:00:05,790 --> 00:00:08,790
[Sandra] Ho sé, ho sé.

4
00:00:08,000 --> 00:00:10,000
(AD) De sobte, són al parc.

5
00:00:10,000 --> 00:00:14,000
(AD) Ara tallen menjar i fan una amanida a una cuina.
"""
|
| 66 |
+
|
| 67 |
+
# Context JSON file (simplified example for the Narrator). Each segment carries
# an id, SRT-style timestamps, a type ("dialog" or "visual_context") and text.
CONTEXT_JSON_CONTENT = """
{
 "segments": [
 {"id": 1, "start": "00:00:00,000", "end": "00:00:05,340", "type": "dialog", "text": "[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final."},
 {"id": 2, "start": "00:00:04,340", "end": "00:00:05,790", "type": "dialog", "text": "[Lucía] Hem de donar-ho tot."},
 {"id": 3, "start": "00:00:05,790", "end": "00:00:08,790", "type": "dialog", "text": "[Sandra] Ho sé, ho sé."},
 {"id": 4, "start": "00:00:08,000", "end": "00:00:10,000", "type": "visual_context", "text": "Cambio de escena a un parque. Personajes caminando."},
 {"id": 5, "start": "00:00:10,000", "end": "00:00:14,000", "type": "visual_context", "text": "Escena en una cocina. Los personajes están cortando vegetales y haciendo una ensalada."}
 ]
}
"""
|
| 79 |
+
|
| 80 |
+
# UNE rules file (technical standard fed to the Critic agent).
# Note: this is an LLM-friendly summary of the relevant rules, not the full norm.
UNE_RULES = """
### Reglas UNE de Audiodescripción (Para el Crítico)
1. **Objetividad y Foco Visual:** La descripción debe ser puramente objetiva, describiendo solo lo que se ve. Debe priorizar la acción y los elementos relevantes (personajes, objetos, localización).
2. **Tiempo y Espacio (Sincronización):** Las audiodescripciones (AD) deben insertarse en los silencios del diálogo. El tiempo de la AD (entre START y END) debe ser suficiente para narrar el contenido sin solaparse con el diálogo o la música importante.
3. **Concisión y Claridad:** Usar lenguaje simple y conciso. Evitar redundancias y juicios de valor.
4. **Formato:** Cada segmento de AD debe tener un formato SRT válido, incluyendo el marcador (AD) al principio de la línea de texto.
5. **Utilidad:** Cada segmento de AD debe ser útil para la comprensión y nunca ser redundante. En caso de repetir algo ya explicado antes, mejor no decir nada.
"""
|
| 90 |
+
|
| 91 |
+
# Rubric used by generate_evaluation_report: the auditor LLM must score each
# of these characteristics from 0 to 7. Names must match CRITERIA_WEIGHTS keys.
EVALUATION_CRITERIA = [
    "Precisió Descriptiva",
    "Sincronització Temporal",
    "Claredat i Concisió",
    "Inclusió de Diàleg/So",
    "Contextualització",
    "Flux i Ritme de la Narració",
]
|
| 99 |
+
|
| 100 |
+
# Per-criterion weights for the weighted mean computed in
# generate_evaluation_report. Temporal synchronization is weighted 4x because
# overlap with dialog is the most damaging AD defect (see UNE rule 2).
# Criteria missing from this dict default to weight 1.
CRITERIA_WEIGHTS = {
    "Precisió Descriptiva": 1,
    "Sincronització Temporal": 4,
    "Claredat i Concisió": 1,
    "Inclusió de Diàleg/So": 1,
    "Contextualització": 1,
    "Flux i Ritme de la Narració": 1,
}
|
| 108 |
+
|
| 109 |
+
# Inicializar ficheros para la ejecución
|
| 110 |
+
def setup_files(initial_srt_content: str, context_json_content: str):
    """Write the initial working files (seed SRT and context JSON) into TEMP_DIR."""
    seeds = {
        "une_ad_0.srt": initial_srt_content,
        "json_ad.json": context_json_content,
    }
    for filename, payload in seeds.items():
        (TEMP_DIR / filename).write_text(payload, encoding="utf-8")
    logger.info("Ficheros iniciales 'une_ad_0.srt' y 'json_ad.json' creados.")
|
| 115 |
+
|
| 116 |
+
# --- Utilidades ---
|
| 117 |
+
def _strip_markdown_fences(content: str) -> str:
|
| 118 |
+
"""Elimina fences ```...``` alrededor de una respuesta JSON si existen."""
|
| 119 |
+
text = content.strip()
|
| 120 |
+
if text.startswith("```"):
|
| 121 |
+
lines = text.splitlines()
|
| 122 |
+
# descartar primera línea con ``` o ```json
|
| 123 |
+
lines = lines[1:]
|
| 124 |
+
# eliminar el cierre ``` (pueden existir varias líneas en blanco finales)
|
| 125 |
+
while lines and lines[-1].strip() == "```":
|
| 126 |
+
lines.pop()
|
| 127 |
+
text = "\n".join(lines).strip()
|
| 128 |
+
return text
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def generate_evaluation_report(srt_content: str, iteration: int) -> tuple[float, float, Path]:
    """Ask the LLM for a structured rubric evaluation and persist it as CSV.

    Scores the SRT on each EVALUATION_CRITERIA item (0-7), writes the result
    to temp/eval_<iteration>.csv, and returns
    (simple mean, CRITERIA_WEIGHTS-weighted mean, path to the CSV).
    On any parse failure a one-row fallback report (score 1) is used.
    """
    criteria_formatted = "\n".join(f"- {name}" for name in EVALUATION_CRITERIA)
    prompt = (
        "Actua com un auditor UNE. Avalua l'SRT generat, puntuant cada característica de 0 a 7 "
        "segons la qualitat observada. Dónega justificació breve però concreta per a cada cas. "
        "Les característiques obligatòries són:\n"
        f"{criteria_formatted}\n"
        "Retorna ÚNICAMENT un array JSON d'objectes amb les claus: "
        "'caracteristica', 'valoracio' (nombre enter de 0 a 7) i 'justificacio'."
    )

    response = llm.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(
                content=(
                    "# SRT AVALUAT\n"
                    f"{srt_content}\n\n"
                    "Assegura't de complir el format indicat."
                )
            ),
        ]
    )

    # The model may wrap the JSON array in ``` fences; strip them before parsing.
    cleaned = _strip_markdown_fences(response.content)
    try:
        data = json.loads(cleaned)
        if not isinstance(data, list):
            raise ValueError("La resposta no és una llista.")
    except Exception as exc:
        logger.error(
            "Error al generar l'avaluació estructurada: %s. Resposta original: %s",
            exc,
            response.content,
        )
        # Fallback: a single low-score row so downstream math still works.
        data = [
            {
                "caracteristica": "Avaluació fallida",
                "valoracio": 1,
                "justificacio": "No s'ha pogut obtenir l'avaluació del LLM.",
            }
        ]

    # Persist the raw per-criterion scores for this iteration.
    eval_path = TEMP_DIR / f"eval_{iteration}.csv"
    with eval_path.open("w", encoding="utf-8", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Caracteristica", "Valoracio (0-7)", "Justificacio"])
        for item in data:
            writer.writerow(
                [
                    item.get("caracteristica", ""),
                    item.get("valoracio", 0),
                    item.get("justificacio", ""),
                ]
            )

    scores = []
    weighted_sum = 0.0
    total_weight = 0.0

    for entry in data:
        # Skip malformed rows (non-dict entries) rather than crashing.
        if not isinstance(entry, dict):
            continue
        try:
            score = float(entry.get("valoracio", 0))
        except (TypeError, ValueError):
            score = 0.0
        scores.append(score)

        # Unknown criterion names get weight 1 (see CRITERIA_WEIGHTS).
        weight = CRITERIA_WEIGHTS.get(entry.get("caracteristica", ""), 1)
        weighted_sum += score * weight
        total_weight += weight

    mean_score = sum(scores) / len(scores) if scores else 0.0
    # With no usable rows, fall back to the simple mean (0.0 in that case).
    weighted_mean = weighted_sum / total_weight if total_weight else mean_score
    return mean_score, weighted_mean, eval_path
|
| 208 |
+
|
| 209 |
+
# --- Definición del Estado de la Gráfica (StateGraph) ---
|
| 210 |
+
class ReflectionState(TypedDict):
    """State carried through the reflection loop (LangGraph StateGraph)."""
    iteration: int  # current cycle (starting at 0)
    current_srt_path: str  # path to the current SRT file (e.g., une_ad_0.srt, une_ad_1.srt)
    critic_report: Dict[str, Union[float, str]]  # latest critic report (score and text)
    # Message history between agents. Widened to include AIMessage: the
    # narrator/critic agents append AIMessage instances to this list.
    history: List[Union[SystemMessage, AIMessage]]
    evaluation_mean: float  # weighted rubric mean of the latest evaluation
    best_iteration: int  # iteration that achieved the best weighted mean so far
    best_weighted_mean: float  # best weighted rubric mean seen so far
    best_srt_path: str  # SRT file of the best-scoring iteration
    best_eval_path: str  # evaluation CSV of the best-scoring iteration
|
| 221 |
+
|
| 222 |
+
# --- Nodos/Agentes de la Gráfica ---
|
| 223 |
+
def narrator_agent(state: ReflectionState):
    """
    Agent that generates or rewrites the SRT.
    - On cycle 0, it "generates" the initial SRT (actually passes through une_ad_0.srt).
    - On cycles > 0, it rewrites the SRT based on critic_report.
    """
    iteration = state["iteration"]
    critic_report = state["critic_report"]
    history = state["history"]

    # Load the visual context and the latest SRT.
    json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    # 1. Build the prompt.
    if iteration == 0:
        # Initial task (in this case une_ad_0.srt is already provided),
        # so the initial generation is simulated.
        # NOTE(review): this prompt is built but never sent to the LLM on
        # iteration 0 — the branch short-circuits to output_srt = current_srt.
        prompt = (
            "Ets un Narrador expert en Audiodescripció (AD). La teva tasca inicial és generar "
            "un fitxer SRT d'audiodescripcions basat en el JSON de context visual. "
            "TOT I AIXÍ, per a aquesta primera iteració, l'SRT ja s'ha generat. "
            "Simplement retorna el contingut de 'une_ad_0.srt' com si fos la teva sortida. "
            "Assegura't que totes les audiodescripcions estiguin en català i que cadascuna pugui ser locutada "
            "dins del temps disponible (utilitza un màxim aproximat d'11 caràcters per segon). Si el tram de temps "
            "és massa curt (<1.5s), combina'l amb el bloc d'AD més proper i ajusta els timestamps perquè la narració sigui fluida. "
            "Evita redundàncies: no repeteixis informació ja descrita en segments d'AD anteriors o al diàleg, i elimina qualsevol detall que no sigui essencial."
        )
        output_srt = current_srt
        reflection_text = "Generación inicial. No hay reflexión."
    else:
        # Reflection task: rewrite the AD lines according to the critic's report.
        prompt = (
            "Ets un Narrador expert en Audiodescripció (AD). Has rebut una crítica sobre la teva última versió de l'SRT. "
            "La teva tasca és REESCRIURE el contingut d'audiodescripció (línies amb '(AD)') del fitxer SRT, "
            "assegurant que sigui coherent amb el JSON de context i, sobretot, que CORREGEIXIS TOTS els problemes "
            "mencionats a l'Informe Crític adjunt. Mantén intactes els diàlegs (línies amb [Nom]) i escriu totes les audiodescripcions en català natural. "
            "Garanteix que cada bloc d'AD pugui ser locutat dins del seu interval temporal disponible considerant un màxim d'11 caràcters per segon. "
            "Si l'interval és massa curt (<1.5s), fusiona'l amb el bloc d'AD anterior o posterior més proper i ajusta els timestamps perquè quedin contínues. "
            "Prefereix frases concises i accionables, prioritzant la informació visual essencial, i elimina redundàncies amb AD anteriors o amb els diàlegs."
        )

        # Concatenate everything the LLM needs into a single human message.
        input_content = f"""
# INFORME CRÍTICO
Porcentaje de Fiabilidad Anterior: {critic_report.get('reliability_percentage')}
Crítica Cualitativa: {critic_report.get('qualitative_critique')}

# JSON DE CONTEXTO VISUAL (Guía para la AD)
{json_context}

# ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration-1}.srt)
{current_srt}

REGLAS: Tu respuesta debe ser *SOLAMENTE* el contenido completo del nuevo archivo SRT (incluyendo diálogos), sin ningún comentario o explicación adicional.
"""

        # LLM call.
        response = llm.invoke(
            [
                SystemMessage(content=prompt),
                HumanMessage(content=input_content)
            ]
        )

        output_srt = response.content
        reflection_text = f"Reescrito en base al informe crítico: {critic_report.get('qualitative_critique', 'N/A')}"

    # 2. Persist the new output (on iteration 0 this rewrites une_ad_0.srt in place).
    new_srt_path = TEMP_DIR / f"une_ad_{iteration}.srt"
    new_srt_path.write_text(output_srt, encoding="utf-8")

    # 3. Persist the reasoning trace (reflection_text).
    (TEMP_DIR / f"thinking_{iteration}.txt").write_text(reflection_text, encoding="utf-8")

    logger.info(f"Narrador: Generada la versión {iteration} del SRT en '{new_srt_path}'.")

    # 4. Update the graph state (best_* fields pass through unchanged here;
    #    the critic agent is the one that updates them).
    new_history = history + [AIMessage(content=f"Narrador v{iteration} completado. Razón de reflexión: {reflection_text}")]
    return {
        "iteration": iteration,
        "current_srt_path": str(new_srt_path),
        "history": new_history,
        "evaluation_mean": state.get("evaluation_mean", 0.0),
        "best_iteration": state.get("best_iteration", -1),
        "best_weighted_mean": state.get("best_weighted_mean", 0.0),
        "best_srt_path": state.get("best_srt_path", str(new_srt_path)),
        "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
    }
|
| 312 |
+
|
| 313 |
+
def identity_manager_agent(state: ReflectionState):
    """
    Agent that manages the user's identity.

    NOTE(review): DEAD CODE — a second `identity_manager_agent` is defined
    later in this module and shadows this one at import time; this version is
    never callable by name. Additionally, `identity_info` is parsed below but
    never used or returned. Consider removing or renaming this definition.
    """
    iteration = state["iteration"]
    history = state["history"]
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    prompt = (
        "Ets un gestor d'identitats. La teva tasca és verificar la identitat de l'usuari "
        "i assegurar-te que les seves dades estiguin actualitzades."
    )

    input_content = f"""
# ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration}.srt):
{current_srt}

REGLAS: Tu respuesta debe ser *SOLAMENTE* un objeto JSON con la información de la identidad del usuario.
"""

    # LLM call.
    response = llm.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=input_content)
        ]
    )

    # Try to parse the LLM reply (it may fail, hence the try/except).
    try:
        cleaned_response = _strip_markdown_fences(response.content)
        identity_info = json.loads(cleaned_response)
        if not isinstance(identity_info, dict):
            raise ValueError("Estructura JSON incorrecta.")
    except Exception as e:
        logger.error(f"Error al parsear el JSON de la identidad: {e}. Respuesta: {response.content}")
        identity_info = {"error": "No s'ha pogut obtenir la informació d'identitat."}

    logger.info(f"Identity Manager: Información de identidad actualizada.")

    # State passes through unchanged except for the appended history entry.
    new_history = history + [AIMessage(content=f"Identity Manager v{iteration} completado.")]
    return {
        "iteration": iteration,
        "current_srt_path": state["current_srt_path"],
        "history": new_history,
        "evaluation_mean": state.get("evaluation_mean", 0.0),
        "best_iteration": state.get("best_iteration", -1),
        "best_weighted_mean": state.get("best_weighted_mean", 0.0),
        "best_srt_path": state.get("best_srt_path", state["current_srt_path"]),
        "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
    }
|
| 364 |
+
|
| 365 |
+
def critic_agent(state: ReflectionState):
    """
    Agent that evaluates the quality of the Narrator's SRT against the UNE rules.

    Produces a reliability percentage plus a qualitative critique, runs the
    structured rubric evaluation, tracks the best-scoring iteration, and
    advances the iteration counter.
    """
    iteration = state["iteration"]
    history = state["history"]
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    prompt = (
        "Ets un Crític d'Audiodescripció molt estricte. La teva tasca és avaluar l'SRT adjunt "
        "únicament segons les Regles UNE proporcionades. L'avaluació ha de ser doble: "
        "1. **Numèrica**: Un percentatge de fiabilitat (ex. 85.5) de 0 a 100%. "
        "2. **Qualitativa**: Una crítica constructiva sobre les principals mancances de les AD respecte a les regles. "
        "Has de ser EXTREMADAMENT estricte amb la sincronització (sense solapament amb el diàleg), "
        "amb l'adequació temporal (velocitat màxima recomanada d'11 caràcters per segon) i amb l'absència de redundàncies. "
        "Comprova també que totes les audiodescripcions estan escrites en català natural."
    )

    input_content = f"""
# REGLAS UNE DE AUDIODESCRIPCIÓN:
{UNE_RULES}

# ARCHIVO SRT A EVALUAR (une_ad_{iteration}.srt):
{current_srt}

REGLAS DE RESPUESTA:
Tu respuesta debe ser *SOLAMENTE* un objeto JSON con dos claves:
1. "reliability_percentage": (float) El porcentaje de fiabilidad.
2. "qualitative_critique": (string) La crítica cualitativa y sugerencias de mejora.
Ejemplo de respuesta: {{"reliability_percentage": 75.0, "qualitative_critique": "El segmento 4 se solapa 0.34s con el diálogo de Sandra. El segmento 5 es demasiado genérico y no describe bien la acción."}}
"""

    # LLM call.
    response = llm.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=input_content)
        ]
    )

    # Try to parse the LLM reply (it may fail, hence the try/except).
    try:
        cleaned_response = _strip_markdown_fences(response.content)
        report = json.loads(cleaned_response)
        if not isinstance(report, dict) or 'reliability_percentage' not in report:
            raise ValueError("Estructura JSON incorrecta.")
    except Exception as e:
        logger.error(f"Error al parsear el JSON del Crítico: {e}. Respuesta: {response.content}")
        # Low-reliability fallback so the loop keeps iterating instead of crashing.
        report = {"reliability_percentage": 1.0, "qualitative_critique": "El Crítico no devolvió un JSON válido. Reintentar."}

    logger.info(f"Crítico: Evaluación completada. Fiabilidad: {report.get('reliability_percentage')}%.")

    # Structured rubric evaluation (writes eval_<iteration>.csv as a side effect).
    mean_score, weighted_mean, eval_path = generate_evaluation_report(current_srt, iteration)

    # Append the evaluation means to this iteration's thinking trace, if present.
    thinking_path = TEMP_DIR / f"thinking_{iteration}.txt"
    if thinking_path.exists():
        previous_text = thinking_path.read_text(encoding="utf-8")
        thinking_path.write_text(
            (
                f"{previous_text}\n\nMitjana simple d'avaluació: {mean_score:.2f} / 7"
                f"\nMitjana ponderada d'avaluació: {weighted_mean:.2f} / 7"
            ),
            encoding="utf-8",
        )

    # Track the best-scoring iteration by weighted mean.
    # NOTE(review): default here is -1.0 while narrator_agent defaults this
    # field to 0.0 — harmless in practice but inconsistent.
    best_iteration = state.get("best_iteration", -1)
    best_weighted_mean = state.get("best_weighted_mean", -1.0)
    best_srt_path = state.get("best_srt_path", state["current_srt_path"])
    best_eval_path = state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv"))

    if weighted_mean > best_weighted_mean:
        best_iteration = iteration
        best_weighted_mean = weighted_mean
        best_srt_path = state["current_srt_path"]
        best_eval_path = str(eval_path)

    new_history = history + [
        AIMessage(
            content=(
                "Crítico v{iter} completado. Fiabilidad: {reliab}%. "
                "Mitjana simple: {mean:.2f}/7. Mitjana ponderada: {wmean:.2f}/7"
            ).format(
                iter=iteration,
                reliab=report.get("reliability_percentage"),
                mean=mean_score,
                wmean=weighted_mean,
            )
        )
    ]
    # The critic is the only agent that advances the iteration counter.
    return {
        "iteration": iteration + 1,
        "critic_report": report,
        "history": new_history,
        "evaluation_mean": weighted_mean,
        "best_iteration": best_iteration,
        "best_weighted_mean": best_weighted_mean,
        "best_srt_path": best_srt_path,
        "best_eval_path": best_eval_path,
    }
|
| 465 |
+
|
| 466 |
+
def identity_manager_agent(state: ReflectionState):
    """
    Agent that checks speaker coherence between the SRT, casting.csv and the
    visual context. Corrects speaker assignments and writes a change log.

    NOTE(review): this definition shadows an earlier `identity_manager_agent`
    defined above in this module; only this one is reachable by name.
    """
    iteration = state["iteration"]

    # Load inputs.
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")
    casting_path = TEMP_DIR / "casting.csv"
    json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")

    # casting.csv is optional: without it this agent is a no-op.
    if not casting_path.exists():
        logger.warning("Casting.csv no encontrado. Saltando identity_manager.")
        return state

    casting_content = casting_path.read_text(encoding="utf-8")

    # NOTE(review): the "{{" / "}}" below are in plain (non-f) string literals,
    # so the LLM receives literal doubled braces in the output template —
    # probably intended as f-string escapes. Confirm and consider fixing.
    prompt = (
        "Ets un Identity Manager. La teva tasca és:\n"
        "1. Verificar que les assignacions de parlants a l'SRT coincideixen amb casting.csv\n"
        "2. Comprovar que els parlants assignats són coherents amb el context visual de json_ad.json\n"
        "3. Si trobes inconsistències, re-assigna els parlants corregint les etiquetes [Nom]\n"
        "4. Justifica canvis al fitxer identity_log.txt\n"
        "\n"
        "Dades d'entrada:\n"
        f"- CASTING.CSV:\n{casting_content}\n"
        f"- JSON CONTEXT:\n{json_context}\n"
        f"- SRT ACTUAL:\n{current_srt}\n"
        "\n"
        "REGLES:\n"
        "- Només modifica les línies de diàleg (ex: [Nom])\n"
        "- Manté la numeració i timestamps\n"
        "- Si no hi ha canvis, retorna l'SRT original\n"
        "\n"
        "Format de sortida:\n"
        "```json\n"
        "{{\n"
        " \"srt_content\": \"<nou contingut SRT>\",\n"
        " \"log_message\": \"<explicació canvis o 'Sense canvis'>\"\n"
        "}}\n"
        "```"
    )

    response = llm.invoke([SystemMessage(content=prompt)])

    try:
        # Parse the JSON reply (KeyError/JSONDecodeError land in the except below).
        cleaned = _strip_markdown_fences(response.content)
        data = json.loads(cleaned)
        new_srt = data["srt_content"]
        log_msg = data["log_message"]

        # Write the per-iteration change log.
        log_path = TEMP_DIR / f"identity_log_{iteration}.txt"
        log_path.write_text(f"Iteració {iteration}: {log_msg}", encoding="utf-8")

        # Only swap in a corrected SRT when the content actually changed.
        if new_srt != current_srt:
            new_srt_path = TEMP_DIR / f"une_ad_{iteration}_corrected.srt"
            new_srt_path.write_text(new_srt, encoding="utf-8")
            logger.info(f"Identity Manager: Correccions aplicades. Detalls: {log_msg}")
            return {
                **state,
                "current_srt_path": str(new_srt_path)
            }

    except Exception as e:
        # Best-effort agent: on any failure, log and fall through unchanged.
        logger.error(f"Error en identity_manager: {e}")

    # No changes (or failure): pass the state through untouched.
    return state
|
| 538 |
+
|
| 539 |
+
def background_descriptor_agent(state: ReflectionState):
    """Check scene descriptions in the SRT against scenarios.csv and fix them.

    Asks the LLM to replace generic scene descriptions in AD lines with the
    official scenario names listed in ``temp/scenarios.csv``.  When the catalogue
    file is missing, or the LLM reply cannot be parsed, the incoming state is
    returned untouched.
    """
    cycle = state["iteration"]

    # SRT currently under review.
    srt_text = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    # The scenario catalogue is optional; without it this agent is a no-op.
    scenarios_path = TEMP_DIR / "scenarios.csv"
    if not scenarios_path.exists():
        logger.warning("Scenarios.csv no encontrado. Saltando background_descriptor.")
        return state

    scenarios_content = scenarios_path.read_text(encoding="utf-8")

    prompt = (
        "Ets un Background Descriptor. La teva tasca és:\n"
        "1. Verificar que les descripcions d'escenaris a l'SRT coincideixen amb scenarios.csv\n"
        "2. Si trobes coincidències, reemplaça les descripcions genèriques pel nom oficial de l'escenari\n"
        "3. Justifica canvis al fitxer background_log.txt\n"
        "\n"
        "Dades d'entrada:\n"
        f"- SCENARIOS.CSV:\n{scenarios_content}\n"
        f"- SRT ACTUAL:\n{srt_text}\n"
        "\n"
        "REGLES:\n"
        "- Només modifica línies d'audiodescripció (ex: (AD) ...)\n"
        "- Manté la numeració i timestamps\n"
        "- Si no hi ha canvis, retorna l'SRT original\n"
        "\n"
        "Format de sortida:\n"
        "```json\n"
        "{{\n"
        " \"srt_content\": \"<nou contingut SRT>\",\n"
        " \"log_message\": \"<explicació canvis o 'Sense canvis'>\"\n"
        "}}\n"
        "```"
    )

    response = llm.invoke([SystemMessage(content=prompt)])

    try:
        # The reply is expected to be a JSON object, possibly inside ``` fences.
        payload = json.loads(_strip_markdown_fences(response.content))
        corrected_srt = payload["srt_content"]
        log_msg = payload["log_message"]

        # Persist this cycle's justification.
        log_path = TEMP_DIR / f"background_log_{cycle}.txt"
        log_path.write_text(f"Iteració {cycle}: {log_msg}", encoding="utf-8")

        # Only advance the SRT pointer when the LLM actually changed something.
        if corrected_srt != srt_text:
            out_path = TEMP_DIR / f"une_ad_{cycle}_scenario_corrected.srt"
            out_path.write_text(corrected_srt, encoding="utf-8")
            logger.info(f"Background Descriptor: Correccions aplicades. Detalls: {log_msg}")
            return {**state, "current_srt_path": str(out_path)}

    except Exception as exc:
        logger.error(f"Error en background_descriptor: {exc}")

    return state
|
| 608 |
+
|
| 609 |
+
# --- Condición de Salida del Bucle ---
|
| 610 |
+
|
| 611 |
+
def should_continue(state: ReflectionState) -> str:
    """Decide whether the reflection loop keeps going.

    Returns ``"end"`` when the weighted evaluation mean reaches the quality
    threshold or the iteration budget is exhausted; ``"continue"`` otherwise.
    """
    MAX_ITERATIONS = 5       # hard cap on refinement cycles
    MIN_AVERAGE_SCORE = 6.0  # quality threshold on the 0-7 scale

    cycle = state["iteration"]
    score = state.get("evaluation_mean", 0.0)

    # Early exit: quality target already met.
    if score >= MIN_AVERAGE_SCORE:
        logger.info(f"FIN: Mitjana ponderada d'avaluació assolida ({score:.2f} >= {MIN_AVERAGE_SCORE}).")
        return "end"

    # Budget exit: no iterations left.
    if cycle >= MAX_ITERATIONS:
        logger.info(f"FIN: S'ha assolit el màxim d'iteracions ({cycle} / {MAX_ITERATIONS}).")
        return "end"

    logger.info(f"CONTINUAR: Iteració {cycle} / {MAX_ITERATIONS}. Mitjana ponderada actual: {score:.2f} / 7.")
    return "continue"
|
| 631 |
+
|
| 632 |
+
# --- Graph Construction ---

# 1. Set up the initial loop state: iteration 0 points at the seed SRT and
#    evaluation files created by setup_files(); "best_*" trackers start empty.
initial_state: ReflectionState = {
    "iteration": 0,
    "current_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
    "critic_report": {"reliability_percentage": 0.0, "qualitative_critique": "Inicializando el proceso."},
    "history": [],
    "evaluation_mean": 0.0,
    "best_iteration": -1,
    "best_weighted_mean": -1.0,
    "best_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
    "best_eval_path": str(TEMP_DIR / "eval_0.csv"),
}

# 2. Define the graph
workflow = StateGraph(ReflectionState)

# Nodes
workflow.add_node("narrator", narrator_agent)
workflow.add_node("identity_manager", identity_manager_agent)
workflow.add_node("background_descriptor", background_descriptor_agent)
workflow.add_node("critic", critic_agent)

# Loop structure: Narrator -> Identity Manager -> Background Descriptor -> Critic -> Check
workflow.set_entry_point("narrator")
workflow.add_edge("narrator", "identity_manager")
workflow.add_edge("identity_manager", "background_descriptor")
workflow.add_edge("background_descriptor", "critic")

# Branching condition (evaluated after the critic node)
workflow.add_conditional_edges(
    "critic",
    should_continue,
    {
        "continue": "narrator",  # threshold/budget not met: loop back to the narrator
        "end": END  # otherwise finish
    }
)

# Compile the graph
app = workflow.compile()
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def generate_free_ad_from_srt(srt_path: Path) -> Path:
    """Produce a detailed free-form narration (free AD) from the final SRT.

    Sends the SRT content to the LLM with a narrator persona prompt, writes the
    reply to ``temp/free_ad.txt`` and returns that path.
    """
    srt_text = srt_path.read_text(encoding="utf-8")

    system_prompt = (
        "Actua com una narradora professional d'audiodescripcions lliures. "
        "A partir de l'SRT proporcionat, escriu un text narratiu en català que descrigui "
        "de manera exhaustiva i fluida tot el que succeeix a la peça audiovisual. "
        "Inclou accions, aparença, gestos, canvis d'escena i qualsevol detall rellevant, "
        "sense limitar-te a les restriccions temporals del format SRT. "
        "Evita repetir literalment els diàlegs, però contextualitza'ls quan sigui útil. "
        "La narració ha de ser clara, coherent i apta per ser locutada com una narració lliure."
    )
    user_prompt = (
        "# SRT FINAL\n"
        f"{srt_text}\n\n"
        "Respon únicamente con la narració lliure sin cap comentario adicional."
    )

    reply = llm.invoke(
        [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
    )

    out_path = TEMP_DIR / "free_ad.txt"
    out_path.write_text(reply.content, encoding="utf-8")
    logger.info(f"Narració lliure generada en '{out_path}'.")
    return out_path
|
| 706 |
+
|
| 707 |
+
# --- Ejecución Principal ---
|
| 708 |
+
|
| 709 |
+
if __name__ == "__main__":
    # Initialise the working files (seed SRT + context JSON) under temp/.
    setup_files(INITIAL_SRT_CONTENT, CONTEXT_JSON_CONTENT)

    logger.info("--- Comenzando el Bucle de Finetuning ---")

    # Run the reflection graph until should_continue() returns "end".
    final_state = app.invoke(initial_state)

    logger.info("\n--- Bucle Finalizado ---")

    # Pick the best-scoring iteration tracked during the loop; fall back to the
    # last current SRT / eval_0.csv when the trackers were never updated.
    best_iteration = final_state.get("best_iteration", -1)
    best_weighted_mean = final_state.get("best_weighted_mean", 0.0)
    best_srt_path = Path(final_state.get("best_srt_path", final_state['current_srt_path']))
    best_eval_path = Path(final_state.get("best_eval_path", TEMP_DIR / "eval_0.csv"))

    final_srt_path = TEMP_DIR / "une_ad.srt"
    final_eval_path = TEMP_DIR / "eval.csv"

    # Copy the winning artefacts to their canonical names; failures are logged,
    # not fatal, so a partial run still reports what it produced.
    try:
        shutil.copy(best_srt_path, final_srt_path)
        logger.info(f"SRT final copiado a '{final_srt_path}'.")
    except Exception as exc:
        logger.error(f"No se pudo copiar el SRT final: {exc}")

    try:
        shutil.copy(best_eval_path, final_eval_path)
        logger.info(f"Evaluación final copiada a '{final_eval_path}'.")
    except Exception as exc:
        logger.error(f"No se pudo copiar el CSV final: {exc}")

    # The free narration is best-effort: keep going if the LLM call fails.
    free_ad_path: Union[Path, None] = None
    try:
        free_ad_path = generate_free_ad_from_srt(final_srt_path)
    except Exception as exc:
        logger.error(f"No s'ha pogut generar la narració lliure: {exc}")

    # Report results on stdout.
    print(f"Número final de ciclos: {final_state['iteration']}")
    print(f"Iteración óptima: {best_iteration} (mitjana ponderada {best_weighted_mean:.2f}/7)")
    print(f"Ruta al SRT final: {final_srt_path}")
    print(f"Ruta a l'avaluació final: {final_eval_path}")
    if free_ad_path is not None:
        print(f"Ruta a la narració lliure: {free_ad_path}")
    else:
        print("No s'ha pogut generar la narració lliure.")

    # Show the final generated SRT.
    print("\n--- Contenido del SRT Final ---")
    print(final_srt_path.read_text(encoding="utf-8"))

    if free_ad_path is not None:
        print("\n--- Narració Lliure ---")
        print(free_ad_path.read_text(encoding="utf-8"))
|
finetuning/lora.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import List, Dict
|
| 5 |
+
|
| 6 |
+
from datasets import Dataset
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoTokenizer,
|
| 9 |
+
AutoModelForCausalLM,
|
| 10 |
+
TrainingArguments,
|
| 11 |
+
Trainer,
|
| 12 |
+
)
|
| 13 |
+
from peft import LoraConfig, get_peft_model
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 17 |
+
DATA_DIR = BASE_DIR / "data"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
    """Collect instruction-tuning pairs from the subfolders of *data_dir*.

    Each subfolder must contain both ``target_une_ad.srt`` and ``free_ad.txt``;
    the SRT becomes the instruction input and the free narration the target
    output, formatted as an instruct-style example.

    Raises:
        FileNotFoundError: if *data_dir* does not exist.
        RuntimeError: if no complete pair is found.
    """
    if not data_dir.exists():
        raise FileNotFoundError(f"Data dir not found: {data_dir}")

    pairs: List[Dict[str, str]] = []

    for entry in sorted(data_dir.iterdir()):
        if not entry.is_dir():
            continue

        srt_file = entry / "target_une_ad.srt"
        free_file = entry / "free_ad.txt"
        if not (srt_file.exists() and free_file.exists()):
            continue

        srt_text = srt_file.read_text(encoding="utf-8")
        free_text = free_file.read_text(encoding="utf-8")

        # Instruction-style prompt (in Catalan), consistent with the task.
        instruction = (
            "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
            "en una narració lliure detallada en català, sense límits de temps. "
            "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
            "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
        )

        pairs.append({"prompt": instruction, "output": free_text.strip()})

    if not pairs:
        raise RuntimeError(f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)")

    return pairs
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
    """Build a tokenized Hugging Face Dataset for causal-LM training.

    Each example is the concatenation ``[PROMPT]\\n[OUTPUT]<eos>`` padded to
    *max_length*.  Labels are masked with -100 over both the prompt tokens and
    the padding tokens, so the loss is computed only on the answer.

    Args:
        pairs: list of ``{"prompt": ..., "output": ...}`` dicts.
        tokenizer: tokenizer of the base model (pad token must be set).
        max_length: fixed sequence length after truncation/padding.

    Returns:
        Dataset with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """

    def _gen():
        for ex in pairs:
            yield {"prompt": ex["prompt"], "output": ex["output"]}

    raw_ds = Dataset.from_generator(_gen)

    def tokenize_fn(batch):
        prompts = batch["prompt"]
        outputs = batch["output"]

        input_ids_list = []
        attention_mask_list = []
        labels_list = []

        for p, o in zip(prompts, outputs):
            full_text = p + "\n" + o + tokenizer.eos_token
            enc = tokenizer(
                full_text,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            input_ids = enc["input_ids"]
            # BUGFIX: use the tokenizer's real attention mask (0 on padding)
            # instead of a hard-coded all-ones mask, so the model does not
            # attend to pad positions.
            attention_mask = enc["attention_mask"]

            # Mask the loss over the prompt tokens.
            prompt_ids = tokenizer(p + "\n", truncation=True, max_length=max_length)["input_ids"]
            prompt_len = min(len(prompt_ids), max_length)

            labels = input_ids.copy()
            for i in range(prompt_len):
                labels[i] = -100
            # BUGFIX: also mask the loss over padding positions; previously the
            # pad token ids were left in the labels and trained on.
            for i, mask_value in enumerate(attention_mask):
                if mask_value == 0:
                    labels[i] = -100

            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)
            labels_list.append(labels)

        return {
            "input_ids": input_ids_list,
            "attention_mask": attention_mask_list,
            "labels": labels_list,
        }

    tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"])
    return tokenized
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05):
    """Load *base_model_name* and wrap it with a LoRA adapter for causal LM training.

    Args:
        base_model_name: HF hub id or local path of the base model.
        r: LoRA rank.
        alpha: LoRA scaling alpha.
        dropout: LoRA dropout probability.

    Returns:
        A PEFT-wrapped model ready for training.
    """
    # Adapter hyper-parameters; target modules are left to PEFT's
    # per-architecture defaults.
    adapter_cfg = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    base = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
    )

    return get_peft_model(base, adapter_cfg)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the LoRA fine-tuning run."""
    arg_parser = argparse.ArgumentParser(description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD")

    # Model / data locations.
    arg_parser.add_argument(
        "--base_model",
        type=str,
        default="projecte-aina/salamandra-instruct-7b",
        help="Nom o ruta del model base (HF hub o path local)",
    )
    arg_parser.add_argument(
        "--data_dir",
        type=str,
        default=str(DATA_DIR),
        help="Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt",
    )
    arg_parser.add_argument(
        "--output_dir",
        type=str,
        default=str(BASE_DIR / "lora_output"),
        help="Directori on desar l'adapter LoRA",
    )

    # Optimisation hyper-parameters.
    arg_parser.add_argument("--batch_size", type=int, default=1)
    arg_parser.add_argument("--gradient_accumulation", type=int, default=8)
    arg_parser.add_argument("--epochs", type=int, default=3)
    arg_parser.add_argument("--lr", type=float, default=2e-4)
    arg_parser.add_argument("--max_length", type=int, default=2048)
    arg_parser.add_argument("--warmup_ratio", type=float, default=0.03)

    # Logging / checkpointing cadence.
    arg_parser.add_argument("--logging_steps", type=int, default=10)
    arg_parser.add_argument("--save_steps", type=int, default=200)
    arg_parser.add_argument("--eval_steps", type=int, default=200)

    # LoRA hyper-parameters.
    arg_parser.add_argument("--r", type=int, default=16, help="Rank de LoRA")
    arg_parser.add_argument("--alpha", type=int, default=32, help="Alpha de LoRA")
    arg_parser.add_argument("--dropout", type=float, default=0.05, help="Dropout de LoRA")

    return arg_parser.parse_args()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def main():
    """Run LoRA fine-tuning end to end: find data, tokenize, train, save adapter."""
    args = parse_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"[lora] Buscant dades a: {data_dir}")
    pairs = find_training_pairs(data_dir)
    print(f"[lora] Nombre d'exemples trobats: {len(pairs)}")

    print(f"[lora] Carregant tokenizer de {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    print("[lora] Construint dataset tokenitzat...")
    dataset = build_dataset(pairs, tokenizer, max_length=args.max_length)

    print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...")
    model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        # BUGFIX: evaluation_strategy="steps" combined with no eval dataset
        # makes Trainer raise at construction time.  No eval split exists in
        # this pipeline, so evaluation is disabled; --eval_steps is kept on
        # the CLI for compatibility but is unused until an eval split exists.
        evaluation_strategy="no",
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("[lora] Iniciant entrenament...")
    trainer.train()

    print("[lora] Guardant adapter LoRA...")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"[lora] Entrenament completat. Adapter guardat a {output_dir}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
finetuning/reflection.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TypedDict, Annotated, List, Dict, Union
|
| 8 |
+
from langgraph.graph import StateGraph, END
|
| 9 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 10 |
+
from langchain_openai import ChatOpenAI
|
| 11 |
+
from operator import itemgetter
|
| 12 |
+
|
| 13 |
+
# --- Configuration and Tools ---

# Working directories: everything generated at runtime lives under ./temp,
# next to this module.
BASE_DIR = Path(__file__).resolve().parent
TEMP_DIR = BASE_DIR / "temp"
TEMP_DIR.mkdir(exist_ok=True)

LOG_FILE = TEMP_DIR / "reflection.log"

# Log both to the console and to temp/reflection.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(LOG_FILE, encoding="utf-8")
    ],
)
logger = logging.getLogger(__name__)

# Make sure an API key is present.  In a real environment set OPENAI_API_KEY;
# the placeholder below only lets the module be imported for demonstration —
# actual LLM calls will fail with it.
if "OPENAI_API_KEY" not in os.environ:
    logger.warning("OPENAI_API_KEY no está configurada. Usando un placeholder.")
    os.environ["OPENAI_API_KEY"] = "sk-..."

# LLM shared by every agent (GPT-4o, chosen for reasoning ability).
# In production, pick a model matching your token/latency requirements.
llm = ChatOpenAI(model="gpt-4o", temperature=0.3)
|
| 43 |
+
|
| 44 |
+
# --- Ficheros de Ejemplo ---
|
| 45 |
+
|
| 46 |
+
# Fichero SRT inicial (Narrador)
|
| 47 |
+
INITIAL_SRT_CONTENT = """
|
| 48 |
+
1
|
| 49 |
+
00:00:00,000 --> 00:00:05,340
|
| 50 |
+
[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final.
|
| 51 |
+
|
| 52 |
+
2
|
| 53 |
+
00:00:04,340 --> 00:00:05,790
|
| 54 |
+
[Lucía] Hem de donar-ho tot.
|
| 55 |
+
|
| 56 |
+
3
|
| 57 |
+
00:00:05,790 --> 00:00:08,790
|
| 58 |
+
[Sandra] Ho sé, ho sé.
|
| 59 |
+
|
| 60 |
+
4
|
| 61 |
+
00:00:08,000 --> 00:00:10,000
|
| 62 |
+
(AD) De sobte, són al parc.
|
| 63 |
+
|
| 64 |
+
5
|
| 65 |
+
00:00:10,000 --> 00:00:14,000
|
| 66 |
+
(AD) Ara tallen menjar i fan una amanida a una cuina.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# Fichero JSON de contexto (ejemplo de la respuesta anterior, pero simplificado para el Narrador)
|
| 70 |
+
CONTEXT_JSON_CONTENT = """
|
| 71 |
+
{
|
| 72 |
+
"segments": [
|
| 73 |
+
{"id": 1, "start": "00:00:00,000", "end": "00:00:05,340", "type": "dialog", "text": "[Sandra] Però de veritat crec que aquest projecte canviarà la nostra nota final."},
|
| 74 |
+
{"id": 2, "start": "00:00:04,340", "end": "00:00:05,790", "type": "dialog", "text": "[Lucía] Hem de donar-ho tot."},
|
| 75 |
+
{"id": 3, "start": "00:00:05,790", "end": "00:00:08,790", "type": "dialog", "text": "[Sandra] Ho sé, ho sé."},
|
| 76 |
+
{"id": 4, "start": "00:00:08,000", "end": "00:00:10,000", "type": "visual_context", "text": "Cambio de escena a un parque. Personajes caminando."},
|
| 77 |
+
{"id": 5, "start": "00:00:10,000", "end": "00:00:14,000", "type": "visual_context", "text": "Escena en una cocina. Los personajes están cortando vegetales y haciendo una ensalada."}
|
| 78 |
+
]
|
| 79 |
+
}
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
# Fichero de Reglas UNE (Norma Técnica para el Crítico)
|
| 83 |
+
# Nota: Aquí se usa un resumen de las reglas pertinentes para un LLM.
|
| 84 |
+
UNE_RULES = """
|
| 85 |
+
### Reglas UNE de Audiodescripción (Para el Crítico)
|
| 86 |
+
1. **Objetividad y Foco Visual:** La descripción debe ser puramente objetiva, describiendo solo lo que se ve. Debe priorizar la acción y los elementos relevantes (personajes, objetos, localización).
|
| 87 |
+
2. **Tiempo y Espacio (Sincronización):** Las audiodescripciones (AD) deben insertarse en los silencios del diálogo. El tiempo de la AD (entre START y END) debe ser suficiente para narrar el contenido sin solaparse con el diálogo o la música importante.
|
| 88 |
+
3. **Concisión y Claridad:** Usar lenguaje simple y conciso. Evitar redundancias y juicios de valor.
|
| 89 |
+
4. **Formato:** Cada segmento de AD debe tener un formato SRT válido, incluyendo el marcador (AD) al principio de la línea de texto.
|
| 90 |
+
5. **Utilidad:** Cada segmento de AD debe ser útil para la comprensión y nunca ser redundante. En caso de repetir algo ya explicado antes, mejor no decir nada.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
EVALUATION_CRITERIA = [
|
| 94 |
+
"Precisió Descriptiva",
|
| 95 |
+
"Sincronització Temporal",
|
| 96 |
+
"Claredat i Concisió",
|
| 97 |
+
"Inclusió de Diàleg/So",
|
| 98 |
+
"Contextualització",
|
| 99 |
+
"Flux i Ritme de la Narració",
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
CRITERIA_WEIGHTS = {
|
| 103 |
+
"Precisió Descriptiva": 1,
|
| 104 |
+
"Sincronització Temporal": 4,
|
| 105 |
+
"Claredat i Concisió": 1,
|
| 106 |
+
"Inclusió de Diàleg/So": 1,
|
| 107 |
+
"Contextualització": 1,
|
| 108 |
+
"Flux i Ritme de la Narració": 1,
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# Initialise the files needed for a run.
def setup_files(initial_srt_content: str, context_json_content: str):
    """Write the seed SRT and the context JSON into the temp working directory."""
    seed_srt = TEMP_DIR / "une_ad_0.srt"
    context_json = TEMP_DIR / "json_ad.json"
    seed_srt.write_text(initial_srt_content, encoding="utf-8")
    context_json.write_text(context_json_content, encoding="utf-8")
    logger.info("Ficheros iniciales 'une_ad_0.srt' y 'json_ad.json' creados.")
+
|
| 118 |
+
# --- Utilidades ---
|
| 119 |
+
def _strip_markdown_fences(content: str) -> str:
|
| 120 |
+
"""Elimina fences ```...``` alrededor de una respuesta JSON si existen."""
|
| 121 |
+
text = content.strip()
|
| 122 |
+
if text.startswith("```"):
|
| 123 |
+
lines = text.splitlines()
|
| 124 |
+
# descartar primera línea con ``` o ```json
|
| 125 |
+
lines = lines[1:]
|
| 126 |
+
# eliminar el cierre ``` (pueden existir varias líneas en blanco finales)
|
| 127 |
+
while lines and lines[-1].strip() == "```":
|
| 128 |
+
lines.pop()
|
| 129 |
+
text = "\n".join(lines).strip()
|
| 130 |
+
return text
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def generate_evaluation_report(srt_content: str, iteration: int) -> tuple[float, float, Path]:
    """Request a structured LLM evaluation of *srt_content* and persist it as CSV.

    The LLM is asked to score each criterion in EVALUATION_CRITERIA from 0 to 7
    and return a JSON array.  The rows are written to ``temp/eval_<iteration>.csv``.

    Returns:
        ``(mean_score, weighted_mean, eval_path)`` — plain and weighted means on
        the 0-7 scale, and the path of the CSV that was written.
    """
    criteria_formatted = "\n".join(f"- {name}" for name in EVALUATION_CRITERIA)
    prompt = (
        "Actua com un auditor UNE. Avalua l'SRT generat, puntuant cada característica de 0 a 7 "
        "segons la qualitat observada. Dónega justificació breve però concreta per a cada cas. "
        "Les característiques obligatòries són:\n"
        f"{criteria_formatted}\n"
        "Retorna ÚNICAMENT un array JSON d'objectes amb les claus: "
        "'caracteristica', 'valoracio' (nombre enter de 0 a 7) i 'justificacio'."
    )

    response = llm.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(
                content=(
                    "# SRT AVALUAT\n"
                    f"{srt_content}\n\n"
                    "Assegura't de complir el format indicat."
                )
            ),
        ]
    )

    # The model may wrap the JSON array in ``` fences; strip those first.
    cleaned = _strip_markdown_fences(response.content)
    try:
        data = json.loads(cleaned)
        if not isinstance(data, list):
            raise ValueError("La resposta no és una llista.")
    except Exception as exc:
        logger.error(
            "Error al generar l'avaluació estructurada: %s. Resposta original: %s",
            exc,
            response.content,
        )
        # Fallback: a single low-score row so the loop can keep going instead
        # of crashing on a malformed reply.
        data = [
            {
                "caracteristica": "Avaluació fallida",
                "valoracio": 1,
                "justificacio": "No s'ha pogut obtenir l'avaluació del LLM.",
            }
        ]

    # Persist the per-criterion scores for this iteration.
    eval_path = TEMP_DIR / f"eval_{iteration}.csv"
    with eval_path.open("w", encoding="utf-8", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Caracteristica", "Valoracio (0-7)", "Justificacio"])
        for item in data:
            writer.writerow(
                [
                    item.get("caracteristica", ""),
                    item.get("valoracio", 0),
                    item.get("justificacio", ""),
                ]
            )

    # Aggregate scores: plain mean, plus a weighted mean using CRITERIA_WEIGHTS
    # (criteria not present in the weight table default to weight 1).
    scores = []
    weighted_sum = 0.0
    total_weight = 0.0

    for entry in data:
        if not isinstance(entry, dict):
            continue
        try:
            score = float(entry.get("valoracio", 0))
        except (TypeError, ValueError):
            # Non-numeric score from the LLM counts as 0.
            score = 0.0
        scores.append(score)

        weight = CRITERIA_WEIGHTS.get(entry.get("caracteristica", ""), 1)
        weighted_sum += score * weight
        total_weight += weight

    mean_score = sum(scores) / len(scores) if scores else 0.0
    # If no weights accumulated (no dict entries at all), fall back to the plain mean.
    weighted_mean = weighted_sum / total_weight if total_weight else mean_score
    return mean_score, weighted_mean, eval_path
|
| 210 |
+
|
| 211 |
+
# --- Graph State Definition (StateGraph) ---
class ReflectionState(TypedDict):
    """State carried through the reflection loop."""
    iteration: int  # current cycle (starting at 0)
    current_srt_path: str  # path to the current SRT file (e.g., une_ad_0.srt, une_ad_1.srt)
    critic_report: Dict[str, Union[float, str]]  # latest critic report (score and text)
    history: List[SystemMessage]  # message history exchanged between agents
    evaluation_mean: float  # weighted evaluation mean of the current SRT (0-7 scale)
    best_iteration: int  # iteration index with the best weighted mean so far
    best_weighted_mean: float  # best weighted mean observed so far
    best_srt_path: str  # SRT path belonging to the best iteration
    best_eval_path: str  # evaluation CSV path belonging to the best iteration
|
| 223 |
+
|
| 224 |
+
# --- Graph nodes/agents ---
def narrator_agent(state: ReflectionState):
    """
    Agent that generates or rewrites the SRT.
    - On cycle 0, it returns the provided initial SRT as-is (no LLM call is made).
    - On cycles > 0, it rewrites the SRT based on critic_report.
    """
    iteration = state["iteration"]
    critic_report = state["critic_report"]
    history = state["history"]

    # Load the visual-context JSON and the latest SRT from the temp workspace.
    json_context = (TEMP_DIR / "json_ad.json").read_text(encoding="utf-8")
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    # 1. Build the prompt
    if iteration == 0:
        # Initial task (in this setup, une_ad_0.srt is already provided).
        # The initial generation is simulated: the prompt is built but no
        # LLM call happens on cycle 0.
        prompt = (
            "Ets un Narrador expert en Audiodescripció (AD). La teva tasca inicial és generar "
            "un fitxer SRT d'audiodescripcions basat en el JSON de context visual. "
            "TOT I AIXÍ, per a aquesta primera iteració, l'SRT ja s'ha generat. "
            "Simplement retorna el contingut de 'une_ad_0.srt' com si fos la teva sortida. "
            "Assegura't que totes les audiodescripcions estiguin en català i que cadascuna pugui ser locutada "
            "dins del temps disponible (utilitza un màxim aproximat d'11 caràcters per segon). Si el tram de temps "
            "és massa curt (<1.5s), combina'l amb el bloc d'AD més proper i ajusta els timestamps perquè la narració sigui fluida. "
            "Evita redundàncies: no repeteixis informació ja descrita en segments d'AD anteriors o al diàleg, i elimina qualsevol detall que no sigui essencial."
        )
        output_srt = current_srt
        reflection_text = "Generación inicial. No hay reflexión."
    else:
        # Reflection task: rewrite the AD lines guided by the critic's report.
        prompt = (
            "Ets un Narrador expert en Audiodescripció (AD). Has rebut una crítica sobre la teva última versió de l'SRT. "
            "La teva tasca és REESCRIURE el contingut d'audiodescripció (línies amb '(AD)') del fitxer SRT, "
            "assegurant que sigui coherent amb el JSON de context i, sobretot, que CORREGEIXIS TOTS els problemes "
            "mencionats a l'Informe Crític adjunt. Mantén intactes els diàlegs (línies amb [Nom]) i escriu totes les audiodescripcions en català natural. "
            "Garanteix que cada bloc d'AD pugui ser locutat dins del seu interval temporal disponible considerant un màxim d'11 caràcters per segon. "
            "Si l'interval és massa curt (<1.5s), fusiona'l amb el bloc d'AD anterior o posterior més proper i ajusta els timestamps perquè quedin contínues. "
            "Prefereix frases concises i accionables, prioritzant la informació visual essencial, i elimina redundàncies amb AD anteriors o amb els diàlegs."
        )

        # Concatenate the input payload for the LLM.
        input_content = f"""
# INFORME CRÍTICO
Porcentaje de Fiabilidad Anterior: {critic_report.get('reliability_percentage')}
Crítica Cualitativa: {critic_report.get('qualitative_critique')}

# JSON DE CONTEXTO VISUAL (Guía para la AD)
{json_context}

# ÚLTIMO ARCHIVO SRT GENERADO (une_ad_{iteration-1}.srt)
{current_srt}

REGLAS: Tu respuesta debe ser *SOLAMENTE* el contenido completo del nuevo archivo SRT (incluyendo diálogos), sin ningún comentario o explicación adicional.
"""

        # LLM call
        response = llm.invoke(
            [
                SystemMessage(content=prompt),
                HumanMessage(content=input_content)
            ]
        )

        output_srt = response.content
        reflection_text = f"Reescrito en base al informe crítico: {critic_report.get('qualitative_critique', 'N/A')}"

    # 2. Persist the new SRT version for this iteration.
    new_srt_path = TEMP_DIR / f"une_ad_{iteration}.srt"
    new_srt_path.write_text(output_srt, encoding="utf-8")

    # 3. Persist the reflection text ("thinking") for this iteration.
    (TEMP_DIR / f"thinking_{iteration}.txt").write_text(reflection_text, encoding="utf-8")

    logger.info(f"Narrador: Generada la versión {iteration} del SRT en '{new_srt_path}'.")

    # 4. Update the state (critic_agent is the one that advances `iteration`).
    new_history = history + [AIMessage(content=f"Narrador v{iteration} completado. Razón de reflexión: {reflection_text}")]
    return {
        "iteration": iteration,
        "current_srt_path": str(new_srt_path),
        "history": new_history,
        "evaluation_mean": state.get("evaluation_mean", 0.0),
        "best_iteration": state.get("best_iteration", -1),
        # NOTE(review): fallback 0.0 here differs from the -1.0 used by
        # critic_agent and initial_state; harmless because initial_state
        # always supplies the key, but worth unifying.
        "best_weighted_mean": state.get("best_weighted_mean", 0.0),
        "best_srt_path": state.get("best_srt_path", str(new_srt_path)),
        "best_eval_path": state.get("best_eval_path", str(TEMP_DIR / f"eval_{iteration}.csv")),
    }
|
| 314 |
+
|
| 315 |
+
def critic_agent(state: ReflectionState):
    """
    Agent that evaluates the quality of the narrator's SRT against the UNE rules.

    Produces a reliability percentage plus a qualitative critique, runs the
    structured 0-7 evaluation, and updates the best-iteration bookkeeping.
    """
    iteration = state["iteration"]
    history = state["history"]
    current_srt = Path(state["current_srt_path"]).read_text(encoding="utf-8")

    prompt = (
        "Ets un Crític d'Audiodescripció molt estricte. La teva tasca és avaluar l'SRT adjunt "
        "únicament segons les Regles UNE proporcionades. L'avaluació ha de ser doble: "
        "1. **Numèrica**: Un percentatge de fiabilitat (ex. 85.5) de 0 a 100%. "
        "2. **Qualitativa**: Una crítica constructiva sobre les principals mancances de les AD respecte a les regles. "
        "Has de ser EXTREMADAMENT estricte amb la sincronització (sense solapament amb el diàleg), "
        "amb l'adequació temporal (velocitat màxima recomanada d'11 caràcters per segon) i amb l'absència de redundàncies. "
        "Comprova també que totes les audiodescripcions estan escrites en català natural."
    )

    input_content = f"""
# REGLAS UNE DE AUDIODESCRIPCIÓN:
{UNE_RULES}

# ARCHIVO SRT A EVALUAR (une_ad_{iteration}.srt):
{current_srt}

REGLAS DE RESPUESTA:
Tu respuesta debe ser *SOLAMENTE* un objeto JSON con dos claves:
1. "reliability_percentage": (float) El porcentaje de fiabilidad.
2. "qualitative_critique": (string) La crítica cualitativa y sugerencias de mejora.
Ejemplo de respuesta: {{"reliability_percentage": 75.0, "qualitative_critique": "El segmento 4 se solapa 0.34s con el diálogo de Sandra. El segmento 5 es demasiado genérico y no describe bien la acción."}}
"""

    # LLM call
    response = llm.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=input_content)
        ]
    )

    # Parse the LLM response (it may not be valid JSON, hence the try/except).
    try:
        cleaned_response = _strip_markdown_fences(response.content)
        report = json.loads(cleaned_response)
        if not isinstance(report, dict) or 'reliability_percentage' not in report:
            raise ValueError("Estructura JSON incorrecta.")
    except Exception as e:
        # Fall back to a low-reliability report so the loop keeps iterating.
        logger.error(f"Error al parsear el JSON del Crítico: {e}. Respuesta: {response.content}")
        report = {"reliability_percentage": 1.0, "qualitative_critique": "El Crítico no devolvió un JSON válido. Reintentar."}

    logger.info(f"Crítico: Evaluación completada. Fiabilidad: {report.get('reliability_percentage')}%.")

    # Structured 0-7 evaluation (simple mean, weighted mean, CSV path).
    mean_score, weighted_mean, eval_path = generate_evaluation_report(current_srt, iteration)

    # Append the evaluation means to this iteration's "thinking" file, if present.
    thinking_path = TEMP_DIR / f"thinking_{iteration}.txt"
    if thinking_path.exists():
        previous_text = thinking_path.read_text(encoding="utf-8")
        thinking_path.write_text(
            (
                f"{previous_text}\n\nMitjana simple d'avaluació: {mean_score:.2f} / 7"
                f"\nMitjana ponderada d'avaluació: {weighted_mean:.2f} / 7"
            ),
            encoding="utf-8",
        )

    # Track the best iteration seen so far (ranked by weighted mean).
    best_iteration = state.get("best_iteration", -1)
    best_weighted_mean = state.get("best_weighted_mean", -1.0)
    best_srt_path = state.get("best_srt_path", state["current_srt_path"])
    best_eval_path = state.get("best_eval_path", str(eval_path))

    if weighted_mean > best_weighted_mean:
        best_iteration = iteration
        best_weighted_mean = weighted_mean
        best_srt_path = state["current_srt_path"]
        best_eval_path = str(eval_path)

    new_history = history + [
        AIMessage(
            content=(
                "Crítico v{iter} completado. Fiabilidad: {reliab}%. "
                "Mitjana simple: {mean:.2f}/7. Mitjana ponderada: {wmean:.2f}/7"
            ).format(
                iter=iteration,
                reliab=report.get("reliability_percentage"),
                mean=mean_score,
                wmean=weighted_mean,
            )
        )
    ]
    return {
        "iteration": iteration + 1,  # the critic is the node that advances the cycle counter
        "critic_report": report,
        "history": new_history,
        "evaluation_mean": weighted_mean,
        "best_iteration": best_iteration,
        "best_weighted_mean": best_weighted_mean,
        "best_srt_path": best_srt_path,
        "best_eval_path": best_eval_path,
    }
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# --- Loop exit condition ---

def should_continue(state: ReflectionState) -> str:
    """Decide whether the reflection loop runs another cycle or stops.

    Returns "end" when the weighted evaluation mean reaches the quality
    threshold or the iteration budget is exhausted; otherwise "continue".
    """
    MAX_ITERATIONS = 5  # hard cap on reflection cycles
    MIN_AVERAGE_SCORE = 6.0  # quality threshold on the 0-7 scale

    iteration = state["iteration"]
    mean_score = state.get("evaluation_mean", 0.0)

    threshold_met = mean_score >= MIN_AVERAGE_SCORE
    budget_spent = iteration >= MAX_ITERATIONS

    if threshold_met:
        logger.info(f"FIN: Mitjana ponderada d'avaluació assolida ({mean_score:.2f} >= {MIN_AVERAGE_SCORE}).")
        return "end"

    if budget_spent:
        logger.info(f"FIN: S'ha assolit el màxim d'iteracions ({iteration} / {MAX_ITERATIONS}).")
        return "end"

    logger.info(f"CONTINUAR: Iteració {iteration} / {MAX_ITERATIONS}. Mitjana ponderada actual: {mean_score:.2f} / 7.")
    return "continue"
|
| 439 |
+
|
| 440 |
+
# --- Graph construction ---

# 1. Configure the initial state for the reflection loop.
initial_state: ReflectionState = {
    "iteration": 0,
    "current_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
    "critic_report": {"reliability_percentage": 0.0, "qualitative_critique": "Inicializando el proceso."},
    "history": [],
    "evaluation_mean": 0.0,
    "best_iteration": -1,
    "best_weighted_mean": -1.0,  # -1.0 so the first real evaluation always becomes the best
    "best_srt_path": str(TEMP_DIR / "une_ad_0.srt"),
    "best_eval_path": str(TEMP_DIR / "eval_0.csv"),
}

# 2. Define the graph.
workflow = StateGraph(ReflectionState)

# Nodes
workflow.add_node("narrator", narrator_agent)
workflow.add_node("critic", critic_agent)

# Loop structure: Narrator -> Critic -> Check
workflow.set_entry_point("narrator")
workflow.add_edge("narrator", "critic")

# Branching condition after the critic
workflow.add_conditional_edges(
    "critic",
    should_continue,
    {
        "continue": "narrator",  # threshold/budget not reached: back to the narrator
        "end": END  # quality reached or budget spent: finish
    }
)

# Compile the graph
app = workflow.compile()

# --- Main execution ---

if __name__ == "__main__":
    # Initialize the working directory with the initial SRT and context JSON.
    setup_files(INITIAL_SRT_CONTENT, CONTEXT_JSON_CONTENT)

    logger.info("--- Comenzando el Bucle de Reflexión ---")

    # Run the graph to completion.
    final_state = app.invoke(initial_state)

    logger.info("\n--- Bucle Finalizado ---")

    best_iteration = final_state.get("best_iteration", -1)
    best_weighted_mean = final_state.get("best_weighted_mean", 0.0)
    best_srt_path = Path(final_state.get("best_srt_path", final_state['current_srt_path']))
    best_eval_path = Path(final_state.get("best_eval_path", TEMP_DIR / "eval_0.csv"))

    final_srt_path = TEMP_DIR / "une_ad.srt"
    final_eval_path = TEMP_DIR / "eval.csv"

    # Copy the best iteration's artifacts to the canonical output names.
    try:
        shutil.copy(best_srt_path, final_srt_path)
        logger.info(f"SRT final copiado a '{final_srt_path}'.")
    except Exception as exc:
        logger.error(f"No se pudo copiar el SRT final: {exc}")

    try:
        shutil.copy(best_eval_path, final_eval_path)
        logger.info(f"Evaluación final copiada a '{final_eval_path}'.")
    except Exception as exc:
        logger.error(f"No se pudo copiar el CSV final: {exc}")

    # Report results
    print(f"Número final de ciclos: {final_state['iteration']}")
    print(f"Iteración òptima: {best_iteration} (mitjana ponderada {best_weighted_mean:.2f}/7)")
    print(f"Ruta al SRT final: {final_srt_path}")
    print(f"Ruta a l'avaluació final: {final_eval_path}")

    # Show the final generated SRT
    print("\n--- Contenido del SRT Final ---")
    print(final_srt_path.read_text(encoding="utf-8"))
|
finetuning/video_analysis.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from datetime import timedelta
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Matches an SRT timing line: "HH:MM:SS,mmm --> HH:MM:SS,mmm".
# Accepts either ',' or '.' as the millisecond separator.
TIME_RE = re.compile(
    r"(?P<start>\d{2}:\d{2}:\d{2}[,\.]\d{3})\s*-->\s*(?P<end>\d{2}:\d{2}:\d{2}[,\.]\d{3})"
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class SRTBlock:
    """One parsed subtitle block from an SRT file."""
    index: int  # block number (from the file, or positional if the file omits it)
    start: float  # start time in seconds
    end: float  # end time in seconds
    text: str  # block text; may span multiple lines
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _parse_timestamp(ts: str) -> float:
|
| 23 |
+
"""Convierte 'HH:MM:SS,mmm' o 'HH:MM:SS.mmm' a segundos (float)."""
|
| 24 |
+
ts = ts.replace(",", ".")
|
| 25 |
+
h, m, s = ts.split(":")
|
| 26 |
+
seconds, millis = (s.split("." ) + ["0"])[:2]
|
| 27 |
+
td = timedelta(
|
| 28 |
+
hours=int(h),
|
| 29 |
+
minutes=int(m),
|
| 30 |
+
seconds=int(seconds),
|
| 31 |
+
milliseconds=int(millis.ljust(3, "0")),
|
| 32 |
+
)
|
| 33 |
+
return td.total_seconds()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _parse_srt(srt_text: str) -> List[SRTBlock]:
    """Parse raw SRT text into a list of SRTBlock entries.

    Chunks without a recognizable timing line are skipped; blocks missing a
    numeric index get a positional one.
    """
    normalized = srt_text.replace("\r\n", "\n").replace("\r", "\n")
    parsed: List[SRTBlock] = []

    for raw_chunk in re.split(r"\n\s*\n", normalized):
        chunk = raw_chunk.strip()
        if not chunk:
            continue

        lines = chunk.split("\n")

        # Optional numeric index on the first line.
        block_no = None
        cursor = 0
        if lines and lines[0].strip().isdigit():
            block_no = int(lines[0].strip())
            cursor = 1

        # The timing line must appear within the next few lines.
        timing = None
        timing_at = None
        for pos in range(cursor, min(cursor + 3, len(lines))):
            found = TIME_RE.search(lines[pos])
            if found:
                timing, timing_at = found, pos
                break

        if timing is None or timing_at is None:
            continue  # malformed chunk: no timing line found

        parsed.append(
            SRTBlock(
                index=block_no if block_no is not None else len(parsed) + 1,
                start=_parse_timestamp(timing.group("start")),
                end=_parse_timestamp(timing.group("end")),
                text="\n".join(lines[timing_at + 1 :]).strip(),
            )
        )

    return parsed
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def analyze_srt(
    srt_text: str,
    *,
    ad_markers: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Analyze an SRT file and return basic pacing/coverage metrics.

    Returned metrics:
    - duration_sec: estimated total video duration (seconds)
    - words_per_min: total words per minute
    - speakers_blocks_per_min: dialogue (non-AD) blocks per minute
    - ad_time_ratio: fraction (0..1) of total time covered by AD blocks
    - blocks_per_min: total blocks per minute

    Heuristics:
    - The video duration is taken as the latest block end time.
    - A block counts as AD when its first line contains one of the markers
      in `ad_markers` (default: "[AD]", "AD:", "(AD)"), case-insensitively.
    """

    blocks = _parse_srt(srt_text)
    if not blocks:
        # Empty or unparseable input: all metrics are zero.
        return {
            "duration_sec": 0.0,
            "words_per_min": 0.0,
            "speakers_blocks_per_min": 0.0,
            "ad_time_ratio": 0.0,
            "blocks_per_min": 0.0,
        }

    duration_sec = max(b.end for b in blocks)
    # Clamp to avoid division by zero on degenerate timestamps.
    duration_min = max(duration_sec / 60.0, 1e-6)

    # Total word count across every block (dialogue and AD alike).
    word_count = sum(len(b.text.split()) for b in blocks)

    markers = ad_markers if ad_markers is not None else ["[AD]", "AD:", "(AD)"]
    upper_markers = [mk.upper() for mk in markers]

    def _is_ad(block: SRTBlock) -> bool:
        # Only the first line of a block is checked for an AD marker.
        head = (block.text.splitlines() or [""])[0].strip().upper()
        return any(mk in head for mk in upper_markers)

    ad_seconds = 0.0
    dialogue_blocks = 0
    for block in blocks:
        if _is_ad(block):
            ad_seconds += max(0.0, block.end - block.start)
        else:
            dialogue_blocks += 1

    return {
        "duration_sec": float(duration_sec),
        "words_per_min": float(word_count / duration_min),
        "speakers_blocks_per_min": float(dialogue_blocks / duration_min),
        "ad_time_ratio": float(ad_seconds / duration_sec) if duration_sec > 0 else 0.0,
        "blocks_per_min": float(len(blocks) / duration_min),
    }
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def embed_srt_sentences(
    srt_text: str,
    *,
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> Dict[str, Any]:
    """Compute sentence embeddings for each non-empty block of an SRT.

    Args:
        srt_text: Full SRT file contents as a string.
        model_name: sentence-transformers model identifier to load.

    Returns:
        Dict with:
        - "model_name": the model identifier used
        - "sentences": one flattened string per non-empty block
        - "embeddings": list of float vectors, one per sentence

    Raises:
        ImportError: if `sentence-transformers` (and a compatible PyTorch
            backend) is not installed.
    """

    parsed = _parse_srt(srt_text)
    sentences = [
        block.text.replace("\n", " ").strip()
        for block in parsed
        if block.text.strip()
    ]

    if not sentences:
        # Nothing to embed: avoid loading the model at all.
        return {"model_name": model_name, "sentences": [], "embeddings": []}

    # Imported lazily so the rest of the module works without the dependency.
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise ImportError(
            "sentence-transformers no está instalado. "
            "Instala la dependencia para poder generar embeddings."
        ) from exc

    encoder = SentenceTransformer(model_name)
    raw_vectors = encoder.encode(sentences, convert_to_numpy=False)

    vectors = [[float(component) for component in vec] for vec in raw_vectors]

    return {
        "model_name": model_name,
        "sentences": sentences,
        "embeddings": vectors,
    }
|
storage/pending_videos_routers.py
CHANGED
|
@@ -1,244 +1,244 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import io
|
| 3 |
-
import shutil
|
| 4 |
-
|
| 5 |
-
import sqlite3
|
| 6 |
-
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
from fastapi import APIRouter, UploadFile, File, Query, HTTPException
|
| 10 |
-
from fastapi.responses import FileResponse, JSONResponse
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
from storage.files.file_manager import FileManager
|
| 14 |
-
from storage.common import validate_token
|
| 15 |
-
|
| 16 |
-
router = APIRouter(prefix="/
|
| 17 |
-
MEDIA_ROOT = Path("/data/
|
| 18 |
-
file_manager = FileManager(MEDIA_ROOT)
|
| 19 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@router.delete("/clear_pending_videos", tags=["Pending Videos Manager"])
|
| 23 |
-
def clear_media(token: str = Query(..., description="Token required for authorization")):
|
| 24 |
-
"""
|
| 25 |
-
Delete all contents of the /data/
|
| 26 |
-
Steps:
|
| 27 |
-
- Validate the token.
|
| 28 |
-
- Ensure the folder exists.
|
| 29 |
-
- Delete all files and subfolders inside /data/
|
| 30 |
-
- Return a JSON response confirming the deletion.
|
| 31 |
-
Warning: This will remove all stored videos, clips, and cast CSV files.
|
| 32 |
-
"""
|
| 33 |
-
validate_token(token)
|
| 34 |
-
|
| 35 |
-
if not MEDIA_ROOT.exists() or not MEDIA_ROOT.is_dir():
|
| 36 |
-
raise HTTPException(status_code=404, detail="/data/
|
| 37 |
-
|
| 38 |
-
# Delete contents
|
| 39 |
-
for item in MEDIA_ROOT.iterdir():
|
| 40 |
-
try:
|
| 41 |
-
if item.is_dir():
|
| 42 |
-
shutil.rmtree(item)
|
| 43 |
-
else:
|
| 44 |
-
item.unlink()
|
| 45 |
-
except Exception as e:
|
| 46 |
-
raise HTTPException(status_code=500, detail=f"Failed to delete {item}: {e}")
|
| 47 |
-
|
| 48 |
-
return {"status": "ok", "message": "All
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
@router.delete("/clear_pending_video", tags=["Pending Videos Manager"])
def clear_pending_video(
    sha1: str = Query(..., description="SHA1 folder to delete inside pending_videos"),
    token: str = Query(..., description="Token required for authorization")
):
    """
    Delete a specific SHA1 folder inside /data/pending_videos.

    Steps:
    - Validate the token.
    - Reject sha1 values that would escape the pending_videos root.
    - Ensure the folder exists.
    - Delete the folder and all its contents.
    - Return a JSON response confirming the deletion.

    Raises:
        HTTPException: 400 for an invalid sha1 value, 404 if the folder does
            not exist, 500 if deletion fails.
    """
    validate_token(token)

    PENDING_ROOT = Path("/data/pending_videos")
    target_folder = PENDING_ROOT / sha1

    # Security fix: `sha1` is client-controlled and was joined directly into
    # the path given to rmtree, so values like "../media" could delete
    # arbitrary directories. Require it to resolve to a direct child of the
    # pending_videos root.
    if target_folder.resolve().parent != PENDING_ROOT.resolve():
        raise HTTPException(status_code=400, detail="Invalid sha1 folder name")

    if not target_folder.exists() or not target_folder.is_dir():
        raise HTTPException(status_code=404, detail=f"Folder {sha1} does not exist in pending_videos")

    try:
        shutil.rmtree(target_folder)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete {sha1}: {e}")

    return {"status": "ok", "message": f"Pending video folder {sha1} deleted successfully"}
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
@router.post("/upload_pending_video", tags=["Pending Videos Manager"])
async def upload_video(
    video: UploadFile = File(...),
    token: str = Query(..., description="Token required for authorization")
):
    """
    Save an uploaded video, hashed with SHA1, under:
        /data/media/<sha1>/<original_filename>

    Behavior:
    - Validate the token.
    - Compute the SHA1 of the uploaded video.
    - Ensure the folder structure exists.
    - Delete any existing .mp4 files under the sha1 folder.
    - Save the uploaded video in the folder.

    Raises:
        HTTPException: 500 if hashing, cleanup, or saving fails.
    """
    # Fix: this endpoint accepted a `token` parameter but never checked it,
    # unlike every other endpoint in this router.
    validate_token(token)

    # Read content into memory (needed both for hashing and for saving).
    file_bytes = await video.read()

    # In-memory file handler for hashing.
    file_handler = io.BytesIO(file_bytes)

    # Compute SHA1
    try:
        sha1 = file_manager.compute_sha1(file_handler)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"SHA1 computation failed: {exc}")

    # Ensure /data/media exists
    MEDIA_ROOT.mkdir(parents=True, exist_ok=True)

    # Path: /data/media/<sha1>
    video_root = MEDIA_ROOT / sha1
    video_root.mkdir(parents=True, exist_ok=True)

    # Delete old MP4 files so the folder holds a single current video.
    try:
        for old_mp4 in video_root.glob("*.mp4"):
            old_mp4.unlink()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to delete old videos: {exc}")

    # Destination path for the new video.
    final_path = video_root / video.filename

    # Save file (a fresh BytesIO: the hashing handler's cursor was consumed).
    save_result = file_manager.upload_file(io.BytesIO(file_bytes), final_path)

    if not save_result["operation_success"]:
        raise HTTPException(status_code=500, detail=save_result["error"])

    return JSONResponse(
        status_code=200,
        content={
            "status": "ok",
            "sha1": sha1,
            "saved_to": str(final_path)
        }
    )
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
@router.get("/download_pending_video", tags=["Pending Videos Manager"])
def download_video(
    sha1: str,
    token: str = Query(..., description="Token required for authorization")
):
    """
    Download a stored video by its SHA-1 directory name.

    Looks for the first MP4 file directly under /data/media/<sha1>/ and
    returns it as a streaming FileResponse. (The previous docstring claimed
    the file lived in a /clip subfolder; the implementation searches the
    sha1 folder itself.)

    Steps:
    - Validate the token.
    - Check that the SHA-1 folder exists inside the media root.
    - Find the first .mp4 file inside it.
    - Use FileManager.get_file to confirm the file is accessible.
    - Return the video as a FileResponse.

    Parameters
    ----------
    sha1 : str
        The SHA-1 hash corresponding to the directory where the video is stored.

    Returns
    -------
    FileResponse
        A streaming response containing the MP4 video.

    Raises
    ------
    HTTPException
        - 404 if the SHA-1 folder does not exist.
        - 404 if no MP4 files are found.
        - 404 if the file cannot be retrieved using FileManager.
    """
    # Fix: this endpoint accepted a `token` parameter but never checked it,
    # unlike the other endpoints in this router.
    validate_token(token)

    sha1_folder = MEDIA_ROOT / sha1

    if not sha1_folder.exists() or not sha1_folder.is_dir():
        raise HTTPException(status_code=404, detail="SHA1 folder not found")

    # Find first MP4 file
    mp4_files = list(sha1_folder.glob("*.mp4"))
    if not mp4_files:
        raise HTTPException(status_code=404, detail="No MP4 files found")

    video_path = mp4_files[0]

    # Convert to a path relative to MEDIA_ROOT for FileManager.
    relative_path = video_path.relative_to(MEDIA_ROOT)

    handler = file_manager.get_file(relative_path)
    if handler is None:
        raise HTTPException(status_code=404, detail="Video not accessible")

    # Only used as an accessibility check; FileResponse streams from disk.
    handler.close()

    return FileResponse(
        path=video_path,
        media_type="video/mp4",
        filename=video_path.name
    )
|
| 197 |
-
|
| 198 |
-
@router.get("/list_pending_videos", tags=["Pending Videos Manager"])
def list_all_videos(
    token: str = Query(..., description="Token required for authorization")
):
    """
    List all videos stored under the media root.

    For each SHA1 folder, the endpoint returns a dict with:
    - sha1: folder name
    - video_name: the most recently modified mp4 in the folder, or None
      when the folder contains no mp4 files.

    Notes:
    - SHA1 folders without mp4 files are still returned (video_name is None).
    - Returns an empty list if the media root does not exist.
    """
    validate_token(token)

    results = []

    # If media root does not exist, return empty list
    if not MEDIA_ROOT.exists():
        return []

    for sha1_dir in MEDIA_ROOT.iterdir():
        if not sha1_dir.is_dir():
            continue  # skip non-folders

        videos = []
        latest_video = None

        # NOTE(review): this check is redundant — iterdir() yielded an
        # existing path and non-directories were skipped above.
        if sha1_dir.exists() and sha1_dir.is_dir():
            mp4_files = list(sha1_dir.glob("*.mp4"))

            # Sort by modification time (newest first)
            mp4_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)

            videos = [f.name for f in mp4_files]

            if mp4_files:
                latest_video = mp4_files[0].name

        results.append({
            "sha1": sha1_dir.name,
            "video_name": latest_video
        })

    return results
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import shutil
|
| 4 |
+
|
| 5 |
+
import sqlite3
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, UploadFile, File, Query, HTTPException
|
| 10 |
+
from fastapi.responses import FileResponse, JSONResponse
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from storage.files.file_manager import FileManager
|
| 14 |
+
from storage.common import validate_token
|
| 15 |
+
|
| 16 |
+
router = APIRouter(prefix="/pending_videos", tags=["Pending Videos Manager"])
|
| 17 |
+
MEDIA_ROOT = Path("/data/pending_videos")
|
| 18 |
+
file_manager = FileManager(MEDIA_ROOT)
|
| 19 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 20 |
+
|
| 21 |
+
|
@router.delete("/clear_pending_videos", tags=["Pending Videos Manager"])
def clear_media(token: str = Query(..., description="Token required for authorization")):
    """
    Remove every file and subfolder under /data/pending_videos.

    Workflow:
    - Authorize the request via the supplied token.
    - Verify that the media root directory is present.
    - Recursively remove each entry (files and directories alike).
    - Report success as a JSON payload.

    Warning: this wipes all stored videos, clips, and cast CSV files.
    """
    validate_token(token)

    if not (MEDIA_ROOT.exists() and MEDIA_ROOT.is_dir()):
        raise HTTPException(status_code=404, detail="/data/pending_videos folder does not exist")

    for entry in MEDIA_ROOT.iterdir():
        try:
            # Directories need a recursive delete; plain files a simple unlink.
            if entry.is_dir():
                shutil.rmtree(entry)
            else:
                entry.unlink()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to delete {entry}: {e}")

    return {"status": "ok", "message": "All pending_videos files deleted successfully"}
| 49 |
+
|
| 50 |
+
|
@router.delete("/clear_pending_video", tags=["Pending Videos Manager"])
def clear_pending_video(
    sha1: str = Query(..., description="SHA1 folder to delete inside pending_videos"),
    token: str = Query(..., description="Token required for authorization")
):
    """
    Delete a specific SHA1 folder inside /data/pending_videos.

    Steps:
    - Validate the token.
    - Ensure the folder exists.
    - Delete the folder and all its contents.
    - Return a JSON response confirming the deletion.

    Raises
    ------
    HTTPException
        - 404 if the SHA1 folder does not exist.
        - 500 if the deletion fails.
    """
    validate_token(token)

    # Consistency fix: reuse the module-level MEDIA_ROOT instead of
    # re-declaring the same Path("/data/pending_videos") locally, so a future
    # relocation of the pending_videos directory is a one-line change.
    target_folder = MEDIA_ROOT / sha1

    if not target_folder.exists() or not target_folder.is_dir():
        raise HTTPException(status_code=404, detail=f"Folder {sha1} does not exist in pending_videos")

    try:
        shutil.rmtree(target_folder)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete {sha1}: {e}")

    return {"status": "ok", "message": f"Pending video folder {sha1} deleted successfully"}
| 78 |
+
|
| 79 |
+
|
@router.post("/upload_pending_video", tags=["Pending Videos Manager"])
async def upload_video(
    video: UploadFile = File(...),
    token: str = Query(..., description="Token required for authorization")
):
    """
    Save an uploaded video by hashing it with SHA1 and placing it under:
    /data/pending_videos/<sha1>/<original_filename>

    Behavior:
    - Validate the token.
    - Compute SHA1 of the uploaded video.
    - Ensure the folder structure exists.
    - Delete any existing .mp4 files under the sha1 folder.
    - Save the uploaded video in the folder.

    Raises
    ------
    HTTPException
        - 500 if hashing, old-file cleanup, or saving fails.
    """
    # Bug fix: the token parameter was required by the signature but never
    # checked, leaving this endpoint unauthenticated while every other
    # endpoint in this router enforces validate_token.
    validate_token(token)

    # Read content into memory (the same bytes are hashed and then persisted)
    file_bytes = await video.read()

    # Create an in-memory file handler for hashing
    file_handler = io.BytesIO(file_bytes)

    # Compute SHA1
    try:
        sha1 = file_manager.compute_sha1(file_handler)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"SHA1 computation failed: {exc}")

    # Ensure the media root exists
    MEDIA_ROOT.mkdir(parents=True, exist_ok=True)

    # Path: /data/pending_videos/<sha1>
    video_root = MEDIA_ROOT / sha1
    video_root.mkdir(parents=True, exist_ok=True)

    # Delete old MP4 files so the folder holds a single current video
    try:
        for old_mp4 in video_root.glob("*.mp4"):
            old_mp4.unlink()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to delete old videos: {exc}")

    # Destination path keeps the client-supplied filename
    final_path = video_root / video.filename

    # Save file via FileManager (fresh BytesIO: the hashing handler was consumed)
    save_result = file_manager.upload_file(io.BytesIO(file_bytes), final_path)

    if not save_result["operation_success"]:
        raise HTTPException(status_code=500, detail=save_result["error"])

    return JSONResponse(
        status_code=200,
        content={
            "status": "ok",
            "sha1": sha1,
            "saved_to": str(final_path)
        }
    )
| 137 |
+
|
| 138 |
+
|
@router.get("/download_pending_video", tags=["Pending Videos Manager"])
def download_video(
    sha1: str,
    token: str = Query(..., description="Token required for authorization")
):
    """
    Download a stored video by its SHA-1 directory name.

    This endpoint looks for a video stored under the path:
    /data/pending_videos/<sha1>/
    and returns the first MP4 file found in that folder.

    The method performs the following steps:
    - Validates the authorization token.
    - Checks if the SHA-1 folder exists inside the media root.
    - Searches for the first .mp4 file inside the folder.
    - Uses the FileManager.get_file method to ensure the file is accessible.
    - Returns the video directly as a FileResponse.

    Parameters
    ----------
    sha1 : str
        The SHA-1 hash corresponding to the directory where the video is stored.

    Returns
    -------
    FileResponse
        A streaming response containing the MP4 video.

    Raises
    ------
    HTTPException
        - 404 if the SHA-1 folder does not exist.
        - 404 if no MP4 files are found.
        - 404 if the file cannot be retrieved using FileManager.
    """
    # Bug fix: token was declared but never validated; every other endpoint
    # in this router enforces validate_token.
    validate_token(token)

    sha1_folder = MEDIA_ROOT / sha1

    if not sha1_folder.exists() or not sha1_folder.is_dir():
        raise HTTPException(status_code=404, detail="SHA1 folder not found")

    # Find first MP4 file (glob order is filesystem-dependent; the folder is
    # expected to hold a single current video — see upload_video)
    mp4_files = list(sha1_folder.glob("*.mp4"))
    if not mp4_files:
        raise HTTPException(status_code=404, detail="No MP4 files found")

    video_path = mp4_files[0]

    # FileManager works with paths relative to the media root
    relative_path = video_path.relative_to(MEDIA_ROOT)

    # Open purely as an accessibility probe; the actual streaming is done by
    # FileResponse, so the handle is closed immediately.
    handler = file_manager.get_file(relative_path)
    if handler is None:
        raise HTTPException(status_code=404, detail="Video not accessible")

    handler.close()

    return FileResponse(
        path=video_path,
        media_type="video/mp4",
        filename=video_path.name
    )
| 197 |
+
|
@router.get("/list_pending_videos", tags=["Pending Videos Manager"])
def list_all_videos(
    token: str = Query(..., description="Token required for authorization")
):
    """
    List all videos stored under /data/pending_videos.

    For each SHA1 folder directly under the media root, the endpoint returns:
    - sha1: folder name
    - video_name: name of the most recently modified .mp4 file in that
      folder, or None when the folder contains no .mp4 files.

    Notes:
    - SHA1 folders without mp4 files are still returned (video_name is None).
    - Non-directory entries under the media root are skipped.
    """
    validate_token(token)

    results = []

    # If media root does not exist, return an empty list rather than erroring.
    if not MEDIA_ROOT.exists():
        return []

    for sha1_dir in MEDIA_ROOT.iterdir():
        if not sha1_dir.is_dir():
            continue  # skip non-folders

        # The loop already guarantees sha1_dir is an existing directory, so the
        # old redundant exists()/is_dir() re-check was removed. Pick the newest
        # .mp4 by modification time; None when the folder has no .mp4 files.
        mp4_files = sorted(
            sha1_dir.glob("*.mp4"),
            key=lambda f: f.stat().st_mtime,
            reverse=True,
        )
        latest_video = mp4_files[0].name if mp4_files else None

        results.append({
            "sha1": sha1_dir.name,
            "video_name": latest_video
        })

    return results