Spaces:
Paused
Paused
add undo button
Browse files
- app.py +47 -19
- javascript/app.js +14 -14
- midi_synthesizer.py +3 -2
app.py
CHANGED
@@ -121,7 +121,7 @@ def send_msgs(msgs):
|
|
121 |
return json.dumps(msgs)
|
122 |
|
123 |
|
124 |
-
def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
125 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
126 |
gen_events, temp, top_p, top_k, allow_cc):
|
127 |
model = models[model_name]
|
@@ -187,8 +187,10 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
|
|
187 |
for token_seq in mid:
|
188 |
mid_seq.append(token_seq.tolist())
|
189 |
elif tab == 2 and mid_seq is not None:
|
|
|
190 |
mid = np.asarray(mid_seq, dtype=np.int64)
|
191 |
else:
|
|
|
192 |
mid_seq = []
|
193 |
mid = None
|
194 |
|
@@ -196,12 +198,11 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
|
|
196 |
max_len += len(mid)
|
197 |
|
198 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
yield mid_seq, None, None, seed, send_msgs(init_msgs)
|
205 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
206 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
207 |
disable_channels=disable_channels, generator=generator)
|
@@ -213,29 +214,51 @@ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, time_sig, key_sig,
|
|
213 |
events.append(tokenizer.tokens2event(token_seq))
|
214 |
ct = time.time()
|
215 |
if ct - t > 0.5:
|
216 |
-
yield mid_seq, None, None, seed,
|
217 |
-
|
|
|
218 |
t = ct
|
219 |
events = []
|
220 |
|
|
|
221 |
mid = tokenizer.detokenize(mid_seq)
|
|
|
222 |
with open(f"output.mid", 'wb') as f:
|
223 |
f.write(MIDI.score2midi(mid))
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
227 |
|
228 |
|
229 |
def cancel_run(model_name, mid_seq):
|
230 |
if mid_seq is None:
|
231 |
return None, None, []
|
232 |
tokenizer = models[model_name][2]
|
|
|
233 |
mid = tokenizer.detokenize(mid_seq)
|
|
|
234 |
with open(f"output.mid", 'wb') as f:
|
235 |
f.write(MIDI.score2midi(mid))
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
238 |
-
|
|
|
|
|
|
|
|
|
239 |
|
240 |
|
241 |
def load_javascript(dir="javascript"):
|
@@ -341,7 +364,7 @@ if __name__ == "__main__":
|
|
341 |
type="value", value=list(models.keys())[0])
|
342 |
tab_select = gr.State(value=0)
|
343 |
with gr.Tabs():
|
344 |
-
with gr.TabItem("
|
345 |
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
346 |
multiselect=True, max_choices=15, type="value")
|
347 |
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
@@ -388,6 +411,7 @@ if __name__ == "__main__":
|
|
388 |
[input_midi, input_midi_events])
|
389 |
with gr.TabItem("last output prompt") as tab3:
|
390 |
gr.Markdown("Continue generating on the last output. Just click the generate button")
|
|
|
391 |
|
392 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
393 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
@@ -406,18 +430,22 @@ if __name__ == "__main__":
|
|
406 |
run_btn = gr.Button("generate", variant="primary")
|
407 |
stop_btn = gr.Button("stop and output")
|
408 |
output_midi_seq = gr.State()
|
|
|
409 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
410 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
411 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
412 |
-
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq,
|
413 |
-
input_drum_kit, input_bpm, input_time_sig, input_key_sig,
|
414 |
-
input_midi_events, input_reduce_cc_st, input_remap_track_channel,
|
415 |
input_add_default_instr, input_remove_empty_channels,
|
416 |
input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
|
417 |
input_top_k, input_allow_cc],
|
418 |
-
[output_midi_seq,
|
|
|
419 |
concurrency_limit=3)
|
420 |
stop_btn.click(cancel_run, [input_model, output_midi_seq],
|
421 |
[output_midi, output_audio, js_msg],
|
422 |
cancels=run_event, queue=False)
|
|
|
|
|
423 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
121 |
return json.dumps(msgs)
|
122 |
|
123 |
|
124 |
+
def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
125 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
126 |
gen_events, temp, top_p, top_k, allow_cc):
|
127 |
model = models[model_name]
|
|
|
187 |
for token_seq in mid:
|
188 |
mid_seq.append(token_seq.tolist())
|
189 |
elif tab == 2 and mid_seq is not None:
|
190 |
+
continuation_state.append(len(mid_seq))
|
191 |
mid = np.asarray(mid_seq, dtype=np.int64)
|
192 |
else:
|
193 |
+
continuation_state = [0]
|
194 |
mid_seq = []
|
195 |
mid = None
|
196 |
|
|
|
198 |
max_len += len(mid)
|
199 |
|
200 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
201 |
+
init_msgs = [create_msg("progress", [0, gen_events])]
|
202 |
+
if tab != 2:
|
203 |
+
init_msgs += [create_msg("visualizer_clear", tokenizer.version),
|
204 |
+
create_msg("visualizer_append", events)]
|
205 |
+
yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
|
|
|
206 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
207 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
208 |
disable_channels=disable_channels, generator=generator)
|
|
|
214 |
events.append(tokenizer.tokens2event(token_seq))
|
215 |
ct = time.time()
|
216 |
if ct - t > 0.5:
|
217 |
+
yield (mid_seq, continuation_state, None, None, seed,
|
218 |
+
send_msgs([create_msg("visualizer_append", events),
|
219 |
+
create_msg("progress", [i + 1, gen_events])]))
|
220 |
t = ct
|
221 |
events = []
|
222 |
|
223 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
224 |
mid = tokenizer.detokenize(mid_seq)
|
225 |
+
audio = synthesizer.synthesis(MIDI.score2opus(mid))
|
226 |
with open(f"output.mid", 'wb') as f:
|
227 |
f.write(MIDI.score2midi(mid))
|
228 |
+
end_msgs = [create_msg("visualizer_clear", tokenizer.version),
|
229 |
+
create_msg("visualizer_append", events),
|
230 |
+
create_msg("visualizer_end", None),
|
231 |
+
create_msg("progress", [0, 0])]
|
232 |
+
yield mid_seq, continuation_state, "output.mid", (44100, audio), seed, send_msgs(end_msgs)
|
233 |
|
234 |
|
235 |
def cancel_run(model_name, mid_seq):
    """Render whatever has been generated so far after the user stops a run.

    Args:
        model_name: Key into the module-level ``models`` mapping; element 2 of
            the selected entry is used as the tokenizer.
        mid_seq: Token sequences produced so far, or ``None`` when nothing has
            been generated yet.

    Returns:
        A 3-tuple ``(midi_path, audio, js_messages)``. When ``mid_seq`` is
        ``None`` this is ``(None, None, [])``. Otherwise it is
        ``("output.mid", (44100, samples), send_msgs(end_msgs))`` where
        ``end_msgs`` clears the visualizer, re-appends every event, marks the
        end of the stream and resets the progress bar.
    """
    if mid_seq is None:
        return None, None, []
    tokenizer = models[model_name][2]
    events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
    mid = tokenizer.detokenize(mid_seq)
    audio = synthesizer.synthesis(MIDI.score2opus(mid))
    # Plain literal: the previous f-string had no placeholders.
    with open("output.mid", 'wb') as f:
        f.write(MIDI.score2midi(mid))
    end_msgs = [create_msg("visualizer_clear", tokenizer.version),
                create_msg("visualizer_append", events),
                create_msg("visualizer_end", None),
                # A total of 0 tells the front end to remove the progress bar.
                create_msg("progress", [0, 0])]
    return "output.mid", (44100, audio), send_msgs(end_msgs)
