Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -136,22 +136,21 @@ class InferenceModel(object):
|
|
| 136 |
@property
|
| 137 |
def input_shapes(self):
|
| 138 |
return {
|
| 139 |
-
|
| 140 |
-
|
| 141 |
}
|
| 142 |
|
| 143 |
def _parse_gin(self, gin_files):
|
| 144 |
"""解析用于训练模型的 gin 文件。"""
|
| 145 |
print(f"[{current_time()}] 日志:解析 gin 文件")
|
| 146 |
gin_bindings = [
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
]
|
| 152 |
with gin.unlock_config():
|
| 153 |
-
gin.parse_config_files_and_bindings(
|
| 154 |
-
gin_files, gin_bindings, finalize_config=False)
|
| 155 |
|
| 156 |
def _load_model(self):
|
| 157 |
"""在解析训练 gin 配置后加载 T5X `Model`。"""
|
|
@@ -159,11 +158,11 @@ class InferenceModel(object):
|
|
| 159 |
model_config = gin.get_configurable(network.T5Config)()
|
| 160 |
module = network.Transformer(config=model_config)
|
| 161 |
return models.ContinuousInputsEncoderDecoderModel(
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
|
| 168 |
|
| 169 |
def restore_from_checkpoint(self, checkpoint_path):
|
|
@@ -176,33 +175,31 @@ class InferenceModel(object):
|
|
| 176 |
partitioner=self.partitioner)
|
| 177 |
|
| 178 |
restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
|
| 179 |
-
|
| 180 |
|
| 181 |
train_state_axes = train_state_initializer.train_state_axes
|
| 182 |
self._predict_fn = self._get_predict_fn(train_state_axes)
|
| 183 |
self._train_state = train_state_initializer.from_checkpoint_or_scratch(
|
| 184 |
-
|
| 185 |
|
| 186 |
@functools.lru_cache()
|
| 187 |
def _get_predict_fn(self, train_state_axes):
|
| 188 |
"""生成一个分区的预测函数用于解码。"""
|
| 189 |
print(f"[{current_time()}] 日志:生成用于解码的预测函数")
|
| 190 |
def partial_predict_fn(params, batch, decode_rng):
|
| 191 |
-
return self.model.predict_batch_with_aux(
|
| 192 |
-
params, batch, decoder_params={'decode_rng': None})
|
| 193 |
return self.partitioner.partition(
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
)
|
| 200 |
|
| 201 |
def predict_tokens(self, batch, seed=0):
|
| 202 |
"""从预处理的数据集批次中预测 tokens。"""
|
| 203 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
| 204 |
-
prediction, _ = self._predict_fn(
|
| 205 |
-
self._train_state.params, batch, jax.random.PRNGKey(seed))
|
| 206 |
return self.vocabulary.decode_tf(prediction).numpy()
|
| 207 |
|
| 208 |
def __call__(self, audio):
|
|
@@ -255,16 +252,16 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
| 255 |
def preprocess(self, ds):
|
| 256 |
pp_chain = [
|
| 257 |
functools.partial(
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
# 在训练期间进行缓存。
|
| 264 |
preprocessors.add_dummy_targets,
|
| 265 |
functools.partial(
|
| 266 |
-
|
| 267 |
-
|
| 268 |
]
|
| 269 |
for pp in pp_chain:
|
| 270 |
ds = pp(ds)
|
|
@@ -276,10 +273,10 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
| 276 |
# 向下取整到最接近的符号化时间步。
|
| 277 |
start_time -= start_time % (1 / self.codec.steps_per_second)
|
| 278 |
return {
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
}
|
| 284 |
|
| 285 |
@staticmethod
|
|
|
|
| 136 |
@property
|
| 137 |
def input_shapes(self):
|
| 138 |
return {
|
| 139 |
+
'encoder_input_tokens': (self.batch_size, self.inputs_length),
|
| 140 |
+
'decoder_input_tokens': (self.batch_size, self.outputs_length)
|
| 141 |
}
|
| 142 |
|
| 143 |
def _parse_gin(self, gin_files):
|
| 144 |
"""解析用于训练模型的 gin 文件。"""
|
| 145 |
print(f"[{current_time()}] 日志:解析 gin 文件")
|
| 146 |
gin_bindings = [
|
| 147 |
+
'from __gin__ import dynamic_registration',
|
| 148 |
+
'from mt3 import vocabularies',
|
| 149 |
+
'[email protected]()',
|
| 150 |
+
'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
|
| 151 |
]
|
| 152 |
with gin.unlock_config():
|
| 153 |
+
gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)
|
|
|
|
| 154 |
|
| 155 |
def _load_model(self):
|
| 156 |
"""在解析训练 gin 配置后加载 T5X `Model`。"""
|
|
|
|
| 158 |
model_config = gin.get_configurable(network.T5Config)()
|
| 159 |
module = network.Transformer(config=model_config)
|
| 160 |
return models.ContinuousInputsEncoderDecoderModel(
|
| 161 |
+
module=module,
|
| 162 |
+
input_vocabulary=self.output_features['inputs'].vocabulary,
|
| 163 |
+
output_vocabulary=self.output_features['targets'].vocabulary,
|
| 164 |
+
optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
|
| 165 |
+
input_depth=spectrograms.input_depth(self.spectrogram_config))
|
| 166 |
|
| 167 |
|
| 168 |
def restore_from_checkpoint(self, checkpoint_path):
|
|
|
|
| 175 |
partitioner=self.partitioner)
|
| 176 |
|
| 177 |
restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
|
| 178 |
+
path=checkpoint_path, mode='specific', dtype='float32')
|
| 179 |
|
| 180 |
train_state_axes = train_state_initializer.train_state_axes
|
| 181 |
self._predict_fn = self._get_predict_fn(train_state_axes)
|
| 182 |
self._train_state = train_state_initializer.from_checkpoint_or_scratch(
|
| 183 |
+
[restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
|
| 184 |
|
| 185 |
@functools.lru_cache()
|
| 186 |
def _get_predict_fn(self, train_state_axes):
|
| 187 |
"""生成一个分区的预测函数用于解码。"""
|
| 188 |
print(f"[{current_time()}] 日志:生成用于解码的预测函数")
|
| 189 |
def partial_predict_fn(params, batch, decode_rng):
|
| 190 |
+
return self.model.predict_batch_with_aux(params, batch, decoder_params={'decode_rng': None})
|
|
|
|
| 191 |
return self.partitioner.partition(
|
| 192 |
+
partial_predict_fn,
|
| 193 |
+
in_axis_resources=(
|
| 194 |
+
train_state_axes.params,
|
| 195 |
+
t5x.partitioning.PartitionSpec('data',), None),
|
| 196 |
+
out_axis_resources=t5x.partitioning.PartitionSpec('data',)
|
| 197 |
)
|
| 198 |
|
| 199 |
def predict_tokens(self, batch, seed=0):
|
| 200 |
"""从预处理的数据集批次中预测 tokens。"""
|
| 201 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
| 202 |
+
prediction, _ = self._predict_fn(self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
|
|
| 203 |
return self.vocabulary.decode_tf(prediction).numpy()
|
| 204 |
|
| 205 |
def __call__(self, audio):
|
|
|
|
| 252 |
def preprocess(self, ds):
|
| 253 |
pp_chain = [
|
| 254 |
functools.partial(
|
| 255 |
+
t5.data.preprocessors.split_tokens_to_inputs_length,
|
| 256 |
+
sequence_length=self.sequence_length,
|
| 257 |
+
output_features=self.output_features,
|
| 258 |
+
feature_key='inputs',
|
| 259 |
+
additional_feature_keys=['input_times']),
|
| 260 |
# 在训练期间进行缓存。
|
| 261 |
preprocessors.add_dummy_targets,
|
| 262 |
functools.partial(
|
| 263 |
+
preprocessors.compute_spectrograms,
|
| 264 |
+
spectrogram_config=self.spectrogram_config)
|
| 265 |
]
|
| 266 |
for pp in pp_chain:
|
| 267 |
ds = pp(ds)
|
|
|
|
| 273 |
# 向下取整到最接近的符号化时间步。
|
| 274 |
start_time -= start_time % (1 / self.codec.steps_per_second)
|
| 275 |
return {
|
| 276 |
+
'est_tokens': tokens,
|
| 277 |
+
'start_time': start_time,
|
| 278 |
+
# 内部 MT3 代码期望原始输入,这里不使用。
|
| 279 |
+
'raw_inputs': []
|
| 280 |
}
|
| 281 |
|
| 282 |
@staticmethod
|