skytnt committed on
Commit
256bea9
1 Parent(s): e593d58

add undo button

Browse files
Files changed (3) hide show
  1. app.py +47 -19
  2. javascript/app.js +14 -14
  3. midi_synthesizer.py +3 -2
app.py CHANGED
@@ -121,7 +121,7 @@ def send_msgs(msgs):
121
  return json.dumps(msgs)
122
 
123
 
124
- def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
  model = models[model_name]
@@ -187,8 +187,10 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
187
  for token_seq in mid:
188
  mid_seq.append(token_seq.tolist())
189
  elif tab == 2 and mid_seq is not None:
 
190
  mid = np.asarray(mid_seq, dtype=np.int64)
191
  else:
 
192
  mid_seq = []
193
  mid = None
194
 
@@ -196,12 +198,11 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
196
  max_len += len(mid)
197
 
198
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
199
- if tab == 2:
200
- init_msgs = [create_msg("visualizer_continue", tokenizer.version)]
201
- else:
202
- init_msgs = [create_msg("visualizer_clear", tokenizer.version),
203
- create_msg("visualizer_append", events)]
204
- yield mid_seq, None, None, seed, send_msgs(init_msgs)
205
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
206
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
207
  disable_channels=disable_channels, generator=generator)
@@ -213,29 +214,51 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
213
  events.append(tokenizer.tokens2event(token_seq))
214
  ct = time.time()
215
  if ct - t > 0.5:
216
- yield mid_seq, None, None, seed, send_msgs(
217
- [create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
 
218
  t = ct
219
  events = []
220
 
 
221
  mid = tokenizer.detokenize(mid_seq)
 
222
  with open(f"output.mid", 'wb') as f:
223
  f.write(MIDI.score2midi(mid))
224
- audio = synthesizer.synthesis(MIDI.score2opus(mid))
225
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
226
- yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
 
 
227
 
228
 
229
  def cancel_run(model_name, mid_seq):
230
  if mid_seq is None:
231
  return None, None, []
232
  tokenizer = models[model_name][2]
 
233
  mid = tokenizer.detokenize(mid_seq)
 
234
  with open(f"output.mid", 'wb') as f:
235
  f.write(MIDI.score2midi(mid))
236
- audio = synthesizer.synthesis(MIDI.score2opus(mid))
 
 
 
 
 
 
 
 
 
 
 
237
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
238
- return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
 
 
 
 
239
 
240
 
241
  def load_javascript(dir="javascript"):
@@ -341,7 +364,7 @@ if __name__ == "__main__":
341
  type="value", value=list(models.keys())[0])
342
  tab_select = gr.State(value=0)
343
  with gr.Tabs():
344
- with gr.TabItem("instrument prompt") as tab1:
345
  input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
346
  multiselect=True, max_choices=15, type="value")
347
  input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
@@ -388,6 +411,7 @@ if __name__ == "__main__":
388
  [input_midi, input_midi_events])
389
  with gr.TabItem("last output prompt") as tab3:
390
  gr.Markdown("Continue generating on the last output. Just click the generate button")
 
391
 
392
  tab1.select(lambda: 0, None, tab_select, queue=False)
393
  tab2.select(lambda: 1, None, tab_select, queue=False)
@@ -406,18 +430,22 @@ if __name__ == "__main__":
406
  run_btn = gr.Button("generate", variant="primary")
407
  stop_btn = gr.Button("stop and output")
408
  output_midi_seq = gr.State()
 
409
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
410
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
411
  output_midi = gr.File(label="output midi", file_types=[".mid"])
412
- run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, input_instruments,
413
- input_drum_kit, input_bpm, input_time_sig, input_key_sig, input_midi,
414
- input_midi_events, input_reduce_cc_st, input_remap_track_channel,
415
  input_add_default_instr, input_remove_empty_channels,
416
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
417
  input_top_k, input_allow_cc],
418
- [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
 
419
  concurrency_limit=3)
420
  stop_btn.click(cancel_run, [input_model, output_midi_seq],
421
  [output_midi, output_audio, js_msg],
422
  cancels=run_event, queue=False)
 
 
