"""Gradio bench app: Russian text-to-speech models and MusicGen text-to-music.

Generated audio is written as WAV files under ``outputs/`` and exposed in the
UI for download and preview.
"""

import os
import time
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Dict

import numpy as np
import pandas as pd
import gradio as gr

OUTDIR = Path("outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

# A synthesizer maps one prompt to ``(sampling_rate, float32 samples)``.
Synthesizer = Callable[[str], Tuple[int, np.ndarray]]


def slug(s: str) -> str:
    """Make a safe filename slug.

    Every character that is not alphanumeric according to ``str.isalnum``
    (so Unicode letters such as Cyrillic are kept) is replaced by an
    underscore. The result is cut to 80 characters and leading/trailing
    underscores are stripped. ``None`` is treated as an empty string.
    """
    if s is None:
        s = ""
    return "".join(c if c.isalnum() else "_" for c in s)[:80].strip("_")


def save_wav(path: Path, sr: int, audio):
    """Write *audio* to *path* as a float32 WAV file.

    Args:
        path: Destination file.
        sr: Sampling rate in Hz.
        audio: Array-like of samples, or an object with ``detach()`` (e.g. a
            torch tensor), which is moved to CPU and converted to NumPy first.
    """
    # scipy is only needed here, so it is imported lazily.
    from scipy.io import wavfile as wav

    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    # atleast_1d: a single-sample input squeezes to a 0-d array, which
    # wavfile.write cannot handle.
    a = np.atleast_1d(np.squeeze(np.asarray(audio, dtype=np.float32)))
    # A 2-D array with fewer rows than columns is flipped so that the longer
    # axis comes first before writing.
    if a.ndim == 2 and a.shape[0] < a.shape[1]:
        a = a.T
    # normalize if needed (safety)
    max_abs = np.max(np.abs(a)) if a.size else 1.0
    if np.isfinite(max_abs) and max_abs > 1.0:
        a = a / max_abs
    wav.write(str(path), int(sr), a)


MODEL_NAMES = {
    "suno/bark-small": "bark",
    "facebook/mms-tts-rus": "mms",
    "facebook/seamless-m4t-v2-large": "seamless",
}

_model_cache: Dict[str, object] = {}
_device_hint = "auto"


def _load_tts_pipeline(model_id: str) -> Synthesizer:
    """Build a ``text-to-speech`` transformers pipeline and wrap it as a synthesizer."""
    import torch
    from transformers import pipeline

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pipe = pipeline(
        task="text-to-speech",
        model=model_id,
        device=device,
        model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32},
    )
    # Fall back to the EOS id when the config carries no pad token id.
    config = pipe.model.config
    if getattr(config, "pad_token_id", None) is None:
        config.pad_token_id = config.eos_token_id

    def generate(text: str):
        out = pipe(text)
        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)

    return generate


def _load_bark():
    """Return a synthesizer backed by ``suno/bark-small``."""
    return _load_tts_pipeline("suno/bark-small")


def _load_mms():
    """Return a synthesizer backed by ``facebook/mms-tts-rus``."""
    return _load_tts_pipeline("facebook/mms-tts-rus")


