Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -257,26 +257,22 @@ def save_midi(tokens, batch_number=None):
|
|
| 257 |
|
| 258 |
@spaces.GPU
|
| 259 |
def generate_music(prime,
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
| 266 |
):
|
| 267 |
|
| 268 |
-
model.cuda()
|
| 269 |
-
model.eval()
|
| 270 |
-
|
| 271 |
-
print('Generating...')
|
| 272 |
-
|
| 273 |
if not prime:
|
| 274 |
inputs = [19461]
|
| 275 |
|
| 276 |
else:
|
| 277 |
-
inputs = prime
|
| 278 |
|
| 279 |
-
if gen_outro:
|
| 280 |
inputs.extend([18945])
|
| 281 |
|
| 282 |
if gen_drums:
|
|
@@ -301,11 +297,17 @@ def generate_music(prime,
|
|
| 301 |
verbose=False)
|
| 302 |
|
| 303 |
output = out.tolist()
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
#==================================================================================
|
| 311 |
|
|
@@ -316,12 +318,13 @@ block_lines = []
|
|
| 316 |
#==================================================================================
|
| 317 |
|
| 318 |
def generate_callback(input_midi,
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
| 325 |
):
|
| 326 |
|
| 327 |
global generated_batches
|
|
@@ -333,7 +336,8 @@ def generate_callback(input_midi,
|
|
| 333 |
block_lines.append(midi_score[-1][1] / 1000)
|
| 334 |
|
| 335 |
batched_gen_tokens = generate_music(final_composition,
|
| 336 |
-
num_gen_tokens,
|
|
|
|
| 337 |
NUM_OUT_BATCHES,
|
| 338 |
gen_outro,
|
| 339 |
gen_drums,
|
|
@@ -385,18 +389,15 @@ def generate_callback(input_midi,
|
|
| 385 |
#==================================================================================
|
| 386 |
|
| 387 |
def generate_callback_wrapper(input_midi,
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
|
|
|
| 394 |
):
|
| 395 |
|
| 396 |
-
print('=' * 70)
|
| 397 |
-
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 398 |
-
start_time = reqtime.time()
|
| 399 |
-
|
| 400 |
print('=' * 70)
|
| 401 |
if input_midi is not None:
|
| 402 |
fn = os.path.basename(input_midi.name)
|
|
@@ -413,6 +414,7 @@ def generate_callback_wrapper(input_midi,
|
|
| 413 |
result = generate_callback(input_midi,
|
| 414 |
num_prime_tokens,
|
| 415 |
num_gen_tokens,
|
|
|
|
| 416 |
gen_outro,
|
| 417 |
gen_drums,
|
| 418 |
model_temperature,
|
|
@@ -420,12 +422,6 @@ def generate_callback_wrapper(input_midi,
|
|
| 420 |
)
|
| 421 |
|
| 422 |
generated_batches.extend([sublist[2] for sublist in result])
|
| 423 |
-
|
| 424 |
-
print('=' * 70)
|
| 425 |
-
print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 426 |
-
print('=' * 70)
|
| 427 |
-
print('Req execution time:', (reqtime.time() - start_time), 'sec')
|
| 428 |
-
print('*' * 70)
|
| 429 |
|
| 430 |
return tuple(item for sublist in result for item in sublist[:2])
|
| 431 |
|
|
@@ -499,7 +495,7 @@ def reset():
|
|
| 499 |
final_composition = []
|
| 500 |
generated_batches = []
|
| 501 |
block_lines = []
|
| 502 |
-
|
| 503 |
#==================================================================================
|
| 504 |
|
| 505 |
PDT = timezone('US/Pacific')
|
|
@@ -529,17 +525,18 @@ with gr.Blocks() as demo:
|
|
| 529 |
for faster execution and endless generation!
|
| 530 |
""")
|
| 531 |
|
| 532 |
-
gr.Markdown("## Upload
|
| 533 |
|
| 534 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
| 535 |
input_midi.upload(reset)
|
| 536 |
|
| 537 |
gr.Markdown("## Generate")
|
| 538 |
|
| 539 |
-
num_prime_tokens = gr.Slider(15,
|
| 540 |
num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
|
| 541 |
-
|
| 542 |
-
gen_drums = gr.Checkbox(value=False, label="
|
|
|
|
| 543 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 544 |
model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
| 545 |
|
|
@@ -561,6 +558,7 @@ with gr.Blocks() as demo:
|
|
| 561 |
[input_midi,
|
| 562 |
num_prime_tokens,
|
| 563 |
num_gen_tokens,
|
|
|
|
| 564 |
gen_outro,
|
| 565 |
gen_drums,
|
| 566 |
model_temperature,
|
|
|
|
| 257 |
|
| 258 |
@spaces.GPU
|
| 259 |
def generate_music(prime,
|
| 260 |
+
num_gen_tokens,
|
| 261 |
+
num_mem_tokens,
|
| 262 |
+
num_gen_batches,
|
| 263 |
+
gen_outro,
|
| 264 |
+
gen_drums,
|
| 265 |
+
model_temperature,
|
| 266 |
+
model_sampling_top_p
|
| 267 |
):
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
if not prime:
|
| 270 |
inputs = [19461]
|
| 271 |
|
| 272 |
else:
|
| 273 |
+
inputs = prime[-num_mem_tokens:]
|
| 274 |
|
| 275 |
+
if gen_outro == 'Force':
|
| 276 |
inputs.extend([18945])
|
| 277 |
|
| 278 |
if gen_drums:
|
|
|
|
| 297 |
verbose=False)
|
| 298 |
|
| 299 |
output = out.tolist()
|
| 300 |
+
|
| 301 |
+
output_batches = []
|
| 302 |
+
|
| 303 |
+
if gen_outro == 'Disable':
|
| 304 |
+
for o in output:
|
| 305 |
+
output_batches.append([t for t in o if not 18944 < t < 19330])
|
| 306 |
+
|
| 307 |
+
else:
|
| 308 |
+
output_batches = output
|
| 309 |
+
|
| 310 |
+
return output_batches
|
| 311 |
|
| 312 |
#==================================================================================
|
| 313 |
|
|
|
|
| 318 |
#==================================================================================
|
| 319 |
|
| 320 |
def generate_callback(input_midi,
|
| 321 |
+
num_prime_tokens,
|
| 322 |
+
num_gen_tokens,
|
| 323 |
+
num_mem_tokens,
|
| 324 |
+
gen_outro,
|
| 325 |
+
gen_drums,
|
| 326 |
+
model_temperature,
|
| 327 |
+
model_sampling_top_p
|
| 328 |
):
|
| 329 |
|
| 330 |
global generated_batches
|
|
|
|
| 336 |
block_lines.append(midi_score[-1][1] / 1000)
|
| 337 |
|
| 338 |
batched_gen_tokens = generate_music(final_composition,
|
| 339 |
+
num_gen_tokens,
|
| 340 |
+
num_mem_tokens,
|
| 341 |
NUM_OUT_BATCHES,
|
| 342 |
gen_outro,
|
| 343 |
gen_drums,
|
|
|
|
| 389 |
#==================================================================================
|
| 390 |
|
| 391 |
def generate_callback_wrapper(input_midi,
|
| 392 |
+
num_prime_tokens,
|
| 393 |
+
num_gen_tokens,
|
| 394 |
+
num_mem_tokens,
|
| 395 |
+
gen_outro,
|
| 396 |
+
gen_drums,
|
| 397 |
+
model_temperature,
|
| 398 |
+
model_sampling_top_p
|
| 399 |
):
|
| 400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
print('=' * 70)
|
| 402 |
if input_midi is not None:
|
| 403 |
fn = os.path.basename(input_midi.name)
|
|
|
|
| 414 |
result = generate_callback(input_midi,
|
| 415 |
num_prime_tokens,
|
| 416 |
num_gen_tokens,
|
| 417 |
+
num_mem_tokens,
|
| 418 |
gen_outro,
|
| 419 |
gen_drums,
|
| 420 |
model_temperature,
|
|
|
|
| 422 |
)
|
| 423 |
|
| 424 |
generated_batches.extend([sublist[2] for sublist in result])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
|
| 426 |
return tuple(item for sublist in result for item in sublist[:2])
|
| 427 |
|
|
|
|
| 495 |
final_composition = []
|
| 496 |
generated_batches = []
|
| 497 |
block_lines = []
|
| 498 |
+
|
| 499 |
#==================================================================================
|
| 500 |
|
| 501 |
PDT = timezone('US/Pacific')
|
|
|
|
| 525 |
for faster execution and endless generation!
|
| 526 |
""")
|
| 527 |
|
| 528 |
+
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
| 529 |
|
| 530 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
| 531 |
input_midi.upload(reset)
|
| 532 |
|
| 533 |
gr.Markdown("## Generate")
|
| 534 |
|
| 535 |
+
num_prime_tokens = gr.Slider(15, 6990, value=600, step=3, label="Number of prime tokens")
|
| 536 |
num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
|
| 537 |
+
num_mem_tokens = gr.Slider(15, 6990, value=6990, step=3, label="Number of memory tokens")
|
| 538 |
+
gen_drums = gr.Checkbox(value=False, label="Introduce drums")
|
| 539 |
+
gen_outro = gr.Radio(["Auto", "Disable", "Force"], value="Auto", label="Outro options")
|
| 540 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 541 |
model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
| 542 |
|
|
|
|
| 558 |
[input_midi,
|
| 559 |
num_prime_tokens,
|
| 560 |
num_gen_tokens,
|
| 561 |
+
num_mem_tokens,
|
| 562 |
gen_outro,
|
| 563 |
gen_drums,
|
| 564 |
model_temperature,
|