WeixuanYuan committed • Commit b88cc47 • Parent(s): 18a55e0
Upload 31 files

- MyTest.py +3 -0
- NN.json +0 -0
- app.py +184 -0
- configurations/conf.json +83 -0
- configurations/read_configuration.py +152 -0
- data_generation/data_generation.py +380 -0
- data_generation/decoding.py +64 -0
- data_generation/encoding.py +92 -0
- example.ipynb +0 -0
- external sources.txt +3 -0
- generate_synthetic_data_online.py +431 -0
- load_data.py +150 -0
- melody_synth/complex_torch_synth.py +221 -0
- melody_synth/melody_generator.py +121 -0
- melody_synth/non_random_LFOs.py +121 -0
- melody_synth/random_duration.py +86 -0
- melody_synth/random_midi.py +86 -0
- melody_synth/random_pitch.py +87 -0
- melody_synth/random_rhythm.py +143 -0
- model/VAE.py +230 -0
- model/VAE_torchV.py +171 -0
- model/perceptual_label_predictor.py +68 -0
- models/decoder_5_13.pt +3 -0
- models/encoder_5_13.pt +3 -0
- models/new_trained_models/perceptual_label_predictor.h5 +3 -0
- models/perceptual_label_predictor.h5 +3 -0
- new_sound_generation.py +203 -0
- requirements.txt +5 -0
- test_audio.wav +0 -0
- tools.py +84 -0
- webUI/initial_example_encodes.json +210 -0
MyTest.py
ADDED
@@ -0,0 +1,3 @@
+from load_data import load_data
+
+data_cache = load_data(500)
NN.json
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,184 @@
+import gradio as gr
+import os
+import json
+import numpy as np
+import torch
+import librosa
+from tools import VAE_out_put_to_spc, np_power_to_db
+from model.VAE_torchV import Encoder, Decoder
+
+SPECTROGRAM_RESOLUTION = (512, 256, 3)
+encoder = Encoder((1, 512, 256), 24, N2=0, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to("cuda")
+decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to("cuda")
+model_name = "test"
+encoder.load_state_dict(torch.load(f"models/{model_name}_encoder_CA.pt"))
+decoder.load_state_dict(torch.load(f"models/{model_name}_decoder_CA.pt"))
+
+INIT_ENCODE_CACHE = {"init": np.random.random((24,))}
+with open('webUI/initial_example_encodes.json', 'r') as f:
+    list_dict = json.load(f)
+for k in list_dict.keys():
+    INIT_ENCODE_CACHE[k] = np.array(list_dict[k])
+
+#################################
+def prepare_image(image):
+    # Rescale to 0-255 and convert to uint8
+    rescaled = (image + 80.0) / 80.0
+    rescaled = (255.0 * rescaled).astype(np.uint8)
+    return rescaled
+
+
+def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
+    """Decode a batch of latent vectors into spectrogram images and Griffin-Lim audio."""
+    reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()
+    flipped_log_spectrums, rec_signals = [], []
+    for reconstruction in reconstruction_batch:
+        spc = VAE_out_put_to_spc(reconstruction)
+        spc = np.reshape(spc, resolution)
+        magnitude_spectrum = np.abs(spc)
+        log_spectrum = np_power_to_db(magnitude_spectrum)
+        flipped_log_spectrum = np.flipud(log_spectrum)
+
+        colorful_spc = np.ones((512, 256, 3)) * -80.0
+        colorful_spc[:, :, 0] = flipped_log_spectrum
+        colorful_spc[:, :, 1] = flipped_log_spectrum
+        colorful_spc[:, :, 2] = np.ones((512, 256)) * -60.0
+        flipped_log_spectrum = prepare_image(colorful_spc)
+
+        # get_audio
+        abs_spec = np.zeros((513, 256))
+        abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spc, (512, 256)))
+        rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
+        flipped_log_spectrums.append(flipped_log_spectrum)
+        rec_signals.append(rec_signal)
+
+    return flipped_log_spectrums, 16000, rec_signals
+
+
+def get_example_module(encodeCache):
+    def show_example(selected_example, encodeCache):
+        example_encode = torch.Tensor(np.reshape(encodeCache[selected_example], (-1, 24))).to("cuda")
+        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(example_encode)
+        flipped_log_spectrum, rec_signal = flipped_log_spectrums[0], rec_signals[0]
+
+        return flipped_log_spectrum, str(encodeCache), (sampleRate, rec_signal), encodeCache
+
+    with gr.Tab("Examples"):
+        gr.Markdown("Predefined examples.")
+        with gr.Row():
+            with gr.Column():
+                selected_example = gr.Dropdown(
+                    list(INIT_ENCODE_CACHE.keys()), label="Examples", info="Choose one example!"
+                )
+                example_button = gr.Button(value="Show example")
+            with gr.Column():
+                example_image_output = gr.Image(label="Reconstruction", type="numpy")
+                example_image_output.style(height=250, width=600)
+                example_audio_output = gr.Audio(type="numpy", label="Play reconstruction output!")
+                example_text_output = gr.Textbox()
+        example_button.click(show_example, inputs=[selected_example, encodeCache],
+                             outputs=[example_image_output, example_text_output, example_audio_output, encodeCache])
+
+
+def get_reconstruction_module():
+
+    def do_nothing(image_input):
+        return np.random.random(SPECTROGRAM_RESOLUTION)
+
+    with gr.Tab("Reconstruction"):
+        gr.Markdown("Test reconstruction.")
+        with gr.Row():
+            with gr.Column():
+                test_reconstruction_input = gr.Number(label="Batch_index")
+                test_reconstruction_button = gr.Button(value="Generate")
+            with gr.Column():
+                test_reconstruction_output = gr.Image(label="Reconstruction", type="numpy")
+                test_reconstruction_output.style(height=250, width=600)
+        test_reconstruction_button.click(do_nothing, inputs=test_reconstruction_input, outputs=test_reconstruction_output)
+
+
+def get_interpolation_module(encodeCache):
+
+    def interpolate(first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache):
+        # Todo: use batch
+        first_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[first_interpulation_input], (-1, 24)))
+        second_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[second_interpulation_input], (-1, 24)))
+        ratio = torch.Tensor([interpulation_input_ratio])
+        interpulation_encode = first_interpulation_input_encode * ratio + second_interpulation_input_encode * (1 - ratio)
+
+        interpulation_input_encode = torch.stack((first_interpulation_input_encode, second_interpulation_input_encode, interpulation_encode), dim=0).to("cuda")
+        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(interpulation_input_encode)
+        first_flipped_log_spectrum, first_rec_signal = flipped_log_spectrums[0], rec_signals[0]
+        second_flipped_log_spectrum, second_rec_signal = flipped_log_spectrums[1], rec_signals[1]
+        interpolation_flipped_log_spectrum, interpolation_rec_signal = flipped_log_spectrums[2], rec_signals[2]
+        return first_flipped_log_spectrum, (sampleRate, first_rec_signal), second_flipped_log_spectrum, (sampleRate, second_rec_signal), interpolation_flipped_log_spectrum, (sampleRate, interpolation_rec_signal), encodeCache
+
+    def refresh_interpolation_input(encodeCache):
+        return gr.Dropdown.update(choices=list(encodeCache.keys())), gr.Dropdown.update(choices=list(encodeCache.keys())), str(list(encodeCache.keys())), encodeCache
+
+    with gr.Tab("Interpolation"):
+        gr.Markdown("Test interpolation.")
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    first_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="First input")
+                    second_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="Second input")
+                with gr.Row():
+                    first_input_audio = gr.Audio(type="numpy", label="first_input_audio")
+                    first_input_audio.style(length=125)
+                    second_input_audio = gr.Audio(type="numpy", label="second_input_audio")
+                    second_input_audio.style(length=125)
+
+                interpulation_input_ratio = gr.Slider(minimum=-0.20, maximum=1.20, value=0.5, step=0.01, label="Ratio of the first input.")
+                interpulation_refresh_button = gr.Button(value="Refresh")
+                interpulation_button = gr.Button(value="Interpolate")
+            with gr.Column():
+                with gr.Row():
+                    first_input_spectrogram = gr.Image(label="First Input", type="numpy")
+                    first_input_spectrogram.style(height=250, width=125)
+                    interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
+                    interpolation_spectrogram.style(height=250, width=125)
+                    second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
+                    second_input_spectrogram.style(height=250, width=125)
+                interpolation_audio = gr.Audio(type="numpy", label="Interpolation")
+                interpolation_audio.style(length=125)
+
+        interpolation_text_output = gr.Textbox()
+        interpulation_refresh_button.click(refresh_interpolation_input, inputs=[encodeCache],
+                                           outputs=[first_interpulation_input, second_interpulation_input, interpolation_text_output, encodeCache])
+        interpulation_button.click(interpolate, inputs=[first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache],
+                                   outputs=[first_input_spectrogram, first_input_audio, second_input_spectrogram, second_input_audio, interpolation_spectrogram, interpolation_audio, encodeCache])
+
+
+def get_random_sampling_module(encodeCache):
+
+    def random_sample(mu, sigma, encodeCache):
+        random_encode = torch.Tensor([mu]) + torch.Tensor([sigma]) * torch.randn(1, 24)
+        encodeCache["mytest"] = random_encode.detach().numpy()
+        random_encode = random_encode.to("cuda")
+        flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(random_encode)
+        random_log_spectrum, random_rec_signal = flipped_log_spectrums[0], rec_signals[0]
+        return random_log_spectrum, (sampleRate, random_rec_signal), encodeCache
+
+    with gr.Tab("Random sampling"):
+        gr.Markdown("Test random sampling.")
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    mu = gr.Number(label="mu")
+                    sigma = gr.Number(label="sigma")
+                random_sampling_button = gr.Button(value="Sample")
+            with gr.Column():
+                random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
+                random_sampling_spectrogram.style(height=250, width=600)
+                random_sampling_audio = gr.Audio(type="numpy", label="Random sampling")
+                random_sampling_audio.style(length=125)
+        random_sampling_button.click(random_sample, inputs=[mu, sigma, encodeCache], outputs=[random_sampling_spectrogram, random_sampling_audio, encodeCache])
+
+
+with gr.Blocks() as demo:
+    initial_examples = gr.State(value=INIT_ENCODE_CACHE)
+    # initial_interpolation_examples = gr.State(value={"init": np.random.random(SPECTROGRAM_RESOLUTION)})
+    get_example_module(initial_examples)
+    # get_reconstruction_module()
+    get_random_sampling_module(initial_examples)
+    get_interpolation_module(initial_examples)
+
+# demo.launch(share=True)
+demo.launch(share=True, debug=True)
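A quick way to exercise the decode path used by every tab above, outside the Gradio UI, is to feed encodeBatch2GradioOutput a random latent batch. A minimal sketch, assuming the checkpoints above loaded and a CUDA device is available (the shapes in the comments follow from the code, not from the commit):

import torch

latent = torch.randn(2, 24).to("cuda")  # two random 24-dim latent vectors
images, sample_rate, signals = encodeBatch2GradioOutput(latent)
print(images[0].shape)   # (512, 256, 3) uint8 spectrogram image
print(sample_rate)       # 16000
print(signals[0].shape)  # roughly 255 * 256 samples from Griffin-Lim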
configurations/conf.json
ADDED
@@ -0,0 +1,83 @@
+{
+  "midpoints":
+  {
+    "osc1_amp": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
+    "osc_amp2": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
+    "osc2_amp": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
+    "attack": [0.001, 0.03, 0.1, 0.25, 0.40, 0.7],
+    "decay": [0.001, 0.2, 0.60, 1.2],
+    "sustain": [0.01, 0.2, 0.5, 1.0],
+    "release": [0.001, 0.15, 0.35, 0.8],
+    "cutoff_freq": [2200, 2400, 2600, 2800, 3000,
+                    3200, 3400, 3600, 3800, 4000,
+                    4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800],
+    "osc_types": [0, 1, 2, 3, 4, 5],
+    "amp_mod_depth": [0, 0, 0, 0.1, 0.3, 0.5, 1.0],
+    "amp_mod_freq": [0, 0, 1, 2, 4, 8],
+    "mod_waveforms": [0, 1, 2, 3],
+    "pitch_mod_depth": [0, 1],
+    "pitch_mod_freq": [1, 2, 4, 8]
+  },
+
+  "subspace_range":
+  {
+    "osc_amp2": 2,
+    "osc1_amp": 2,
+    "osc2_amp": 2,
+    "attack": 1,
+    "decay": 1,
+    "sustain": 2,
+    "release": 2,
+    "cutoff_freq": 0,
+    "osc_types": 0
+  },
+
+  "is_discrete":
+  {
+    "osc_amp2": false,
+    "osc1_amp": false,
+    "osc2_amp": false,
+    "attack": false,
+    "decay": false,
+    "sustain": false,
+    "release": false,
+    "cutoff_freq": false,
+    "duration": false,
+    "osc_types": true,
+    "mod_waveforms": true,
+    "amp_mod_depth": true,
+    "amp_mod_freq": true,
+    "pitch_mod_depth": true,
+    "pitch_mod_freq": true
+  },
+
+  "sample_rate": 16384,
+  "n_sample_note": 65536,
+  "n_sample_music": 65536,
+
+  "STFT_hyperParameter":
+  {
+    "frame_length": 512,
+    "frame_step": 256
+  },
+
+  "midi_midpoints":
+  {
+    "duration": [0.1, 0.5, 1.0, 2.0],
+    "pitch": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76]
+  },
+
+  "midi_is_discrete":
+  {
+    "duration": false,
+    "pitch": true
+  },
+
+  "midi_max_n_notes": 8,
+
+  "resolution":
+  {
+    "time_resolution": 509,
+    "freq_resolution": 513
+  }
+}
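The helpers in configurations/read_configuration.py (the next file) are the intended way to consume this configuration. A short sketch, assuming it is run from the repository root so the relative path 'configurations/conf.json' resolves:

from configurations.read_configuration import parameter_range, is_discrete, get_conf_sample_rate

print(get_conf_sample_rate())     # 16384
print(parameter_range('attack'))  # [0.001, 0.03, 0.1, 0.25, 0.4, 0.7]
print(is_discrete('osc_types'))   # True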
configurations/read_configuration.py
ADDED
@@ -0,0 +1,152 @@
+import json
+
+bins_path = 'configurations/conf.json'
+
+
+def parameter_range(parameter_name):
+    """
+    :param parameter_name:
+    :return: List[Float]--midpoints of bins for the input synthesizer parameter
+    """
+    with open(bins_path) as f:
+        midpoints = json.load(f)["midpoints"]
+    return midpoints[parameter_name]
+
+
+def cluster_range():
+    """
+    :return: Dict[String:Int]--defines the range of cluster to search
+    """
+    with open(bins_path) as f:
+        cluster_r = json.load(f)["subspace_range"]
+    return cluster_r
+
+
+def midi_parameter_range(parameter_name):
+    """
+    :param parameter_name:
+    :return: List[Float]--midpoints of bins for the input midi parameter
+    """
+    with open(bins_path) as f:
+        r = json.load(f)["midi_midpoints"]
+    return r[parameter_name]
+
+
+def is_discrete(parameter_name):
+    """
+    :param parameter_name:
+    :return: Boolean--whether the input synthesizer parameter is discrete
+    """
+    with open(bins_path) as f:
+        is_dis = json.load(f)["is_discrete"]
+    return is_dis[parameter_name]
+
+
+def midi_is_discrete(parameter_name):
+    """
+    :param parameter_name:
+    :return: Boolean--whether the input midi parameter is discrete
+    """
+    with open(bins_path) as f:
+        is_dis = json.load(f)["midi_is_discrete"]
+    return is_dis[parameter_name]
+
+
+def get_label_size():
+    """
+    :return: Int--length of synthesizer parameter encoding
+    """
+    with open(bins_path) as f:
+        conf = json.load(f)
+    midpoints = conf["midpoints"]
+    n_labels = 0
+    for key in midpoints:
+        n_labels = n_labels + len(midpoints[key])
+
+    return n_labels
+
+
+def get_bins_length():
+    """
+    :return: Dict[String:Int]--number of bins for all synthesizer parameters
+    """
+    with open(bins_path) as f:
+        midpoints = json.load(f)["midpoints"]
+    bins_length = {}
+
+    for key in midpoints:
+        bins_length[key] = len(midpoints[key])
+
+    return bins_length
+
+
+def get_conf_stft_hyperparameter():
+    """
+    :return: Dict[String:Int]--STFT hyperparameters
+    """
+    with open(bins_path) as f:
+        STFT_hyperParameters = json.load(f)["STFT_hyperParameter"]
+
+    return STFT_hyperParameters
+
+
+def get_conf_sample_rate():
+    """
+    :return: Int--sample_rate
+    """
+    with open(bins_path) as f:
+        sample_rate = json.load(f)["sample_rate"]
+
+    return sample_rate
+
+
+def get_conf_n_sample_note():
+    """
+    :return: Int--sample count of a note example
+    """
+    with open(bins_path) as f:
+        n_sample_note = json.load(f)["n_sample_note"]
+
+    return n_sample_note
+
+
+def get_conf_n_sample():
+    """
+    :return: Int--sample count of a melody example
+    """
+    with open(bins_path) as f:
+        n_sample = json.load(f)["n_sample_music"]
+
+    return n_sample
+
+
+def get_conf_time_resolution():
+    """
+    :return: Int--spectrogram resolution on the time dimension
+    """
+    with open(bins_path) as f:
+        resolution = json.load(f)["resolution"]
+
+    return resolution["time_resolution"]
+
+
+def get_conf_pitch_resolution():
+    """
+    :return: Int--spectrogram resolution on the pitch dimension
+    """
+    with open(bins_path) as f:
+        resolution = json.load(f)["resolution"]
+
+    return resolution["freq_resolution"]
+
+
+def get_conf_max_n_notes():
+    """
+    :return: Int--maximum number of notes to be generated in a melody
+    """
+    with open(bins_path) as f:
+        max_n_notes = json.load(f)["midi_max_n_notes"]
+
+    return max_n_notes
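Every helper above re-opens and re-parses conf.json on each call, which is fine at this scale. If the lookups ever become a hot path, a cached variant could look like the following sketch (_load_conf and parameter_range_cached are hypothetical names, not part of the commit):

import json
from functools import lru_cache


@lru_cache(maxsize=1)
def _load_conf(path='configurations/conf.json'):
    # Parse the configuration once; later calls reuse the cached dict.
    with open(path) as f:
        return json.load(f)


def parameter_range_cached(parameter_name):
    return _load_conf()["midpoints"][parameter_name]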
data_generation/data_generation.py
ADDED
@@ -0,0 +1,380 @@
+from typing import List
+
+import librosa
+
+from data_generation.encoding import ParameterDescription, Sample
+from melody_synth.random_midi import RandomMidi
+from melody_synth.melody_generator import MelodyGenerator
+from scipy.io.wavfile import write
+from pathlib import Path
+from tqdm import tqdm
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import matplotlib
+from configurations.read_configuration import parameter_range, is_discrete, get_conf_stft_hyperparameter
+import shutil
+
+# from model.log_spectrogram import power_to_db
+from tools import power_to_db
+
+num_params = 16
+
+
+def plot_spectrogram(signal: np.ndarray,
+                     path: str,
+                     frame_length=512,
+                     frame_step=256):
+    """Computes the spectrogram of the given signal and saves it.
+
+    Parameters
+    ----------
+    signal: np.ndarray
+        The signal for which to compute the spectrogram.
+    path: str
+        Path to save the computed spectrogram.
+    frame_length:
+        Window size of the FFT.
+    frame_step:
+        Hop size of the FFT.
+    """
+
+    # Compute spectrum for each frame. Returns complex tensor.
+    # todo: duplicate code in log_spectrogram.py. Move this somewhere else perhaps.
+    spectrogram = tf.signal.stft(signal,
+                                 frame_length=frame_length,
+                                 frame_step=frame_step,
+                                 pad_end=False)  # Returns 63 frames instead of 64 otherwise
+
+    # Compute the magnitudes
+    magnitude_spectrum = tf.abs(spectrogram)
+    log_spectrum = power_to_db(magnitude_spectrum)
+    matplotlib.pyplot.imsave(path, np.transpose(log_spectrum), vmin=-100, vmax=0, origin='lower')
+
+
+def plot_mel_spectrogram(signal: np.ndarray,
+                         path: str,
+                         frame_length=512,
+                         frame_step=256):
+    spectrogram = librosa.feature.melspectrogram(signal, sr=16384, n_fft=2048, hop_length=frame_step, win_length=frame_length)
+    matplotlib.pyplot.imsave(path, spectrogram, vmin=-100, vmax=0, origin='lower')
+
+
+# List of ParameterDescription objects that specify the parameters for generation
+param_descriptions: List[ParameterDescription]
+param_descriptions = [
+
+    # Oscillator levels
+    ParameterDescription(name="osc1_amp",
+                         values=parameter_range('osc1_amp'),
+                         discrete=is_discrete('osc1_amp')),
+    ParameterDescription(name="osc2_amp",
+                         values=parameter_range('osc2_amp'),
+                         discrete=is_discrete('osc2_amp')),
+
+    # ADSR params
+    ParameterDescription(name="attack",
+                         values=parameter_range('attack'),
+                         discrete=is_discrete('attack')),
+    ParameterDescription(name="decay",
+                         values=parameter_range('decay'),
+                         discrete=is_discrete('decay')),
+    ParameterDescription(name="sustain",
+                         values=parameter_range('sustain'),
+                         discrete=is_discrete('sustain')),
+    ParameterDescription(name="release",
+                         values=parameter_range('release'),
+                         discrete=is_discrete('release')),
+
+    ParameterDescription(name="cutoff_freq",
+                         values=parameter_range('cutoff_freq'),
+                         discrete=is_discrete('cutoff_freq')),
+
+    # Oscillator types
+    # 0 for sin saw, 1 for sin square, 2 for saw square
+    # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
+    ParameterDescription(name="osc_types",
+                         values=parameter_range('osc_types'),
+                         discrete=is_discrete('osc_types')),
+]
+
+
+def generate_dataset_for_cnn(n: int,
+                             path_name="./data/data_cnn_model",
+                             sample_rate=16384,
+                             n_samples_for_note=16384 * 4,
+                             n_samples_for_melody=16384 * 4, write_parameter=True, write_spectrogram=True):
+    """
+    Generate dataset of size n for the 'InverSynth' cnn model
+
+    :param n: Int
+    :param path_name: String--path to save the dataset
+    :param sample_rate: Int
+    :param n_samples_for_note: Int
+    :param n_samples_for_melody: Int
+    :param write_parameter: Boolean--whether to write parameter values in a .txt file
+    :param write_spectrogram: Boolean--whether to write a spectrogram with parameter values in the file name
+    :return:
+    """
+
+    if Path(path_name).exists():
+        shutil.rmtree(path_name)
+    Path(path_name).mkdir(parents=True, exist_ok=True)
+    print("Generating dataset...")
+
+    synth = MelodyGenerator(sample_rate,
+                            n_samples_for_note, n_samples_for_melody)
+    randomMidi = RandomMidi()
+
+    for i in tqdm(range(n)):
+        parameter_values = [param.generate() for param in param_descriptions]
+
+        # Dict of parameter values, what our synthesizer expects as input
+        parameter_values_raw = {param.name: param.value for param in parameter_values}
+
+        strategy = {"rhythm_strategy": "free_rhythm",
+                    "pitch_strategy": "free_pitch",
+                    "duration_strategy": "random_duration",
+                    }
+        midi_encode, midi = randomMidi(strategy)
+        signal = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+
+        # Path to store each sample with its label
+        path = path_name + f"/{i}"
+        Path(path).mkdir(parents=True, exist_ok=True)
+
+        if write_parameter:
+            suffix = 'spectrogram'
+            for parameter_value in parameter_values:
+                suffix += f'_{parameter_value.name}={"%.3f" % parameter_value.value}'
+            if write_spectrogram:
+                plot_spectrogram(signal, path=path + f"/{suffix}.png", frame_length=1024, frame_step=256)
+            else:
+                with open(path + f"/{suffix}.txt", "w") as f:
+                    f.write("test")
+
+        write(path + f"/{i}.wav", synth.sample_rate, signal)
+
+        sample = Sample(parameter_values)
+
+        # Dump label as json
+        with open(path + "/label.json", "w") as label_file:
+            label = sample.get_values()
+            label['midi'] = midi
+            # print(len(label["encoding"]))
+            json.dump(label, label_file, ensure_ascii=True)
+
+    print('Data generation done!')
+
+
+def generate_dataset_for_triplet(n: int,
+                                 path_name="./data/data_triplet_val_10_500",
+                                 sample_rate=16384,
+                                 n_samples_for_note=16384 * 4,
+                                 n_samples_for_melody=16384 * 4,
+                                 n_labels=30,
+                                 write_spectrogram=True):
+    """
+    Generate dataset of size n for the triplet model
+
+    :param n: Int
+    :param path_name: String--path to save the dataset
+    :param sample_rate: Int
+    :param n_samples_for_note: Int
+    :param n_samples_for_melody: Int
+    :param n_labels: Int--number of synthesizer parameter combinations contained in the dataset (a hyperparameter of the triplet model)
+    :param write_spectrogram: Boolean--whether to write the spectrogram
+    """
+
+    if Path(path_name).exists():
+        shutil.rmtree(path_name)
+    Path(path_name).mkdir(parents=True, exist_ok=True)
+    print("Generating dataset...")
+    synth = MelodyGenerator(sample_rate,
+                            n_samples_for_note, n_samples_for_melody)
+    randomMidi = RandomMidi()
+
+    parameter_values_examples = [[param.generate() for param in param_descriptions] for i in range(n_labels)]
+    parameter_values_raw_examples = [{param.name: param.value for param in parameter_values} for parameter_values in
+                                     parameter_values_examples]
+
+    np.random.seed()
+    for i in tqdm(range(n)):
+        label_index = np.random.randint(0, n_labels)
+        parameter_values = parameter_values_examples[label_index]
+        parameter_values_raw = parameter_values_raw_examples[label_index]
+
+        strategy = {"rhythm_strategy": "free_rhythm",
+                    "pitch_strategy": "free_pitch",
+                    "duration_strategy": "random_duration",
+                    }
+        midi_encode, midi = randomMidi(strategy)
+        signal = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+
+        # Path to store each sample with its label
+        path = path_name + f"/{i}"
+        Path(path).mkdir(parents=True, exist_ok=True)
+
+        write(path + f"/{i}.wav", synth.sample_rate, signal)
+        suffix = 'spectrogram'
+        for parameter_value in parameter_values:
+            suffix += f'_{parameter_value.name}={"%.3f" % parameter_value.value}'
+
+        if write_spectrogram:
+            hp = get_conf_stft_hyperparameter()
+            frame_l = hp['frame_length']
+            frame_s = hp['frame_step']
+            plot_spectrogram(signal, path=path + f"/{suffix}.png", frame_length=frame_l, frame_step=frame_s)
+        else:
+            with open(path + f"/{suffix}.txt", "w") as f:
+                f.write("test")
+
+        with open(path + "/label_index.json", "w") as label_index_file:
+            index_json = {'index': label_index}
+            json.dump(index_json, label_index_file, ensure_ascii=False)
+
+        # save midi as .txt file
+        with open(path + "/midi.txt", "w") as midi_file:
+            midi_file.write(str(midi))
+
+    print('Data generation done!')
+
+
+def manhattan_distance(SP1, SP2):
+    """
+    :param SP1: first input synthesizer parameter combination
+    :param SP2: second input synthesizer parameter combination
+    :return: Float--Manhattan distance between SP1 and SP2
+    """
+
+    md = []
+    for key in SP1:
+        parameter_name = key
+        value1 = SP1[parameter_name]
+        value2 = SP2[parameter_name]
+        bins = parameter_range(parameter_name)
+        bin_index1 = np.argmin(np.abs(np.array(bins) - value1))
+        bin_index2 = np.argmin(np.abs(np.array(bins) - value2))
+
+        if parameter_name == "osc_types":
+            if bin_index1 == bin_index2:
+                d = 0
+            else:
+                d = 1
+        else:
+            d = np.abs(bin_index1 - bin_index2) / (len(bins) - 1)
+        md.append(d)
+
+    return np.average(md)
+
+
+def generate_dataset_for_mixed_input_model(n: int,
+                                           path_name="./data/data_mixed_input",
+                                           sample_rate=16384,
+                                           n_samples_for_note=16384 * 4,
+                                           n_samples_for_melody=16384 * 4
+                                           ):
+    """
+    Generate dataset of size n for the mixed-input model
+
+    :param n: Int
+    :param path_name: String--path to save the dataset
+    :param sample_rate: Int
+    :param n_samples_for_note: Int
+    :param n_samples_for_melody: Int
+    :return:
+    """
+
+    if Path(path_name).exists():
+        shutil.rmtree(path_name)
+    Path(path_name).mkdir(parents=True, exist_ok=True)
+    print("Generating dataset...")
+    synth = MelodyGenerator(sample_rate,
+                            n_samples_for_note, n_samples_for_melody)
+    randomMidi = RandomMidi()
+
+    strategy = {"rhythm_strategy": "free_rhythm",
+                "pitch_strategy": "free_pitch",
+                "duration_strategy": "random_duration",
+                }
+    strategy0 = {"rhythm_strategy": "single_note_rhythm",
+                 "pitch_strategy": "fixed_pitch",
+                 "duration_strategy": "fixed_duration",
+                 }
+    strategy1 = {"rhythm_strategy": "single_note_rhythm",
+                 "pitch_strategy": "fixed_pitch1",
+                 "duration_strategy": "fixed_duration",
+                 }
+    strategy2 = {"rhythm_strategy": "single_note_rhythm",
+                 "pitch_strategy": "fixed_pitch2",
+                 "duration_strategy": "fixed_duration",
+                 }
+    strategy3 = {"rhythm_strategy": "single_note_rhythm",
+                 "pitch_strategy": "fixed_pitch3",
+                 "duration_strategy": "fixed_duration",
+                 }
+    strategy4 = {"rhythm_strategy": "single_note_rhythm",
+                 "pitch_strategy": "fixed_pitch4",
+                 "duration_strategy": "fixed_duration",
+                 }
+
+    np.random.seed()
+    for i in tqdm(range(n)):
+        path = path_name + f"/{i}"
+        Path(path).mkdir(parents=True, exist_ok=True)
+        parameter_values = [param.generate() for param in param_descriptions]
+        parameter_values_raw = {param.name: param.value for param in parameter_values}
+
+        # generate query music
+        midi_encode, midi = randomMidi(strategy)
+        signal_query = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}.wav", synth.sample_rate, signal_query)
+        # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
+
+        if np.random.rand() < 0.01:  # positive pair: the reference keeps the query's parameters
+            with open(path + "/label.json", "w") as label_file:
+                sample = Sample(parameter_values)
+                label = sample.get_values()
+                label['manhattan_distance'] = 0.
+                json.dump(label, label_file, ensure_ascii=False)
+        else:
+            with open(path + "/label.json", "w") as label_file:
+                query_sp = parameter_values_raw
+                parameter_values = [param.generate() for param in param_descriptions]
+                parameter_values_raw = {param.name: param.value for param in parameter_values}
+                sample = Sample(parameter_values)
+                label = sample.get_values()
+                md = manhattan_distance(query_sp, parameter_values_raw)
+                label['manhattan_distance'] = md
+                json.dump(label, label_file, ensure_ascii=False)
+
+        # generate single-note reference audio (strategy0)
+        midi_encode, midi = randomMidi(strategy0)
+        signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}_0.wav", synth.sample_rate, signal_single_note)
+
+        # generate single-note reference audio (strategy1)
+        midi_encode, midi = randomMidi(strategy1)
+        signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}_1.wav", synth.sample_rate, signal_single_note)
+
+        # generate single-note reference audio (strategy2)
+        midi_encode, midi = randomMidi(strategy2)
+        signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}_2.wav", synth.sample_rate, signal_single_note)
+
+        # generate single-note reference audio (strategy3)
+        midi_encode, midi = randomMidi(strategy3)
+        signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}_3.wav", synth.sample_rate, signal_single_note)
+
+        # generate single-note reference audio (strategy4)
+        midi_encode, midi = randomMidi(strategy4)
+        signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
+        write(path + f"/{i}_4.wav", synth.sample_rate, signal_single_note)
+
+    print('Data generation done!')
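manhattan_distance compares two parameter dicts bin by bin: each parameter contributes its normalized bin-index distance (or 0/1 for osc_types), and the result is the average. A small worked example using bin values from conf.json above (the two dicts are illustrative):

from data_generation.data_generation import manhattan_distance

sp1 = {"attack": 0.001, "osc_types": 0}
sp2 = {"attack": 0.7, "osc_types": 1}
# attack falls into bins 0 and 5 of 6, so |0 - 5| / 5 = 1.0; osc_types differ, so 1.
print(manhattan_distance(sp1, sp2))  # 1.0, the average of [1.0, 1]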
data_generation/decoding.py
ADDED
@@ -0,0 +1,64 @@
+from typing import Dict
+from data_generation.data_generation import param_descriptions
+import numpy as np
+
+from melody_synth.melody_generator import MelodyGenerator
+from melody_synth.random_midi import RandomMidi
+
+
+def decode_label(prediction: np.ndarray,
+                 sample_rate: int,
+                 n_samples: int,
+                 return_params=False,
+                 discard_parameters=[]):
+    """Parses a network prediction array, synthesizes the described audio and returns it.
+
+    Parameters
+    ----------
+    prediction: np.ndarray
+        The network prediction array.
+    sample_rate: int
+        Sample rate of the audio to generate.
+    n_samples: int
+        Number of samples per wav file.
+    return_params: bool
+        Whether to also return the parameters alongside the signal.
+    discard_parameters: List[str]
+        Parameter names that should be discarded (set to their default value).
+
+    Returns
+    -------
+    np.ndarray:
+        The generated signal
+    """
+
+    params: Dict[str, float] = {}
+    index = 0
+    for i, param_description in enumerate(param_descriptions):
+        # Parses the one-hot-encoding of the prediction array
+        bits = len(param_description.values)
+        curr_prediction = prediction[index:index + bits]
+
+        hot_index = curr_prediction.argmax()
+        params[param_description.name] = param_description.parameter_value(hot_index).value
+        index += bits
+
+    for param_str in discard_parameters:
+        params[param_str] = 0  # todo: make this safe and change to default value and not just 0
+
+    synth = MelodyGenerator(sample_rate,
+                            n_samples, n_samples)
+    randomMidi = RandomMidi()
+
+    strategy = {"rhythm_strategy": "single_note_rhythm",
+                "pitch_strategy": "fixed_pitch",
+                "duration_strategy": "fixed_duration",
+                }
+    midi_encode, midi = randomMidi(strategy)
+
+    signal = synth.get_melody(params, midi=midi).numpy()
+
+    if return_params:
+        return signal, params
+
+    return signal
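decode_label expects the prediction to be the concatenated one-hot blocks in the same order as param_descriptions. A minimal round-trip sketch, assuming the melody_synth package and its dependencies are importable:

import numpy as np
from data_generation.data_generation import param_descriptions
from data_generation.decoding import decode_label

# Build a synthetic "prediction": one random one-hot block per parameter,
# mirroring the label layout produced by Sample.get_values().
prediction = np.hstack([p.generate().encoding for p in param_descriptions])
signal, params = decode_label(prediction, sample_rate=16384, n_samples=16384 * 4,
                              return_params=True)
print(params)  # the decoded parameter dict fed to the synthesizer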
data_generation/encoding.py
ADDED
@@ -0,0 +1,92 @@
+from typing import List, Dict
+import numpy as np
+
+
+def parameter_range_low_high(parameter_range: List):
+    """
+    :param parameter_range: List[Float]--midpoints of bins
+    :return: List[Float]--lower and upper bounds of bins
+    """
+    temp1 = np.array(parameter_range[1:])
+    temp2 = np.array(parameter_range[:len(temp1)])
+    temp1 = 0.5 * (temp1 + temp2)
+
+    return np.hstack([parameter_range[0], temp1, parameter_range[len(parameter_range) - 1]])
+
+
+class ParameterValue:
+    """Describes a one-hot encoded parameter value."""
+
+    name: str
+    value: float
+    encoding: List[float]
+    index: int
+
+    def __init__(self, name, value, encoding, index):
+        self.name = name
+        self.value = value
+        self.encoding = encoding
+        self.index = index
+
+
+class ParameterDescription:
+    """A description for generating a parameter value."""
+
+    # Discrete is used to generate samples that don't exactly fit into a bin for training.
+    def __init__(self, name, values: List[float], discrete=True):
+        self.name = name
+        self.values = values
+        self.discrete = discrete
+        self.parameter_low_high = parameter_range_low_high(values)
+
+    # one-hot encoding as per paper
+    # Value used for specifying a different value than values[index], useful for non-discrete params. todo: too adhoc?
+    def parameter_value(self, index, value=None) -> ParameterValue:
+        if value is None:
+            value = self.values[index]
+        encoding = np.zeros(len(self.values), dtype=float)
+        encoding[index] = 1.0
+        return ParameterValue(
+            name=self.name,
+            value=value,
+            encoding=encoding,
+            index=index
+        )
+
+    # random even distribution as per paper
+    def generate(self) -> ParameterValue:
+        # choose a bin if the parameter is discrete
+        if self.discrete:
+            index = np.random.randint(0, len(self.values))
+            return self.parameter_value(index)
+        # otherwise generate a random value inside a randomly chosen bin
+        else:
+            indexFinder = np.random.uniform(0, 1)
+            l = np.linspace(0.0, 1, len(self.values))
+            index = np.argmin(np.abs(l - indexFinder))
+            value = (self.parameter_low_high[index + 1] - self.parameter_low_high[index]) * np.random.uniform(0, 1) + self.parameter_low_high[index]
+
+            return self.parameter_value(index, value)
+
+    # get the index of the best matching bin
+    def get_bin_index(self, value):
+        return np.argmin(np.abs(np.array(self.values) - value))
+
+    def decode(self, encoding: List[float]) -> ParameterValue:
+        index = np.array(encoding).argmax()
+        return self.parameter_value(index)
+
+
+class Sample:
+    """Describes the label of one training sample."""
+
+    parameters: List[ParameterValue]
+
+    def __init__(self, parameters):
+        self.parameters = parameters
+
+    def get_values(self) -> Dict[str, dict]:
+        return {
+            "parameters": {p.name: p.value for p in self.parameters},
+            "encoding": list(np.hstack([p.encoding for p in self.parameters]))
+        }
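To see the continuous sampling at work: the bin midpoints are first turned into low/high bin edges, and generate() then draws uniformly inside a randomly chosen bin. A short sketch using the 'sustain' midpoints from conf.json:

from data_generation.encoding import ParameterDescription

pd = ParameterDescription(name="sustain", values=[0.01, 0.2, 0.5, 1.0], discrete=False)
print(pd.parameter_low_high)         # [0.01, 0.105, 0.35, 0.75, 1.0]
v = pd.generate()
print(v.index, v.value, v.encoding)  # e.g. 2, a value in [0.35, 0.75), [0., 0., 1., 0.]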
example.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
external sources.txt
ADDED
@@ -0,0 +1,3 @@
+1. External datasets have been preprocessed and stored in the directory '/data/external_data'. Links to the datasets can be found in the same directory.
+
+2. External code: some code in '/model/VAE.py', 'non_random_LFOs.py' and '/melody/complex_torch_synth.py' is adapted from external sources; the references are given in those files.
generate_synthetic_data_online.py
ADDED
@@ -0,0 +1,431 @@
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import librosa
|
3 |
+
import matplotlib
|
4 |
+
import pandas as pd
|
5 |
+
from typing import Optional
|
6 |
+
from torch import tensor
|
7 |
+
from ddsp.core import tf_float32
|
8 |
+
import torch
|
9 |
+
from torch import Tensor
|
10 |
+
import numpy as np
|
11 |
+
import tensorflow as tf
|
12 |
+
from torchsynth.config import SynthConfig
|
13 |
+
import ddsp
|
14 |
+
from pathlib import Path
|
15 |
+
from typing import Dict
|
16 |
+
from data_generation.encoding import ParameterDescription
|
17 |
+
from typing import List
|
18 |
+
from configurations.read_configuration import parameter_range, is_discrete, midi_parameter_range, midi_is_discrete
|
19 |
+
import shutil
|
20 |
+
from tqdm import tqdm
|
21 |
+
from scipy.io.wavfile import write
|
22 |
+
from melody_synth.complex_torch_synth import DoubleSawSynth, SinSawSynth, SinTriangleSynth, TriangleSawSynth
|
23 |
+
|
24 |
+
sample_rate = 16000
|
25 |
+
n_samples = sample_rate * 4.5
|
26 |
+
|
27 |
+
|
28 |
+
class NoteGenerator:
|
29 |
+
"""
|
30 |
+
This class is responsible for single-note audio generation by function 'get_note'.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
sample_rate=sample_rate,
|
35 |
+
n_samples=sample_rate * 4.5):
|
36 |
+
self.sample_rate = sample_rate
|
37 |
+
self.n_samples = n_samples
|
38 |
+
synthconfig = SynthConfig(
|
39 |
+
batch_size=1, reproducible=False, sample_rate=sample_rate,
|
40 |
+
buffer_size_seconds=np.float64(n_samples) / np.float64(sample_rate)
|
41 |
+
)
|
42 |
+
self.Saw_Square_Voice = DoubleSawSynth(synthconfig)
|
43 |
+
self.SinSawVoice = SinSawSynth(synthconfig)
|
44 |
+
self.SinTriVoice = SinTriangleSynth(synthconfig)
|
45 |
+
self.TriSawVoice = TriangleSawSynth(synthconfig)
|
46 |
+
|
47 |
+
def get_note(self, params: Dict[str, float]):
|
48 |
+
osc_amp2 = np.float64(params.get("osc_amp2", 0))
|
49 |
+
|
50 |
+
if osc_amp2 < 0.45:
|
51 |
+
osc1_amp = 0.9
|
52 |
+
osc2_amp = osc_amp2
|
53 |
+
else:
|
54 |
+
osc1_amp = 0.9 - osc_amp2
|
55 |
+
osc2_amp = 0.9
|
56 |
+
|
57 |
+
attack_1 = np.float64(params.get("attack_1", 0))
|
58 |
+
decay_1 = np.float64(params.get("decay_1", 0))
|
59 |
+
sustain_1 = np.float64(params.get("sustain_1", 0))
|
60 |
+
release_1 = np.float64(params.get("release_1", 0))
|
61 |
+
|
62 |
+
attack_2 = np.float64(params.get("attack_2", 0))
|
63 |
+
decay_2 = np.float64(params.get("decay_2", 0))
|
64 |
+
sustain_2 = np.float64(params.get("sustain_2", 0))
|
65 |
+
release_2 = np.float64(params.get("release_2", 0))
|
66 |
+
|
67 |
+
amp_mod_freq = params.get("amp_mod_freq", 0)
|
68 |
+
amp_mod_depth = params.get("amp_mod_depth", 0)
|
69 |
+
amp_mod_waveform = params.get("amp_mod_waveform", 0)
|
70 |
+
|
71 |
+
pitch_mod_freq_1 = params.get("pitch_mod_freq_1", 0)
|
72 |
+
pitch_mod_depth = params.get("pitch_mod_depth", 0)
|
73 |
+
|
74 |
+
cutoff_freq = params.get("cutoff_freq", 4000)
|
75 |
+
|
76 |
+
pitch = np.float64(params.get("pitch", 0))
|
77 |
+
duration = np.float64(params.get("duration", 0))
|
78 |
+
|
79 |
+
syn_parameters = {
|
80 |
+
("adsr_1", "attack"): tensor([attack_1]), # [0.0, 2.0]
|
81 |
+
("adsr_1", "decay"): tensor([decay_1]), # [0.0, 2.0]
|
82 |
+
("adsr_1", "sustain"): tensor([sustain_1]), # [0.0, 2.0]
|
83 |
+
("adsr_1", "release"): tensor([release_1]), # [0.0, 2.0]
|
84 |
+
("adsr_1", "alpha"): tensor([5]), # [0.1, 6.0]
|
85 |
+
|
86 |
+
("adsr_2", "attack"): tensor([attack_2]), # [0.0, 2.0]
|
87 |
+
("adsr_2", "decay"): tensor([decay_2]), # [0.0, 2.0]
|
88 |
+
("adsr_2", "sustain"): tensor([sustain_2]), # [0.0, 2.0]
|
89 |
+
("adsr_2", "release"): tensor([release_2]), # [0.0, 2.0]
|
90 |
+
("adsr_2", "alpha"): tensor([5]), # [0.1, 6.0]
|
91 |
+
("keyboard", "midi_f0"): tensor([pitch]),
|
92 |
+
("keyboard", "duration"): tensor([duration]),
|
93 |
+
|
94 |
+
# Mixer parameter
|
95 |
+
("mixer", "vco_1"): tensor([osc1_amp]), # [0, 1]
|
96 |
+
("mixer", "vco_2"): tensor([osc2_amp]), # [0, 1]
|
97 |
+
|
98 |
+
# Constant parameters:
|
99 |
+
("vco_1", "mod_depth"): tensor([pitch_mod_depth]), # [-96, 96]
|
100 |
+
("vco_1", "tuning"): tensor([0.0]), # [-24.0, 24]
|
101 |
+
("vco_2", "mod_depth"): tensor([pitch_mod_depth]), # [-96, 96]
|
102 |
+
("vco_2", "tuning"): tensor([0.0]), # [-24.0, 24]
|
103 |
+
|
104 |
+
# LFOs
|
105 |
+
("lfo_amp_sin", "frequency"): tensor([amp_mod_freq]), # [0, 20]
|
106 |
+
("lfo_amp_sin", "mod_depth"): tensor([0]), # [-10, 20]
|
107 |
+
("lfo_pitch_sin_1", "frequency"): tensor([pitch_mod_freq_1]), # [0, 20]
|
108 |
+
("lfo_pitch_sin_1", "mod_depth"): tensor([10]), # [-10, 20]
|
109 |
+
("lfo_pitch_sin_2", "frequency"): tensor([pitch_mod_freq_1]), # [0, 20]
|
110 |
+
("lfo_pitch_sin_2", "mod_depth"): tensor([10]), # [-10, 20]
|
111 |
+
}
|
112 |
+
|
113 |
+
osc_types = params.get("osc_types", 0)
|
114 |
+
if osc_types == 0:
|
115 |
+
synth = self.SinSawVoice
|
116 |
+
syn_parameters[("vco_2", "shape")] = tensor([1])
|
117 |
+
elif osc_types == 1:
|
118 |
+
synth = self.SinSawVoice
|
119 |
+
syn_parameters[("vco_2", "shape")] = tensor([0])
|
120 |
+
elif osc_types == 2:
|
121 |
+
synth = self.Saw_Square_Voice
|
122 |
+
syn_parameters[("vco_1", "shape")] = tensor([1])
|
123 |
+
syn_parameters[("vco_2", "shape")] = tensor([0])
|
124 |
+
elif osc_types == 3:
|
125 |
+
synth = self.SinTriVoice
|
126 |
+
elif osc_types == 4:
|
127 |
+
synth = self.TriSawVoice
|
128 |
+
syn_parameters[("vco_2", "shape")] = tensor([1])
|
129 |
+
else:
|
130 |
+
synth = self.TriSawVoice
|
131 |
+
syn_parameters[("vco_2", "shape")] = tensor([0])
|
132 |
+
|
133 |
+
synth.set_parameters(syn_parameters)
|
134 |
+
audio_out = synth.get_signal(amp_mod_depth, amp_mod_waveform, int(sample_rate * duration), osc1_amp, osc2_amp)
|
135 |
+
single_note = audio_out[0].detach().numpy()
|
136 |
+
|
137 |
+
cutoff_freq = tf_float32(cutoff_freq)
|
138 |
+
impulse_response = ddsp.core.sinc_impulse_response(cutoff_freq, 2048, self.sample_rate)
|
139 |
+
single_note = tf_float32(single_note)
|
140 |
+
return ddsp.core.fft_convolve(single_note[tf.newaxis, :], impulse_response)[0, :]
|
141 |
+
|
142 |
+
|
143 |
+
class MelodyGenerator:
|
144 |
+
"""
|
145 |
+
This class is responsible for multi-note audio generation by function 'get_melody'.
|
146 |
+
"""
|
147 |
+
|
148 |
+
def __init__(self,
|
149 |
+
sample_rate=sample_rate,
|
150 |
+
n_note_samples=sample_rate * 4.5,
|
151 |
+
n_melody_samples=sample_rate * 4.5):
|
152 |
+
self.sample_rate = sample_rate
|
153 |
+
self.noteGenerator = NoteGenerator(sample_rate, sample_rate * 4.5)
|
154 |
+
self.n_melody_samples = int(n_melody_samples)
|
155 |
+
|
156 |
+
def get_melody(self, params_list: List[Dict[str, float]], onsets):
|
157 |
+
track = np.zeros(self.n_melody_samples)
|
158 |
+
for i in range(len(onsets)):
|
159 |
+
location = onsets[i]
|
160 |
+
single_note = self.noteGenerator.get_note(params_list[i])
|
161 |
+
single_note = np.hstack(
|
162 |
+
[np.zeros(int(location)), single_note, np.zeros(self.n_melody_samples)])[
|
163 |
+
:self.n_melody_samples]
|
164 |
+
track = track + single_note
|
165 |
+
return track
|
166 |
+
|
167 |
+
|
168 |
+
def plot_log_spectrogram(signal: np.ndarray,
|
169 |
+
path: str,
|
170 |
+
n_fft=2048,
|
171 |
+
frame_length=1024,
|
172 |
+
frame_step=256):
|
173 |
+
"""Write spectrogram."""
|
174 |
+
stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
|
175 |
+
amp = np.square(np.real(stft)) + np.square(np.imag(stft))
|
176 |
+
magnitude_spectrum = np.abs(amp)
|
177 |
+
log_mel = np_power_to_db(magnitude_spectrum)
|
178 |
+
matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')
|
179 |
+
|
180 |
+
|
181 |
+
def np_power_to_db(S, amin=1e-16, top_db=80.0):
|
182 |
+
"""A helper function for scaling."""
|
183 |
+
|
184 |
+
def np_log10(x):
|
185 |
+
numerator = np.log(x)
|
186 |
+
denominator = np.log(10)
|
187 |
+
return numerator / denominator
|
188 |
+
|
189 |
+
# Scale magnitude relative to maximum value in S. Zeros in the output
|
190 |
+
# correspond to positions where S == ref.
|
191 |
+
ref = np.max(S)
|
192 |
+
|
193 |
+
# 每个元素取max
|
194 |
+
log_spec = 10.0 * np_log10(np.maximum(amin, S))
|
195 |
+
log_spec -= 10.0 * np_log10(np.maximum(amin, ref))
|
196 |
+
|
197 |
+
log_spec = np.maximum(log_spec, np.max(log_spec) - top_db)
|
198 |
+
|
199 |
+
return log_spec
|
200 |
+
|
201 |
+
|
202 |
+
synth = MelodyGenerator()

param_descriptions: List[ParameterDescription] = [

    # Oscillator levels
    ParameterDescription(name="osc_amp2",
                         values=parameter_range('osc_amp2'),
                         discrete=is_discrete('osc_amp2')),

    # ADSR parameters
    ParameterDescription(name="attack_1",
                         values=parameter_range('attack'),
                         discrete=is_discrete('attack')),
    ParameterDescription(name="decay_1",
                         values=parameter_range('decay'),
                         discrete=is_discrete('decay')),
    ParameterDescription(name="sustain_1",
                         values=parameter_range('sustain'),
                         discrete=is_discrete('sustain')),
    ParameterDescription(name="release_1",
                         values=parameter_range('release'),
                         discrete=is_discrete('release')),
    ParameterDescription(name="attack_2",
                         values=parameter_range('attack'),
                         discrete=is_discrete('attack')),
    ParameterDescription(name="decay_2",
                         values=parameter_range('decay'),
                         discrete=is_discrete('decay')),
    ParameterDescription(name="sustain_2",
                         values=parameter_range('sustain'),
                         discrete=is_discrete('sustain')),
    ParameterDescription(name="release_2",
                         values=parameter_range('release'),
                         discrete=is_discrete('release')),

    ParameterDescription(name="cutoff_freq",
                         values=parameter_range('cutoff_freq'),
                         discrete=is_discrete('cutoff_freq')),
    ParameterDescription(name="pitch",
                         values=midi_parameter_range('pitch'),
                         discrete=midi_is_discrete('pitch')),
    ParameterDescription(name="duration",
                         values=midi_parameter_range('duration'),
                         discrete=midi_is_discrete('duration')),

    ParameterDescription(name="amp_mod_freq",
                         values=parameter_range('amp_mod_freq'),
                         discrete=is_discrete('amp_mod_freq')),
    ParameterDescription(name="amp_mod_depth",
                         values=parameter_range('amp_mod_depth'),
                         discrete=is_discrete('amp_mod_depth')),

    ParameterDescription(name="pitch_mod_freq_1",
                         values=parameter_range('pitch_mod_freq'),
                         discrete=is_discrete('pitch_mod_freq')),
    ParameterDescription(name="pitch_mod_freq_2",
                         values=parameter_range('pitch_mod_freq'),
                         discrete=is_discrete('pitch_mod_freq')),
    ParameterDescription(name="pitch_mod_depth",
                         values=parameter_range('pitch_mod_depth'),
                         discrete=is_discrete('pitch_mod_depth')),

    # Oscillator types:
    # 0 for sin saw, 1 for sin square, 2 for saw square,
    # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
    ParameterDescription(name="osc_types",
                         values=parameter_range('osc_types'),
                         discrete=is_discrete('osc_types')),
]

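
# Illustration only (not part of the original upload): drawing one random synth
# configuration from the descriptions above, the same pattern the generators below use.
def _demo_random_parameter_set():
    parameter_values = [param.generate() for param in param_descriptions]
    return {param.name: param.value for param in parameter_values}
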
frame_length = 1024
frame_step = 256
spectrogram_len = 256

n_fft = 1024


def generate_synth_dataset_log_muted_512(n: int, path_name="./data/data_log", write_spec=False):
    """Generate the synthetic dataset without a progress bar."""
    if Path(path_name).exists():
        shutil.rmtree(path_name)

    Path(path_name).mkdir(parents=True, exist_ok=True)
    print("Generating dataset...")

    synthetic_data = np.ones((n, 512, 256))

    for i in range(n):
        parameter_values = [param.generate() for param in param_descriptions]
        parameter_values_raw = {param.name: param.value for param in parameter_values}
        parameter_values_raw["duration"] = 3.0
        parameter_values_raw["pitch"] = 52
        parameter_values_raw["pitch_mod_depth"] = 0.0
        signal = synth.get_melody([parameter_values_raw], [0])
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))

        synthetic_data[i] = amp[:512, :256]

        if write_spec:
            write(path_name + f"/{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/{i}.png", frame_length=frame_length, frame_step=frame_step)
    print(f"Generating dataset over, {n} samples generated!")
    return synthetic_data


def generate_synth_dataset_log_512(n: int, path_name="./data/data_log", write_spec=False):
    """Generate the synthetic dataset with a progress bar."""
    if Path(path_name).exists():
        shutil.rmtree(path_name)

    Path(path_name).mkdir(parents=True, exist_ok=True)
    print("Generating dataset...")

    synthetic_data = np.ones((n, 512, 256))

    for i in tqdm(range(n)):
        parameter_values = [param.generate() for param in param_descriptions]
        parameter_values_raw = {param.name: param.value for param in parameter_values}
        parameter_values_raw["duration"] = 3.0
        parameter_values_raw["pitch"] = 52
        parameter_values_raw["pitch_mod_depth"] = 0.0
        signal = synth.get_melody([parameter_values_raw], [0])
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))

        synthetic_data[i] = amp[:512, :256]

        if write_spec:
            write(path_name + f"/{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/{i}.png", frame_length=frame_length, frame_step=frame_step)
    print(f"Generating dataset over, {n} samples generated!")
    return synthetic_data


def generate_DANN_dataset_muted(n: int, path_name="./data/data_DANN", write_spec=False):
    """Generate the paired multi-note / single-note dataset without a progress bar."""
    if Path(path_name).exists():
        shutil.rmtree(path_name)

    Path(path_name).mkdir(parents=True, exist_ok=True)
    print("Generating dataset...")

    multinote_data = np.ones((n, 512, 256))
    single_data = np.ones((n, 512, 256))
    for i in range(n):
        par_list = []
        n_notes = np.random.randint(1, 5)
        onsets = []
        for j in range(n_notes):
            parameter_values = [param.generate() for param in param_descriptions]
            parameter_values_raw = {param.name: param.value for param in parameter_values}
            parameter_values_raw["pitch_mod_depth"] = 0.0
            par_list.append(parameter_values_raw)
            onsets.append(np.random.randint(0, sample_rate * 3))

        signal = synth.get_melody(par_list, onsets)
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))
        multinote_data[i] = amp[:512, :256]
        if write_spec:
            write(path_name + f"/mul_{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/mul_{i}.png", frame_length=frame_length,
                                 frame_step=frame_step)

        # The single-note domain re-renders the earliest note of the melody alone.
        single_par = par_list[np.argmin(onsets)]
        single_par["duration"] = 3.0
        single_par["pitch"] = 52
        signal = synth.get_melody([single_par], [0])
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))
        single_data[i] = amp[:512, :256]
        if write_spec:
            # Distinct file name so the single-note render does not overwrite the melody.
            write(path_name + f"/single_{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/single_{i}.png", frame_length=frame_length,
                                 frame_step=frame_step)
    print(f"Generating dataset over, {n} samples generated!")
    return multinote_data, single_data


def generate_DANN_dataset(n: int, path_name="./data/data_DANN", write_spec=False):
    """Generate the paired multi-note / single-note dataset for adversarial (DANN) training, with a progress bar."""
    if Path(path_name).exists():
        shutil.rmtree(path_name)

    Path(path_name).mkdir(parents=True, exist_ok=True)
    print("Generating dataset...")

    multinote_data = np.ones((n, 512, 256))
    single_data = np.ones((n, 512, 256))
    for i in tqdm(range(n)):
        par_list = []
        n_notes = np.random.randint(1, 5)
        onsets = []
        for j in range(n_notes):
            parameter_values = [param.generate() for param in param_descriptions]
            parameter_values_raw = {param.name: param.value for param in parameter_values}
            parameter_values_raw["pitch_mod_depth"] = 0.0
            par_list.append(parameter_values_raw)
            onsets.append(np.random.randint(0, sample_rate * 3))

        signal = synth.get_melody(par_list, onsets)
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))
        multinote_data[i] = amp[:512, :256]
        if write_spec:
            write(path_name + f"/mul_{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/mul_{i}.png", frame_length=frame_length,
                                 frame_step=frame_step)

        # The single-note domain re-renders the earliest note of the melody alone.
        single_par = par_list[np.argmin(onsets)]
        single_par["duration"] = 3.0
        single_par["pitch"] = 52
        signal = synth.get_melody([single_par], [0])
        stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
        amp = np.square(np.real(stft)) + np.square(np.imag(stft))
        single_data[i] = amp[:512, :256]
        if write_spec:
            # Distinct file name so the single-note render does not overwrite the melody.
            write(path_name + f"/single_{i}.wav", synth.sample_rate, signal)
            plot_log_spectrogram(signal, path=path_name + f"/single_{i}.png", frame_length=frame_length,
                                 frame_step=frame_step)
    print(f"Generating dataset over, {n} samples generated!")
    return multinote_data, single_data
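A minimal usage sketch for the generators above (the sample count and path are arbitrary assumptions):

from generate_synthetic_data_online import generate_synth_dataset_log_512

specs = generate_synth_dataset_log_512(16, path_name="./data/demo", write_spec=False)
print(specs.shape)  # (16, 512, 256) power spectrograms, one fixed note per sample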
load_data.py
ADDED
@@ -0,0 +1,150 @@
import joblib
import numpy as np
from generate_synthetic_data_online import generate_synth_dataset_log_512, generate_synth_dataset_log_muted_512
from tools import show_spc, spc_to_VAE_input, VAE_out_put_to_spc, np_log10
import torch.utils.data as data


class Data_cache:
    """Caches the synthetic dataset together with the external sources."""

    def __init__(self, synthetic_data, external_sources):
        self.n_synthetic = np.shape(synthetic_data)[0]
        self.synthetic_data = synthetic_data.astype(np.float32)
        self.external_sources = external_sources.astype(np.float32)
        self.epsilon = 1e-20

    def get_all_data(self):
        return np.vstack([self.synthetic_data, self.external_sources])

    def refresh(self):
        """Re-generate the synthetic part of the cache."""
        self.synthetic_data = generate_synth_dataset(self.n_synthetic, mute=True)

    def get_data_loader(self, shuffle=True, BATCH_SIZE=8, new_way=False):
        all_data = self.get_all_data()
        our_data = []
        for i in range(len(all_data)):
            if new_way:
                spectrogram = VAE_out_put_to_spc(np.reshape(all_data[i], (1, 512, 256)))
                log_spectrogram = np.log10(spectrogram + self.epsilon)
                our_data.append(log_spectrogram)
            else:
                our_data.append(np.reshape(all_data[i], (1, 512, 256)))

        iterator = data.DataLoader(our_data, shuffle=shuffle, batch_size=BATCH_SIZE)
        return iterator


def generate_synth_dataset(n_synthetic, mute=False):
    """Generate and preprocess the synthetic data."""
    if mute:
        Input0 = generate_synth_dataset_log_muted_512(n_synthetic)
    else:
        Input0 = generate_synth_dataset_log_512(n_synthetic)
    Input0 = spc_to_VAE_input(Input0)
    Input0 = Input0.reshape(Input0.shape[0], Input0.shape[1], Input0.shape[2], 1)
    return Input0


def read_data(data_path):
    """Read one external-source dataset from disk."""
    data = np.array(joblib.load(data_path))
    data = spc_to_VAE_input(data)
    data = data.reshape(data.shape[0], data.shape[1], data.shape[2], 1)
    return data


def load_data(n_synthetic):
    """Assemble the hybrid dataset (synthetic data plus external sources)."""
    Input_synthetic = generate_synth_dataset(n_synthetic)

    Input_AU = read_data("./data/external_data/ARTURIA_data")
    print("ARTURIA dataset loaded.")

    Input_NSynth = read_data("./data/external_data/NSynth_data")
    print("NSynth dataset loaded.")

    # The SoundFonts spectrograms are 251 frames wide; zero-pad them to 256.
    Input_SF = read_data("./data/external_data/soundfonts_data")
    Input_SF_256 = np.zeros((337, 512, 256, 1))
    Input_SF_256[:, :, :251, :] += Input_SF
    Input_SF = Input_SF_256
    print("SoundFonts dataset loaded.")

    Input_google = read_data("./data/external_data/WaveNet_samples")

    Input_external = np.vstack([Input_AU, Input_NSynth, Input_SF, Input_google])
    data_cache = Data_cache(Input_synthetic, Input_external)
    print(f"Data loaded, data shape: {np.shape(data_cache.get_all_data())}")
    return data_cache


def show_data(dataset_name, n_sample=3, index=-1, new_way=False):
    """Show and return a certain dataset.

    Parameters
    ----------
    dataset_name: str
        Name of the dataset to show.
    n_sample: int
        Number of samples to show.
    index: int
        Setting 'index' greater than or equal to 0 shows the 'index'-th sample of the chosen dataset.

    Returns
    -------
    np.ndarray:
        The dataset that was shown.
    """

    if dataset_name == "ARTURIA":
        data = read_data("./data/external_data/ARTURIA_data")
    elif dataset_name == "NSynth":
        data = read_data("./data/external_data/NSynth_data")
    elif dataset_name == "SoundFonts":
        data = read_data("./data/external_data/soundfonts_data")
    elif dataset_name == "Synthetic":
        data = generate_synth_dataset(int(n_sample * 3))
    else:
        print("Example command: \"!python thesis_main.py show_data -s [ARTURIA, NSynth, SoundFonts, Synthetic] -n 5\"")
        return

    if index >= 0:
        show_spc(VAE_out_put_to_spc(data[index]))
    else:
        for i in range(n_sample):
            index = np.random.randint(0, len(data))
            print(index)
            show_spc(VAE_out_put_to_spc(data[index]))
    return data


# Note: this was also named `show_data` in the upload, shadowing the function above.
def show_data_batch(tensor_batch, index=-1, new_way=False):
    """Show one sample from a tensor batch."""
    if index < 0:
        index = np.random.randint(0, tensor_batch.shape[0])

    if new_way:
        sample = tensor_batch[index].detach().numpy()
        spectrogram = 10.0 ** sample
        print(f"The {index}-th sample:")
        show_spc(spectrogram)
    else:
        sample = tensor_batch[index].detach().numpy()
        show_spc(VAE_out_put_to_spc(sample))
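A short usage sketch for the cache above; given the `data_cache` returned by `load_data`, the loader yields VAE-ready batches (batch size 8 is just an example):

loader = data_cache.get_data_loader(shuffle=True, BATCH_SIZE=8)
for batch in loader:
    print(batch.shape)  # torch.Size([8, 1, 512, 256])
    break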
melody_synth/complex_torch_synth.py
ADDED
@@ -0,0 +1,221 @@
from typing import Optional

import numpy as np
import torch
from torch import Tensor, tensor

from torchsynth.config import SynthConfig
from torchsynth.module import (
    ADSR,
    VCA,
    AudioMixer,
    ControlRateUpsample,
    MonophonicKeyboard,
    SineVCO,
    SquareSawVCO,
    VCO, LFO, ModulationMixer,
)
from torchsynth.signal import Signal
from torchsynth.synth import AbstractSynth

# from configurations.read_configuration import get_conf_sample_rate
from melody_synth.non_random_LFOs import SinLFO, SawLFO, TriLFO, SquareLFO, RSawLFO


class TriangleVCO(VCO):
    """A VCO subclass that produces triangle waves."""

    def oscillator(self, argument: Signal, midi_f0: Tensor) -> Signal:
        return torch.arcsin(torch.sin(argument * 2)) * 2.0 / torch.pi


class AmpModTorchSynth(AbstractSynth):
    """An abstract class that assembles the synthesizers generating the training set
    from torchsynth modules. (The implementation of this class references code in TorchSynth.)"""

    def __init__(
            self,
            synthconfig: Optional[SynthConfig] = None,
            nebula: Optional[str] = "nebula",
            *args,
            **kwargs,
    ):
        AbstractSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
        self.share_modules = [
            ("keyboard", MonophonicKeyboard),
            ("adsr_1", ADSR),
            ("adsr_2", ADSR),
            ("upsample", ControlRateUpsample),
            ("vca", VCA),
            ("lfo_amp_sin", SinLFO),
            ("lfo_pitch_sin_1", SinLFO),
            ("lfo_pitch_sin_2", SinLFO),
            (
                "mixer",
                AudioMixer,
                {
                    "n_input": 2,
                    "curves": [1.0, 1.0],
                    "names": ["vco_1", "vco_2"],
                },
            )
        ]

    def output(self) -> Tensor:
        """Synthesizes the signal as a Tensor."""

        midi_f0, note_on_duration = self.keyboard()
        adsr1 = self.adsr_1(note_on_duration)
        adsr1 = self.upsample(adsr1)

        adsr2 = self.adsr_2(note_on_duration)
        adsr2 = self.upsample(adsr2)

        amp_modulation = self.lfo_amp_sin()
        amp_modulation = self.upsample(amp_modulation)

        pitch_modulation_1 = self.lfo_pitch_sin_1()
        pitch_modulation_1 = self.upsample(pitch_modulation_1)
        pitch_modulation_2 = self.lfo_pitch_sin_2()
        pitch_modulation_2 = self.upsample(pitch_modulation_2)

        vco_amp1 = adsr1 * (amp_modulation * 0.5 + 0.5)
        vco_amp2 = adsr2 * (amp_modulation * 0.5 + 0.5)
        vco_1_out = self.vco_1(midi_f0, pitch_modulation_1)
        vco_1_out = self.vca(vco_1_out, vco_amp1)

        vco_2_out = self.vco_2(midi_f0, pitch_modulation_2)
        vco_2_out = self.vca(vco_2_out, vco_amp2)

        return self.mixer(vco_1_out, vco_2_out)

    def get_signal(self, amp_mod_depth, amp_waveform, duration_l, amp1, amp2):
        """Synthesizes the signal as a Tensor."""

        midi_f0, note_on_duration = self.keyboard()
        adsr1 = self.adsr_1(note_on_duration)
        adsr1 = self.upsample(adsr1)

        adsr2 = self.adsr_2(note_on_duration)
        adsr2 = self.upsample(adsr2)

        amp_modulation = self.lfo_amp_sin()
        amp_modulation = self.upsample(amp_modulation)

        pitch_modulation_1 = self.lfo_pitch_sin_1()
        pitch_modulation_1 = self.upsample(pitch_modulation_1)
        pitch_modulation_2 = self.lfo_pitch_sin_2()
        pitch_modulation_2 = self.upsample(pitch_modulation_2)

        vco_amp1 = adsr1 * (amp_modulation * 0.5 + 0.5)
        vco_amp2 = adsr2 * (amp_modulation * 0.5 + 0.5)
        vco_1_out = self.vco_1(midi_f0, pitch_modulation_1)
        vco_1_out = self.vca(vco_1_out, vco_amp1)

        # Use the second pitch LFO for the second VCO, mirroring output() above.
        vco_2_out = self.vco_2(midi_f0, pitch_modulation_2)
        vco_2_out = self.vca(vco_2_out, vco_amp2)

        return self.mixer(vco_1_out, vco_2_out)


class DoubleSawSynth(AmpModTorchSynth):
    """In addition to the shared modules, this synthesizer uses two "SquareSawVCO" modules to generate square and
    sawtooth waves."""

    def __init__(
            self,
            synthconfig: Optional[SynthConfig] = None,
            nebula: Optional[str] = "saw_square_voice",
            *args,
            **kwargs,
    ):
        AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)

        # Register all modules as children
        module_list = self.share_modules
        module_list.append(("vco_1", SquareSawVCO))
        module_list.append(("vco_2", SquareSawVCO))
        self.add_synth_modules(module_list)


class SinSawSynth(AmpModTorchSynth):
    """In addition to the shared modules, this synthesizer uses a "SineVCO" and a "SquareSawVCO" to generate
    sine and sawtooth/square waves."""

    def __init__(
            self,
            synthconfig: Optional[SynthConfig] = None,
            nebula: Optional[str] = "sin_saw_voice",
            *args,
            **kwargs,
    ):
        AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)

        # Register all modules as children
        module_list = self.share_modules
        module_list.append(("vco_1", SineVCO))
        module_list.append(("vco_2", SquareSawVCO))
        self.add_synth_modules(module_list)


class SinTriangleSynth(AmpModTorchSynth):
    """In addition to the shared modules, this synthesizer uses a "SineVCO" and a "TriangleVCO" to generate
    sine and triangle waves."""

    def __init__(
            self,
            synthconfig: Optional[SynthConfig] = None,
            nebula: Optional[str] = "sin_tri_voice",
            *args,
            **kwargs,
    ):
        AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)

        # Register all modules as children
        module_list = self.share_modules
        module_list.append(("vco_1", SineVCO))
        module_list.append(("vco_2", TriangleVCO))
        self.add_synth_modules(module_list)


class TriangleSawSynth(AmpModTorchSynth):
    """In addition to the shared modules, this synthesizer uses a "TriangleVCO" and a "SquareSawVCO" to generate
    triangle and sawtooth/square waves."""

    def __init__(
            self,
            synthconfig: Optional[SynthConfig] = None,
            nebula: Optional[str] = "triangle_saw_voice",
            *args,
            **kwargs,
    ):
        AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)

        # Register all modules as children
        module_list = self.share_modules
        module_list.append(("vco_1", TriangleVCO))
        module_list.append(("vco_2", SquareSawVCO))
        self.add_synth_modules(module_list)


def amp_mod_with_duration(env, duration_l):
    """Find the last local maximum of the envelope within `duration_l` samples and rescale the tail after it."""
    env_np = env.detach().numpy()[0] + 1e-30
    env_np_shift = np.hstack([[0], env_np[:-1]])
    env_np_sign = (env_np - env_np_shift)[:duration_l] + 1e-30

    env_np_sign_nor = np.around(env_np_sign / np.abs(env_np_sign))
    env_np_sign_nor_shift = np.hstack([[0], env_np_sign_nor[:-1]])
    extreme_points = (env_np_sign_nor - env_np_sign_nor_shift)

    # A +1 -> -1 sign flip (difference of -2) marks a local maximum.
    (max_loc,) = np.where(extreme_points == -2)

    n_max = len(max_loc)
    if n_max == 0:
        return env
    else:
        last_max_loc = max_loc[n_max - 1] - 1
        # new_env = np.hstack([env_np[:last_max_loc], np.ones(len(env_np) - last_max_loc) * env_np[last_max_loc]])
        new_env = np.hstack([env_np[:last_max_loc], (env_np[last_max_loc:] * 0.8 + 0.2)])

        return tensor([new_env])
melody_synth/melody_generator.py
ADDED
@@ -0,0 +1,121 @@
from typing import Dict
import torch
from ddsp.core import tf_float32
import tensorflow as tf
import ddsp
import numpy as np
from torch import tensor
from melody_synth.complex_torch_synth import SinSawSynth, DoubleSawSynth, TriangleSawSynth, SinTriangleSynth
from torchsynth.config import SynthConfig

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class MelodyGenerator:
    """This is the only external interface of the melody_synth package."""

    def __init__(self,
                 sample_rate: int,
                 n_note_samples: int,
                 n_melody_samples: int):
        self.sample_rate = sample_rate
        self.n_note_samples = n_note_samples
        self.n_melody_samples = n_melody_samples
        synthconfig = SynthConfig(
            batch_size=1, reproducible=False, sample_rate=sample_rate,
            buffer_size_seconds=np.float64(n_note_samples) / np.float64(sample_rate)
        )
        self.Saw_Square_Voice = DoubleSawSynth(synthconfig)
        self.SinSawVoice = SinSawSynth(synthconfig)
        self.SinTriVoice = SinTriangleSynth(synthconfig)
        self.TriSawVoice = TriangleSawSynth(synthconfig)

    def get_melody(self, params: Dict[str, float], midi) -> tf.Tensor:
        """Generates the melody audio.

        Parameters
        ----------
        params: Dict[str, float]
            Dictionary of specifications (see Readme).
        midi: List[(float, float, float)]
            Melody midi as (onset, pitch, duration) tuples (see Readme).

        Returns
        -------
        audio: tf.Tensor
            The rendered melody.
        """

        # Built-in float (np.float has been removed from NumPy).
        osc1_amp = float(params.get("osc1_amp", 0))
        osc2_amp = float(params.get("osc2_amp", 0))
        attack = float(params.get("attack", 0))
        decay = float(params.get("decay", 0))
        sustain = float(params.get("sustain", 0))
        release = float(params.get("release", 0))
        cutoff_freq = params.get("cutoff_freq", 4000)

        syn_parameters = {
            ("adsr", "attack"): tensor([attack]),  # [0.0, 2.0]
            ("adsr", "decay"): tensor([decay]),  # [0.0, 2.0]
            ("adsr", "sustain"): tensor([sustain]),  # [0.0, 2.0]
            ("adsr", "release"): tensor([release]),  # [0.0, 2.0]
            ("adsr", "alpha"): tensor([3]),  # [0.1, 6.0]

            # Mixer parameters
            ("mixer", "vco_1"): tensor([osc1_amp]),  # [0, 1]
            ("mixer", "vco_2"): tensor([osc2_amp]),  # [0, 1]

            # Constant parameters:
            ("vco_1", "mod_depth"): tensor([0.0]),  # [-96, 96]
            ("vco_1", "tuning"): tensor([0.0]),  # [-24.0, 24]
            ("vco_2", "mod_depth"): tensor([0.0]),  # [-96, 96]
            ("vco_2", "tuning"): tensor([0.0]),  # [-24.0, 24]
        }

        # 0 for sin saw, 1 for sin square, 2 for saw square,
        # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
        osc_types = params.get("osc_types", 0)
        if osc_types == 0:
            synth = self.SinSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([1])
        elif osc_types == 1:
            synth = self.SinSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([0])
        elif osc_types == 2:
            synth = self.Saw_Square_Voice
            syn_parameters[("vco_1", "shape")] = tensor([1])
            syn_parameters[("vco_2", "shape")] = tensor([0])
        elif osc_types == 3:
            synth = self.SinTriVoice
        elif osc_types == 4:
            synth = self.TriSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([1])
        else:
            synth = self.TriSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([0])

        track = np.zeros(self.n_melody_samples)
        for i in range(len(midi)):
            (location, pitch, duration) = midi[i]
            syn_parameters[("keyboard", "midi_f0")] = tensor([pitch])
            syn_parameters[("keyboard", "duration")] = tensor([duration])
            synth.set_parameters(syn_parameters)

            audio_out, parameters, is_train = synth()
            single_note = audio_out[0]

            single_note = np.hstack(
                [np.zeros(int(location * self.sample_rate)), single_note, np.zeros(self.n_melody_samples)])[
                          :self.n_melody_samples]
            track = track + single_note

        no_cutoff = False
        if no_cutoff:
            return track
        # Apply a low-pass filter by convolving with a sinc impulse response (DDSP).
        cutoff_freq = tf_float32(cutoff_freq)
        impulse_response = ddsp.core.sinc_impulse_response(cutoff_freq,
                                                           2048,
                                                           self.sample_rate)
        track = tf_float32(track)
        return ddsp.core.fft_convolve(track[tf.newaxis, :], impulse_response)[0, :]
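A minimal usage sketch (the sample rate, parameter values, and the two-note midi are illustrative assumptions):

gen = MelodyGenerator(sample_rate=16000, n_note_samples=4 * 16000, n_melody_samples=4 * 16000)
params = {"osc1_amp": 0.8, "osc2_amp": 0.4, "attack": 0.05, "decay": 0.3,
          "sustain": 0.6, "release": 0.4, "cutoff_freq": 4000, "osc_types": 0}
audio = gen.get_melody(params, [(0.0, 60, 1.0), (1.0, 64, 1.0)])  # 4 s tf.Tensor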
melody_synth/non_random_LFOs.py
ADDED
@@ -0,0 +1,121 @@
from typing import Optional

import torch
from torch import tensor

from torchsynth.module import LFO
from torchsynth.signal import Signal


class SinLFO(LFO):
    """An LFO that generates the sine waveform.
    (The implementation of this class is a modification of the code in TorchSynth.)"""

    def output(self, mod_signal: Optional[Signal] = None) -> Signal:
        # This module accepts signals at control rate!
        if mod_signal is not None:
            assert mod_signal.shape == (self.batch_size, self.control_buffer_size)

        frequency = self.make_control(mod_signal)
        argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
        argument = argument + self.p("initial_phase").unsqueeze(1)

        shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)

        # Fixed mixing weights over the shapes (sin, tri, saw, rsaw, sqr)
        # replace the random waveform mode of the stock LFO.
        mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
        mode[0] = tensor([1.0, 0., 0., 0., 0.])
        mode = torch.pow(mode, self.exponent)
        mode = mode / torch.sum(mode, dim=1, keepdim=True)
        return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)


class TriLFO(LFO):
    """An LFO that generates the triangle waveform.
    (The implementation of this class is a modification of the code in TorchSynth.)"""

    def output(self, mod_signal: Optional[Signal] = None) -> Signal:
        # This module accepts signals at control rate!
        if mod_signal is not None:
            assert mod_signal.shape == (self.batch_size, self.control_buffer_size)

        frequency = self.make_control(mod_signal)
        argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
        argument = argument + self.p("initial_phase").unsqueeze(1)

        shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)

        mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
        mode[0] = tensor([0.5, 0.5, 0., 0., 0.])
        mode = torch.pow(mode, self.exponent)
        mode = mode / torch.sum(mode, dim=1, keepdim=True)
        return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)


class SawLFO(LFO):
    """An LFO that generates the sawtooth waveform.
    (The implementation of this class is a modification of the code in TorchSynth.)"""

    def output(self, mod_signal: Optional[Signal] = None) -> Signal:
        # This module accepts signals at control rate!
        if mod_signal is not None:
            assert mod_signal.shape == (self.batch_size, self.control_buffer_size)

        frequency = self.make_control(mod_signal)
        argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
        argument = argument + self.p("initial_phase").unsqueeze(1)

        shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)

        mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
        mode[0] = tensor([0.5, 0., 0.5, 0., 0.])
        mode = torch.pow(mode, self.exponent)
        mode = mode / torch.sum(mode, dim=1, keepdim=True)
        return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)


class RSawLFO(LFO):
    """An LFO that generates the reverse sawtooth (ramp-down) waveform.
    (The implementation of this class is a modification of the code in TorchSynth.)"""

    def output(self, mod_signal: Optional[Signal] = None) -> Signal:
        # This module accepts signals at control rate!
        if mod_signal is not None:
            assert mod_signal.shape == (self.batch_size, self.control_buffer_size)

        frequency = self.make_control(mod_signal)
        argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
        argument = argument + self.p("initial_phase").unsqueeze(1)

        shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)

        mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
        mode[0] = tensor([0.5, 0., 0.0, 0.5, 0.])
        mode = torch.pow(mode, self.exponent)
        mode = mode / torch.sum(mode, dim=1, keepdim=True)
        return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)


class SquareLFO(LFO):
    """An LFO that generates the square waveform.
    (The implementation of this class is a modification of the code in TorchSynth.)"""

    def output(self, mod_signal: Optional[Signal] = None) -> Signal:
        # This module accepts signals at control rate!
        if mod_signal is not None:
            assert mod_signal.shape == (self.batch_size, self.control_buffer_size)

        frequency = self.make_control(mod_signal)
        argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
        argument = argument + self.p("initial_phase").unsqueeze(1)

        shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)

        mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
        mode[0] = tensor([0.5, 0., 0., 0., 0.5])
        mode = torch.pow(mode, self.exponent)
        mode = mode / torch.sum(mode, dim=1, keepdim=True)
        return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
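A quick standalone sketch of one of these LFOs; the call pattern mirrors how AmpModTorchSynth invokes them, and the exact output shape is an assumption:

from torchsynth.config import SynthConfig
from melody_synth.non_random_LFOs import SinLFO

config = SynthConfig(batch_size=1, reproducible=False)
lfo = SinLFO(config)
control = lfo()  # control-rate sine signal, one row per batch element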
melody_synth/random_duration.py
ADDED
@@ -0,0 +1,86 @@
from configurations.read_configuration import midi_parameter_range, midi_is_discrete
from data_generation.encoding import ParameterDescription


def get_full_duration(midi):
    """Uses the "full_duration" strategy to generate durations (see Readme)."""
    n = len(midi)
    (time_starting_point, pitch) = midi[n - 1]
    new_midi = [(time_starting_point, pitch, 0.5)]
    next_time_starting_point = time_starting_point
    for i in range(n - 1):
        (time_starting_point, pitch) = midi[n - 2 - i]
        duration = (next_time_starting_point - time_starting_point) * 0.9
        new_midi.insert(0, (time_starting_point, pitch, duration))
        next_time_starting_point = time_starting_point

    return new_midi


def get_random_duration(midi):
    """Uses the "random_duration" strategy to generate random durations (see Readme)."""
    parameterDescription = ParameterDescription(name="duration",
                                                values=midi_parameter_range('duration'),
                                                discrete=midi_is_discrete('duration'))
    n = len(midi)
    new_midi = []
    for i in range(n):
        (location, pitch) = midi[i]
        duration = float(parameterDescription.generate().value)
        new_midi.append((location, pitch, duration))
    return new_midi


def get_fixed_duration(midi):
    """Assigns a fixed duration of 2.0 seconds to every note."""
    return [(location, pitch, 2.0) for (location, pitch) in midi]


def get_limited_random_duration(midi):
    """Uses the "limited_random_duration" strategy to generate random durations (see Readme)."""
    parameterDescription = ParameterDescription(name="duration",
                                                values=midi_parameter_range('duration'),
                                                discrete=midi_is_discrete('duration'))
    n = len(midi)
    (time_starting_point, pitch) = midi[n - 1]
    duration = float(parameterDescription.generate().value)
    new_midi = [(time_starting_point, pitch, duration)]
    next_time_starting_point = time_starting_point
    for i in range(n - 1):
        (time_starting_point, pitch) = midi[n - 2 - i]
        max_duration = (next_time_starting_point - time_starting_point) * 0.9
        duration = float(parameterDescription.generate().value)
        duration = min(duration, max_duration)
        new_midi.insert(0, (time_starting_point, pitch, duration))
        next_time_starting_point = time_starting_point

    return new_midi


class RandomDuration:
    """Third component in the random midi pipeline, responsible for generating random durations (keyboard hold times)."""

    def __call__(self, strategy: str, midi, *args, **kwargs):
        """Choose the required strategy to generate a random duration for each note.

        Parameters
        ----------
        strategy: str
            Strategy name for random durations (see Readme).
        midi: List[(float, float)]
            Random rhythm and pitch from the previous pipeline components.

        Returns
        -------
        midi: List[(float, float, float)]
            The input list with a duration assigned to each note onset.
        """

        if strategy == 'random_duration':
            midi = get_random_duration(midi)
        elif strategy == 'limited_random_duration':
            midi = get_limited_random_duration(midi)
        elif strategy == 'fixed_duration':
            midi = get_fixed_duration(midi)
        else:
            midi = get_full_duration(midi)
        return midi
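A small worked example of the duration component; "full_duration" is deterministic, so the output follows directly from the made-up input:

random_duration = RandomDuration()
midi = [(0.0, 60), (0.5, 64), (1.2, 67)]
print(random_duration("full_duration", midi))
# ≈ [(0.0, 60, 0.45), (0.5, 64, 0.63), (1.2, 67, 0.5)] -- 90% of each gap, 0.5 s for the last note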
melody_synth/random_midi.py
ADDED
@@ -0,0 +1,86 @@
import numpy as np

from configurations.read_configuration import get_conf_sample_rate, get_conf_stft_hyperparameter, \
    midi_parameter_range, get_conf_time_resolution, get_conf_max_n_notes
from melody_synth.random_duration import RandomDuration
from melody_synth.random_pitch import RandomPitch
from melody_synth.random_rhythm import RandomRhythm


class RandomMidi:
    """Pipeline generating random midi."""

    def __init__(self):
        self.randomRhythm = RandomRhythm()
        self.randomPitch = RandomPitch()
        self.randomDuration = RandomDuration()
        self.max_n_notes = get_conf_max_n_notes()

    def __call__(self, strategy=None, *args, **kwargs):
        """Assembles the pipeline based on the given strategies and returns a random midi.

        Parameters
        ----------
        strategy: Dict[str, str]
            Strategy names for random rhythm, pitch and duration generation (see Readme).

        Returns
        -------
        encode, midi: List[int], List[(float, float, float)]
            encode -- The midi's label as a list of 0s and 1s.
            midi -- A list of (onset, pitch, duration) tuples, one tuple per note.
        """

        if strategy is None:
            strategy = {"rhythm_strategy": "non-test",
                        "pitch_strategy": "random_major",
                        "duration_strategy": "limited_random",
                        }

        midi = self.randomRhythm(strategy["rhythm_strategy"])
        midi = self.randomPitch(strategy["pitch_strategy"], midi)
        midi = self.randomDuration(strategy["duration_strategy"], midi)

        return self.get_encode(midi), midi

    def get_encode(self, midi):
        """Generate labels for a midi.

        Parameters
        ----------
        midi: List[(onset, pitch, duration)]
            A list of (onset, pitch, duration) tuples, one tuple per note.

        Returns
        -------
        encode: List[int]
            The midi's label as a list of 0s and 1s.

        Encoding method
        ---------------
        One-hot encoding for each note; all note labels are stacked to form the midi label.
        """
        duration_range = midi_parameter_range("duration")
        pitch_range = midi_parameter_range("pitch")
        time_resolution = get_conf_time_resolution()

        # Duration in seconds of one spectrogram time pixel.
        pixel_duration = get_conf_stft_hyperparameter()["frame_step"] / get_conf_sample_rate()
        single_note_encode_length = (time_resolution + len(pitch_range) + len(duration_range))
        encode_length = single_note_encode_length * self.max_n_notes
        encode = []
        for i in range(len(midi)):
            (location, pitch, duration) = midi[i]

            location_index = int(float(location) / pixel_duration)
            if location_index >= time_resolution:
                break
            pitch_index = pitch - pitch_range[0]
            duration_index = np.argmin(np.abs(np.array(duration_range) - duration))

            single_note_encode = np.zeros(single_note_encode_length)
            single_note_encode[location_index] = 1
            single_note_encode[time_resolution + pitch_index] = 1
            single_note_encode[time_resolution + len(pitch_range) + duration_index] = 1
            encode = np.hstack([encode, single_note_encode])

        # Pad (or truncate) to the fixed label length.
        return np.hstack([encode, np.zeros(encode_length)])[:encode_length]
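A usage sketch for the pipeline; the strategy names are taken from the individual components:

random_midi = RandomMidi()
strategy = {"rhythm_strategy": "free_rhythm",
            "pitch_strategy": "random_major",
            "duration_strategy": "limited_random_duration"}
encode, midi = random_midi(strategy)
print(len(encode), midi[0])  # fixed-length 0/1 label, first (onset, pitch, duration) tuple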
melody_synth/random_pitch.py
ADDED
@@ -0,0 +1,87 @@
import numpy as np

from configurations.read_configuration import midi_parameter_range


class RandomPitch:
    """Second component in the random midi pipeline, responsible for generating random pitches."""

    def __init__(self):
        self.pitch_range = midi_parameter_range("pitch")
        self.major = [0, 2, 4, 5, 7, 9, 11]
        self.minor = [0, 2, 3, 5, 7, 8, 10]

    def __call__(self, strategy: str, onsets, *args, **kwargs):
        """Choose the required strategy to generate a random pitch for each note.

        Parameters
        ----------
        strategy: str
            Strategy name for random pitches (see Readme).
        onsets: List[float]
            Random rhythm from the previous pipeline component.

        Returns
        -------
        midi: List[(float, float)]
            The input list with a pitch assigned to each note onset.
        """

        if strategy == 'random_major':
            return self.get_random_major(onsets)
        elif strategy == 'random_minor':
            return self.get_random_minor(onsets)
        elif strategy == 'fixed_pitch':
            return self.get_fixed_pitch(onsets)
        elif strategy == 'fixed_pitch1':
            return self.get_fixed_pitch1(onsets)
        elif strategy == 'fixed_pitch2':
            return self.get_fixed_pitch2(onsets)
        elif strategy == 'fixed_pitch3':
            return self.get_fixed_pitch3(onsets)
        elif strategy == 'fixed_pitch4':
            return self.get_fixed_pitch4(onsets)
        else:
            return self.get_random_pitch(onsets)

    def get_random_major(self, midi):
        """Uses the "random_major" strategy: random pitches from a random major scale (see Readme)."""
        random_scale = np.random.randint(0, 12)
        scale = [one for one in self.pitch_range if (one - random_scale) % 12 in self.major]
        midi = [(onset, scale[np.random.randint(0, len(scale))]) for onset in midi]
        # midi[0] = (midi[0][0], random_scale + self.pitch_range[-1])
        midi[len(midi) - 1] = (midi[len(midi) - 1][0], random_scale + self.pitch_range[0])
        return midi

    def get_random_pitch(self, midi):
        """Uses the "free_pitch" strategy: uniformly random pitches (see Readme)."""
        return [(onset, np.random.randint(self.pitch_range[0], self.pitch_range[-1])) for onset in midi]

    def get_fixed_pitch(self, midi):
        """Assigns the fixed MIDI pitch 48 to every onset."""
        return [(onset, 48) for onset in midi]

    def get_fixed_pitch1(self, midi):
        """Assigns the fixed MIDI pitch 55 to every onset."""
        return [(onset, 55) for onset in midi]

    def get_fixed_pitch2(self, midi):
        """Assigns the fixed MIDI pitch 62 to every onset."""
        return [(onset, 62) for onset in midi]

    def get_fixed_pitch3(self, midi):
        """Assigns the fixed MIDI pitch 69 to every onset."""
        return [(onset, 69) for onset in midi]

    def get_fixed_pitch4(self, midi):
        """Assigns the fixed MIDI pitch 76 to every onset."""
        return [(onset, 76) for onset in midi]

    def get_random_minor(self, midi):
        """Uses the "random_minor" strategy: random pitches from a random minor scale (see Readme)."""
        random_scale = np.random.randint(0, 12)
        scale = [one for one in self.pitch_range if (one - random_scale) % 12 in self.minor]
        midi = [(onset, scale[np.random.randint(0, len(scale))]) for onset in midi]
        # midi[0] = (midi[0][0], random_scale + self.pitch_range[-1])
        midi[len(midi) - 1] = (midi[len(midi) - 1][0], random_scale + self.pitch_range[0])
        return midi
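A quick sketch of the pitch component (pitches are random, so the shown result is just one possible draw):

random_pitch = RandomPitch()
print(random_pitch("random_major", [0.0, 0.5, 1.0]))
# e.g. [(0.0, 64), (0.5, 69), (1.0, 52)] -- the last note is forced to the scale root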
melody_synth/random_rhythm.py
ADDED
@@ -0,0 +1,143 @@
import random

import numpy as np

from configurations.read_configuration import get_conf_n_sample, get_conf_sample_rate, get_conf_max_n_notes


def get_random_note_type_index(distribution):
    """A helper method that randomly chooses the next note type based on a distribution.

    Parameters
    ----------
    distribution: List[float]
        Note type distribution.

    Returns
    -------
    index: int
        Random type index.
    """

    r = np.random.random()
    for i in range(len(distribution)):
        r = r - distribution[i]
        if r < 0:
            return i
    return len(distribution) - 1


# Todo: rewrite this part
def to_onsets_in_seconds(bpm, notes):
    """A helper method that transforms a list of note lengths into a list of note onsets (in seconds).

    Parameters
    ----------
    bpm: float
        Beats per minute.
    notes: List[float]
        Note lengths as fractions of a whole note.

    Returns
    -------
    onsets: List[float]
        Note onsets in seconds.
    """

    full_note_length = 4 * 60 / bpm
    onsets = [0]
    for i in range(len(notes)):
        onsets.append(onsets[i] + full_note_length * notes[i])
    return onsets


class RandomRhythm:
    """First component in the random midi pipeline, responsible for generating random rhythms (note onsets)."""

    def __init__(self):
        self.note_types = [0, 1, 3 / 4, 0.5, 3 / 8, 0.25, 1 / 8]
        self.first_note_type_distribution = np.array([0, 0.2, 0.05, 0.25, 0.05, 0.3, 0.15])
        self.rhythm_generation_matrix = np.array([
            [0.1, 0.1, 0.25, 0.1, 0.25, 0.2],
            [0.05, 0.25, 0.25, 0.05, 0.3, 0.1],
            [0.1, 0.1, 0.3, 0.05, 0.35, 0.1],
            [0.05, 0.05, 0.2, 0.2, 0.25, 0.25],
            [0.1, 0.05, 0.1, 0.05, 0.4, 0.3],
            [0.1, 0.05, 0.1, 0.1, 0.3, 0.35],
        ])
        # self.bpm = bpm
        self.rhythm_duration = np.array([0, 1, 3 / 4, 0.5, 3 / 8, 0.25])
        self.audio_length = get_conf_n_sample() / get_conf_sample_rate()
        self.bpm_range = [90, 100, 110, 120, 130, 140, 150, 160, 170]
        self.max_n_notes = get_conf_max_n_notes()

    def __call__(self, strategy: str, *args, **kwargs):
        """Choose the required strategy to generate a random rhythm (note onsets).

        Parameters
        ----------
        strategy: str
            Strategy name for random rhythm (see Readme).

        Returns
        -------
        onsets: List[float]
            A list of floats referring to note onsets in seconds.
        """
        if strategy == 'bpm_based_rhythm':
            rhythm = self.get_bpm_based_rhythm()
        elif strategy == 'free_rhythm':
            rhythm = self.get_free_rhythm()
        elif strategy == 'single_note_rhythm':
            rhythm = self.get_single_note()
        else:
            rhythm = [0.0, 1, 2, 3, 4]

        return rhythm[:self.max_n_notes]

    def get_bpm_based_rhythm(self):
        """Uses the "bpm_based_rhythm" strategy to generate a random rhythm (see Readme)."""
        # Todo: clean up this part

        bpm = random.choice(self.bpm_range)

        first_note = get_random_note_type_index(self.first_note_type_distribution)
        note_type_indexes = [first_note]
        current_note_type = first_note
        while True:
            current_note_type = get_random_note_type_index(self.rhythm_generation_matrix[current_note_type - 1]) + 1
            note_type_indexes.append(current_note_type)

            # Random early stop
            if np.random.random() < 9 / bpm:
                break

        notes = [self.note_types[note_type_index] for note_type_index in note_type_indexes]

        onsets = to_onsets_in_seconds(bpm, notes)
        return onsets

    def get_free_rhythm(self):
        """Uses the "free_rhythm" strategy to generate a random rhythm (see Readme)."""
        n_notes = np.random.randint(int(self.max_n_notes * 0.6), self.max_n_notes)
        # n_notes = np.random.randint(int(1), self.max_n_notes)

        onsets = np.random.rand(n_notes)
        onsets.sort()

        # Avoid notes that are too close together
        pre = onsets[0]
        n_removed = 0
        for i in range(len(onsets) - 1):
            index = i - n_removed + 1
            if (onsets[index] - pre) < 0.05:
                new_onsets = np.delete(onsets, index)
                onsets = new_onsets
                n_removed = n_removed + 1
            else:
                pre = onsets[index]

        return ((onsets - onsets[0]) * self.audio_length).tolist()

    def get_single_note(self):
        """Returns a single note at onset 0."""
        return [0.0]
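A quick sketch of the rhythm component (the output is random and capped at the configured maximum note count):

random_rhythm = RandomRhythm()
onsets = random_rhythm("bpm_based_rhythm")
print(onsets)  # e.g. [0, 0.6, 1.2, 1.5] -- note onsets in seconds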
model/VAE.py
ADDED
@@ -0,0 +1,230 @@
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Flatten, Reshape, Lambda, BatchNormalization
from keras.models import Model
import numpy as np
import threading

KL = tf.keras.layers


def cbam_layer(inputs_tensor=None, ratio=None):
    """Channel attention layer (CBAM style).
    Source: https://blog.csdn.net/ZXF_1991/article/details/104615942
    """
    channels = K.int_shape(inputs_tensor)[-1]

    def share_layer(inputs=None):
        x_ = KL.Conv2D(channels // ratio, (1, 1), strides=1, padding="valid")(inputs)
        x_ = KL.Activation('relu')(x_)
        output_share = KL.Conv2D(channels, (1, 1), strides=1, padding="valid")(x_)
        return output_share

    x_global_avg_pool = KL.GlobalAveragePooling2D()(inputs_tensor)
    x_global_avg_pool = KL.Reshape((1, 1, channels))(x_global_avg_pool)
    x_global_max_pool = KL.GlobalMaxPool2D()(inputs_tensor)
    x_global_max_pool = KL.Reshape((1, 1, channels))(x_global_max_pool)
    x_global_avg_pool = share_layer(x_global_avg_pool)
    x_global_max_pool = share_layer(x_global_max_pool)
    x = KL.Add()([x_global_avg_pool, x_global_max_pool])
    x = KL.Activation('sigmoid')(x)
    CAM = KL.multiply([inputs_tensor, x])
    output = CAM
    return output


def res_cell(x, n_channel=64, stride=1):
    """The basic unit of the VAE: a residual cell."""
    if stride == -1:
        # upsample cell
        skip = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
        skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=1, padding='same')(skip)
        x = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = tf.keras.activations.elu(x)
        x = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), padding='same')(x)

    elif stride == 2:
        # downsample cell
        skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=2, padding='same')(x)
        x = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')(x)
        x = BatchNormalization()(x)
        x = tf.keras.activations.elu(x)
        x = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')(x)
    else:
        # shape-preserving cell
        skip = tf.identity(x)
        x = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')(x)
        x = BatchNormalization()(x)
        x = tf.keras.activations.elu(x)
        x = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')(x)

    x = BatchNormalization()(x)
    x = cbam_layer(inputs_tensor=x, ratio=8)
    x = x + skip
    x = tf.keras.activations.elu(x)
    return x


def res_block(x, n_channel=64, upsample=False, n_cells=2):
    """A block is a stack of cells: one resampling cell followed by shape-preserving cells."""
    if upsample:
        x = res_cell(x, n_channel=n_channel, stride=-1)
    else:
        x = res_cell(x, n_channel=n_channel, stride=2)
    for _ in range(n_cells - 1):
        x = res_cell(x, n_channel=n_channel, stride=1)
    return x


def l1_distance(x1, x2):
    return tf.reduce_mean(tf.math.abs(x1 - x2))


def l1_log_distance(x1, x2):
    return tf.reduce_mean(tf.math.abs(tf.math.log(tf.maximum(1e-6, x1)) - tf.math.log(tf.maximum(1e-6, x2))))


img_height = 512
img_width = 256
num_channels = 1
input_shape = (img_height, img_width, num_channels)
timbre_dim = 20
n_filters = 64
act = 'elu'


def compute_latent(x):
    """Re-parameterization trick: sample z = mu + exp(sigma / 2) * eps."""
    mu, sigma = x
    batch = K.shape(mu)[0]
    dim = K.int_shape(mu)[1]
    eps = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps


def get_encoder(N2=0, channel_sizes=None):
    """Assemble and return the VAE encoder."""
    if channel_sizes is None:
        channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
    encoder_input = Input(shape=input_shape)

    encoder_conv = res_block(encoder_input, channel_sizes[0], upsample=False, n_cells=1)

    for c in channel_sizes[1:]:
        encoder_conv = res_block(encoder_conv, c, upsample=False, n_cells=1 + N2)

    encoder = Flatten()(encoder_conv)

    mu_timbre = Dense(timbre_dim)(encoder)
    sigma_timbre = Dense(timbre_dim)(encoder)
    latent_vector = Lambda(compute_latent, output_shape=(timbre_dim,))([mu_timbre, sigma_timbre])

    kl_loss = -0.5 * (1 + sigma_timbre - tf.square(mu_timbre) - tf.exp(sigma_timbre))
    kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

    encoder = Model(encoder_input, [latent_vector, kl_loss])
    return encoder


def get_decoder(N2=0, N3=8, channel_sizes=None):
    """Assemble and return the VAE decoder."""
    if channel_sizes is None:
        channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
    conv_shape = [-1, 2 ** (9 - N3), 2 ** (8 - N3), channel_sizes[-1]]
    decoder_input = Input(shape=(timbre_dim,))

    decoder = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation=act)(decoder_input)
    decoder_conv = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(decoder)

    for c in list(reversed(channel_sizes))[1:]:
        decoder_conv = res_block(decoder_conv, c, upsample=True, n_cells=1 + N2)

    decoder_conv = Conv2DTranspose(filters=num_channels, kernel_size=5, strides=2,
                                   padding='same', activation='sigmoid')(decoder_conv)
|
145 |
+
|
146 |
+
decoder = Model(decoder_input, decoder_conv)
|
147 |
+
return decoder
|
148 |
+
|
149 |
+
|
150 |
+
def VAE(N2=0, N3=8, channel_sizes=None):
|
151 |
+
"""Assemble and return the VAE."""
|
152 |
+
if channel_sizes is None:
|
153 |
+
channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
|
154 |
+
print("Creating model...")
|
155 |
+
assert N2 >= 0, "Please set N2 >= 0"
|
156 |
+
assert N3 >= 1, "Please set 1 <= N3 <= 8"
|
157 |
+
assert N3 <= 8, "Please set 1 <= N3 <= 8"
|
158 |
+
assert N3 == len(channel_sizes), "Please set N3 = len(channel_sizes)"
|
159 |
+
encoder = get_encoder(N2, channel_sizes)
|
160 |
+
decoder = get_decoder(N2, N3, channel_sizes)
|
161 |
+
|
162 |
+
# encoder = tf.keras.models.load_model(f"encoder_thesis_record_1.h5")
|
163 |
+
# decoder = tf.keras.models.load_model(f"decoder_thesis_record_1.h5")
|
164 |
+
|
165 |
+
encoder_input1 = Input(shape=input_shape)
|
166 |
+
scalar_input1 = Input(shape=(1,))
|
167 |
+
|
168 |
+
embedding_1_timbre, kl_loss = encoder(encoder_input1)
|
169 |
+
reconstruction_1 = decoder(embedding_1_timbre)
|
170 |
+
|
171 |
+
VAE = Model([encoder_input1, scalar_input1], [reconstruction_1, kl_loss])
|
172 |
+
# decoder.summary()
|
173 |
+
VAE.summary()
|
174 |
+
return encoder, decoder, VAE
|
175 |
+
|
176 |
+
|
177 |
+
def my_thread(data_cache):
|
178 |
+
data_cache.refresh()
|
179 |
+
|
180 |
+
|
181 |
+
def train_VAE(vae, encoder, decoder, data_cache, stages, batch_size):
|
182 |
+
"""Train the VAE.
|
183 |
+
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
vae: keras.engine.functional.Functional
|
187 |
+
The VAE.
|
188 |
+
encoder: keras.engine.functional.Functional
|
189 |
+
The VAE encoder.
|
190 |
+
decoder: keras.engine.functional.Functional
|
191 |
+
The VAE decoder.
|
192 |
+
data_cache: Data_cache
|
193 |
+
A Data_cache entity that provides training data.
|
194 |
+
stages: Dict
|
195 |
+
Defines the training stages. In each stage, the synthetic data will be refreshed and
|
196 |
+
models will be stored once.
|
197 |
+
|
198 |
+
Returns
|
199 |
+
-------
|
200 |
+
"""
|
201 |
+
threshold = 1e-0
|
202 |
+
kl_weight = 100.0
|
203 |
+
|
204 |
+
def weighted_binary_cross_entropy_loss(true, pred):
|
205 |
+
b_n = true * tf.math.log(tf.maximum(1e-20, pred)) + (1 - true) * tf.math.log(tf.maximum(1e-20, 1 - pred))
|
206 |
+
w = tf.maximum(threshold, true)
|
207 |
+
return -tf.reduce_sum(b_n / w) / batch_size
|
208 |
+
|
209 |
+
def reconstruction_loss(true, pred):
|
210 |
+
reconstruction_loss = weighted_binary_cross_entropy_loss(K.flatten(true), K.flatten(pred))
|
211 |
+
return K.mean(reconstruction_loss)
|
212 |
+
|
213 |
+
def kl_loss(true, pred):
|
214 |
+
return pred * kl_weight
|
215 |
+
|
216 |
+
for stage in stages:
|
217 |
+
threshold = stage["threshold"]
|
218 |
+
kl_weight = stage["kl_weight"]
|
219 |
+
vae.compile(tf.keras.optimizers.Adam(learning_rate=stage["learning_rate"]), loss=[reconstruction_loss, kl_loss])
|
220 |
+
|
221 |
+
Input_all = data_cache.get_all_data()
|
222 |
+
n_total = np.shape(Input_all)[0]
|
223 |
+
|
224 |
+
t = threading.Thread(target=my_thread, args=(data_cache,))
|
225 |
+
t.start()
|
226 |
+
history = vae.fit([Input_all, np.ones(n_total)], [Input_all, np.ones(n_total)], epochs=stage["n_epoch"],
|
227 |
+
batch_size=batch_size)
|
228 |
+
t.join()
|
229 |
+
encoder.save(f"./models/new_trained_models/encoder.h5")
|
230 |
+
decoder.save(f"./models/new_trained_models/decoder.h5")
|
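
A minimal training sketch for this module, assuming `load_data` returns a `Data_cache` with the `get_all_data()`/`refresh()` interface that `train_VAE` expects; the stage values are illustrative placeholders, not the settings used for the released checkpoints:

from model.VAE import VAE, train_VAE
from load_data import load_data

encoder, decoder, vae = VAE(N2=0, N3=8, channel_sizes=[32, 64, 64, 96, 96, 128, 160, 216])
data_cache = load_data(500)  # assumed to return a Data_cache of spectrograms
stages = [
    {"threshold": 1.0, "kl_weight": 100.0, "learning_rate": 1e-4, "n_epoch": 10},  # illustrative
    {"threshold": 0.1, "kl_weight": 10.0, "learning_rate": 1e-5, "n_epoch": 10},   # illustrative
]
train_VAE(vae, encoder, decoder, data_cache, stages, batch_size=16)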
model/VAE_torchV.py
ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        y = avg_out + max_out
        y = self.sigmoid(y)

        return x * y.expand_as(x)


class ResCell(nn.Module):
    def __init__(self, input_channel, output_channel, stride=1):
        super(ResCell, self).__init__()

        self.stride = stride
        self.input_channel = input_channel
        self.output_channel = output_channel

        if self.stride == -1:
            # upsample cell
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.skip = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=1, stride=1, padding=0)
            self.conv1 = nn.ConvTranspose2d(self.input_channel, self.output_channel, kernel_size=5, stride=2, padding=2, output_padding=1)
            self.conv2 = nn.ConvTranspose2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)

        elif self.stride == 2:
            # downsample cell
            self.skip = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=1, stride=2, padding=0)
            self.conv1 = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=5, stride=self.stride, padding=2)
            self.conv2 = nn.Conv2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)

        else:
            # shape-preserving cell (input_channel must equal output_channel)
            self.conv1 = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=5, stride=self.stride, padding=2)
            self.conv2 = nn.Conv2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)

        self.bn1 = nn.BatchNorm2d(self.output_channel)
        self.bn2 = nn.BatchNorm2d(self.output_channel)

        # Please replace `CBAM` with the actual module and parameters
        self.cbam = ChannelAttention(self.output_channel)

    def forward(self, x):
        if self.stride == -1:
            upsampled_x = self.upsample(x)
            skip = self.skip(upsampled_x)
            x = F.elu(self.bn1(self.conv1(x)))
            x = self.conv2(x)
        elif self.stride == 2:
            skip = self.skip(x)
            x = F.elu(self.bn1(self.conv1(x)))
            x = self.conv2(x)
        else:
            skip = x
            x = F.elu(self.bn1(self.conv1(x)))
            x = self.conv2(x)

        x = self.bn2(x)
        x = self.cbam(x)
        x = x + skip
        x = F.elu(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, input_channel, output_channel, upsample=False, n_cells=2):
        super(ResBlock, self).__init__()

        stride = -1 if upsample else 2
        self.cells = nn.ModuleList([ResCell(input_channel, output_channel, stride=stride)])

        # Subsequent cells preserve the channel count, so they must take
        # output_channel as their input channel (the original passed
        # input_channel here, which only works when the two are equal).
        for _ in range(n_cells - 1):
            self.cells.append(ResCell(output_channel, output_channel, stride=1))

    def forward(self, x):
        for cell in self.cells:
            x = cell(x)
        return x


class Encoder(nn.Module):
    def __init__(self, input_shape, timbre_dim, N2=0, channel_sizes=None):
        super(Encoder, self).__init__()

        if channel_sizes is None:
            channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]

        self.input_shape = input_shape
        self.timbre_dim = timbre_dim
        self.blocks = nn.ModuleList()

        self.blocks.append(ResBlock(input_channel=1, output_channel=channel_sizes[0], upsample=False, n_cells=1))
        input_channel = channel_sizes[0]

        for c in channel_sizes[1:]:
            self.blocks.append(ResBlock(input_channel=input_channel, output_channel=c, upsample=False, n_cells=1 + N2))
            input_channel = c

        self.flatten = nn.Flatten()
        self.mu_timbre = nn.Linear(self._get_flattened_dim(), timbre_dim)
        self.sigma_timbre = nn.Linear(self._get_flattened_dim(), timbre_dim)

    def _get_flattened_dim(self):
        # Trace a dummy input through the blocks to infer the flattened size.
        x = torch.zeros((1,) + self.input_shape)
        for block in self.blocks:
            x = block(x)
        x = self.flatten(x)
        return x.shape[1]

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        x = self.flatten(x)
        mu = self.mu_timbre(x)
        logvar = self.sigma_timbre(x)
        latent_vector = self.reparameterize(mu, logvar)

        # kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        # kl_loss = torch.mean(kl_loss)

        return mu, logvar, latent_vector


class Decoder(nn.Module):
    def __init__(self, timbre_dim, N2=0, N3=8, channel_sizes=None):
        super(Decoder, self).__init__()

        if channel_sizes is None:
            channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]

        self.conv_shape = [-1, channel_sizes[-1], 2 ** (9 - N3), 2 ** (8 - N3)]

        self.dense = nn.Linear(timbre_dim, self.conv_shape[1] * self.conv_shape[2] * self.conv_shape[3])
        self.blocks = nn.ModuleList()

        input_channel = channel_sizes[-1]
        for c in list(reversed(channel_sizes))[1:]:
            self.blocks.append(ResBlock(input_channel=input_channel, output_channel=c, upsample=True, n_cells=1 + N2))
            input_channel = c

        self.decoder_conv = nn.ConvTranspose2d(channel_sizes[0], 1, kernel_size=5, stride=2, padding=2, output_padding=1)

    def forward(self, x):
        x = F.elu(self.dense(x))
        x = x.view(-1, self.conv_shape[1], self.conv_shape[2], self.conv_shape[3])
        for block in self.blocks:
            x = block(x)

        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x
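
A quick shape-check sketch for this PyTorch port, using the module's default channel sizes, a 20-dimensional latent space, and the (1, 512, 256) spectrogram shape used throughout:

import torch
from model.VAE_torchV import Encoder, Decoder

encoder = Encoder(input_shape=(1, 512, 256), timbre_dim=20, N2=0)
decoder = Decoder(timbre_dim=20, N2=0, N3=8)

x = torch.rand(2, 1, 512, 256)      # a batch of two VAE inputs in [0, 1]
mu, logvar, z = encoder(x)          # each of shape (2, 20)
reconstruction = decoder(z)         # shape (2, 1, 512, 256)
print(z.shape, reconstruction.shape)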
model/perceptual_label_predictor.py
ADDED
@@ -0,0 +1,68 @@
import tensorflow as tf
from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras.losses import binary_crossentropy
from load_data import read_data
import joblib
import numpy as np

KL = tf.keras.layers


def perceptual_label_predictor():
    """Assemble and return the perceptual_label_predictor."""
    mini_input = Input((20,))
    p = Dense(20, activation='relu')(mini_input)
    p = Dropout(0.2)(p)
    p = Dense(16, activation='relu')(p)
    p = Dropout(0.2)(p)
    p = Dense(5, activation='sigmoid')(p)
    style_predictor = Model(mini_input, p)
    style_predictor.summary()
    return style_predictor


def train_perceptual_label_predictor(perceptual_label_predictor, encoder):
    """Train the perceptual_label_predictor. (Including data loading.)"""

    Input_synthetic = read_data("./data/labeled_dataset/synthetic_data")
    Input_AU = read_data("./data/external_data/ARTURIA_data")[:100]

    AU_labels = joblib.load("./data/labeled_dataset/ARTURIA_labels")
    synth_labels = joblib.load("./data/labeled_dataset/synthetic_data_labels")

    AU_encode = encoder.predict(Input_AU)[0]
    Synth_encode = encoder.predict(Input_synthetic)[0]

    perceptual_label_predictor.compile(optimizer='adam', loss=binary_crossentropy)
    perceptual_label_predictor.fit(np.vstack([AU_encode, Synth_encode]), np.vstack([AU_labels, synth_labels]), epochs=140, validation_split=0.05, batch_size=16)
    perceptual_label_predictor.save(f"./models/new_trained_models/perceptual_label_predictor.h5")
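
An inference sketch for the predictor, loading the checkpoint shipped in this commit with `compile=False` since only `predict` is needed; the random codes are stand-ins for VAE latent vectors, and the column order follows the label list in new_sound_generation.py:

import numpy as np
import tensorflow as tf

predictor = tf.keras.models.load_model("./models/perceptual_label_predictor.h5", compile=False)
codes = np.random.normal(size=(4, 20)).astype(np.float32)  # stand-ins for encoder outputs
scores = predictor.predict(codes)                          # shape (4, 5), sigmoid values in [0, 1]
# columns: metallic, warm, breathy, evolving, aggressiv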
models/decoder_5_13.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7c25c8e241ade613409293d0ba3b37823a129699e802e1a2e9d9bb0074b11d3
size 16884433
models/encoder_5_13.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3da8ae2c98eea048a5cf4247ed637c67edf8b121e97efd3c7564e8ed4ea9fa51
size 21627911
models/new_trained_models/perceptual_label_predictor.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c6d2c3df579658bc54e07dd76d0e9efb02831bebf5e2d8c104becac930f0071
size 46080
models/perceptual_label_predictor.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f59db5a982347ec8c40452b8f77a35432d00253a8ed347726f9b102fc4550d44
size 46176
new_sound_generation.py
ADDED
@@ -0,0 +1,203 @@
import numpy as np
import matplotlib
from pathlib import Path
import shutil
from tqdm import tqdm
from tools import save_results, VAE_out_put_to_spc, show_spc


def test_reconstruction(encoder, decoder, data, n_sample=5, f=0, path_name="./data/test_reconstruction", save_data=False):
    """Generate and show reconstruction results. Randomly reconstruct 'n_sample' samples in 'data'.
    The index of the first reconstructed sample can be set manually with 'f'.

    Parameters
    ----------
    encoder: keras.engine.functional.Functional
        The VAE encoder.
    decoder: keras.engine.functional.Functional
        The VAE decoder.
    data: numpy array
        The dataset to draw samples from.
    n_sample: int
        Number of samples to reconstruct.
    f: int
        Index of the first reconstructed sample.
    path_name: String
        Path to save the results.
    save_data: bool
        Whether to save the results.

    Returns
    -------
    """
    if save_data:
        if Path(path_name).exists():
            shutil.rmtree(path_name)
        Path(path_name).mkdir(parents=True, exist_ok=True)

    for i in range(n_sample):
        index = np.random.randint(np.shape(data)[0])
        if i == 0:
            index = f
        print("######################################################")
        print(f"index: {index}")

        input = data[index]
        print(f"Original:")
        show_spc(VAE_out_put_to_spc(input))
        if save_data:
            save_results(VAE_out_put_to_spc(input), f"{path_name}/origin_{index}.png", f"{path_name}/origin_{index}.wav")

        input = data[index:index + 1]
        timbre_encode = encoder.predict(input)[0]

        encode = timbre_encode

        reconstruction = decoder.predict(encode)[0]
        reconstruction = VAE_out_put_to_spc(reconstruction)
        reconstruction = np.minimum(5000, reconstruction)
        print(f"Reconstruction:")
        show_spc(reconstruction)
        if save_data:
            save_results(reconstruction, f"{path_name}/reconstruction_{index}.png", f"{path_name}/reconstruction_{index}.wav")


def test_interpulation(data0, data1, encoder, decoder, path_name="./data/test_interpolation", save_data=False):
    """Generate new sounds by latent space interpolation.

    Parameters
    ----------
    data0: numpy array
        First input for interpolation.
    data1: numpy array
        Second input for interpolation.
    encoder: keras.engine.functional.Functional
        The VAE encoder.
    decoder: keras.engine.functional.Functional
        The VAE decoder.
    path_name: String
        Path to save the results.
    save_data: bool
        Whether to save the results.

    Returns
    -------
    """
    if save_data:
        if Path(path_name).exists():
            shutil.rmtree(path_name)
        Path(path_name).mkdir(parents=True, exist_ok=True)

    if save_data:
        save_results(VAE_out_put_to_spc(data0), f"{path_name}/origin_0.png", f"{path_name}/origin_0.wav")
        save_results(VAE_out_put_to_spc(data1), f"{path_name}/origin_1.png", f"{path_name}/origin_1.wav")

    print("First Original:")
    show_spc(VAE_out_put_to_spc(data0))
    print("Second Original:")
    show_spc(VAE_out_put_to_spc(data1))
    print("######################################################")
    print("Interpolations:")
    data0 = np.reshape(data0, (1, 512, 256, 1))
    data1 = np.reshape(data1, (1, 512, 256, 1))
    timbre_encode0 = encoder.predict(data0)[0]
    timbre_encode1 = encoder.predict(data1)[0]

    n_f = 8
    for i in tqdm(range(n_f + 1)):
        rate = 1 - i / n_f
        new_timbre = rate * timbre_encode0 + (1 - rate) * timbre_encode1
        output = decoder.predict(new_timbre)

        spc = np.reshape(VAE_out_put_to_spc(output), (512, 256))
        if save_data:
            save_results(spc, f"{path_name}/test_interpolation_{i}.png", f"{path_name}/test_interpolation_{i}.wav")
        show_spc(spc)


def test_random_sampling(decoder, n_sample=20, mu=np.zeros(20), sigma=np.ones(20), save_data=False, path_name="./data/test_random_sampling"):
    """Generate new sounds by random sampling in the latent space.

    Parameters
    ----------
    decoder: keras.engine.functional.Functional
        The VAE decoder.
    n_sample: int
        Number of samples to generate.
    mu: numpy array
        Mean of the sampling distribution.
    sigma: numpy array
        Spread of the sampling distribution.
    path_name: String
        Path to save the results.
    save_data: bool
        Whether to save the results.

    Returns
    -------
    """
    if save_data:
        if Path(path_name).exists():
            shutil.rmtree(path_name)
        Path(path_name).mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(n_sample)):
        # Note: np.random.normal takes a standard deviation as scale, so sigma is squared here.
        off_set = np.random.normal(mu, np.square(sigma))
        new_timbre = np.reshape(off_set, (1, 20))

        output = decoder.predict(new_timbre)

        spc = np.reshape(VAE_out_put_to_spc(output), (512, 256))
        if save_data:
            save_results(spc, f"{path_name}/random_sampling_{i}.png", f"{path_name}/random_sampling_{i}.wav")
        show_spc(spc)


def test_style_transform(original, encoder, decoder, perceptual_label_predictor, n_samples=10, save_data=False, goal=0, direction=0, path_name="./data/test_style_transform"):
    """Transform the style of a sound by perturbing its latent code and keeping the
    candidate that the perceptual label predictor rates best for the chosen goal.

    Parameters
    ----------
    original: numpy array
        Original for the style transform.
    encoder: keras.engine.functional.Functional
        The VAE encoder.
    decoder: keras.engine.functional.Functional
        The VAE decoder.
    perceptual_label_predictor: keras.engine.functional.Functional
        Model that selects the output.
    path_name: String
        Path to save the results.
    save_data: bool
        Whether to save the results.

    Returns
    -------
    """
    if save_data:
        if Path(path_name).exists():
            shutil.rmtree(path_name)
        Path(path_name).mkdir(parents=True, exist_ok=True)
        save_results(VAE_out_put_to_spc(original), f"{path_name}/origin.png", f"{path_name}/origin.wav")
    labels_names = ["metallic", "warm", "breathy", "evolving", "aggressiv"]
    timbre_dim = 20

    print("Original:")
    show_spc(VAE_out_put_to_spc(original))
    print("######################################################")
    original_code = encoder.predict(np.reshape(original, (1, 512, 256, 1)))[0]
    new_encodes = np.zeros((n_samples, timbre_dim)) + original_code

    new_encodes = [new_encode + np.random.normal(np.zeros(timbre_dim) * 0.2, np.ones(timbre_dim)) for new_encode in new_encodes]
    new_encodes = np.array(new_encodes, dtype=np.float32)
    perceptual_labels = perceptual_label_predictor.predict(new_encodes)[:, goal]

    if direction == 0:
        best_index = np.argmin(perceptual_labels)
        suffix = f"less_{labels_names[goal]}"
    else:
        best_index = np.argmax(perceptual_labels)
        suffix = f"more_{labels_names[goal]}"

    output = decoder.predict(new_encodes[best_index:best_index + 1])

    spc = np.reshape(VAE_out_put_to_spc(output), (512, 256))
    if save_data:
        save_results(spc, f"{path_name}/{suffix}.png", f"{path_name}/{suffix}.wav")
    print(f"Manipulated ({suffix}):")
    show_spc(spc)
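
A usage sketch for these helpers, assuming a Keras decoder checkpoint such as the one written by `train_VAE` in model/VAE.py (path taken from that file; adjust to your checkout):

import numpy as np
import tensorflow as tf
from new_sound_generation import test_random_sampling

decoder = tf.keras.models.load_model("./models/new_trained_models/decoder.h5", compile=False)
# Draw five new sounds from the latent prior and display their spectrograms.
test_random_sampling(decoder, n_sample=5, mu=np.zeros(20), sigma=np.ones(20))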
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio
ddsp
torchsynth
pytorch_lightning==1.7.0
test_audio.wav
ADDED
Binary file (522 kB)
tools.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib
import librosa
from scipy.io.wavfile import write

k = 1e-16


def np_log10(x):
    numerator = np.log(x + 1e-16)
    denominator = np.log(10)
    return numerator / denominator


def sigmoid(x):
    s = 1 / (1 + np.exp(-x))
    return s


def inv_sigmoid(s):
    x = np.log((s / (1 - s)) + 1e-16)
    return x


def spc_to_VAE_input(spc):
    """Restrict the value range to 0 to 1."""
    return spc / (1 + spc)


def VAE_out_put_to_spc(o):
    """Inverse transform of function 'spc_to_VAE_input'."""
    return o / (1 - o + k)


def denoise(spc):
    """Filter background noise. (Not used.)"""
    return np.maximum(0, spc - (2e-5))


hop_length = 256
win_length = 1024


def np_power_to_db(S, amin=1e-16, top_db=80.0):
    """Helper method for scaling."""
    ref = np.max(S)

    # set fixed value for ref

    # take the element-wise maximum
    log_spec = 10.0 * np_log10(np.maximum(amin, S))
    log_spec -= 10.0 * np_log10(np.maximum(amin, ref))

    log_spec = np.maximum(log_spec, np.max(log_spec) - top_db)

    return log_spec


def show_spc(spc, resolution=(512, 256)):
    """Show a spectrogram."""
    spc = np.reshape(spc, resolution)
    magnitude_spectrum = np.abs(spc)
    log_spectrum = np_power_to_db(magnitude_spectrum)
    plt.imshow(np.flipud(log_spectrum))
    plt.show()


def save_results(spectrogram, spectrogram_image_path, waveform_path):
    """Save the input 'spectrogram' and its waveform (reconstructed by Griffin-Lim)
    to the paths given by 'spectrogram_image_path' and 'waveform_path'."""
    # save image
    magnitude_spectrum = np.abs(spectrogram)
    log_spc = np_power_to_db(magnitude_spectrum)
    log_spc = np.reshape(log_spc, (512, 256))
    matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
                             origin='lower')

    # save waveform
    abs_spec = np.zeros((513, 256))
    abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
    rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
    write(waveform_path, 16000, rec_signal)
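
A small check that `VAE_out_put_to_spc` inverts `spc_to_VAE_input` (up to the constant k = 1e-16), using a stand-in power spectrogram:

import numpy as np
from tools import spc_to_VAE_input, VAE_out_put_to_spc

spc = np.random.random((512, 256)) * 100.0           # stand-in power spectrogram
restored = VAE_out_put_to_spc(spc_to_VAE_input(spc))
print(np.max(np.abs(restored - spc)))                # close to 0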
webUI/initial_example_encodes.json
ADDED
@@ -0,0 +1,210 @@
{
    "random": [
        0.398047679983798,
        0.3181480556003946,
        0.4481247707840732,
        0.17181013678356805,
        0.38504974079327525,
        0.1630011861056878,
        0.2718486665735521,
        0.4338781507304229,
        0.6538380475574059,
        0.4158802639661583,
        0.23043953925285032,
        0.10156416680503988,
        0.30416174813259533,
        0.5135342367189637,
        0.5187104878467569,
        0.3526902627535946,
        0.7747258429690094,
        0.2627179357923156,
        0.9086876170530048,
        0.9271088722414674,
        0.7019110161290005,
        0.7718117435584002,
        0.36794622268993094,
        0.201057894940179
    ],
    "few_overtone_fading_out": [
        -1.0641387701034546,
        -0.22593635320663452,
        -3.0761210918426514,
        0.3395313322544098,
        3.4432809352874756,
        1.2180149555206299,
        -3.312405824661255,
        -0.9400798082351685,
        1.4439473152160645,
        -2.7892191410064697,
        -4.55153751373291,
        2.2633938789367676,
        -1.9029977321624756,
        0.37142419815063477,
        1.7903343439102173,
        3.5063657760620117,
        -1.5748300552368164,
        -2.555540084838867,
        0.07989349216222763,
        0.23952914774417877,
        0.5571582317352295,
        -1.200455904006958,
        -1.2390071153640747,
        -0.5626499652862549
    ],
    "much_overtone": [
        -1.649719476699829,
        -2.1237597465515137,
        -3.3006417751312256,
        -1.4381871223449707,
        2.0898361206054688,
        0.34304752945899963,
        -1.6530671119689941,
        0.6339190006256104,
        -0.700051486492157,
        -0.6604726910591125,
        -2.3133397102355957,
        3.0709166526794434,
        -1.2595521211624146,
        0.6515411138534546,
        1.8037245273590088,
        0.17395955324172974,
        -1.1443099975585938,
        -2.599336624145508,
        -1.909640908241272,
        -1.422598123550415,
        0.5974328517913818,
        -2.559039354324341,
        -2.4977917671203613,
        -0.9264755249023438
    ],
    "much_overtone_high_register": [
        -1.5539246797561646,
        -1.4718247652053833,
        -0.2083689421415329,
        0.47305312752723694,
        -0.3550421893596649,
        1.6288657188415527,
        -2.5005292892456055,
        -0.5079396367073059,
        1.9173517227172852,
        -2.8692283630371094,
        -1.3840413093566895,
        1.8955140113830566,
        -0.6880097389221191,
        1.2770217657089233,
        0.2371762990951538,
        2.726161241531372,
        -1.7957547903060913,
        -0.28189635276794434,
        -0.6052505373954773,
        -0.557244598865509,
        -0.7524706721305847,
        -0.6265298128128052,
        -1.4746038913726807,
        -2.393972396850586
    ],
    "blurry": [
        1.3638803958892822,
        -1.7874348163604736,
        0.4853333532810211,
        -0.9626323580741882,
        -0.2609613537788391,
        -1.5900236368179321,
        -2.30368709564209,
        0.4847792387008667,
        1.5201102495193481,
        -0.5147597789764404,
        -2.452840805053711,
        1.5057097673416138,
        -0.16519485414028168,
        2.126760721206665,
        0.7019514441490173,
        1.612999439239502,
        -0.3407663106918335,
        -2.8276443481445312,
        1.311924934387207,
        -1.5173300504684448,
        -0.015635617077350616,
        -2.7689361572265625,
        1.1804192066192627,
        -3.077000141143799
    ],
    "interesting_release": [
        1.7602112293243408,
        -1.9114476442337036,
        -2.730947256088257,
        -0.6088603138923645,
        3.317946195602417,
        -0.2716876268386841,
        -2.3106558322906494,
        -2.114469289779663,
        -4.443742752075195,
        -1.0665826797485352,
        -3.0929622650146484,
        1.1979585886001587,
        -1.6287152767181396,
        -1.537142276763916,
        2.4184482097625732,
        0.22694607079029083,
        0.10934393107891083,
        -0.18058283627033234,
        -2.489964723587036,
        -4.448374271392822,
        1.2452409267425537,
        0.05835026502609253,
        0.8547804355621338,
        0.8163737654685974
    ],
    "global_trend": [
        -1.0987968444824219,
        -1.1155377626419067,
        0.14996160566806793,
        -3.165109157562256,
        -2.5396244525909424,
        1.8292016983032227,
        -3.5159406661987305,
        3.4396510124206543,
        -2.3765876293182373,
        0.5692415833473206,
        -1.7827686071395874,
        -0.4062053859233856,
        -1.6925498247146606,
        0.7511563897132874,
        0.12510846555233002,
        -0.14617247879505157,
        0.5096412897109985,
        2.2399022579193115,
        0.5798826217651367,
        -1.5942487716674805,
        -0.36588573455810547,
        -0.9877008199691772,
        4.732168674468994,
        -5.468194007873535
    ],
    "crescendo": [
        1.3167105913162231,
        -1.1503334045410156,
        -3.488548517227173,
        -1.146520972251892,
        2.478545665740967,
        -0.5853592753410339,
        -2.1441550254821777,
        -2.1898915767669678,
        -7.137173175811768,
        -0.34099096059799194,
        -3.832253932952881,
        2.0366034507751465,
        -1.3639447689056396,
        -2.3450658321380615,
        1.1388293504714966,
        1.1278795003890991,
        0.6025446653366089,
        -0.2925209701061249,
        -0.07147964835166931,
        -3.1970367431640625,
        2.4061689376831055,
        0.477089524269104,
        -0.8897881507873535,
        2.827509880065918
    ]
}