asigalov61 commited on
Commit
2a46ddc
·
verified ·
1 Parent(s): 493a4ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -278,7 +278,8 @@ def generate_music(prime, num_gen_tokens, num_gen_batches, model_temperature, mo
278
  if len(prime) >= 7168:
279
  prime = [18816] + prime[-7168:]
280
 
281
- inputs = prime if prime else [18816]
 
282
  print("Generating...")
283
  inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
284
  with ctx:
@@ -297,7 +298,7 @@ def generate_music(prime, num_gen_tokens, num_gen_batches, model_temperature, mo
297
  print_sep()
298
  return out.tolist()
299
 
300
- def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens,
301
  model_temperature, model_top_p, add_drums, add_outro, final_composition, generated_batches, block_lines):
302
  """
303
  Generate tokens using the model, update the composition state, and prepare outputs.
@@ -334,6 +335,14 @@ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens,
334
  # Use the last note's time as a marker.
335
  block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
336
 
 
 
 
 
 
 
 
 
337
  if final_composition:
338
  if add_outro:
339
  final_composition.append(18817) # Outro token
@@ -446,6 +455,8 @@ def reset(final_composition=[], generated_batches=[], block_lines=[]):
446
  print_sep()
447
  return [], [], []
448
 
 
 
449
  # -----------------------------
450
  # GRADIO INTERFACE SETUP
451
  # -----------------------------
@@ -506,6 +517,8 @@ with gr.Blocks() as demo:
506
  [final_composition, generated_batches, block_lines])
507
 
508
  gr.Markdown("## Generate")
 
 
509
  num_prime_tokens = gr.Slider(16, 7168, value=7168, step=1, label="Number of prime tokens")
510
  num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
511
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
@@ -524,7 +537,7 @@ with gr.Blocks() as demo:
524
  outputs.extend([audio_output, plot_output])
525
  generate_btn.click(
526
  generate_music_and_state,
527
- [input_midi, num_prime_tokens, num_gen_tokens, model_temperature, model_top_p, add_drums, add_outro,
528
  final_composition, generated_batches, block_lines],
529
  outputs
530
  )
 
278
  if len(prime) >= 7168:
279
  prime = [18816] + prime[-7168:]
280
 
281
+ inputs = prime
282
+
283
  print("Generating...")
284
  inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
285
  with ctx:
 
298
  print_sep()
299
  return out.tolist()
300
 
301
+ def generate_music_and_state(input_midi, prime_instruments, num_prime_tokens, num_gen_tokens,
302
  model_temperature, model_top_p, add_drums, add_outro, final_composition, generated_batches, block_lines):
303
  """
304
  Generate tokens using the model, update the composition state, and prepare outputs.
 
335
  # Use the last note's time as a marker.
336
  block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
337
 
338
+ if not final_composition and input is None:
339
+ final_composition = [18816, 0]
340
+
341
+ for i, instr in enumerate(prime_instruments):
342
+ instr_num = patch2number[instr]
343
+ final_composition.append((128*instr_num)+(72-(i*12))+256)
344
+ final_composition.append((8*16)+5+16768)
345
+
346
  if final_composition:
347
  if add_outro:
348
  final_composition.append(18817) # Outro token
 
455
  print_sep()
456
  return [], [], []
457
 
458
+ patch2number = {v: k for k, v in TMIDIX.Number2patch.items()}
459
+
460
  # -----------------------------
461
  # GRADIO INTERFACE SETUP
462
  # -----------------------------
 
517
  [final_composition, generated_batches, block_lines])
518
 
519
  gr.Markdown("## Generate")
520
+ prime_instruments = gr.Dropdown(label="Prime instruments (select up to 5)", choices=list(patch2number.keys()),
521
+ multiselect=True, max_choices=5, type="value")
522
  num_prime_tokens = gr.Slider(16, 7168, value=7168, step=1, label="Number of prime tokens")
523
  num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
524
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
 
537
  outputs.extend([audio_output, plot_output])
538
  generate_btn.click(
539
  generate_music_and_state,
540
+ [input_midi, prime_instruments, num_prime_tokens, num_gen_tokens, model_temperature, model_top_p, add_drums, add_outro,
541
  final_composition, generated_batches, block_lines],
542
  outputs
543
  )