File size: 5,043 Bytes
31d4d14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from __future__ import annotations
from pathlib import Path
from typing import Optional
import yaml
from .reflection import refine_srt_with_reflection, refine_video_with_reflection
from .reflection_ma import refine_srt_with_reflection_ma, refine_video_with_reflection_ma
from .reflexion import refine_srt_with_reflexion
from .introspection import refine_srt_with_introspection
def _coerce_flag(value: object) -> bool:
    """Interpret a YAML scalar as a boolean flag.

    Strings are compared case-insensitively against common "off" spellings,
    because ``bool("false")`` is ``True`` and would silently enable a step
    when the YAML value is quoted. Any other value uses plain truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "no", "off"}
    return bool(value)


def _load_refinement_flags(config_path: Optional[Path] = None) -> dict:
    """Load the refinement flags from a YAML configuration file.

    When ``config_path`` is omitted, ``demo/config.yaml`` is used, resolved
    relative to the directory two levels above this file's folder. The keys
    read from the ``refinement`` section are ``reflection_enabled``,
    ``reflexion_enabled``, ``introspection_enabled`` and
    ``reflection_ma_enabled``.

    Args:
        config_path: Optional location of the YAML file. A plain string is
            accepted as well and converted to ``Path``.

    Returns:
        A new dict mapping each flag name to a bool. Defaults are returned
        (``reflection_enabled=True``, everything else ``False``) when the file
        is missing, unreadable, or does not contain a ``refinement`` mapping.
    """
    # Local import so this block stays self-contained within the module.
    import logging

    if config_path is None:
        # Repository root: two levels above this file's directory.
        root = Path(__file__).resolve().parents[2]
        config_path = root / "demo" / "config.yaml"
    else:
        # Tolerate callers that pass a str instead of a Path.
        config_path = Path(config_path)

    flags = {
        "reflection_enabled": True,
        "reflexion_enabled": False,
        "introspection_enabled": False,
        "reflection_ma_enabled": False,
    }

    log = logging.getLogger(__name__)

    try:
        if not config_path.exists():
            return flags
        with config_path.open("r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
    except Exception:
        # Best-effort by design: a broken config must not break the pipeline,
        # but the failure is reported instead of being silently discarded.
        log.warning(
            "Could not read refinement flags from %s; using defaults",
            config_path,
            exc_info=True,
        )
        return flags

    ref_cfg = cfg.get("refinement") if isinstance(cfg, dict) else None
    if ref_cfg is None:
        return flags
    if not isinstance(ref_cfg, dict):
        log.warning(
            "Section 'refinement' in %s is not a mapping; using defaults",
            config_path,
        )
        return flags

    # Iterate over a snapshot of the keys because the dict is updated in place.
    for name in list(flags):
        flags[name] = _coerce_flag(ref_cfg.get(name, flags[name]))
    return flags
def execute_refinement(initial_srt: str, *, config_path: Optional[Path] = None) -> str:
    """Run the enabled refinement steps over an SRT string.

    The flags come from ``_load_refinement_flags(config_path)`` and select
    which steps run, always in this order:

    1. ``reflection_ma`` if ``reflection_ma_enabled`` is set; otherwise
       ``reflection`` if ``reflection_enabled`` is set. At most one of the
       two runs.
    2. ``reflexion`` if ``reflexion_enabled`` is set.
    3. ``introspection`` if ``introspection_enabled`` is set.

    Args:
        initial_srt: SRT content to refine.
        config_path: Optional path to the YAML config holding the flags.

    Returns:
        The SRT after all enabled steps, or ``initial_srt`` unchanged when no
        step is enabled.
    """
    enabled = _load_refinement_flags(config_path)

    # Collect the steps first, then apply them in order.
    stages = []
    if enabled.get("reflection_ma_enabled", False):
        stages.append(refine_srt_with_reflection_ma)
    elif enabled.get("reflection_enabled", False):
        stages.append(refine_srt_with_reflection)
    if enabled.get("reflexion_enabled", False):
        stages.append(refine_srt_with_reflexion)
    if enabled.get("introspection_enabled", False):
        stages.append(refine_srt_with_introspection)

    result = initial_srt
    for stage in stages:
        result = stage(result)
    return result
def execute_refinement_for_video(
    sha1sum: str,
    version: str,
    *,
    config_path: Optional[Path] = None,
) -> str:
    """Run the refinement pipeline for a video identified by (sha1sum, version).

    The flags come from ``_load_refinement_flags(config_path)``:

    1. Starting SRT: ``refine_video_with_reflection_ma(sha1sum, version)`` if
       ``reflection_ma_enabled`` is set; otherwise
       ``refine_video_with_reflection(sha1sum, version)`` if
       ``reflection_enabled`` is set; otherwise the ``une_ad`` field of the row
       returned by ``demo.databases.get_audiodescription``.
    2. ``reflexion`` over the resulting SRT if ``reflexion_enabled`` is set.
    3. ``introspection`` over the resulting SRT if ``introspection_enabled``
       is set.

    Args:
        sha1sum: Identifier of the video.
        version: Version of the audio description to refine.
        config_path: Optional path to the YAML config holding the flags.

    Returns:
        The final SRT string.

    Raises:
        ValueError: If both reflection variants are disabled and no row with
            an ``une_ad`` field is found for the given video.
    """
    settings = _load_refinement_flags(config_path)
    use_reflection_ma = settings.get("reflection_ma_enabled", False)
    use_reflection = settings.get("reflection_enabled", False)

    # Step 1: obtain the starting SRT.
    if use_reflection_ma:
        current = refine_video_with_reflection_ma(sha1sum, version)
    elif use_reflection:
        current = refine_video_with_reflection(sha1sum, version)
    else:
        # Reflection is off: use the stored une_ad as the starting point.
        from demo.databases import get_audiodescription  # type: ignore

        record = get_audiodescription(sha1sum, version)
        has_une_ad = record is not None and "une_ad" in record.keys()
        if not has_une_ad:
            raise ValueError(
                f"No s'ha trobat une_ad a audiodescriptions.db per sha1sum={sha1sum}, version={version}"
            )
        stored = record["une_ad"]
        current = stored if stored else ""

    # Steps 2 and 3: optional post-processing of the in-memory SRT.
    post_steps = (
        ("reflexion_enabled", refine_srt_with_reflexion),
        ("introspection_enabled", refine_srt_with_introspection),
    )
    for flag_name, step in post_steps:
        if settings.get(flag_name, False):
            current = step(current)

    return current
if __name__ == "__main__":  # Small manual demo
    # One-cue sample SRT; refinement runs with the default config location.
    sample_srt = "1\n00:00:00,000 --> 00:00:03,000\n(AD) Una noia entra a l'aula.\n"
    outcome = execute_refinement(sample_srt)

    # Print the input and the refined output, each under its own heading.
    for heading, body in (
        ("=== SRT original ===", sample_srt),
        ("\n=== SRT refinat ===", outcome),
    ):
        print(heading)
        print(body)
|