engine / refinement /multiagent_refinement.py
VeuReu's picture
Upload 10 files
31d4d14 verified
from __future__ import annotations
from pathlib import Path
from typing import Optional
import yaml
from .reflection import refine_srt_with_reflection, refine_video_with_reflection
from .reflection_ma import refine_srt_with_reflection_ma, refine_video_with_reflection_ma
from .reflexion import refine_srt_with_reflexion
from .introspection import refine_srt_with_introspection
def _coerce_flag(value: object, default: bool) -> bool:
    """Interpret a single config value as a boolean flag.

    Non-string values (native YAML booleans, numbers, None) go through
    ``bool()`` exactly as before. Strings are parsed explicitly because
    ``bool("false")`` is ``True``: a quoted ``"false"`` in config.yaml would
    otherwise silently enable a refinement step. Unrecognised strings fall
    back to ``default``.

    Args:
        value: raw value read from the ``refinement`` config section.
        default: value to use when a string cannot be interpreted.

    Returns:
        The boolean flag value.
    """
    if not isinstance(value, str):
        return bool(value)
    token = value.strip().lower()
    if token in ("1", "true", "yes", "on"):
        return True
    if token in ("", "0", "false", "no", "off"):
        return False
    return default


def _load_refinement_flags(config_path: Optional[Path] = None) -> dict:
    """Load the refinement step flags from a config.yaml file.

    Reads the ``refinement`` section and returns the flags
    ``reflection_enabled``, ``reflexion_enabled``, ``introspection_enabled``
    and ``reflection_ma_enabled``. Any flag that is missing keeps its default.
    If the file is absent, unreadable, not valid YAML, or not shaped as
    expected, all defaults are returned and a warning is logged. Loading the
    config is best-effort and never breaks the pipeline.

    Args:
        config_path: path to the YAML config (a ``str`` is also accepted).
            When ``None``, ``demo/config.yaml`` under ``parents[2]`` of this
            file is used (commented in the original as the repo root,
            ``.../hf_spaces``).

    Returns:
        A fresh dict mapping each flag name to a bool.
    """
    import logging

    logger = logging.getLogger(__name__)

    defaults = {
        "reflection_enabled": True,
        "reflexion_enabled": False,
        "introspection_enabled": False,
        "reflection_ma_enabled": False,
    }

    if config_path is None:
        config_path = Path(__file__).resolve().parents[2] / "demo" / "config.yaml"
    else:
        config_path = Path(config_path)

    try:
        if not config_path.exists():
            return defaults
        with config_path.open("r", encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh) or {}
    except (OSError, UnicodeDecodeError) as exc:
        # Listed first so that this clause does not depend on the yaml module.
        logger.warning("Could not read refinement config %s (%s); using defaults", config_path, exc)
        return defaults
    except yaml.YAMLError as exc:
        logger.warning("Invalid YAML in refinement config %s (%s); using defaults", config_path, exc)
        return defaults

    if not isinstance(cfg, dict):
        logger.warning("Refinement config %s is not a mapping; using defaults", config_path)
        return defaults

    section = cfg.get("refinement") or {}
    if not isinstance(section, dict):
        logger.warning("'refinement' in %s is not a mapping; using defaults", config_path)
        return defaults

    # Build the result in one go so a bad value can never leave a half-updated dict.
    return {
        name: _coerce_flag(section.get(name, default), default)
        for name, default in defaults.items()
    }