423
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
121
  return json.dumps(msgs)
122
 
123
 
124
+ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
  model = models[model_name]
 
187
  for token_seq in mid:
188
  mid_seq.append(token_seq.tolist())
189
  elif tab == 2 and mid_seq is not None:
190
+ continuation_state.append(len(mid_seq))
191
  mid = np.asarray(mid_seq, dtype=np.int64)
192
  else:
193
+ continuation_state = [0]
194
  mid_seq = []
195
  mid = None
196
 
 
198
  max_len += len(mid)
199
 
200
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
201
+ init_msgs = [create_msg("progress", [0, gen_events])]
202
+ if tab != 2:
203
+ init_msgs += [create_msg("visualizer_clear", tokenizer.version),
204
+ create_msg("visualizer_append", events)]
205
+ yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
 
206
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
207
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
208
  disable_channels=disable_channels, generator=generator)
 
214
  events.append(tokenizer.tokens2event(token_seq))
215
  ct = time.time()
216
  if ct - t > 0.5:
217
+ yield (mid_seq, continuation_state, None, None, seed,
218
+ send_msgs([create_msg("visualizer_append", events),
219
+ create_msg("progress", [i + 1, gen_events])]))
220
  t = ct
221
  events = []
222
 
223
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
224
  mid = tokenizer.detokenize(mid_seq)
225
+ audio = synthesizer.synthesis(MIDI.score2opus(mid))
226
  with open(f"output.mid", 'wb') as f:
227
  f.write(MIDI.score2midi(mid))
228
+ end_msgs = [create_msg("visualizer_clear", tokenizer.version),
229
+ create_msg("visualizer_append", events),
230
+ create_msg("visualizer_end", None),
231
+ create_msg("progress", [0, 0])]
232
+ yield mid_seq, continuation_state, "output.mid", (44100, audio), seed, send_msgs(end_msgs)
233
 
234
 
235
  def cancel_run(model_name, mid_seq):
236
  if mid_seq is None:
237
  return None, None, []
238
  tokenizer = models[model_name][2]
239
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
240
  mid = tokenizer.detokenize(mid_seq)
241
+ audio = synthesizer.synthesis(MIDI.score2opus(mid))
242
  with open(f"output.mid", 'wb') as f:
243
  f.write(MIDI.score2midi(mid))
244
+ end_msgs = [create_msg("visualizer_clear", tokenizer.version),
245
+ create_msg("visualizer_append", events),
246
+ create_msg("visualizer_end", None),
247
+ create_msg("progress", [0, 0])]
248
+ return "output.mid", (44100, audio), send_msgs(end_msgs)
249
+
250
+
251
def undo_continuation(mid_seq, continuation_state):
    """Roll the stored token sequence back to before the last continuation.

    ``run`` appends ``len(mid_seq)`` to ``continuation_state`` each time it
    continues from the last output, and resets the state to ``[0]`` otherwise.
    The final entry is therefore the sequence length to restore.

    Args:
        mid_seq: The stored token sequences, or ``None`` if nothing has been
            generated.
        continuation_state: List of sequence lengths recorded by ``run``.

    Returns:
        ``(mid_seq, continuation_state, js_messages)``. When there is nothing
        to undo the inputs are returned unchanged with an empty message batch.
    """
    if mid_seq is None or len(continuation_state) <= 1:
        return mid_seq, continuation_state, send_msgs([])

    restore_len = continuation_state[-1]
    restored_seq = mid_seq[:restore_len]
    remaining_state = continuation_state[:-1]

    # NOTE(review): `tokenizer` is never assigned in this function, so it is
    # looked up at module scope. cancel_run obtains it via
    # models[model_name][2]; this function receives no model name. Confirm a
    # module-level `tokenizer` exists and matches the selected model.
    restored_events = list(map(tokenizer.tokens2event, restored_seq))
    refresh_msgs = [
        create_msg("visualizer_clear", tokenizer.version),
        create_msg("visualizer_append", restored_events),
        create_msg("visualizer_end", None),
        create_msg("progress", [0, 0]),
    ]
    return restored_seq, remaining_state, send_msgs(refresh_msgs)
