# Spaces: Running
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import List, Tuple, Dict | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
# Root folder for all generated audio. It is created at import time so the
# callbacks below can write into it straight away.
OUTDIR = Path("outputs")
os.makedirs(OUTDIR, exist_ok=True)
def slug(s: str) -> str:
    """Turn arbitrary text into a string that is safe to use in a filename.

    Every character for which ``str.isalnum()`` is false becomes an
    underscore; alphanumeric characters (including non-ASCII letters such as
    Cyrillic) are kept unchanged. The result is cut to 80 characters and then
    stripped of leading and trailing underscores. ``None`` is treated as an
    empty string.
    """
    text = "" if s is None else s
    mapped = [ch if ch.isalnum() else "_" for ch in text]
    truncated = "".join(mapped)[:80]
    return truncated.strip("_")
def save_wav(path: Path, sr: int, audio):
    """Write ``audio`` to ``path`` as a float32 WAV file.

    Args:
        path: Destination file path.
        sr: Sampling rate in Hz; converted with ``int()`` before writing.
        audio: Array-like samples, or any object exposing ``detach()``
            (e.g. a torch tensor), which is moved to CPU and converted to
            a NumPy array first.

    The samples are converted to float32 and squeezed. A 2-D array with
    fewer rows than columns is transposed before writing. If the peak
    absolute value is finite and above 1.0, the signal is scaled down so
    the peak is 1.0.
    """
    import numpy as np
    from scipy.io import wavfile

    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()

    samples = np.squeeze(np.asarray(audio, dtype=np.float32))
    # squeeze() collapses single-sample input (e.g. shape (1, 1)) to a 0-d
    # array, which wavfile.write cannot handle; restore one dimension.
    samples = np.atleast_1d(samples)

    # Put the longer axis first.
    if samples.ndim == 2 and samples.shape[0] < samples.shape[1]:
        samples = samples.T

    # Safety normalisation: only scale when the signal would clip.
    peak = np.max(np.abs(samples)) if samples.size else 1.0
    if np.isfinite(peak) and peak > 1.0:
        samples = samples / peak

    wavfile.write(str(path), int(sr), samples)
| MODEL_NAMES = { | |
| "suno/bark-small": "bark", | |
| "facebook/mms-tts-rus": "mms", | |
| "facebook/seamless-m4t-v2-large": "seamless", | |
| } | |
| _model_cache: Dict[str, object] = {} | |
| _device_hint = "auto" | |
def _load_tts_pipeline(model_id: str):
    """Build a ``generate(text)`` closure around a transformers TTS pipeline.

    Shared by the Bark and MMS loaders, which differ only in the model id.

    Args:
        model_id: Hugging Face model id passed to ``transformers.pipeline``.

    Returns:
        A callable ``generate(text) -> (sampling_rate, audio)`` where
        ``sampling_rate`` is an int and ``audio`` is a float32 NumPy array,
        both taken from the pipeline output.
    """
    import torch
    from transformers import pipeline

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pipe = pipeline(
        task="text-to-speech",
        model=model_id,
        device=device,
        model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32},
    )

    # Fall back to the EOS token when the config has no pad token set.
    config = pipe.model.config
    if getattr(config, "pad_token_id", None) is None:
        config.pad_token_id = config.eos_token_id

    def generate(text: str):
        out = pipe(text)
        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)

    return generate


def _load_bark():
    """Load ``suno/bark-small`` and return its ``generate(text)`` callable."""
    return _load_tts_pipeline("suno/bark-small")


def _load_mms():
    """Load ``facebook/mms-tts-rus`` and return its ``generate(text)`` callable."""
    return _load_tts_pipeline("facebook/mms-tts-rus")