def execute_refinement(initial_srt: str, *, config_path: Optional[Path] = None) -> str:
    """Run the multi-agent refinement pipeline over an SRT string.

    The ``refinement.*`` flags loaded from config.yaml (or ``config_path``)
    decide which steps run. Steps are applied in this order:

    1. Reflection: the multi-agent variant when ``reflection_ma_enabled`` is
       set, otherwise the single-agent variant when ``reflection_enabled`` is
       set. The two are never both applied.
    2. Reflexion, when ``reflexion_enabled`` is set (length adjustment /
       filtering of AD cues, per the original notes).
    3. Introspection, when ``introspection_enabled`` is set (applies rules
       learned from human feedback, per the original notes).

    Args:
        initial_srt: SRT text to refine.
        config_path: optional override for the config file; ``None`` uses the
            default location chosen by ``_load_refinement_flags``.

    Returns:
        The refined SRT, or ``initial_srt`` unchanged when no step is enabled.
    """
    enabled = _load_refinement_flags(config_path)
    refined = initial_srt

    # Step 1: at most one reflection variant; the multi-agent one wins.
    if enabled.get("reflection_ma_enabled", False):
        refined = refine_srt_with_reflection_ma(refined)
    elif enabled.get("reflection_enabled", False):
        refined = refine_srt_with_reflection(refined)

    # Steps 2-3: independent post-passes, each gated by its own flag, in order.
    post_passes = (
        ("reflexion_enabled", refine_srt_with_reflexion),
        ("introspection_enabled", refine_srt_with_introspection),
    )
    for flag_name, refine_pass in post_passes:
        if enabled.get(flag_name, False):
            refined = refine_pass(refined)

    return refined
def _read_une_ad_from_db(sha1sum: str, version: str) -> str:
    """Fetch the stored ``une_ad`` SRT for a video from the demo database.

    Used as the starting SRT when no reflection step is enabled.

    Raises:
        ValueError: if no row exists or the row has no ``une_ad`` field.
    """
    from demo.databases import get_audiodescription  # type: ignore

    record = get_audiodescription(sha1sum, version)
    if record is None or "une_ad" not in record.keys():
        raise ValueError(
            f"No s'ha trobat une_ad a audiodescriptions.db per sha1sum={sha1sum}, version={version}"
        )
    # A NULL/empty column value is normalised to an empty SRT string.
    return record["une_ad"] or ""


def execute_refinement_for_video(
    sha1sum: str,
    version: str,
    *,
    config_path: Optional[Path] = None,
) -> str:
    """Run the refinement pipeline for a stored video (``sha1sum``, ``version``).

    The ``refinement.*`` flags from config.yaml (or ``config_path``) decide
    which steps run:

    1. Produce the starting SRT. This uses ``refine_video_with_reflection_ma``
       when ``reflection_ma_enabled`` is set, otherwise
       ``refine_video_with_reflection`` when ``reflection_enabled`` is set.
       These helpers read the video's data from the demo databases themselves
       (per the original notes: une_ad/json_ad/casting/scenarios). When
       neither flag is set, the raw ``une_ad`` SRT is read from the
       audiodescriptions DB instead.
    2. Reflexion over the resulting SRT, when ``reflexion_enabled`` is set.
    3. Introspection over the resulting SRT, when ``introspection_enabled``
       is set.

    Returns:
        The final SRT text.

    Raises:
        ValueError: if reflection is disabled and no ``une_ad`` is stored for
            this video.
    """
    enabled = _load_refinement_flags(config_path)

    if enabled.get("reflection_ma_enabled", False):
        current = refine_video_with_reflection_ma(sha1sum, version)
    elif enabled.get("reflection_enabled", False):
        current = refine_video_with_reflection(sha1sum, version)
    else:
        current = _read_une_ad_from_db(sha1sum, version)

    # In-memory post-passes over the SRT text, applied in this fixed order.
    for flag_name, refine_pass in (
        ("reflexion_enabled", refine_srt_with_reflexion),
        ("introspection_enabled", refine_srt_with_introspection),
    ):
        if enabled.get(flag_name, False):
            current = refine_pass(current)

    return current
if __name__ == "__main__":
    # Small manual demo: refine a one-cue SRT and print it before and after.
    sample_srt = "1\n00:00:00,000 --> 00:00:03,000\n(AD) Una noia entra a l'aula.\n"
    output_srt = execute_refinement(sample_srt)
    for heading, text in (
        ("=== SRT original ===", sample_srt),
        ("\n=== SRT refinat ===", output_srt),
    ):
        print(heading)
        print(text)