262
 
263
 
264
  def load_javascript(dir="javascript"):
 
364
  type="value", value=list(models.keys())[0])
365
  tab_select = gr.State(value=0)
366
  with gr.Tabs():
367
+ with gr.TabItem("custom prompt") as tab1:
368
  input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
369
  multiselect=True, max_choices=15, type="value")
370
  input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
 
411
  [input_midi, input_midi_events])
412
  with gr.TabItem("last output prompt") as tab3:
413
  gr.Markdown("Continue generating on the last output. Just click the generate button")
414
+ undo_btn = gr.Button("undo the last continuation")
415
 
416
  tab1.select(lambda: 0, None, tab_select, queue=False)
417
  tab2.select(lambda: 1, None, tab_select, queue=False)
 
430
  run_btn = gr.Button("generate", variant="primary")
431
  stop_btn = gr.Button("stop and output")
432
  output_midi_seq = gr.State()
433
+ output_continuation_state = gr.State([0])
434
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
435
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
436
  output_midi = gr.File(label="output midi", file_types=[".mid"])
437
+ run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
438
+ input_instruments, input_drum_kit, input_bpm, input_time_sig, input_key_sig,
439
+ input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
440
  input_add_default_instr, input_remove_empty_channels,
441
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
442
  input_top_k, input_allow_cc],
443
+ [output_midi_seq, output_continuation_state,
444
+ output_midi, output_audio, input_seed, js_msg],
445
  concurrency_limit=3)
446
  stop_btn.click(cancel_run, [input_model, output_midi_seq],
447
  [output_midi, output_audio, js_msg],
448
  cancels=run_event, queue=False)
449
+ undo_btn.click(undo_continuation, [output_midi_seq, output_continuation_state],
450
+ [output_midi_seq, output_continuation_state, js_msg], queue=False)
451
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
javascript/app.js CHANGED
@@ -400,6 +400,8 @@ customElements.define('midi-visualizer', MidiVisualizer);
400
  }
401
  })
402
 
 
 
403
  function createProgressBar(progressbarContainer){
404
  let parentProgressbar = progressbarContainer.parentNode;
405
  let divProgress = document.createElement('div');
@@ -421,15 +423,23 @@ customElements.define('midi-visualizer', MidiVisualizer);
421
  divInner.style.width = "0%";
422
  divProgress.appendChild(divInner);
423
  parentProgressbar.insertBefore(divProgress, progressbarContainer);
 
424
  }
425
 
426
  function removeProgressBar(progressbarContainer){
427
  let parentProgressbar = progressbarContainer.parentNode;
428
  let divProgress = parentProgressbar.querySelector(".progressDiv");
429
  parentProgressbar.removeChild(divProgress);
 
430
  }
431
 
432
  function setProgressBar(progressbarContainer, progress, total){
 
 
 
 
 
 
433
  let parentProgressbar = progressbarContainer.parentNode;
434
  let divProgress = parentProgressbar.querySelector(".progressDiv");
435
  let divInner = parentProgressbar.querySelector(".progress");
@@ -453,31 +463,21 @@ customElements.define('midi-visualizer', MidiVisualizer);
453
  case "visualizer_clear":
454
  midi_visualizer.clearMidiEvents(false);
455
  midi_visualizer.version = msg.data
456
- createProgressBar(midi_visualizer_container_inited)
457
- break;
458
- case "visualizer_continue":
459
- midi_visualizer.version = msg.data
460
- createProgressBar(midi_visualizer_container_inited)
461
  break;
462
  case "visualizer_append":
463
  msg.data.forEach( value => {
464
  midi_visualizer.appendMidiEvent(value);
465
  })
466
  break;
 
 
 
 
467
  case "progress":
468
  let progress = msg.data[0]
469
  let total = msg.data[1]
470
  setProgressBar(midi_visualizer_container_inited, progress, total)
471
  break;
472
- case "visualizer_end":
473
- midi_visualizer.clearMidiEvents(true);
474
- msg.data.forEach( value => {
475
- midi_visualizer.appendMidiEvent(value);
476
- })
477
- midi_visualizer.finishAppendMidiEvent()
478
- midi_visualizer.setPlayTime(0);
479
- removeProgressBar(midi_visualizer_container_inited);
480
- break;
481
  default:
482
  }
