BIBLETUM commited on
Commit
6cb8e5b
·
verified ·
1 Parent(s): 332bdfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -14
app.py CHANGED
@@ -42,36 +42,51 @@ _model_cache: Dict[str, object] = {}
42
  _device_hint = "auto"
43
 
44
  def _load_bark():
 
45
  from transformers import pipeline
46
- pipe = pipeline("text-to-speech", model="suno/bark-small", device_map=_device_hint)
 
 
 
 
 
 
 
 
47
  if getattr(pipe.model.config, "pad_token_id", None) is None:
48
  pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
49
 
50
- def generate(text: str) -> Tuple[int, np.ndarray]:
51
  out = pipe(text)
52
  return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
53
 
54
  return generate
55
 
 
56
  def _load_mms():
 
57
  from transformers import pipeline
58
- pipe = pipeline("text-to-speech", model="facebook/mms-tts-rus", device_map=_device_hint)
 
 
 
 
 
 
59
  if getattr(pipe.model.config, "pad_token_id", None) is None:
60
  pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
61
 
62
- def generate(text: str) -> Tuple[int, np.ndarray]:
63
  out = pipe(text)
64
  return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
65
-
66
  return generate
67
 
 
68
  def _load_seamless():
69
  import torch
70
  import numpy as np
71
  from transformers import AutoProcessor
72
- from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
73
- SeamlessM4Tv2Model,
74
- )
75
 
76
  device = "cuda" if torch.cuda.is_available() else "cpu"
77
 
@@ -79,7 +94,11 @@ def _load_seamless():
79
  "facebook/seamless-m4t-v2-large",
80
  use_fast=False
81
  )
82
- model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
 
 
 
 
83
 
84
  def generate(text: str):
85
  inputs = proc(text=text, src_lang="rus", return_tensors="pt")
@@ -88,9 +107,9 @@ def _load_seamless():
88
  audio = model.generate(**inputs, tgt_lang="rus")[0]
89
  audio = audio.detach().cpu().numpy().squeeze().astype(np.float32)
90
  return 16000, audio
91
-
92
  return generate
93
 
 
94
  def get_generator(kind: str):
95
  if kind in _model_cache:
96
  return _model_cache[kind]
@@ -171,13 +190,18 @@ MUSIC_MODELS = [
171
  ]
172
 
173
  def get_music_pipe(model_name: str):
174
- if model_name in _music_pipes:
175
- return _music_pipes[model_name]
176
  from transformers import pipeline
177
- pipe = pipeline("text-to-audio", model=model_name, device_map=_device_hint)
178
- _music_pipes[model_name] = pipe
 
 
 
 
 
179
  return pipe
180
 
 
181
  MUSIC_DEFAULT_PROMPTS = (
182
  "High-energy 90s rock track with distorted electric guitars, driving bass, and hard-hitting acoustic drums\n"
183
  "Modern electronic dance track with punchy kick, bright synth lead, and sidechained pads, 128 BPM\n"
 
42
  _device_hint = "auto"
43
 
44
  def _load_bark():
45
+ import torch
46
  from transformers import pipeline
47
+
48
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
49
+ pipe = pipeline(
50
+ task="text-to-speech",
51
+ model="suno/bark-small",
52
+ device=device,
53
+ model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32}
54
+ )
55
+
56
  if getattr(pipe.model.config, "pad_token_id", None) is None:
57
  pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
58
 
59
+ def generate(text: str):
60
  out = pipe(text)
61
  return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
62
 
63
  return generate
64
 
65
+
66
  def _load_mms():
67
+ import torch
68
  from transformers import pipeline
69
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
70
+ pipe = pipeline(
71
+ "text-to-speech",
72
+ model="facebook/mms-tts-rus",
73
+ device=device,
74
+ model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32}
75
+ )
76
  if getattr(pipe.model.config, "pad_token_id", None) is None:
77
  pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
78
 
79
+ def generate(text: str):
80
  out = pipe(text)
81
  return int(out["sampling_rate"]), np.asarray(out["audio"], dtype=np.float32)
 
82
  return generate
83
 
84
+
85
  def _load_seamless():
86
  import torch
87
  import numpy as np
88
  from transformers import AutoProcessor
89
+ from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import SeamlessM4Tv2Model
 
 
90
 
91
  device = "cuda" if torch.cuda.is_available() else "cpu"
92
 
 
94
  "facebook/seamless-m4t-v2-large",
95
  use_fast=False
96
  )
97
+
98
+ model = SeamlessM4Tv2Model.from_pretrained(
99
+ "facebook/seamless-m4t-v2-large",
100
+ low_cpu_mem_usage=False
101
+ ).to(device)
102
 
103
  def generate(text: str):
104
  inputs = proc(text=text, src_lang="rus", return_tensors="pt")
 
107
  audio = model.generate(**inputs, tgt_lang="rus")[0]
108
  audio = audio.detach().cpu().numpy().squeeze().astype(np.float32)
109
  return 16000, audio
 
110
  return generate
111
 
112
+
113
  def get_generator(kind: str):
114
  if kind in _model_cache:
115
  return _model_cache[kind]
 
190
  ]
191
 
192
  def get_music_pipe(model_name: str):
193
+ import torch
 
194
  from transformers import pipeline
195
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
196
+ pipe = pipeline(
197
+ "text-to-audio",
198
+ model=model_name,
199
+ device=device,
200
+ model_kwargs={"low_cpu_mem_usage": False, "torch_dtype": torch.float32}
201
+ )
202
  return pipe
203
 
204
+
205
  MUSIC_DEFAULT_PROMPTS = (
206
  "High-energy 90s rock track with distorted electric guitars, driving bass, and hard-hitting acoustic drums\n"
207
  "Modern electronic dance track with punchy kick, bright synth lead, and sidechained pads, 128 BPM\n"