def _load_seamless():
    """Return a synthesizer backed by ``facebook/seamless-m4t-v2-large`` (Russian in, Russian out)."""
    import torch
    from transformers import AutoProcessor
    from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import SeamlessM4Tv2Model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    proc = AutoProcessor.from_pretrained(
        "facebook/seamless-m4t-v2-large",
        use_fast=False,
    )
    model = SeamlessM4Tv2Model.from_pretrained(
        "facebook/seamless-m4t-v2-large",
        low_cpu_mem_usage=False,
    ).to(device)
    # Read the output rate from the model config instead of hard-coding it;
    # 16000 stays as the fallback if the attribute is missing.
    sampling_rate = int(getattr(model.config, "sampling_rate", 16000))

    def generate(text: str):
        inputs = proc(text=text, src_lang="rus", return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            audio = model.generate(**inputs, tgt_lang="rus")[0]
        audio = audio.detach().cpu().numpy().squeeze().astype(np.float32)
        return sampling_rate, audio

    return generate


# Loader per model kind; loaders run lazily, on first request only.
_TTS_LOADERS: Dict[str, Callable[[], Synthesizer]] = {
    "bark": _load_bark,
    "mms": _load_mms,
    "seamless": _load_seamless,
}


def get_generator(kind: str):
    """Return the (cached) synthesizer for a model kind.

    Args:
        kind: One of the values of ``MODEL_NAMES`` (``"bark"``, ``"mms"``,
            ``"seamless"``).

    Raises:
        ValueError: If *kind* is not a known model kind.
    """
    if kind in _model_cache:
        return _model_cache[kind]
    loader = _TTS_LOADERS.get(kind)
    if loader is None:
        raise ValueError(f"Unknown model kind: {kind}")
    gen = loader()
    _model_cache[kind] = gen
    return gen


DEFAULT_PROMPTS = (
    "Привет! Это короткий тест русского TTS.\n"
    "Сегодня мы проверяем интонации, паузы и четкость дикции.\n"
    "Немного сложнее: числа 3.14 и 2025 читаем правильно."
)


def _split_prompts(prompts_text: Optional[str], split_lines: bool) -> List[str]:
    """Turn the textbox content into a list of non-empty, stripped prompts."""
    text = prompts_text or ""
    if split_lines:
        return [line for line in (s.strip() for s in text.splitlines()) if line]
    stripped = text.strip()
    return [stripped] if stripped else []


def _generate_batch(
    task: str,
    model_label: str,
    prompts: List[str],
    out_dir: Path,
    synthesize: Synthesizer,
    name_prefix: str = "",
):
    """Synthesize every prompt, save each result as WAV and collect metadata.

    Returns ``(file_paths, metadata_frame, last_file_path)``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    rows = []
    for index, prompt in enumerate(prompts):
        started = time.perf_counter()
        sr, audio = synthesize(prompt)
        elapsed = time.perf_counter() - started
        # The index keeps names unique: duplicate prompts, or prompts that
        # share their first 80 slug characters, would otherwise overwrite
        # each other's file.
        path = out_dir / f"{index:03d}__{name_prefix}{slug(prompt)}.wav"
        save_wav(path, sr, audio)
        rows.append({
            "task": task,
            "model": model_label,
            "prompt": prompt,
            "file": str(path),
            "sr": sr,
            "gen_time_s": round(elapsed, 3),
        })
    file_paths = [row["file"] for row in rows]
    last_audio_path = file_paths[-1] if file_paths else None
    return file_paths, pd.DataFrame(rows), last_audio_path


def run_tts(
    prompts_text: str,
    split_lines: bool,
    model_choice: str,
):
    """Main Gradio callback: TTS.

    Args:
        prompts_text: Raw textbox content.
        split_lines: When true, every non-empty line is a separate prompt;
            otherwise the whole text is a single prompt.
        model_choice: A key of ``MODEL_NAMES``.

    Returns:
        files: list[str] - paths of the written wav files
        df: pd.DataFrame - metadata table (one row per prompt)
        last_audio: str | None - path of the last file, for preview

    Raises:
        KeyError: If *model_choice* is not a key of ``MODEL_NAMES``.
    """
    text_items = _split_prompts(prompts_text, split_lines)
    if not text_items:
        return [], pd.DataFrame(), None

    gen = get_generator(MODEL_NAMES[model_choice])
    stamp_dir = OUTDIR / "tts" / time.strftime("%Y%m%d-%H%M%S")
    return _generate_batch(
        "tts",
        model_choice,
        text_items,
        stamp_dir,
        gen,
        name_prefix=f"{slug(model_choice)}__",
    )


_music_pipes: Dict[str, object] = {}

# To offer the stereo variant as well, add "facebook/musicgen-stereo-small"
# to this list.
MUSIC_MODELS = [
    "facebook/musicgen-small",
]


def get_music_pipe(model_name: str):
    """Return the (cached) ``text-to-audio`` pipeline for *model_name*.

    The pipeline is built on first use and kept in ``_music_pipes`` so that
    repeated generations do not reload the model.
    """
    cached = _music_pipes.get(model_name)
    if cached is not None:
        return cached

    import torch
    from transformers import pipeline

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pipe = pipeline(
        "text-to-audio",
        model=model_name,
        device=device,
        model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32},
    )
    _music_pipes[model_name] = pipe
    return pipe


MUSIC_DEFAULT_PROMPTS = (
    "High-energy 90s rock track with distorted electric guitars, driving bass, and hard-hitting acoustic drums\n"
    "Modern electronic dance track with punchy kick, bright synth lead, and sidechained pads, 128 BPM\n"
    "Dark industrial electro with gritty bass, sharp snares, and mechanical percussion"
)


def run_music(
    prompts_text: str,
    split_lines: bool,
    model_name: str,
    do_sample: bool,
):
    """Main Gradio callback: MusicGen.

    Args:
        prompts_text: Raw textbox content.
        split_lines: When true, every non-empty line is a separate prompt;
            otherwise the whole text is a single prompt.
        model_name: Model id passed to the ``text-to-audio`` pipeline.
        do_sample: Forwarded to the pipeline as the ``do_sample`` forward
            parameter.

    Returns:
        The same ``(files, df, last_audio)`` triple as ``run_tts``.
    """
    text_items = _split_prompts(prompts_text, split_lines)
    if not text_items:
        return [], pd.DataFrame(), None

    pipe = get_music_pipe(model_name)

    def synthesize(prompt: str):
        out = pipe(prompt, forward_params={"do_sample": bool(do_sample)})
        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)

    stamp_dir = OUTDIR / "music" / slug(model_name) / time.strftime("%Y%m%d-%H%M%S")
    return _generate_batch("music", model_name, text_items, stamp_dir, synthesize)


tts_description_md = (
    """
Russian TTS Bench: выберите модель и введите один или несколько промптов.
По умолчанию каждая строка — отдельный промпт. Результаты сохраняются в `outputs/tts/…`.

**Модели:**
- `suno/bark-small` — небольшой мультиязычный TTS.
- `facebook/mms-tts-rus` — русская TTS из проекта MMS.
- `facebook/seamless-m4t-v2-large` — крупная модель перевода/говорения; тяжёлая для CPU.
"""
)

music_description_md = (
    """
**Music Gen:** текст → музыка на базе MusicGen. По умолчанию каждая строка — отдельный промпт.
Результаты сохраняются в `outputs/music/<model>/…`.

**Модели:**
- `facebook/musicgen-small`
- (опционально) `facebook/musicgen-stereo-small` — раскомментируйте в коде.
"""
)


def _ui_outputs(files: List[str], last: Optional[str]):
    """Shape a callback result into (files, preview value, dropdown update)."""
    selected = last or (files[-1] if files else None)
    samples_update = gr.update(choices=files, value=selected)
    return files, (last or None), samples_update


def run_tts_ui(prompts_text, split_lines, model_choice):
    """UI wrapper around ``run_tts``: drops the metadata table and refreshes the sample list."""
    files, _, last = run_tts(prompts_text, split_lines, model_choice)
    return _ui_outputs(files, last)


def run_music_ui(prompts_text, split_lines, model_name, do_sample):
    """UI wrapper around ``run_music``: drops the metadata table and refreshes the track list."""
    files, _, last = run_music(prompts_text, split_lines, model_name, do_sample)
    return _ui_outputs(files, last)


with gr.Blocks(title="Speech & Music Bench") as demo:
    # Markdown needs a space after '#' for the line to render as a heading.
    gr.Markdown("# Speech & Music Bench")

    with gr.Tab("TTS"):
        gr.Markdown(tts_description_md)
        with gr.Row():
            model_choice = gr.Dropdown(
                label="Модель TTS",
                choices=list(MODEL_NAMES.keys()),
                value="suno/bark-small",
            )
            split_lines_tts = gr.Checkbox(value=True, label="Одна строка = один промпт")
        prompts_tts = gr.Textbox(
            label="Промпты",
            value=DEFAULT_PROMPTS,
            lines=6,
            placeholder="Каждая строка — отдельный промпт…",
        )
        run_btn_tts = gr.Button("Сгенерировать речь", variant="primary")
        with gr.Row():
            files_tts = gr.Files(label="Файлы .wav для скачивания")
        with gr.Row():
            samples_tts = gr.Dropdown(
                label="Все сгенерённые семплы (выберите для прослушивания)",
                choices=[],
                allow_custom_value=False,
            )
        with gr.Row():
            preview_tts = gr.Audio(label="Предпросмотр выбранного семпла", autoplay=False)

        run_btn_tts.click(
            fn=run_tts_ui,
            inputs=[prompts_tts, split_lines_tts, model_choice],
            outputs=[files_tts, preview_tts, samples_tts],
        )
        # Picking a sample in the dropdown loads it into the preview player.
        samples_tts.change(
            fn=lambda p: gr.update(value=p),
            inputs=samples_tts,
            outputs=preview_tts,
        )

    with gr.Tab("Music"):
        gr.Markdown(music_description_md)
        with gr.Row():
            music_model = gr.Dropdown(
                label="Модель MusicGen",
                choices=MUSIC_MODELS,
                value=MUSIC_MODELS[0],
            )
            split_lines_music = gr.Checkbox(value=True, label="Одна строка = один промпт")
            do_sample = gr.Checkbox(value=True, label="do_sample")
        prompts_music = gr.Textbox(
            label="Музыкальные промпты",
            value=MUSIC_DEFAULT_PROMPTS,
            lines=6,
            placeholder="Каждая строка — отдельный промпт…",
        )
        run_btn_music = gr.Button("Сгенерировать музыку", variant="primary")
        with gr.Row():
            files_music = gr.Files(label="Файлы .wav для скачивания")
        with gr.Row():
            samples_music = gr.Dropdown(
                label="Все сгенерённые треки (выберите для прослушивания)",
                choices=[],
                allow_custom_value=False,
            )
        with gr.Row():
            preview_music = gr.Audio(label="Предпросмотр выбранного трека", autoplay=False)

        run_btn_music.click(
            fn=run_music_ui,
            inputs=[prompts_music, split_lines_music, music_model, do_sample],
            outputs=[files_music, preview_music, samples_music],
        )
        # Picking a track in the dropdown loads it into the preview player.
        samples_music.change(
            fn=lambda p: gr.update(value=p),
            inputs=samples_music,
            outputs=preview_music,
        )


if __name__ == "__main__":
    demo.launch()