from __future__ import annotations

import os
import subprocess
from pathlib import Path
from typing import Dict, Any, List, Tuple

# Minimal, robust MVP audio-only pipeline:
# - Extract audio with ffmpeg.
# - Diarize with pyannote (if an HF token is available); otherwise fall back to a single segment over the full duration.
# - ASR with Whisper (optionally an AINA model, if available). To keep the footprint reasonable and robust,
#   we default to a lightweight faster-whisper if present; otherwise, return empty text.
# - Generate a basic SRT from the segments and ASR texts.
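#
# For reference, generate() returns a dict shaped roughly like this (values are
# illustrative only, not real output):
#
#     {
#         "une_srt": "1\n00:00:00,000 --> 00:00:12,400\n[SPEAKER_00]: ...\n",
#         "free_text": "full transcript of the audio ...",
#         "artifacts": {"wav_path": "/path/to/out/video.wav"},
#     }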


def extract_audio_ffmpeg(video_path: str, audio_out: Path, sr: int = 16000, mono: bool = True) -> str:
    """Extract a WAV track (16 kHz mono by default) from the input video with ffmpeg."""
    audio_out.parent.mkdir(parents=True, exist_ok=True)
    # Build the argument list directly so paths containing spaces or quotes are handled safely.
    cmd = ["ffmpeg", "-y", "-i", video_path, "-vn"]
    if mono:
        cmd += ["-ac", "1"]
    cmd += ["-ar", str(sr), "-f", "wav", str(audio_out)]
    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return str(audio_out)


def _get_video_duration_seconds(video_path: str) -> float:
    """Best-effort media duration via ffprobe; returns 0.0 if it cannot be determined."""
    try:
        cmd = [
            "ffprobe", "-v", "error", "-select_streams", "v:0",
            "-show_entries", "stream=duration", "-of", "default=nw=1", video_path,
        ]
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("utf-8", errors="ignore")
        for line in out.splitlines():
            if line.startswith("duration="):
                try:
                    return float(line.split("=", 1)[1])
                except ValueError:
                    pass
    except Exception:
        pass
    return 0.0


def diarize_audio(wav_path: str, base_dir: Path, hf_token_env: str | None = None) -> Tuple[List[Dict[str, Any]], List[str]]:
    """Returns segments [{'start','end','speaker'}] and dummy clip_paths (not used in MVP); base_dir is reserved for future clip export."""
    segments: List[Dict[str, Any]] = []
    clip_paths: List[str] = []
    # Prefer PYANNOTE_TOKEN if provided; then the explicit env var name; then HF_TOKEN as a last resort.
    token = os.getenv("PYANNOTE_TOKEN") or (os.getenv(hf_token_env) if hf_token_env else None) or os.getenv("HF_TOKEN")
    try:
        if token:
            from pyannote.audio import Pipeline  # type: ignore
            pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=token)
            diarization = pipeline(wav_path)
            # Collect segments
            # We don't export individual clips in MVP; just timestamps.
            for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
                segments.append({
                    "start": float(getattr(turn, "start", 0.0) or 0.0),
                    "end": float(getattr(turn, "end", 0.0) or 0.0),
                    "speaker": str(speaker) if speaker is not None else f"SPEAKER_{i:02d}",
                })
        else:
            # Fallback: no token available, so emit a single segment covering the whole recording.
            # We only have the wav here, so leave end at 0.0; generate() patches it with the
            # real media duration via ffprobe (the UI also tolerates 0..0).
            segments.append({"start": 0.0, "end": 0.0, "speaker": "SPEAKER_00"})
    except Exception:
        # Robust fallback
        segments.append({"start": 0.0, "end": 0.0, "speaker": "SPEAKER_00"})
    # Sort by start
    segments = sorted(segments, key=lambda s: s.get("start", 0.0))
    return segments, clip_paths


def _fmt_srt_time(seconds: float) -> str:
    # Work in whole milliseconds so the fractional part can never round up to 1000
    # (e.g. 1.9996 s would otherwise format as "00:00:01,1000").
    total_ms = int(round(seconds * 1000))
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02}:{m:02}:{s:02},{ms:03}"


def _generate_srt(segments: List[Dict[str, Any]], texts: List[str]) -> str:
    n = min(len(segments), len(texts))
    lines: List[str] = []
    for i in range(n):
        seg = segments[i]
        text = (texts[i] or "").strip()
        start = float(seg.get("start", 0.0))
        # Default to a 2-second cue when the segment has no end; clamp inverted times.
        end = max(float(seg.get("end", start + 2.0)), start)
        speaker = seg.get("speaker")
        if speaker:
            text = f"[{speaker}]: {text}" if text else f"[{speaker}]"
        lines.append(str(i + 1))
        lines.append(f"{_fmt_srt_time(start)} --> {_fmt_srt_time(end)}")
        lines.append(text)
        lines.append("")
    return "\n".join(lines).strip() + "\n"


def asr_transcribe_wav_simple(wav_path: str) -> str:
    """Best-effort ASR stub: use faster-whisper small if installed; otherwise return empty text.

    Intended for an MVP on Spaces without a heavy GPU.
    """
    try:
        from faster_whisper import WhisperModel  # type: ignore
        model = WhisperModel("Systran/faster-whisper-small", device="cpu")
        # Plain transcript without timestamps; the language is auto-detected.
        segments, _ = model.transcribe(wav_path, vad_filter=True, without_timestamps=True, language=None)
        text = " ".join(seg.text.strip() for seg in segments if getattr(seg, "text", None))
        return text.strip()
    except Exception:
        # As last resort, empty text
        return ""


def generate(video_path: str, out_dir: Path) -> Dict[str, Any]:
    """End-to-end MVP that returns {'une_srt','free_text','artifacts':{...}}."""
    out_dir.mkdir(parents=True, exist_ok=True)
    wav_path = extract_audio_ffmpeg(video_path, out_dir / f"{Path(video_path).stem}.wav")

    # Diarization (robust). If it fell back to a single zero-length segment,
    # patch the end time with the real media duration from ffprobe.
    segments, _ = diarize_audio(wav_path, out_dir, hf_token_env="HF_TOKEN")
    if len(segments) == 1 and float(segments[0].get("end", 0.0)) <= 0.0:
        duration = _get_video_duration_seconds(video_path)
        if duration > 0:
            segments[0]["end"] = duration

    # ASR (for MVP: single transcript of full audio to use as 'free_text')
    free_text = asr_transcribe_wav_simple(wav_path)

    # Build per-segment 'texts' using a simple split of free_text if we have multiple segments
    if not segments:
        segments = [{"start": 0.0, "end": 0.0, "speaker": "SPEAKER_00"}]
    texts: List[str] = []
    if len(segments) <= 1:
        texts = [free_text]
    else:
        # Naive split into N parts by words
        words = free_text.split()
        chunk = max(1, len(words) // len(segments))
        for i in range(len(segments)):
            start_idx = i * chunk
            end_idx = (i + 1) * chunk if i < len(segments) - 1 else len(words)
            texts.append(" ".join(words[start_idx:end_idx]))

    une_srt = _generate_srt(segments, texts)

    return {
        "une_srt": une_srt,
        "free_text": free_text,
        "artifacts": {
            "wav_path": str(wav_path),
        },
    }
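

if __name__ == "__main__":
    # Minimal manual smoke test (illustrative sketch, not part of the pipeline API):
    # run the full MVP on a local video and print the generated SRT.
    # "sample.mp4" and "./out" are hypothetical defaults; pass a real video path
    # as the first CLI argument.
    import sys

    video = sys.argv[1] if len(sys.argv) > 1 else "sample.mp4"
    result = generate(video, Path("./out"))
    print(result["une_srt"])
    print("WAV written to:", result["artifacts"]["wav_path"])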