483
  }
 
400
  }
401
  })
402
 
403
+ let hasProgressBar = false;
404
+
405
  function createProgressBar(progressbarContainer){
406
  let parentProgressbar = progressbarContainer.parentNode;
407
  let divProgress = document.createElement('div');
 
423
  divInner.style.width = "0%";
424
  divProgress.appendChild(divInner);
425
  parentProgressbar.insertBefore(divProgress, progressbarContainer);
426
+ hasProgressBar = true;
427
  }
428
 
429
  function removeProgressBar(progressbarContainer){
430
  let parentProgressbar = progressbarContainer.parentNode;
431
  let divProgress = parentProgressbar.querySelector(".progressDiv");
432
  parentProgressbar.removeChild(divProgress);
433
+ hasProgressBar = false;
434
  }
435
 
436
  function setProgressBar(progressbarContainer, progress, total){
437
+ if (!hasProgressBar)
438
+ createProgressBar(midi_visualizer_container_inited)
439
+ if (hasProgressBar && total === 0){
440
+ removeProgressBar(midi_visualizer_container_inited)
441
+ return
442
+ }
443
  let parentProgressbar = progressbarContainer.parentNode;
444
  let divProgress = parentProgressbar.querySelector(".progressDiv");
445
  let divInner = parentProgressbar.querySelector(".progress");
 
463
  case "visualizer_clear":
464
  midi_visualizer.clearMidiEvents(false);
465
  midi_visualizer.version = msg.data
 
 
 
 
 
466
  break;
467
  case "visualizer_append":
468
  msg.data.forEach( value => {
469
  midi_visualizer.appendMidiEvent(value);
470
  })
471
  break;
472
+ case "visualizer_end":
473
+ midi_visualizer.finishAppendMidiEvent()
474
+ midi_visualizer.setPlayTime(0);
475
+ break;
476
  case "progress":
477
  let progress = msg.data[0]
478
  let total = msg.data[1]
479
  setProgressBar(midi_visualizer_container_inited, progress, total)
480
  break;
 
 
 
 
 
 
 
 
 
481
  default:
482
  }
483
  }
midi_synthesizer.py CHANGED
@@ -1,6 +1,7 @@
1
  import fluidsynth
2
  import numpy as np
3
 
 
4
  class MidiSynthesizer:
5
  def __init__(self, soundfont_path, sample_rate=44100):
6
  self.soundfont_path = soundfont_path
@@ -21,8 +22,8 @@ class MidiSynthesizer:
21
  return device
22
 
23
  def release_fluidsynth(self, device):
24
- device[0].system_reset()
25
  device[0].get_samples(self.sample_rate*5) # wait for silence
 
26
  device[2] = False
27
 
28
  def synthesis(self, midi_opus):
@@ -73,4 +74,4 @@ class MidiSynthesizer:
73
  if max_val != 0:
74
  ss = (ss / max_val) * np.iinfo(np.int16).max
75
  ss = ss.astype(np.int16)
76
- return ss
 
1
  import fluidsynth
2
  import numpy as np
3
 
4
+
5
  class MidiSynthesizer:
6
  def __init__(self, soundfont_path, sample_rate=44100):
7
  self.soundfont_path = soundfont_path
 
22
  return device
23
 
24
  def release_fluidsynth(self, device):
 
25
  device[0].get_samples(self.sample_rate*5) # wait for silence
26
+ device[0].system_reset()
27
  device[2] = False
28
 
29
  def synthesis(self, midi_opus):
 
74
  if max_val != 0:
75
  ss = (ss / max_val) * np.iinfo(np.int16).max
76
  ss = ss.astype(np.int16)
77
+ return ss