# engine/api_client.py
# Provenance: Hugging Face Space "veureu" file viewer
# (commit ac342d0 verified, "Update api_client.py", ~12 kB)
# api_client.py (UI - Space "veureu")
import os
import io
import base64
import zipfile
import requests
from typing import Iterable, Dict, Any
class APIClient:
"""
High-level client for communicating with the Veureu Engine API.
Endpoints managed:
POST /jobs
→ {"job_id": "..."}
GET /jobs/{job_id}/status
→ {"status": "queued|processing|done|failed", ...}
GET /jobs/{job_id}/result
→ JobResult such as {"book": {...}, "une": {...}, ...}
This class is used by the Streamlit UI to submit videos, poll job status,
retrieve results, generate audio, and interact with the TTS and casting services.
"""
def __init__(
self,
base_url: str,
use_mock: bool = False,
data_dir: str | None = None,
token: str | None = None,
timeout: int = 180
):
"""
Initialize the API client.
Args:
base_url: Base URL of the engine or TTS service.
use_mock: Whether to respond with mock data instead of real API calls.
data_dir: Optional data folder for local mock/test files.
token: Authentication token (fallback: API_SHARED_TOKEN env var).
timeout: Timeout in seconds for requests.
"""
self.base_url = base_url.rstrip("/")
self.tts_url = self.base_url # For HF Spaces, TTS lives at same base URL
self.use_mock = use_mock
self.data_dir = data_dir
self.timeout = timeout
self.session = requests.Session()
# Authorization header if token provided
token = token or os.getenv("API_SHARED_TOKEN")
if token:
self.session.headers.update({"Authorization": f"Bearer {token}"})
# -------------------------------------------------------------------------
# Internal engine calls
# -------------------------------------------------------------------------
def _post_jobs(self, video_path: str, modes: Iterable[str]) -> Dict[str, Any]:
"""Submit a video and processing modes to /jobs."""
url = f"{self.base_url}/jobs"
files = {
"file": (os.path.basename(video_path), open(video_path, "rb"), "application/octet-stream")
}
data = {"modes": ",".join(modes)}
r = self.session.post(url, files=files, data=data, timeout=self.timeout)
r.raise_for_status()
return r.json()
def _get_status(self, job_id: str) -> Dict[str, Any]:
"""Query job status."""
url = f"{self.base_url}/jobs/{job_id}/status"
r = self.session.get(url, timeout=self.timeout)
r.raise_for_status()
return r.json()
def _get_result(self, job_id: str) -> Dict[str, Any]:
"""Retrieve job result."""
url = f"{self.base_url}/jobs/{job_id}/result"
r = self.session.get(url, timeout=self.timeout)
r.raise_for_status()
return r.json()
# -------------------------------------------------------------------------
# Public API used by streamlit_app.py
# -------------------------------------------------------------------------
def process_video(self, video_path: str, modes: Iterable[str]) -> Dict[str, Any]:
"""Return {"job_id": "..."} either from mock or engine."""
if self.use_mock:
return {"job_id": "mock-123"}
return self._post_jobs(video_path, modes)
def get_job(self, job_id: str) -> Dict[str, Any]:
"""
Returns UI-friendly job data:
{"status": "done", "results": {"book": {...}, "une": {...}}}
Maps engine responses into the expected 'results' format.
"""
if self.use_mock:
return {
"status": "done",
"results": {
"book": {"text": "Example text (book)", "mp3_bytes": b""},
"une": {
"srt": "1\n00:00:00,000 --> 00:00:01,000\nExample UNE\n",
"mp3_bytes": b""
}
}
}
status_data = self._get_status(job_id)
# If still processing, return minimal structure
if status_data.get("status") in {"queued", "processing"}:
return {"status": status_data.get("status", "queued")}
raw_result = self._get_result(job_id)
results = {}
# Direct mapping of book/une sections
if "book" in raw_result:
results["book"] = {"text": raw_result["book"].get("text")}
if "une" in raw_result:
results["une"] = {"srt": raw_result["une"].get("srt")}
# Preserve characters/metrics if present
for section in ("book", "une"):
if section in raw_result:
if "characters" in raw_result[section]:
results[section]["characters"] = raw_result[section]["characters"]
if "metrics" in raw_result[section]:
results[section]["metrics"] = raw_result[section]["metrics"]
final_status = "done" if results else status_data.get("status", "unknown")
return {"status": final_status, "results": results}
# -------------------------------------------------------------------------
# TTS Services
# -------------------------------------------------------------------------
def tts_matxa(self, text: str, voice: str = "central/grau") -> dict:
"""
Call the TTS /tts/text endpoint to synthesize short audio.
Returns:
{"mp3_bytes": b"..."} on success
{"error": "..."} on failure
"""
if not self.tts_url:
raise ValueError("TTS service URL not configured.")
url = f"{self.tts_url.rstrip('/')}/tts/text"
data = {"texto": text, "voice": voice, "formato": "mp3"}
try:
r = requests.post(url, data=data, timeout=self.timeout)
r.raise_for_status()
return {"mp3_bytes": r.content}
except requests.exceptions.RequestException as e:
return {"error": str(e)}
def rebuild_video_with_ad(self, video_path: str, srt_path: str) -> dict:
"""
Rebuild a video including audio description (AD)
by calling /tts/srt. The server returns a ZIP containing an MP4.
"""
if not self.tts_url:
raise ValueError("TTS service URL not configured.")
url = f"{self.tts_url.rstrip('/')}/tts/srt"
try:
files = {
"video": (os.path.basename(video_path), open(video_path, "rb"), "video/mp4"),
"srt": (os.path.basename(srt_path), open(srt_path, "rb"), "application/x-subrip")
}
data = {"include_final_mp4": 1}
r = requests.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
for name in z.namelist():
if name.endswith(".mp4"):
return {"video_bytes": z.read(name)}
return {"error": "MP4 file not found inside ZIP."}
except zipfile.BadZipFile:
return {"error": "Invalid ZIP response from server."}
except requests.exceptions.RequestException as e:
return {"error": str(e)}
# -------------------------------------------------------------------------
# Engine casting services
# -------------------------------------------------------------------------
def create_initial_casting(
self,
video_path: str = None,
video_bytes: bytes = None,
video_name: str = None,
epsilon: float = 0.5,
min_cluster_size: int = 2
) -> dict:
"""
Calls /create_initial_casting to produce the initial actor/face clustering.
Args:
video_path: Load video from disk.
video_bytes: Provide video already in memory.
video_name: Name used if video_bytes is provided.
epsilon: DBSCAN epsilon for clustering.
min_cluster_size: Minimum number of samples for DBSCAN.
"""
url = f"{self.base_url}/create_initial_casting"
try:
# Prepare video input
if video_bytes:
files = {"video": (video_name or "video.mp4", video_bytes, "video/mp4")}
elif video_path:
with open(video_path, "rb") as f:
files = {"video": (os.path.basename(video_path), f.read(), "video/mp4")}
else:
return {"error": "Either video_path or video_bytes must be provided."}
data = {
"epsilon": str(epsilon),
"min_cluster_size": str(min_cluster_size)
}
r = self.session.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
if r.headers.get("content-type", "").startswith("application/json"):
return r.json()
return {"ok": True}
except Exception as e:
return {"error": str(e)}
# -------------------------------------------------------------------------
# Long text TTS helpers
# -------------------------------------------------------------------------
def generate_audio_from_text_file(self, text_content: str, voice: str = "central/grau") -> dict:
"""
Converts a large text into an SRT-like structure, calls /tts/srt,
and extracts 'ad_master.mp3' from the resulting ZIP.
Useful for audiobook-like generation.
"""
if not self.tts_url:
raise ValueError("TTS service URL not configured.")
# Build synthetic SRT in memory
srt_content = ""
start = 0
for idx, raw_line in enumerate(text_content.strip().split("\n")):
line = raw_line.strip()
if not line:
continue
end = start + 5 # simplistic 5 seconds per subtitle
def fmt(seconds):
h = seconds // 3600
m = (seconds % 3600) // 60
s = seconds % 60
return f"{h:02d}:{m:02d}:{s:02d},000"
srt_content += f"{idx+1}\n"
srt_content += f"{fmt(start)} --> {fmt(end)}\n"
srt_content += f"{line}\n\n"
start = end
if not srt_content:
return {"error": "Provided text is empty or cannot be processed."}
# Call server
url = f"{self.tts_url.rstrip('/')}/tts/srt"
try:
files = {"srt": ("fake_ad.srt", srt_content, "application/x-subrip")}
data = {"voice": voice, "ad_format": "mp3"}
r = requests.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
if "ad_master.mp3" in z.namelist():
return {"mp3_bytes": z.read("ad_master.mp3")}
return {"error": "'ad_master.mp3' not found inside ZIP."}
except requests.exceptions.RequestException as e:
return {"error": f"Error calling SRT API: {e}"}
except zipfile.BadZipFile:
return {"error": "Invalid ZIP response from server."}
def tts_long_text(self, text: str, voice: str = "central/grau") -> dict:
"""
Call /tts/text_long for very long text TTS synthesis.
Returns raw MP3 bytes.
"""
if not self.tts_url:
raise ValueError("TTS service URL not configured.")
url = f"{self.tts_url.rstrip('/')}/tts/text_long"
data = {"texto": text, "voice": voice, "formato": "mp3"}
try:
r = requests.post(url, data=data, timeout=self.timeout * 10)
r.raise_for_status()
return {"mp3_bytes": r.content}
except requests.exceptions.RequestException as e:
return {"error": str(e)}