Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import mido | |
| from io import BytesIO | |
| # import pyrubberband as pyrb | |
| from webUI.natural_language_guided_4.track_maker import DiffSynth, Track | |
| def get_arrangement_module(gradioWebUI, virtual_instruments_state, midi_files_state): | |
| # Load configurations | |
| uNet = gradioWebUI.uNet | |
| freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution | |
| VAE_scale = gradioWebUI.VAE_scale | |
| height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels | |
| timesteps = gradioWebUI.timesteps | |
| VAE_quantizer = gradioWebUI.VAE_quantizer | |
| VAE_decoder = gradioWebUI.VAE_decoder | |
| CLAP = gradioWebUI.CLAP | |
| CLAP_tokenizer = gradioWebUI.CLAP_tokenizer | |
| device = gradioWebUI.device | |
| squared = gradioWebUI.squared | |
| sample_rate = gradioWebUI.sample_rate | |
| noise_strategy = gradioWebUI.noise_strategy | |
def read_midi(midi, midi_dict):
    """Parse an uploaded MIDI file and register it as the current midi.

    Args:
        midi: Raw bytes of the uploaded file. They are wrapped in a
            ``BytesIO`` and parsed with ``mido.MidiFile``. ``None`` is
            treated as "nothing uploaded".
        midi_dict: State dict whose ``"midis"`` entry maps midi names to
            parsed files. The parsed upload is stored there under the key
            ``"uploaded_midi"`` (the dict is mutated in place).

    Returns:
        A dict keyed by output components. For a parsed file it holds a
        ``gr.Textbox`` for ``midi_info_textbox`` whose placeholder lists
        the event count of every track, the name ``"uploaded_midi"`` for
        ``current_midi_state`` and the updated dict for
        ``midi_files_state``. When ``midi`` is ``None`` only
        ``midi_files_state`` is returned, carrying ``midi_dict`` unchanged.
    """
    if midi is None:
        # BytesIO(None) is an empty buffer, so there is nothing to parse.
        # NOTE(review): presumably the handler receives None when the upload
        # is cleared -- confirm against Gradio's File.change behaviour.
        # Returning a single key relies on Gradio's dict-return form, which
        # leaves output components that are not listed untouched.
        return {midi_files_state: midi_dict}

    mid = mido.MidiFile(file=BytesIO(midi))
    tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]

    midi_info_text = "Uploaded midi:" + "".join(
        f"\n{len(track.events)} events loaded from Track {i}."
        for i, track in enumerate(tracks)
    )

    # The uploaded file always replaces any previous upload under one key.
    midi_dict["midis"]["uploaded_midi"] = mid

    return {midi_info_textbox: gr.Textbox(label="Midi info", lines=10,
                                          placeholder=midi_info_text),
            current_midi_state: "uploaded_midi",
            midi_files_state: midi_dict}
def make_track(inpaint_steps, current_midi_name, midi_dict, max_notes, noising_strength, attack, before_release,
               current_instruments, virtual_instruments_dict):
    """Synthesize audio for the selected midi with the chosen instruments.

    Args:
        inpaint_steps: Converted to ``int`` and used as ``sample_steps`` in
            every instrument config.
        current_midi_name: Key of the midi to render in ``midi_dict["midis"]``.
        midi_dict: State dict holding the available midis under ``"midis"``.
        max_notes: Forwarded to ``DiffSynth.get_music`` as ``max_notes``.
        noising_strength: Stored in every instrument config; a warning is
            printed when it is below 1.
        attack: Stored in every instrument config under ``'attack'``.
        before_release: Stored in every instrument config under
            ``'before_release'``.
        current_instruments: Iterable of instrument names; each one must be
            a key of ``virtual_instruments_dict["virtual_instruments"]``.
        virtual_instruments_dict: State dict holding the instruments under
            ``"virtual_instruments"``.

    Returns:
        A dict mapping ``track_audio`` to ``(sample_rate, full_audio)``.
    """
    if noising_strength < 1:
        print(f"Warning: making track with noising_strength = {noising_strength} < 1")

    instrument_bank = virtual_instruments_dict["virtual_instruments"]
    n_steps = int(inpaint_steps)
    print(f"current_instruments: {current_instruments}")

    def _config_for(name):
        # Build the per-instrument settings consumed by DiffSynth.
        instrument = instrument_bank[name]
        latent = torch.tensor(instrument["latent_representation"], dtype=torch.float32).to(device)
        sampler = instrument["sampler"]
        batchsize = 1
        return {
            'sample_steps': n_steps,
            'sampler': sampler,
            'noising_strength': noising_strength,
            'latent_representation': latent.repeat(batchsize, 1, 1, 1),
            'attack': attack,
            'before_release': before_release,
        }

    # Repeated names simply rebuild the same entry, as in a plain loop.
    instruments_configs = {name: _config_for(name) for name in current_instruments}

    diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device)

    selected_midi = midi_dict["midis"][current_midi_name]
    full_audio = diffSynth.get_music(selected_midi, current_instruments, max_notes=max_notes)

    return {track_audio: (sample_rate, full_audio)}