|
249 |
+
|
250 |
+
|
251 |
+
def undo_continuation(mid_seq, continuation_state, model_name=None):
    """Drop the most recent continuation from ``mid_seq`` and redraw the visualizer.

    ``continuation_state`` is a list of lengths: ``run`` appends
    ``len(mid_seq)`` to it before continuing from the last output and resets
    it to ``[0]`` for any other prompt, so its last entry is the length
    ``mid_seq`` had before the latest continuation.

    Args:
        mid_seq: Token sequences of the last output, or ``None``.
        continuation_state: List of sequence lengths as described above.
        model_name: Optional key into the module-level ``models`` mapping.
            When given, the tokenizer is taken from ``models[model_name][2]``
            (the same lookup ``cancel_run`` performs). When omitted, the
            function falls back to a module-level ``tokenizer`` name, which is
            what the previous body relied on.

    Returns:
        ``(mid_seq, continuation_state, js_messages)``. The inputs are
        returned unchanged, with an empty message list, when there is nothing
        to undo.
    """
    if mid_seq is None or len(continuation_state) < 2:
        return mid_seq, continuation_state, send_msgs([])
    # The earlier body used a bare ``tokenizer`` that is never bound inside
    # this function. Prefer the tokenizer of the selected model when the
    # caller supplies its name.
    # NOTE(review): the click handler must pass the model name as a third
    # input for this lookup to take effect.
    if model_name is not None:
        active_tokenizer = models[model_name][2]
    else:
        active_tokenizer = tokenizer
    # Truncate back to the length recorded before the last continuation,
    # then pop that record. Slicing builds new lists; the inputs are not
    # mutated.
    mid_seq = mid_seq[:continuation_state[-1]]
    continuation_state = continuation_state[:-1]
    events = [active_tokenizer.tokens2event(tokens) for tokens in mid_seq]
    end_msgs = [create_msg("visualizer_clear", active_tokenizer.version),
                create_msg("visualizer_append", events),
                create_msg("visualizer_end", None),
                create_msg("progress", [0, 0])]
    return mid_seq, continuation_state, send_msgs(end_msgs)