def _load_seamless():
    """Load ``facebook/seamless-m4t-v2-large`` and return a TTS callable.

    Returns:
        A callable ``generate(text) -> (sampling_rate, audio)``. The text is
        processed with ``src_lang="rus"`` and speech is generated with
        ``tgt_lang="rus"``; ``audio`` is a squeezed float32 NumPy array.
    """
    import torch
    import numpy as np
    from transformers import AutoProcessor
    from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import SeamlessM4Tv2Model

    model_id = "facebook/seamless-m4t-v2-large"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    proc = AutoProcessor.from_pretrained(model_id, use_fast=False)
    model = SeamlessM4Tv2Model.from_pretrained(
        model_id,
        low_cpu_mem_usage=False,
    ).to(device)

    # Take the output rate from the model config instead of hard-coding it;
    # 16000 is kept as the fallback for configs without the attribute.
    sample_rate = int(getattr(model.config, "sampling_rate", 16000))

    def generate(text: str):
        inputs = proc(text=text, src_lang="rus", return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            audio = model.generate(**inputs, tgt_lang="rus")[0]
        audio = audio.detach().cpu().numpy().squeeze().astype(np.float32)
        return sample_rate, audio

    return generate
def get_generator(kind: str):
    """Return the generator callable for ``kind``, loading it on first use.

    Args:
        kind: Loader key, one of ``"bark"``, ``"mms"`` or ``"seamless"``.

    Returns:
        The callable produced by the matching ``_load_*`` function. It is
        stored in ``_model_cache`` so later calls reuse it.

    Raises:
        ValueError: If ``kind`` is not a known loader key.
    """
    try:
        return _model_cache[kind]
    except KeyError:
        pass

    # Lambdas keep the loader names unresolved until one is actually needed.
    loaders = {
        "bark": lambda: _load_bark(),
        "mms": lambda: _load_mms(),
        "seamless": lambda: _load_seamless(),
    }
    loader = loaders.get(kind)
    if loader is None:
        raise ValueError(f"Unknown model kind: {kind}")

    generator = loader()
    _model_cache[kind] = generator
    return generator
# Default contents of the TTS prompt box: three Russian test sentences, one
# per line.
DEFAULT_PROMPTS = "\n".join(
    (
        "Привет! Это короткий тест русского TTS.",
        "Сегодня мы проверяем интонации, паузы и четкость дикции.",
        "Немного сложнее: числа 3.14 и 2025 читаем правильно.",
    )
)
def run_tts(
    prompts_text: str,
    split_lines: bool,
    model_choice: str,
):
    """Main Gradio callback: synthesise speech for one or more prompts.

    Args:
        prompts_text: Raw text from the prompt box; ``None`` is treated as
            empty.
        split_lines: When true, every non-blank line is a separate prompt;
            otherwise the whole stripped text is a single prompt.
        model_choice: A key of ``MODEL_NAMES``.

    Returns:
        A tuple ``(files, df, last_audio)``:
        files: list[str] -- paths of the written WAV files.
        df: pd.DataFrame -- one metadata row per prompt (task, model,
            prompt, file, sr, gen_time_s).
        last_audio: str | None -- path of the last file, for preview.
        With no prompts the result is ``([], pd.DataFrame(), None)``.

    Raises:
        KeyError: If ``model_choice`` is not in ``MODEL_NAMES``.
    """
    raw = prompts_text or ""
    if split_lines:
        text_items: List[str] = [ln.strip() for ln in raw.splitlines() if ln.strip()]
    else:
        text_items = [raw.strip()] if raw.strip() else []
    if not text_items:
        return [], pd.DataFrame(), None

    kind = MODEL_NAMES[model_choice]
    gen = get_generator(kind)

    stamp_dir = OUTDIR / "tts" / time.strftime("%Y%m%d-%H%M%S")
    stamp_dir.mkdir(parents=True, exist_ok=True)

    rows = []
    file_paths: List[str] = []
    for p in text_items:
        # perf_counter is monotonic, so the timing survives clock changes.
        t0 = time.perf_counter()
        sr, audio = gen(p)
        dt = time.perf_counter() - t0

        # Slugs are truncated and lossy, so distinct prompts can map to the
        # same name; add a numeric suffix rather than overwrite a file.
        stem = f"{slug(model_choice)}__{slug(p)}"
        path = stamp_dir / f"{stem}.wav"
        attempt = 1
        while path.exists():
            attempt += 1
            path = stamp_dir / f"{stem}_{attempt}.wav"

        save_wav(path, sr, audio)
        rows.append({
            "task": "tts",
            "model": model_choice,
            "prompt": p,
            "file": str(path),
            "sr": sr,
            "gen_time_s": round(dt, 3),
        })
        file_paths.append(str(path))

    return file_paths, pd.DataFrame(rows), file_paths[-1]
| _music_pipes: Dict[str, object] = {} | |
| MUSIC_MODELS = [ | |
| "facebook/musicgen-small", | |
| ] | |
def get_music_pipe(model_name: str):
    """Return a text-to-audio pipeline for ``model_name``, loading it once.

    The pipeline is stored in the module-level ``_music_pipes`` dict so that
    repeated calls with the same model name reuse it instead of reloading
    the model every time.

    Args:
        model_name: Hugging Face model id passed to ``transformers.pipeline``.

    Returns:
        The ``transformers`` pipeline object for the model.
    """
    # Check the cache before the heavy imports so a hit stays cheap.
    cached = _music_pipes.get(model_name)
    if cached is not None:
        return cached

    import torch
    from transformers import pipeline

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pipe = pipeline(
        "text-to-audio",
        model=model_name,
        device=device,
        model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32},
    )
    _music_pipes[model_name] = pipe
    return pipe