| with gr.Tab("Arrangement"): | |
| default_instrument = "preset_string" | |
| current_instruments_state = gr.State(value=[default_instrument for _ in range(100)]) | |
| current_midi_state = gr.State(value="Ode_to_Joy_Easy_variation") | |
| gr.Markdown("Make music with generated sounds!") | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=3): | |
def check_midis(midi_dict):
    """Create the preset-midi dropdown and wire up its selection handler.

    Args:
        midi_dict: State dict whose ``"midis"`` entry maps midi names to
            parsed midi files; the names become the dropdown choices.
    """
    # NOTE(review): no decorator or call site for this function is visible
    # in this part of the file; presumably it is registered or invoked
    # elsewhere -- confirm.
    midis = midi_dict["midis"]
    preset_dropdown = gr.Dropdown(
        list(midis.keys()), label="Select from preset midi files", value="Ode_to_Joy_Easy_variation"
    )

    def select_midi(midi_name):
        """Describe the chosen midi and make it the current one."""
        chosen = midis[midi_name]
        tracks = [Track(t, chosen.ticks_per_beat) for t in chosen.tracks]
        per_track_lines = "".join(
            f"\n{len(track.events)} events loaded from Track {i}."
            for i, track in enumerate(tracks)
        )
        midi_info_text = f"Name: {midi_name}" + per_track_lines

        info_box = gr.Textbox(label="Midi info", lines=10, placeholder=midi_info_text)
        return {midi_info_textbox: info_box,
                current_midi_state: midi_name}

    preset_dropdown.select(select_midi, inputs=preset_dropdown,
                           outputs=[midi_info_textbox, current_midi_state])
| midi_file = gr.File(label="Upload a midi file", type="binary", scale=1) | |
| midi_info_textbox = gr.Textbox(label="Midi info", lines=10, | |
| placeholder="Please select/upload a midi on the left.", scale=3, | |
| visible=False) | |
| with gr.Column(scale=3, ): | |
def render_select_instruments(current_midi_name, midi_dict, virtual_instruments_dict):
    """Create one instrument dropdown per track of the current midi.

    Args:
        current_midi_name: Key of the midi to inspect in ``midi_dict["midis"]``.
        midi_dict: State dict holding the available midis under ``"midis"``.
        virtual_instruments_dict: State dict whose ``"virtual_instruments"``
            keys are offered as the dropdown choices.
    """
    # NOTE(review): no decorator or call site for this function is visible
    # in this part of the file; presumably it is registered or invoked
    # elsewhere -- confirm.
    choices = list(virtual_instruments_dict["virtual_instruments"].keys())

    mid = midi_dict["midis"][current_midi_name]
    tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]

    dropdowns = [
        gr.Dropdown(
            choices, value=default_instrument, label=f"Track {i}: {len(track.events)} notes",
            info=f"Select an instrument to play this track!"
        )
        for i, track in enumerate(tracks)
    ]

    def select_instruments(*instruments):
        # The values of all dropdowns, in track order, become the new state.
        return instruments

    # Any single selection re-reads every dropdown so the state stays complete.
    for dropdown in dropdowns:
        dropdown.select(select_instruments, inputs=dropdowns,
                        outputs=current_instruments_state)
| with gr.Column(scale=3): | |
| max_notes_slider = gr.Slider(minimum=10.0, maximum=999.0, value=100.0, step=1.0, | |
| label="Maximum number of synthesized notes in each track", | |
| info="Lower this value to prevent Gradio timeouts") | |
| make_track_button = gr.Button(variant="primary", value="Make track", scale=1) | |
| track_audio = gr.Audio(type="numpy", label="Play music", interactive=False) | |
| with gr.Row(variant="panel", visible=False): | |
| with gr.Tab("Origin sound"): | |
| inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0, | |
| label="inpaint_steps") | |
| noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.) | |
| end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01, | |
| label="end_noise_level_ratio") | |
| attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec") | |
| before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, | |
| label="before_release in sec") | |
| release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec") | |
| mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01, | |
| label="mask_flexivity") | |
| with gr.Tab("Length adjustment config"): | |
| use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True) | |
| test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1) | |
| test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1) | |
| test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1) | |
| duration_slider = gradioWebUI.get_duration_slider() | |
| with gr.Tab("Pitch shift config"): | |
| pitch_shift_radio = gr.Radio(choices=["librosa", "torchaudio", "rubberband"], | |
| value="librosa") | |
| with gr.Row(variant="panel", visible=False): | |
| with gr.Column(scale=2): | |
| with gr.Row(variant="panel"): | |
| source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", | |
| height=600, scale=1) | |
| source_sound_phase_image = gr.Image(label="New sound phase", type="numpy", | |
| height=600, scale=1) | |
| make_track_button.click(make_track, | |
| inputs=[inpaint_steps_slider, current_midi_state, midi_files_state, | |
| max_notes_slider, noising_strength_slider, | |
| attack_slider, | |
| before_release_slider, | |
| current_instruments_state, | |
| virtual_instruments_state], | |
| outputs=[track_audio]) | |
| midi_file.change(read_midi, | |
| inputs=[midi_file, | |
| midi_files_state], | |
| outputs=[midi_info_textbox, | |
| current_midi_state, | |
| midi_files_state]) | |