|
262 |
|
263 |
|
264 |
def load_javascript(dir="javascript"):
|
|
|
364 |
type="value", value=list(models.keys())[0])
|
365 |
tab_select = gr.State(value=0)
|
366 |
with gr.Tabs():
|
367 |
+
with gr.TabItem("custom prompt") as tab1:
|
368 |
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
369 |
multiselect=True, max_choices=15, type="value")
|
370 |
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
|
|
411 |
[input_midi, input_midi_events])
|
412 |
with gr.TabItem("last output prompt") as tab3:
|
413 |
gr.Markdown("Continue generating on the last output. Just click the generate button")
|
414 |
+
undo_btn = gr.Button("undo the last continuation")
|
415 |
|
416 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
417 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
|
|
430 |
run_btn = gr.Button("generate", variant="primary")
|
431 |
stop_btn = gr.Button("stop and output")
|
432 |
output_midi_seq = gr.State()
|
433 |
+
output_continuation_state = gr.State([0])
|
434 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
435 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
436 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
437 |
+
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
|
438 |
+
input_instruments, input_drum_kit, input_bpm, input_time_sig, input_key_sig,
|
439 |
+
input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
|
440 |
input_add_default_instr, input_remove_empty_channels,
|
441 |
input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
|
442 |
input_top_k, input_allow_cc],
|
443 |
+
[output_midi_seq, output_continuation_state,
|
444 |
+
output_midi, output_audio, input_seed, js_msg],
|
445 |
concurrency_limit=3)
|
446 |
stop_btn.click(cancel_run, [input_model, output_midi_seq],
|
447 |
[output_midi, output_audio, js_msg],
|
448 |
cancels=run_event, queue=False)
|
449 |
+
undo_btn.click(undo_continuation, [output_midi_seq, output_continuation_state],
|
450 |
+
[output_midi_seq, output_continuation_state, js_msg], queue=False)
|
451 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
CHANGED
@@ -400,6 +400,8 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
400 |
}
|
401 |
})
|
402 |
|
|
|
|
|
403 |
function createProgressBar(progressbarContainer){
|
404 |
let parentProgressbar = progressbarContainer.parentNode;
|
405 |
let divProgress = document.createElement('div');
|
@@ -421,15 +423,23 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
421 |
divInner.style.width = "0%";
|
422 |
divProgress.appendChild(divInner);
|
423 |
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
|
|
424 |
}
|
425 |
|
426 |
/**
 * Detach the progress bar from the DOM.
 *
 * Looks up the first ".progressDiv" element under the parent of
 * `progressbarContainer` and removes it from that parent.
 */
