import librosa
import numpy as np
import torch
import gradio as gr
import mido
from io import BytesIO
# import pyrubberband as pyrb
import torchaudio.transforms as transforms

from model.DiffSynthSampler import DiffSynthSampler
from tools import adsr_envelope, adjust_audio_length
from webUI.natural_language_guided.track_maker import DiffSynth, Track
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT, phase_to_Gradio_image, \
    spectrogram_to_Gradio_image


def get_arrangement_module(gradioWebUI, virtual_instruments_state, midi_files_state):
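    """Build the "Arrangement" tab of the natural-language-guided web UI.

    Lets the user upload a MIDI file, pick previously generated virtual
    instruments, and render the MIDI into audio with DiffSynth.
    `virtual_instruments_state` and `midi_files_state` are gr.State objects
    shared with the other tabs.
    """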
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def read_midi(midi, midi_files_dict):
        # `midi` arrives as raw bytes because the midi_file component uses
        # type="binary", so the original filename is not available here.
        mid = mido.MidiFile(file=BytesIO(midi))
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]

        midi_info_text = f"Tracks: {len(tracks)}"
        for i, track in enumerate(tracks):
            midi_info_text += f"\nTrack {i}: {len(track.events)} events"

        return {midi_info_textbox: gr.Textbox(label="Midi info", lines=10,
                                              placeholder=midi_info_text),
                midi_files_state: midi_files_dict}

    def refresh_instruments(virtual_instruments_dict):
        virtual_instruments_names = list(virtual_instruments_dict["virtual_instruments"].keys())

        # Repopulate the dropdown with the instruments currently stored in state.
        return {select_instrument_dropdown: gr.Dropdown(choices=virtual_instruments_names,
                                                        label="Select an instrument")}

    def select_sound(virtual_instrument_name, virtual_instruments_dict):
        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        virtual_instrument = virtual_instruments[virtual_instrument_name]

        return {source_sound_spectrogram_image: virtual_instrument["spectrogram_gradio_image"],
                source_sound_phase_image: virtual_instrument["phase_gradio_image"],
                source_sound_audio: virtual_instrument["signal"]}
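
    # Each instrument record in virtual_instruments_state is expected to carry
    # at least these keys (as read here and in make_track below):
    #   "latent_representation"  -- array repeated per batch and decoded by the VAE
    #   "sampler"                -- sampler choice passed through to DiffSynth
    #   "signal"                 -- audio playable by a gr.Audio(type="numpy") component
    #   "spectrogram_gradio_image" / "phase_gradio_image" -- preview images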

    def make_track(inpaint_steps, midi, noising_strength, attack, before_release, instrument_names,
                   virtual_instruments_dict):
        if noising_strength < 1:
            print(f"Warning: making track with noising_strength = {noising_strength} < 1")
        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        sample_steps = int(inpaint_steps)

        # Instrument names arrive as a single "@"-separated string,
        # e.g. "piano@strings" (names as stored in virtual_instruments_state).
        instrument_names = instrument_names.split("@")
        instruments_configs = {}
        for virtual_instrument_name in instrument_names:
            virtual_instrument = virtual_instruments[virtual_instrument_name]

            latent_representation = torch.tensor(virtual_instrument["latent_representation"],
                                                 dtype=torch.float32).to(device)
            sampler = virtual_instrument["sampler"]

            batchsize = 1
            latent_representation = latent_representation.repeat(batchsize, 1, 1, 1)

            instruments_configs[virtual_instrument_name] = {
                'sample_steps': sample_steps,
                'sampler': sampler,
                'noising_strength': noising_strength,
                'latent_representation': latent_representation,
                'attack': attack,
                'before_release': before_release}

        # Parse the MIDI bytes once; the same file is rendered by every instrument.
        mid = mido.MidiFile(file=BytesIO(midi))
        diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device)

        full_audio = diffSynth.get_music(mid, instrument_names)

        return {track_audio: (sample_rate, full_audio)}

    with gr.Tab("Arrangement"):
        gr.Markdown("Make music with generated sounds!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                preset_button_1 = gr.Button(variant="primary", value="Ode_to_Joy", scale=1)
                preset_button_2 = gr.Button(variant="primary", value="Ode_to_Joy", scale=1)
                preset_button_3 = gr.Button(variant="primary", value="Ode_to_Joy", scale=1)
                midi_file = gr.File(label="Upload midi file", type="binary", scale=2)
            with gr.Column(scale=3):
                midi_info_textbox = gr.Textbox(label="Midi info", lines=10, placeholder="Please select/upload a midi on the left.")
                instrument_names_textbox = gr.Textbox(label="Instrument names", lines=2,
                                                      placeholder="Names of the instruments used to play the MIDI, separated by '@'",
                                                      scale=1)
            with gr.Column(scale=3):
                refresh_instrument_button = gr.Button(variant="primary", value="Refresh instruments", scale=1)
                select_instrument_dropdown = gr.Dropdown(choices=[], label="Select an instrument")
                source_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
            with gr.Column(scale=3):
                make_track_button = gr.Button(variant="primary", value="Make track", scale=1)
                track_audio = gr.Audio(type="numpy", label="Play track", interactive=False)
        with gr.Row(variant="panel"):
            with gr.Tab("Origin sound"):
                inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0,
                                                 label="inpaint_steps")
                noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.)
                end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01,
                                                         label="end_noise_level_ratio")
                attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec")
                before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01,
                                                  label="before_release in sec")
                release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec")
                mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01,
                                                  label="mask_flexivity")
            with gr.Tab("Length adjustment config"):
                use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True)
                test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1)
                test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1)
                test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1)
                duration_slider = gradioWebUI.get_duration_slider()
            with gr.Tab("Pitch shift config"):
                pitch_shift_radio = gr.Radio(choices=["librosa", "torchaudio", "rubberband"],
                                                          value="librosa")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                with gr.Row(variant="panel"):
                    source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                              height=600, scale=1)
                    source_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                        height=600, scale=1)



    select_instrument_dropdown.change(select_sound,
                                      inputs=[select_instrument_dropdown, virtual_instruments_state],
                                      outputs=[source_sound_spectrogram_image,
                                               source_sound_phase_image,
                                               source_sound_audio])

    refresh_instrument_button.click(refresh_instruments,
                                   inputs=[virtual_instruments_state],
                                   outputs=[select_instrument_dropdown])

    make_track_button.click(make_track,
                            inputs=[inpaint_steps_slider, midi_file,
                                    noising_strength_slider,
                                    attack_slider,
                                    before_release_slider,
                                    instrument_names_textbox,
                                    virtual_instruments_state],
                            outputs=[track_audio])

    midi_file.change(read_midi,
                     inputs=[midi_file,
                             midi_files_state],
                     outputs=[midi_info_textbox,
                              midi_files_state])
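

# Minimal usage sketch (hypothetical wiring; in the real app, `gradioWebUI` and
# the gr.State holders are constructed by the surrounding Blocks setup):
#
#     with gr.Blocks() as demo:
#         virtual_instruments_state = gr.State({"virtual_instruments": {}})
#         midi_files_state = gr.State({})
#         get_arrangement_module(gradioWebUI, virtual_instruments_state, midi_files_state)
#     demo.launch()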