# Default contents of the music prompt box: three style descriptions, one per
# line.
MUSIC_DEFAULT_PROMPTS = "\n".join(
    (
        "High-energy 90s rock track with distorted electric guitars, driving bass, and hard-hitting acoustic drums",
        "Modern electronic dance track with punchy kick, bright synth lead, and sidechained pads, 128 BPM",
        "Dark industrial electro with gritty bass, sharp snares, and mechanical percussion",
    )
)
def run_music(
    prompts_text: str,
    split_lines: bool,
    model_name: str,
    do_sample: bool,
):
    """Main Gradio callback: generate music for one or more prompts.

    Args:
        prompts_text: Raw text from the prompt box; ``None`` is treated as
            empty.
        split_lines: When true, every non-blank line is a separate prompt;
            otherwise the whole stripped text is a single prompt.
        model_name: Model id handed to ``get_music_pipe``.
        do_sample: Forwarded to the pipeline as the ``do_sample`` forward
            parameter.

    Returns:
        A tuple ``(files, df, last_audio)``: the list of written WAV paths,
        a ``pd.DataFrame`` with one metadata row per prompt (task, model,
        prompt, file, sr, gen_time_s), and the path of the last file (or
        ``None``). With no prompts the result is
        ``([], pd.DataFrame(), None)``.
    """
    raw = prompts_text or ""
    if split_lines:
        text_items: List[str] = [ln.strip() for ln in raw.splitlines() if ln.strip()]
    else:
        text_items = [raw.strip()] if raw.strip() else []
    if not text_items:
        return [], pd.DataFrame(), None

    pipe = get_music_pipe(model_name)

    stamp_dir = OUTDIR / "music" / slug(model_name) / time.strftime("%Y%m%d-%H%M%S")
    stamp_dir.mkdir(parents=True, exist_ok=True)

    forward_params = {"do_sample": bool(do_sample)}
    rows = []
    file_paths: List[str] = []
    for p in text_items:
        # perf_counter is monotonic, so the timing survives clock changes.
        t0 = time.perf_counter()
        out = pipe(p, forward_params=forward_params)
        dt = time.perf_counter() - t0

        sr = int(out["sampling_rate"])
        audio = np.asarray(out["audio"], dtype=np.float32)

        # Slugs are truncated and lossy, so distinct prompts can map to the
        # same name; add a numeric suffix rather than overwrite a file.
        stem = slug(p)
        path = stamp_dir / f"{stem}.wav"
        attempt = 1
        while path.exists():
            attempt += 1
            path = stamp_dir / f"{stem}_{attempt}.wav"

        save_wav(path, sr, audio)
        rows.append({
            "task": "music",
            "model": model_name,
            "prompt": p,
            "file": str(path),
            "sr": sr,
            "gen_time_s": round(dt, 3),
        })
        file_paths.append(str(path))

    return file_paths, pd.DataFrame(rows), file_paths[-1]
# Markdown shown at the top of the TTS tab (user-facing text, in Russian).
tts_description_md = """
Russian TTS Bench: выберите модель и введите один или несколько промптов.\
По умолчанию каждая строка — отдельный промпт. Результаты сохраняются в `outputs/tts/…`.
**Модели:**
- `suno/bark-small` — небольшой мультиязычный TTS.
- `facebook/mms-tts-rus` — русская TTS из проекта MMS.
- `facebook/seamless-m4t-v2-large` — крупная модель перевода/говорения; тяжёлая для CPU.
"""
# Markdown shown at the top of the Music tab (user-facing text, in Russian).
music_description_md = """
**Music Gen:** текст → музыка на базе MusicGen. По умолчанию каждая строка — отдельный промпт.\
Результаты сохраняются в `outputs/music/<model>/…`.
**Модели:**
- `facebook/musicgen-small`
- (опционально) `facebook/musicgen-stereo-small` — раскомментируйте в коде.
"""
def run_tts_ui(prompts_text, split_lines, model_choice):
    """UI wrapper around ``run_tts`` for the TTS tab.

    Returns the file list, the path to preview (or ``None``), and a
    ``gr.update`` that fills the sample dropdown with the files and selects
    the last one. The metadata table from ``run_tts`` is discarded.
    """
    files, _unused_table, last = run_tts(prompts_text, split_lines, model_choice)

    if last:
        selected = last
    elif files:
        selected = files[-1]
    else:
        selected = None

    preview = last if last else None
    samples_update = gr.update(choices=files, value=selected)
    return files, preview, samples_update