function removeProgressBar(progressbarContainer){
    const host = progressbarContainer.parentNode;
    host.removeChild(host.querySelector(".progressDiv"));
}
|
431 |
|
432 |
function setProgressBar(progressbarContainer, progress, total){
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
let parentProgressbar = progressbarContainer.parentNode;
|
434 |
let divProgress = parentProgressbar.querySelector(".progressDiv");
|
435 |
let divInner = parentProgressbar.querySelector(".progress");
|
@@ -453,31 +463,21 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
453 |
case "visualizer_clear":
|
454 |
midi_visualizer.clearMidiEvents(false);
|
455 |
midi_visualizer.version = msg.data
|
456 |
-
createProgressBar(midi_visualizer_container_inited)
|
457 |
-
break;
|
458 |
-
case "visualizer_continue":
|
459 |
-
midi_visualizer.version = msg.data
|
460 |
-
createProgressBar(midi_visualizer_container_inited)
|
461 |
break;
|
462 |
case "visualizer_append":
|
463 |
msg.data.forEach( value => {
|
464 |
midi_visualizer.appendMidiEvent(value);
|
465 |
})
|
466 |
break;
|
|
|
|
|
|
|
|
|
467 |
case "progress":
|
468 |
let progress = msg.data[0]
|
469 |
let total = msg.data[1]
|
470 |
setProgressBar(midi_visualizer_container_inited, progress, total)
|
471 |
break;
|
472 |
-
case "visualizer_end":
|
473 |
-
midi_visualizer.clearMidiEvents(true);
|
474 |
-
msg.data.forEach( value => {
|
475 |
-
midi_visualizer.appendMidiEvent(value);
|
476 |
-
})
|
477 |
-
midi_visualizer.finishAppendMidiEvent()
|
478 |
-
midi_visualizer.setPlayTime(0);
|
479 |
-
removeProgressBar(midi_visualizer_container_inited);
|
480 |
-
break;
|
481 |
default:
|
482 |
}
|
483 |
}
|
|
|
400 |
}
|
401 |
})
|
402 |
|
403 |
+
let hasProgressBar = false;
|
404 |
+
|
405 |
function createProgressBar(progressbarContainer){
|
406 |
let parentProgressbar = progressbarContainer.parentNode;
|
407 |
let divProgress = document.createElement('div');
|
|
|
423 |
divInner.style.width = "0%";
|
424 |
divProgress.appendChild(divInner);
|
425 |
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
426 |
+
hasProgressBar = true;
|
427 |
}
|
428 |
|
429 |
/**
 * Detach the progress bar from the DOM and clear the `hasProgressBar` flag.
 *
 * Looks up the first ".progressDiv" element under the parent of
 * `progressbarContainer` and removes it when present. The flag is cleared
 * either way, so a later progress update can create a fresh bar.
 */
