File size: 4,653 Bytes
287f01b
b17b915
 
 
287f01b
b17b915
 
 
ed7696e
b17b915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648c0b6
b17b915
 
287f01b
 
648c0b6
 
 
b17b915
648c0b6
 
 
 
 
b17b915
 
ed7696e
 
 
 
 
 
 
 
 
 
 
 
 
 
b17b915
 
 
ed7696e
 
648c0b6
 
ed7696e
 
b17b915
 
 
 
 
ed7696e
 
648c0b6
 
ed7696e
 
b17b915
 
 
 
 
ed7696e
 
648c0b6
 
ed7696e
 
b17b915
 
 
 
 
ed7696e
 
648c0b6
 
ed7696e
 
b17b915
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# llm_router.py — routes calls to remote Spaces according to config.yaml
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pathlib import Path
import os
import yaml

from remote_clients import InstructClient, VisionClient, ToolsClient, ASRClient
import time

def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML file and return its parsed content.

    Args:
        path: Location of the YAML file on disk.

    Returns:
        The value produced by ``yaml.safe_load`` for the file's UTF-8 text,
        or an empty dict when the path does not exist or the parsed value is
        falsy (for example an empty document).
    """
    config_path = Path(path)
    if config_path.exists():
        text = config_path.read_text(encoding="utf-8")
        parsed = yaml.safe_load(text)
        # An empty document parses to None; normalise falsy results to {}.
        return parsed or {}
    return {}

class LLMRouter:
    """Dispatch model calls to remote Space clients as described by a config mapping.

    Attributes:
        cfg: The configuration mapping passed to the constructor.
        rem: Collection of model keys read from
            ``models.routing.use_remote_for``; only these keys are routed to a
            remote client.
        client_factories: Maps each known model key to a zero-argument
            callable that builds a fresh client for that endpoint.
        service_names: Maps each known model key to the short name used in
            log lines.
    """

    def __init__(self, cfg: Dict[str, Any]):
        """Read routing, endpoint and security settings from ``cfg``.

        Args:
            cfg: Configuration mapping (typically the result of parsing
                config.yaml). Missing or empty sections fall back to defaults.
        """
        self.cfg = cfg

        def section(parent: Dict[str, Any], key: str) -> Dict[str, Any]:
            # A key that is present but empty in YAML parses to None, which
            # would break chained .get() calls; treat it as an empty section.
            return parent.get(key) or {}

        routing = section(section(cfg, "models"), "routing")
        remote_models = routing.get("use_remote_for")
        self.rem = [] if remote_models is None else remote_models

        remote = section(cfg, "remote_spaces")
        base_user = remote.get("user") or "veureu"
        eps = section(remote, "endpoints")
        # The timeout is shared by every endpoint, so compute it once.
        timeout = int(section(remote, "http").get("timeout_seconds", 180))

        security = section(cfg, "security")
        token_enabled = security.get("use_hf_token", False)
        token_env = security.get("hf_token_env") or "HF_TOKEN"
        hf_token = os.getenv(token_env) if token_enabled else None

        def mk_factory(endpoint_key: str, cls):
            info = section(eps, endpoint_key)
            base_url = info.get("base_url")
            space = info.get("space")
            if not base_url and space:
                base_url = f"https://{base_user}-{space}.hf.space"
            use_gradio = (info.get("client", "gradio") == "gradio")

            def _factory():
                # Fail here, on first use, rather than in __init__, so that
                # endpoints that are never called need not be configured.
                if not base_url:
                    raise RuntimeError(
                        f"No 'base_url' or 'space' configured for endpoint: {endpoint_key}"
                    )
                return cls(base_url=base_url, use_gradio=use_gradio, hf_token=hf_token, timeout=timeout)
            return _factory

        self.client_factories = {
            "salamandra-instruct": mk_factory("salamandra-instruct", InstructClient),
            "salamandra-vision": mk_factory("salamandra-vision", VisionClient),
            "salamandra-tools": mk_factory("salamandra-tools", ToolsClient),
            "whisper-catalan": mk_factory("whisper-catalan", ASRClient),
        }

        self.service_names = {
            "salamandra-instruct": "schat",
            "salamandra-vision": "svision",
            "salamandra-tools": "stools",
            "whisper-catalan": "asr",
        }

    def _log_connect(self, model_key: str, phase: str, elapsed: Optional[float] = None):
        """Print a progress line for a remote call.

        Args:
            model_key: Model key; translated to its short service name when
                one is registered, otherwise printed as-is.
            phase: ``"connect"`` before the call or ``"done"`` after it; any
                other value prints nothing.
            elapsed: Seconds taken by the call, used only for ``"done"``.
        """
        svc = self.service_names.get(model_key, model_key)
        if phase == "connect":
            print(f"[LLMRouter] Connecting to {svc} space...")
        elif phase == "done":
            if elapsed is None:
                # Formatting None with ':.2f' would raise TypeError.
                print(f"[LLMRouter] Response from {svc} space received")
            else:
                print(f"[LLMRouter] Response from {svc} space received in {elapsed:.2f} s")

    def _call_remote(self, model: str, invoke: Any) -> Any:
        """Build the client for ``model``, run ``invoke(client)`` and log timing.

        Args:
            model: Model key to route.
            invoke: Callable that receives the freshly built client and
                returns the result of the remote call.

        Returns:
            Whatever ``invoke`` returns.

        Raises:
            RuntimeError: If ``model`` is not listed in ``self.rem``.
            KeyError: If ``model`` is listed in ``self.rem`` but has no
                registered client factory.
        """
        if model not in self.rem:
            raise RuntimeError(f"Modelo local no implementado para: {model}")
        self._log_connect(model, "connect")
        # perf_counter is monotonic, so the elapsed value cannot go negative
        # if the system clock is adjusted during the call.
        t0 = time.perf_counter()
        client = self.client_factories[model]()
        out = invoke(client)
        self._log_connect(model, "done", time.perf_counter() - t0)
        return out

    # ---- INSTRUCT ----
    def instruct(self, prompt: str, system: Optional[str] = None, model: str = "salamandra-instruct", **kwargs) -> str:
        """Return ``client.generate(prompt, system=system, **kwargs)`` for ``model``.

        Raises:
            RuntimeError: If ``model`` is not configured for remote use.
        """
        return self._call_remote(
            model,
            lambda client: client.generate(prompt, system=system, **kwargs),  # type: ignore
        )

    # ---- VISION ----
    def vision_describe(self, image_paths: List[str], context: Optional[Dict[str, Any]] = None, model: str = "salamandra-vision", **kwargs) -> List[str]:
        """Return ``client.describe(image_paths, context=context, **kwargs)`` for ``model``.

        Raises:
            RuntimeError: If ``model`` is not configured for remote use.
        """
        return self._call_remote(
            model,
            lambda client: client.describe(image_paths, context=context, **kwargs),  # type: ignore
        )

    # ---- TOOLS ----
    def chat_with_tools(self, messages: List[Dict[str, str]], tools: Optional[List[Dict[str, Any]]] = None, model: str = "salamandra-tools", **kwargs) -> Dict[str, Any]:
        """Return ``client.chat(messages, tools=tools, **kwargs)`` for ``model``.

        Raises:
            RuntimeError: If ``model`` is not configured for remote use.
        """
        return self._call_remote(
            model,
            lambda client: client.chat(messages, tools=tools, **kwargs),  # type: ignore
        )

    # ---- ASR ----
    def asr_transcribe(self, audio_path: str, model: str = "whisper-catalan", **kwargs) -> Dict[str, Any]:
        """Return ``client.transcribe(audio_path, **kwargs)`` for ``model``.

        Raises:
            RuntimeError: If ``model`` is not configured for remote use.
        """
        return self._call_remote(
            model,
            lambda client: client.transcribe(audio_path, **kwargs),  # type: ignore
        )