def run_music_ui(prompts_text, split_lines, model_name, do_sample):
    """UI wrapper around ``run_music`` for the Music tab.

    Returns the file list, the path to preview (or ``None``), and a
    ``gr.update`` that fills the track dropdown with the files and selects
    the last one. The metadata table from ``run_music`` is discarded.
    """
    files, _unused_table, last = run_music(prompts_text, split_lines, model_name, do_sample)

    if last:
        selected = last
    elif files:
        selected = files[-1]
    else:
        selected = None

    preview = last if last else None
    samples_update = gr.update(choices=files, value=selected)
    return files, preview, samples_update
# Gradio UI: one tab for TTS and one for MusicGen. Each tab has a model
# picker, a prompt box, a run button, a download list, a dropdown of generated
# files and an audio preview that follows the dropdown selection.
with gr.Blocks(title="Speech & Music Bench") as demo:
    # A Markdown heading needs a space after "#"; without it the text renders
    # as a literal "#Speech & Music Bench".
    gr.Markdown("# Speech & Music Bench")

    with gr.Tab("TTS"):
        gr.Markdown(tts_description_md)
        with gr.Row():
            model_choice = gr.Dropdown(
                label="Модель TTS",
                choices=list(MODEL_NAMES.keys()),
                value="suno/bark-small",
            )
            split_lines_tts = gr.Checkbox(value=True, label="Одна строка = один промпт")
        prompts_tts = gr.Textbox(
            label="Промпты",
            value=DEFAULT_PROMPTS,
            lines=6,
            placeholder="Каждая строка — отдельный промпт…",
        )
        run_btn_tts = gr.Button("Сгенерировать речь", variant="primary")
        with gr.Row():
            files_tts = gr.Files(label="Файлы .wav для скачивания")
        with gr.Row():
            samples_tts = gr.Dropdown(
                label="Все сгенерённые семплы (выберите для прослушивания)",
                choices=[],
                allow_custom_value=False,
            )
        with gr.Row():
            preview_tts = gr.Audio(label="Предпросмотр выбранного семпла", autoplay=False)

        # Generate, then fill the download list, preview and sample dropdown.
        run_btn_tts.click(
            fn=run_tts_ui,
            inputs=[prompts_tts, split_lines_tts, model_choice],
            outputs=[files_tts, preview_tts, samples_tts],
        )
        # Selecting a sample loads it into the preview player.
        samples_tts.change(
            fn=lambda p: gr.update(value=p),
            inputs=samples_tts,
            outputs=preview_tts,
        )

    with gr.Tab("Music"):
        gr.Markdown(music_description_md)
        with gr.Row():
            music_model = gr.Dropdown(
                label="Модель MusicGen",
                choices=MUSIC_MODELS,
                value=MUSIC_MODELS[0],
            )
            split_lines_music = gr.Checkbox(value=True, label="Одна строка = один промпт")
            do_sample = gr.Checkbox(value=True, label="do_sample")
        prompts_music = gr.Textbox(
            label="Музыкальные промпты",
            value=MUSIC_DEFAULT_PROMPTS,
            lines=6,
            placeholder="Каждая строка — отдельный промпт…",
        )
        run_btn_music = gr.Button("Сгенерировать музыку", variant="primary")
        with gr.Row():
            files_music = gr.Files(label="Файлы .wav для скачивания")
        with gr.Row():
            samples_music = gr.Dropdown(
                label="Все сгенерённые треки (выберите для прослушивания)",
                choices=[],
                allow_custom_value=False,
            )
        with gr.Row():
            preview_music = gr.Audio(label="Предпросмотр выбранного трека", autoplay=False)

        # Generate, then fill the download list, preview and track dropdown.
        run_btn_music.click(
            fn=run_music_ui,
            inputs=[prompts_music, split_lines_music, music_model, do_sample],
            outputs=[files_music, preview_music, samples_music],
        )
        # Selecting a track loads it into the preview player.
        samples_music.change(
            fn=lambda p: gr.update(value=p),
            inputs=samples_music,
            outputs=preview_music,
        )
# Start the Gradio server only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()