function removeProgressBar(progressbarContainer){
    const parentProgressbar = progressbarContainer.parentNode;
    const divProgress = parentProgressbar.querySelector(".progressDiv");
    // removeChild(null) throws a TypeError. Without this guard a missing
    // element would abort here and leave hasProgressBar stuck at true.
    if (divProgress !== null) {
        parentProgressbar.removeChild(divProgress);
    }
    hasProgressBar = false;
}
|
435 |
|
436 |
function setProgressBar(progressbarContainer, progress, total){
|
437 |
+
if (!hasProgressBar)
|
438 |
+
createProgressBar(midi_visualizer_container_inited)
|
439 |
+
if (hasProgressBar && total === 0){
|
440 |
+
removeProgressBar(midi_visualizer_container_inited)
|
441 |
+
return
|
442 |
+
}
|
443 |
let parentProgressbar = progressbarContainer.parentNode;
|
444 |
let divProgress = parentProgressbar.querySelector(".progressDiv");
|
445 |
let divInner = parentProgressbar.querySelector(".progress");
|
|
|
463 |
case "visualizer_clear":
|
464 |
midi_visualizer.clearMidiEvents(false);
|
465 |
midi_visualizer.version = msg.data
|
|
|
|
|
|
|
|
|
|
|
466 |
break;
|
467 |
case "visualizer_append":
|
468 |
msg.data.forEach( value => {
|
469 |
midi_visualizer.appendMidiEvent(value);
|
470 |
})
|
471 |
break;
|
472 |
+
case "visualizer_end":
|
473 |
+
midi_visualizer.finishAppendMidiEvent()
|
474 |
+
midi_visualizer.setPlayTime(0);
|
475 |
+
break;
|
476 |
case "progress":
|
477 |
let progress = msg.data[0]
|
478 |
let total = msg.data[1]
|
479 |
setProgressBar(midi_visualizer_container_inited, progress, total)
|
480 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
default:
|
482 |
}
|
483 |
}
|
midi_synthesizer.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import fluidsynth
|
2 |
import numpy as np
|
3 |
|
|
|
4 |
class MidiSynthesizer:
|
5 |
def __init__(self, soundfont_path, sample_rate=44100):
|
6 |
self.soundfont_path = soundfont_path
|
@@ -21,8 +22,8 @@ class MidiSynthesizer:
|
|
21 |
return device
|
22 |
|
23 |
def release_fluidsynth(self, device):
    """Reset a synth device and flag it as released.

    Args:
        device: Mutable sequence whose element 0 is the synth object and
            whose element 2 is set to ``False`` here (presumably an "in use"
            flag; TODO confirm against the code that acquires devices).
    """
    synth = device[0]
    synth.system_reset()
    synth.get_samples(self.sample_rate * 5)  # wait for silence
    device[2] = False
|
27 |
|
28 |
def synthesis(self, midi_opus):
|
@@ -73,4 +74,4 @@ class MidiSynthesizer:
|
|
73 |
if max_val != 0:
|
74 |
ss = (ss / max_val) * np.iinfo(np.int16).max
|
75 |
ss = ss.astype(np.int16)
|
76 |
-
return ss
|
|
|
1 |
import fluidsynth
|
2 |
import numpy as np
|
3 |
|
4 |
+
|
5 |
class MidiSynthesizer:
|
6 |
def __init__(self, soundfont_path, sample_rate=44100):
|
7 |
self.soundfont_path = soundfont_path
|
|
|
22 |
return device
|
23 |
|
24 |
def release_fluidsynth(self, device):
    """Drain, reset and flag a synth device as released.

    Args:
        device: Mutable sequence whose element 0 is the synth object and
            whose element 2 is set to ``False`` here (presumably an "in use"
            flag; TODO confirm against the code that acquires devices).
    """
    synth = device[0]
    # Render five seconds' worth of samples first, then reset the synth.
    synth.get_samples(self.sample_rate * 5)  # wait for silence
    synth.system_reset()
    device[2] = False
|
28 |
|
29 |
def synthesis(self, midi_opus):
|
|
|
74 |
if max_val != 0:
|
75 |
ss = (ss / max_val) * np.iinfo(np.int16).max
|
76 |
ss = ss.astype(np.int16)
|
77 |
+
return ss
|