BIBLETUM committed
Commit 76c29f8 · verified · Parent: 5c17c33

Update app.py

Files changed (1)
  1. app.py +201 -42
app.py CHANGED
@@ -3,68 +3,227 @@ import time
 from pathlib import Path
 from typing import List, Tuple, Dict

-
 import numpy as np
 import pandas as pd
 import gradio as gr

-
 # === Utils ===
 OUTDIR = Path("outputs")
 OUTDIR.mkdir(parents=True, exist_ok=True)


 def slug(s: str) -> str:
-    return "".join(c if c.isalnum() else "_" for c in s)[:80].strip("_")


 def save_wav(path: Path, sr: int, audio):
-    import numpy as np
-    import scipy.io.wavfile as wav
-
-    if hasattr(audio, "detach"):
-        audio = audio.detach().cpu().numpy()
-    a = np.array(audio).astype(np.float32)
-    a = np.squeeze(a)
-    if a.ndim == 2 and a.shape[0] < a.shape[1]:
-        a = a.T
-    # normalize if needed
-    max_abs = np.max(np.abs(a)) if a.size else 1.0
-    if np.isfinite(max_abs) and max_abs > 1.0:
-        a = a / max_abs
-    wav.write(str(path), int(sr), a)
-


 # === Lazy model registry ===
 MODEL_NAMES = {
-    "suno/bark-small": "bark",
-    "facebook/mms-tts-rus": "mms",
-    "facebook/seamless-m4t-v2-large": "seamless",
 }

-
 _model_cache: Dict[str, object] = {}
-_device_hint = "auto"  # for pipelines; Seamless picks cpu/gpu inside
-


 def _load_bark():
-    from transformers import pipeline
-    pipe = pipeline("text-to-speech", model="suno/bark-small", device_map=_device_hint)
-    # Bark sometimes has no pad_token_id
-    if getattr(pipe.model.config, "pad_token_id", None) is None:
-        pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
-
-    def generate(text: str) -> Tuple[int, np.ndarray]:
-        out = pipe(text)
-        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
-
-    return generate
-
-demo.launch()
 from pathlib import Path
 from typing import List, Tuple, Dict

 import numpy as np
 import pandas as pd
 import gradio as gr

 # === Utils ===
 OUTDIR = Path("outputs")
 OUTDIR.mkdir(parents=True, exist_ok=True)


 def slug(s: str) -> str:
+    """Make a safe filename slug (ASCII, underscores)."""
+    if s is None:
+        s = ""
+    return "".join(c if c.isalnum() else "_" for c in s)[:80].strip("_")


 def save_wav(path: Path, sr: int, audio):
+    import numpy as np
+    import scipy.io.wavfile as wav
+
+    if hasattr(audio, "detach"):
+        audio = audio.detach().cpu().numpy()
+    a = np.array(audio).astype(np.float32)
+    a = np.squeeze(a)
+    if a.ndim == 2 and a.shape[0] < a.shape[1]:
+        a = a.T
+    # normalize if needed
+    max_abs = np.max(np.abs(a)) if a.size else 1.0
+    if np.isfinite(max_abs) and max_abs > 1.0:
+        a = a / max_abs
+    wav.write(str(path), int(sr), a)


 # === Lazy model registry ===
 MODEL_NAMES = {
+    "suno/bark-small": "bark",
+    "facebook/mms-tts-rus": "mms",
+    "facebook/seamless-m4t-v2-large": "seamless",
 }

 _model_cache: Dict[str, object] = {}
+_device_hint = "auto"  # for pipelines; Seamless picks cpu/gpu inside


 def _load_bark():
+    from transformers import pipeline
+    pipe = pipeline("text-to-speech", model="suno/bark-small", device_map=_device_hint)
+    # Bark sometimes has no pad_token_id
+    if getattr(pipe.model.config, "pad_token_id", None) is None:
+        pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
+
+    def generate(text: str) -> Tuple[int, np.ndarray]:
+        out = pipe(text)
+        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
+
+    return generate
+
+
+def _load_mms():
+    from transformers import pipeline
+    pipe = pipeline("text-to-speech", model="facebook/mms-tts-rus", device_map=_device_hint)
+    if getattr(pipe.model.config, "pad_token_id", None) is None:
+        pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
+
+    def generate(text: str) -> Tuple[int, np.ndarray]:
+        out = pipe(text)
+        return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
+
+    return generate
+
+
+def _load_seamless():
+    import torch
+    import numpy as np
+    from transformers import AutoProcessor
+    # NOTE: the model class is imported from a transformers submodule
+    from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
+        SeamlessM4Tv2Model,
+    )
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    proc = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
+    model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
+
+    def generate(text: str) -> Tuple[int, np.ndarray]:
+        inputs = proc(text=text, src_lang="rus", return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        with torch.no_grad():
+            audio = model.generate(**inputs, tgt_lang="rus")[0]
+        audio = audio.detach().cpu().numpy().squeeze().astype(np.float32)
+        return 16000, audio  # Seamless outputs 16 kHz audio
+
+    return generate
+
+
+def get_generator(kind: str):
+    if kind in _model_cache:
+        return _model_cache[kind]
+    if kind == "bark":
+        gen = _load_bark()
+    elif kind == "mms":
+        gen = _load_mms()
+    elif kind == "seamless":
+        gen = _load_seamless()
+    else:
+        raise ValueError(f"Unknown model kind: {kind}")
+    _model_cache[kind] = gen
+    return gen
+
+
+# === Inference ===
+# Russian test prompts (the app benchmarks Russian TTS, so the payload stays Russian)
+DEFAULT_PROMPTS = (
+    "Привет! Это короткий тест русского TTS.\n"
+    "Сегодня мы проверяем интонации, паузы и четкость дикции.\n"
+    "Немного сложнее: числа 3.14 и 2025 читаем правильно."
+)
+
+
+def run_tts(
+    prompts_text: str,
+    split_lines: bool,
+    model_choice: str,
+) -> tuple:
+    """Main Gradio callback.
+
+    Returns:
+        files: list[str], file paths for download
+        df: pd.DataFrame, metadata table
+        last_audio: tuple[int, np.ndarray] | None, preview of the last sample
+    """
+    text_items: List[str] = []
+    if split_lines:
+        for line in [s.strip() for s in prompts_text.splitlines()]:
+            if line:
+                text_items.append(line)
+    else:
+        text_items = [prompts_text.strip()] if prompts_text.strip() else []
+
+    if not text_items:
+        return [], pd.DataFrame(), None
+
+    kind = MODEL_NAMES[model_choice]
+    gen = get_generator(kind)
+
+    stamp_dir = OUTDIR / time.strftime("%Y%m%d-%H%M%S")
+    stamp_dir.mkdir(parents=True, exist_ok=True)
+
+    rows = []
+    file_paths: List[str] = []
+    last_audio_payload = None
+
+    for p in text_items:
+        t0 = time.time()
+        sr, audio = gen(p)
+        dt = time.time() - t0
+        path = stamp_dir / f"{slug(model_choice)}__{slug(p)}.wav"
+        save_wav(path, sr, audio)
+        rows.append(
+            {
+                "model": model_choice,
+                "prompt": p,
+                "file": str(path),
+                "sr": sr,
+                "gen_time_s": round(dt, 3),
+            }
+        )
+        file_paths.append(str(path))
+        last_audio_payload = (sr, audio)
+
+    df = pd.DataFrame(rows)
+    return file_paths, df, last_audio_payload
+
+
+# === UI ===
+description_md = (
+    """
+Russian TTS Bench: choose a model and enter one or more prompts.\
+ By default each line is a separate prompt. Results are saved under `outputs/…`.
+
+**Models:**
+- `suno/bark-small`: a small multilingual TTS model.
+- `facebook/mms-tts-rus`: Russian TTS from the MMS project.
+- `facebook/seamless-m4t-v2-large`: a large translation/speech model; heavy on CPU.
+
+⚠️ On CPU, generation can be very slow, especially for Seamless. For comfortable use, choose a Space with a GPU.
+"""
+)
+
+with gr.Blocks(title="Russian TTS Bench") as demo:
+    gr.Markdown("# 🗣️ Russian TTS Bench")
+    gr.Markdown(description_md)
+
+    with gr.Row():
+        model_choice = gr.Dropdown(
+            label="Model",
+            choices=list(MODEL_NAMES.keys()),
+            value="suno/bark-small",
+        )
+        split_lines = gr.Checkbox(value=True, label="One line = one prompt")
+
+    prompts = gr.Textbox(
+        label="Prompts",
+        value=DEFAULT_PROMPTS,
+        lines=6,
+        placeholder="Each line is a separate prompt…",
+    )
+
+    run_btn = gr.Button("Generate", variant="primary")
+
+    with gr.Row():
+        files = gr.Files(label=".wav files for download")
+    with gr.Row():
+        df_out = gr.Dataframe(label="Results table", interactive=False)
+    with gr.Row():
+        preview = gr.Audio(label="Preview of the last sample", autoplay=False)
+
+    run_btn.click(
+        fn=run_tts,
+        inputs=[prompts, split_lines, model_choice],
+        outputs=[files, df_out, preview],
+    )
+
+if __name__ == "__main__":
+    demo.launch()
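
For reference, the imports in the new app.py imply roughly the following runtime dependencies. A minimal requirements.txt sketch, assuming recent package releases; the Space's actual requirements file is not part of this commit, so the exact pins are an assumption:

    gradio
    transformers
    accelerate      # needed because the pipelines pass device_map="auto"
    torch
    scipy           # save_wav uses scipy.io.wavfile
    numpy
    pandas
    sentencepiece   # tokenizer dependency for seamless-m4t-v2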
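The lazy registry can also be exercised without the UI: demo.launch() is guarded by the __main__ check, so importing app.py only builds the Blocks object. A minimal smoke-test sketch, assuming the code above is saved as app.py; the filename smoke_test.py and the sample prompt are illustrative, and "mms" is chosen as the lightest of the three models:

    # smoke_test.py: hypothetical check of one generator outside the Gradio UI
    from pathlib import Path

    from app import get_generator, save_wav

    gen = get_generator("mms")              # lazily loads facebook/mms-tts-rus, then caches it
    sr, audio = gen("Привет, мир!")         # returns (sampling_rate, float32 waveform)
    save_wav(Path("smoke.wav"), sr, audio)  # writes a WAV, rescaling only if the peak exceeds 1.0
    print(sr, audio.shape)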