WeixuanYuan committed on
Commit b88cc47
1 Parent(s): 18a55e0

Upload 31 files

MyTest.py ADDED
@@ -0,0 +1,3 @@
1
+ from load_data import load_data
2
+
3
+ data_cache = load_data(500)
NN.json ADDED
File without changes
app.py ADDED
@@ -0,0 +1,184 @@
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import torch
6
+ import librosa
7
+ from tools import VAE_out_put_to_spc, np_power_to_db
8
+ from model.VAE_torchV import Encoder, Decoder
9
+
10
+ SPECTROGRAM_RESOLUTION = (512, 256, 3)
11
+ encoder = Encoder((1, 512, 256), 24, N2=0, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to("cuda")
12
+ decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to("cuda")
13
+ model_name = "test"
14
+ encoder.load_state_dict(torch.load(f"models/{model_name}_encoder_CA.pt"))
15
+ decoder.load_state_dict(torch.load(f"models/{model_name}_decoder_CA.pt"))
16
+
17
+ INIT_ENCODE_CACHE = {"init": np.random.random((24, ))}
18
+ with open('webUI/initial_example_encodes.json', 'r') as f:
19
+ list_dict = json.load(f)
20
+ for k in list_dict.keys():
21
+ INIT_ENCODE_CACHE[k] = np.array(list_dict[k])
22
+
23
+ #################################
24
+ def prepare_image(image):
25
+ # Rescale to 0-255 and convert to uint8
26
+ rescaled = (image + 80.0) / 80.0
27
+ rescaled = (255.0 * rescaled).astype(np.uint8)
28
+ return rescaled
29
+
30
+ def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
31
+ """Decode a batch of latent vectors into spectrogram images and Griffin-Lim audio reconstructions."""
32
+ reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()
33
+ flipped_log_spectrums, rec_signals = [], []
34
+ for reconstruction in reconstruction_batch:
35
+ spc = VAE_out_put_to_spc(reconstruction)
36
+ spc = np.reshape(spc, resolution)
37
+ magnitude_spectrum = np.abs(spc)
38
+ log_spectrum = np_power_to_db(magnitude_spectrum)
39
+ flipped_log_spectrum = np.flipud(log_spectrum)
40
+
41
+ colorful_spc = np.ones((512, 256, 3)) * -80.0
42
+ colorful_spc[:, :, 0] = flipped_log_spectrum
43
+ colorful_spc[:, :, 1] = flipped_log_spectrum
44
+ colorful_spc[:, :, 2] = np.ones((512, 256)) * -60.0
45
+ flipped_log_spectrum = prepare_image(colorful_spc)
46
+
47
+ # get_audio
48
+ abs_spec = np.zeros((513, 256))
49
+ abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spc, (512, 256)))
50
+ rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
51
+ flipped_log_spectrums.append(flipped_log_spectrum)
52
+ rec_signals.append(rec_signal)
53
+
54
+ return flipped_log_spectrums, 16000, rec_signals
55
+
56
+
57
+ def get_example_module(encodeCache):
58
+ def show_example(selected_example, encodeCache):
59
+ example_encode = torch.Tensor(np.reshape(encodeCache[selected_example], (-1, 24))).to("cuda")
60
+ flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(example_encode)
61
+ flipped_log_spectrum, rec_signal = flipped_log_spectrums[0], rec_signals[0]
62
+
63
+ return flipped_log_spectrum, str(encodeCache), (sampleRate, rec_signal), encodeCache
64
+
65
+ with gr.Tab("Examples"):
66
+ gr.Markdown("Predefined examples.")
67
+ with gr.Row():
68
+ with gr.Column():
69
+ selected_example = gr.Dropdown(
70
+ list(INIT_ENCODE_CACHE.keys()), label="Examples", info="Choose one example!"
71
+ )
72
+ example_button = gr.Button(value="Show example")
73
+ with gr.Column():
74
+ example_image_output = gr.Image(label="Reconstruction", type="numpy")
75
+ example_image_output.style(height=250, width=600)
76
+ example_audio_output = gr.Audio(type="numpy", label="Play reconstruction output!")
77
+ example_text_output = gr.Textbox()
78
+ example_button.click(show_example, inputs=[selected_example, encodeCache],
79
+ outputs=[example_image_output, example_text_output, example_audio_output, encodeCache])
80
+
81
+ def get_reconstruction_module():
82
+
83
+ def do_nothing(image_input):
84
+ return np.random.random(SPECTROGRAM_RESOLUTION)
85
+
86
+ with gr.Tab("Reconstruction"):
87
+ gr.Markdown("Test reconstruction.")
88
+ with gr.Row():
89
+ with gr.Column():
90
+ test_reconstruction_input = gr.Number(label="Batch_index")
91
+ test_reconstruction_button = gr.Button(value="Generate")
92
+ with gr.Column():
93
+ test_reconstruction_output = gr.Image(label="Reconstruction", type="numpy")
94
+ test_reconstruction_output.style(height=250, width=600)
95
+ test_reconstruction_button.click(do_nothing, inputs=test_reconstruction_input, outputs=test_reconstruction_output)
96
+
97
+ def get_interpolation_module(encodeCache):
98
+
99
+ def interpolate(first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache):
100
+ # Todo: use batch
101
+ first_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[first_interpulation_input], (-1, 24)))
102
+ second_interpulation_input_encode = torch.Tensor(np.reshape(encodeCache[second_interpulation_input], (-1, 24)))
103
+ ratio = torch.Tensor([interpulation_input_ratio])
104
+ interpulation_encode = first_interpulation_input_encode * ratio + second_interpulation_input_encode * (1 - ratio)
105
+
106
+ interpulation_input_encode = torch.stack((first_interpulation_input_encode, second_interpulation_input_encode, interpulation_encode), dim=0).to("cuda")
107
+ flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(interpulation_input_encode)
108
+ first_flipped_log_spectrum, first_rec_signal = flipped_log_spectrums[0], rec_signals[0]
109
+ second_flipped_log_spectrum, second_rec_signal = flipped_log_spectrums[1], rec_signals[1]
110
+ interpolation_flipped_log_spectrum, interpolation_rec_signal = flipped_log_spectrums[2], rec_signals[2]
111
+ return first_flipped_log_spectrum, (sampleRate, first_rec_signal), second_flipped_log_spectrum, (sampleRate, second_rec_signal), interpolation_flipped_log_spectrum, (sampleRate, interpolation_rec_signal), encodeCache
112
+
113
+ def refresh_interpolation_input(encodeCache):
114
+ return gr.Dropdown.update(choices=list(encodeCache.keys())), gr.Dropdown.update(choices=list(encodeCache.keys())), str(list(encodeCache.keys())), encodeCache
115
+
116
+ with gr.Tab("Interpolation"):
117
+ gr.Markdown("Test Interpolation.")
118
+ with gr.Row():
119
+ with gr.Column():
120
+ with gr.Row():
121
+ first_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="First input")
122
+ second_interpulation_input = gr.Dropdown(list(INIT_ENCODE_CACHE.keys()), label="Second input")
123
+ with gr.Row():
124
+ first_input_audio = gr.Audio(type="numpy", label="first_input_audio")
125
+ first_input_audio.style(length=125)
126
+ second_input_audio = gr.Audio(type="numpy", label="second_input_audio")
127
+ second_input_audio.style(length=125)
128
+
129
+ interpulation_input_ratio = gr.Slider(minimum=-0.20, maximum=1.20, value=0.5, step=0.01, label="Ratio of the first input.")
130
+ interpulation_refresh_button = gr.Button(value="Refresh")
131
+ interpulation_button = gr.Button(value="Interpolate")
132
+ with gr.Column():
133
+ with gr.Row():
134
+ first_input_spectrogram = gr.Image(label="First Input", type="numpy")
135
+ first_input_spectrogram.style(height=250, width=125)
136
+ interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
137
+ interpolation_spectrogram.style(height=250, width=125)
138
+ second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
139
+ second_input_spectrogram.style(height=250, width=125)
140
+ interpolation_audio = gr.Audio(type="numpy", label="Interpolation")
141
+ interpolation_audio.style(length=125)
142
+
143
+ interpolation_text_output = gr.Textbox()
144
+ interpulation_refresh_button.click(refresh_interpolation_input, inputs=[encodeCache],
145
+ outputs=[first_interpulation_input, second_interpulation_input, interpolation_text_output, encodeCache])
146
+ interpulation_button.click(interpolate, inputs=[first_interpulation_input, second_interpulation_input, interpulation_input_ratio, encodeCache],
147
+ outputs=[first_input_spectrogram, first_input_audio, second_input_spectrogram, second_input_audio, interpolation_spectrogram, interpolation_audio, encodeCache])
148
+
149
+ def get_random_sampling_module(encodeCache):
150
+
151
+ def random_sample(mu, sigma, encodeCache):
152
+ random_encode = torch.Tensor([mu]) + torch.Tensor([sigma]) * torch.randn(1, 24)
153
+ encodeCache["mytest"] = random_encode.detach().numpy()
154
+ random_encode = random_encode.to("cuda")
155
+ flipped_log_spectrums, sampleRate, rec_signals = encodeBatch2GradioOutput(random_encode)
156
+ random_log_spectrum, random_rec_signal = flipped_log_spectrums[0], rec_signals[0]
157
+ return random_log_spectrum, (sampleRate, random_rec_signal), encodeCache
158
+
159
+ with gr.Tab("Random sampling"):
160
+ gr.Markdown("Test random sampling.")
161
+ with gr.Row():
162
+ with gr.Column():
163
+ with gr.Row():
164
+ mu = gr.Number(label="mu")
165
+ sigma = gr.Number(label="sigma")
166
+ random_sampling_button = gr.Button(value="Sample")
167
+ with gr.Column():
168
+ random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
169
+ random_sampling_spectrogram.style(height=250, width=600)
170
+ random_sampling_audio = gr.Audio(type="numpy", label="Random sampling")
171
+ random_sampling_audio.style(length=125)
172
+ random_sampling_button.click(random_sample, inputs=[mu, sigma, encodeCache], outputs=[random_sampling_spectrogram, random_sampling_audio, encodeCache])
173
+
174
+
175
+ with gr.Blocks() as demo:
176
+ initial_examples = gr.State(value=INIT_ENCODE_CACHE)
177
+ # initial_interpolation_examples = gr.State(value={"init": np.random.random(SPECTROGRAM_RESOLUTION)})
178
+ get_example_module(initial_examples)
179
+ # get_reconstruction_module()
180
+ get_random_sampling_module(initial_examples)
181
+ get_interpolation_module(initial_examples)
182
+
183
+ # demo.launch(share=True)
184
+ demo.launch(share=True, debug=True)
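For reference, a minimal offline sketch of the decode-and-reconstruct path used above (assumes the same checkpoint under models/ and a CUDA device; the output file name is arbitrary):

import numpy as np
import torch
import librosa
from scipy.io.wavfile import write
from tools import VAE_out_put_to_spc
from model.VAE_torchV import Decoder

# Rebuild the decoder exactly as in app.py and load the committed weights.
decoder = Decoder(24, N2=0, N3=8, channel_sizes=[64, 64, 64, 96, 96, 128, 160, 216]).to("cuda")
decoder.load_state_dict(torch.load("models/test_decoder_CA.pt"))

latent = torch.randn(1, 24).to("cuda")  # one random 24-dimensional latent vector
spc = VAE_out_put_to_spc(decoder(latent).to("cpu").detach().numpy())

# Griffin-Lim reconstruction, mirroring encodeBatch2GradioOutput above.
abs_spec = np.zeros((513, 256))
abs_spec[:512, :] = np.sqrt(np.reshape(spc, (512, 256)))
signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
write("reconstruction.wav", 16000, signal)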
configurations/conf.json ADDED
@@ -0,0 +1,83 @@
1
+ {
2
+ "midpoints":
3
+ {
4
+ "osc1_amp": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
5
+ "osc_amp2": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
6
+ "osc2_amp": [0.000, 0.01, 0.05, 0.15, 0.25, 0.45, 0.65, 0.75, 0.85, 0.89, 0.9],
7
+ "attack": [0.001, 0.03, 0.1, 0.25, 0.40, 0.7],
8
+ "decay": [0.001, 0.2, 0.60, 1.2],
9
+ "sustain": [0.01, 0.2, 0.5, 1.0],
10
+ "release": [0.001, 0.15, 0.35, 0.8],
11
+ "cutoff_freq": [2200, 2400, 2600, 2800, 3000,
12
+ 3200, 3400, 3600, 3800, 4000,
13
+ 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800],
14
+ "osc_types": [0, 1, 2, 3, 4, 5],
15
+ "amp_mod_depth": [0, 0, 0, 0.1, 0.3, 0.5, 1.0],
16
+ "amp_mod_freq": [0, 0, 1, 2, 4, 8],
17
+ "mod_waveforms": [0,1,2,3],
18
+ "pitch_mod_depth": [0,1],
19
+ "pitch_mod_freq": [1,2,4,8]
20
+ },
21
+
22
+ "subspace_range":
23
+ {
24
+ "osc_amp2": 2,
25
+ "osc1_amp": 2,
26
+ "osc2_amp": 2,
27
+ "attack": 1,
28
+ "decay": 1,
29
+ "sustain": 2,
30
+ "release": 2,
31
+ "cutoff_freq": 0,
32
+ "osc_types": 0
33
+ },
34
+
35
+ "is_discrete":
36
+ {
37
+ "osc_amp2": false,
38
+ "osc1_amp": false,
39
+ "osc2_amp": false,
40
+ "attack": false,
41
+ "decay": false,
42
+ "sustain": false,
43
+ "release": false,
44
+ "cutoff_freq": false,
45
+ "duration": false,
46
+ "osc_types": true,
47
+ "mod_waveforms": true,
48
+ "amp_mod_depth": true,
49
+ "amp_mod_freq": true,
50
+ "pitch_mod_depth": true,
51
+ "pitch_mod_freq": true
52
+ },
53
+
54
+ "sample_rate": 16384,
55
+ "n_sample_note": 65536,
56
+ "n_sample_music": 65536,
57
+
58
+ "STFT_hyperParameter":
59
+ {
60
+ "frame_length": 512,
61
+ "frame_step": 256
62
+ },
63
+
64
+ "midi_midpoints":
65
+ {
66
+ "duration": [0.1,0.5,1.0,2.0],
67
+ "pitch": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76]
68
+ },
69
+
70
+ "midi_is_discrete":
71
+ {
72
+ "duration": false,
73
+ "pitch": true
74
+ },
75
+
76
+ "midi_max_n_notes": 8,
77
+
78
+ "resolution":
79
+ {
80
+ "time_resolution": 509,
81
+ "freq_resolution": 513
82
+ }
83
+ }
configurations/read_configuration.py ADDED
@@ -0,0 +1,152 @@
1
+ import json
2
+ bins_path = 'configurations/conf.json'
3
+
4
+
5
+ def parameter_range(parameter_name):
6
+ """
7
+ :param parameter_name:
8
+ :return: List[Float]--midpoints of bins for the input synthesizer parameter
9
+ """
10
+
11
+ with open(bins_path) as f:
12
+ midpoints = json.load(f)["midpoints"]
13
+ return midpoints[parameter_name]
14
+
15
+
16
+ def cluster_range():
17
+ """
18
+ :return: Dict[String:Int]--defines the range of cluster to search
19
+ """
20
+ with open(bins_path) as f:
21
+ cluster_r = json.load(f)["subspace_range"]
22
+ return cluster_r
23
+
24
+
25
+ def midi_parameter_range(parameter_name):
26
+ """
27
+ :param parameter_name:
28
+ :return: List[Float]--midpoints of bins for the input midi parameter
29
+ """
30
+ with open(bins_path) as f:
31
+ r = json.load(f)["midi_midpoints"]
32
+ return r[parameter_name]
33
+
34
+
35
+ def is_discrete(parameter_name):
36
+ """
37
+ :param parameter_name:
38
+ :return: Boolean--if the input synthesizer parameter is discrete
39
+ """
40
+ with open(bins_path) as f:
41
+ is_dis = json.load(f)["is_discrete"]
42
+ return is_dis[parameter_name]
43
+
44
+
45
+ def midi_is_discrete(parameter_name):
46
+ """
47
+ :param parameter_name:
48
+ :return: Boolean--if the input midi parameter is discrete
49
+ """
50
+ with open(bins_path) as f:
51
+ is_dis = json.load(f)["midi_is_discrete"]
52
+ return is_dis[parameter_name]
53
+
54
+
55
+ def get_label_size():
56
+ """
57
+ :return: Int--length of synthesizer parameter encoding
58
+ """
59
+ with open(bins_path) as f:
60
+ conf = json.load(f)
61
+ midpoints = conf["midpoints"]
62
+ n_labels = 0
63
+ for key in midpoints:
64
+ n_labels = n_labels + len(midpoints[key])
65
+
66
+ return n_labels
67
+
68
+
69
+ def get_bins_length():
70
+ """
71
+ :return: Dict[String:Int]--Number of bins for all synthesizer parameters
72
+ """
73
+ with open(bins_path) as f:
74
+ midpoints = json.load(f)["midpoints"]
75
+ bins_length = {}
76
+
77
+ for key in midpoints:
78
+ bins_length[key] = len(midpoints[key])
79
+
80
+ return bins_length
81
+
82
+
83
+ def get_conf_stft_hyperparameter():
84
+ """
85
+ :return: Dict[String:Int]--STFT hyperparameters
86
+ """
87
+ with open(bins_path) as f:
88
+ STFT_hyperParameters = json.load(f)["STFT_hyperParameter"]
89
+
90
+ return STFT_hyperParameters
91
+
92
+
93
+ def get_conf_sample_rate():
94
+ """
95
+ :return: Int--sample_rate
96
+ """
97
+ with open(bins_path) as f:
98
+ sample_rate = json.load(f)["sample_rate"]
99
+
100
+ return sample_rate
101
+
102
+
103
+ def get_conf_n_sample_note():
104
+ """
105
+ :return: Int--sample number of a note example
106
+ """
107
+ with open(bins_path) as f:
108
+ n_sample_note = json.load(f)["n_sample_note"]
109
+
110
+ return n_sample_note
111
+
112
+
113
+ def get_conf_n_sample():
114
+ """
115
+ :return: Int--sample number of a melody example
116
+ """
117
+ with open(bins_path) as f:
118
+ n_sample = json.load(f)["n_sample_music"]
119
+
120
+ return n_sample
121
+
122
+
123
+ def get_conf_time_resolution():
124
+ """
125
+ :return: Int--spectrogram resolution on time dimension
126
+ """
127
+ with open(bins_path) as f:
128
+ resolution = json.load(f)["resolution"]
129
+
130
+ return resolution["time_resolution"]
131
+
132
+
133
+ def get_conf_pitch_resolution():
134
+ """
135
+ :return: Int--spectrogram resolution on pitch dimension
136
+ """
137
+ with open(bins_path) as f:
138
+ resolution = json.load(f)["resolution"]
139
+
140
+ return resolution["freq_resolution"]
141
+
142
+
143
+ def get_conf_max_n_notes():
144
+ """
145
+ :return: Int--maximum number of notes to be generated in a melody
146
+ """
147
+ with open(bins_path) as f:
148
+ max_n_notes = json.load(f)["midi_max_n_notes"]
149
+
150
+ return max_n_notes
151
+
152
+
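A brief usage sketch of these helpers (hypothetical session; assumes configurations/conf.json is reachable from the working directory):

from configurations.read_configuration import parameter_range, is_discrete, get_bins_length, get_label_size

print(parameter_range("attack"))         # [0.001, 0.03, 0.1, 0.25, 0.40, 0.7], straight from conf.json
print(is_discrete("attack"))             # False
print(get_bins_length()["cutoff_freq"])  # number of cutoff_freq bins
print(get_label_size())                  # total one-hot label length, summed over all midpoint lists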
data_generation/data_generation.py ADDED
@@ -0,0 +1,380 @@
1
+ from typing import List
2
+
3
+ import librosa
4
+
5
+ from data_generation.encoding import ParameterDescription, Sample
6
+ from melody_synth.random_midi import RandomMidi
7
+ from melody_synth.melody_generator import MelodyGenerator
8
+ from scipy.io.wavfile import write
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import json
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ import tensorflow as tf
15
+ import matplotlib
16
+ from configurations.read_configuration import parameter_range, is_discrete, get_conf_stft_hyperparameter
17
+ import shutil
18
+
19
+ # from model.log_spectrogram import power_to_db
20
+ from tools import power_to_db
21
+
22
+ num_params = 16
23
+
24
+
25
+ def plot_spectrogram(signal: np.ndarray,
26
+ path: str,
27
+ frame_length=512,
28
+ frame_step=256):
29
+ """Computes the spectrogram of the given signal and saves it.
30
+
31
+ Parameters
32
+ ----------
33
+ signal: np.ndarray
34
+ The signal for which to compute the spectrogram.
35
+ path: str
36
+ Path to save the computed spectrogram.
37
+ frame_length:
38
+ Window size of the FFT.
39
+ frame_step:
40
+ Hop size of the FFT.
41
+ """
42
+
43
+ # Compute spectrum for each frame. Returns complex tensor.
44
+ # todo: duplicate code in log_spectrogram.py. Move this somewhere else perhaps.
45
+ spectrogram = tf.signal.stft(signal,
46
+ frame_length=frame_length,
47
+ frame_step=frame_step,
48
+ pad_end=False) # Returns 63 frames instead of 64 otherwise
49
+
50
+ # Compute the magnitudes
51
+ magnitude_spectrum = tf.abs(spectrogram)
52
+ log_spectrum = power_to_db(magnitude_spectrum)
53
+ matplotlib.pyplot.imsave(path, np.transpose(log_spectrum), vmin=-100, vmax=0, origin='lower')
54
+
55
+
56
+ def plot_mel_spectrogram(signal: np.ndarray,
57
+ path: str,
58
+ frame_length=512,
59
+ frame_step=256):
60
+
61
+ spectrogram = librosa.feature.melspectrogram(signal, sr=16384, n_fft=2048, hop_length=frame_step, win_length=frame_length)
62
+ matplotlib.pyplot.imsave(path, spectrogram, vmin=-100, vmax=0, origin='lower')
63
+
64
+
65
+ # List of ParameterDescription objects that specify the parameters for generation
66
+ param_descriptions: List[ParameterDescription]
67
+ param_descriptions = [
68
+
69
+ # Oscillator levels
70
+ ParameterDescription(name="osc1_amp",
71
+ values=parameter_range('osc1_amp'),
72
+ discrete=is_discrete('osc1_amp')),
73
+ ParameterDescription(name="osc2_amp",
74
+ values=parameter_range('osc2_amp'),
75
+ discrete=is_discrete('osc2_amp')),
76
+
77
+ # ADSR params
78
+ ParameterDescription(name="attack",
79
+ values=parameter_range('attack'),
80
+ discrete=is_discrete('attack')),
81
+ ParameterDescription(name="decay",
82
+ values=parameter_range('decay'),
83
+ discrete=is_discrete('decay')),
84
+ ParameterDescription(name="sustain",
85
+ values=parameter_range('sustain'),
86
+ discrete=is_discrete('sustain')),
87
+ ParameterDescription(name="release",
88
+ values=parameter_range('release'),
89
+ discrete=is_discrete('release')),
90
+
91
+ ParameterDescription(name="cutoff_freq",
92
+ values=parameter_range('cutoff_freq'),
93
+ discrete=is_discrete('cutoff_freq')),
94
+
95
+ # Oscillators types
96
+ # 0 for sin saw, 1 for sin square, 2 for saw square
97
+ # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
98
+ ParameterDescription(name="osc_types",
99
+ values=parameter_range('osc_types'),
100
+ discrete=is_discrete('osc_types')),
101
+ ]
102
+
103
+
104
+ def generate_dataset_for_cnn(n: int,
105
+ path_name="./data/data_cnn_model",
106
+ sample_rate=16384,
107
+ n_samples_for_note=16384 * 4,
108
+ n_samples_for_melody=16384 * 4, write_parameter=True, write_spectrogram=True):
109
+ """
110
+ Generate dataset of size n for 'Inversynth' cnn model
111
+
112
+ :param n: Int
113
+ :param path_name: String--path to save the dataset
114
+ :param sample_rate: Int
115
+ :param n_samples_for_note: Int
116
+ :param n_samples_for_melody: Int
117
+ :param write_parameter: Boolean--whether to write parameter values into the spectrogram/text file name
118
+ :param write_spectrogram: Boolean--whether to write a spectrogram with the parameter values in its file name
119
+ :return:
120
+ """
121
+
122
+ if Path(path_name).exists():
+ shutil.rmtree(path_name)
123
+ Path(path_name).mkdir(parents=True, exist_ok=True)
124
+ print("Generating dataset...")
125
+
126
+ synth = MelodyGenerator(sample_rate,
127
+ n_samples_for_note, n_samples_for_melody)
128
+ randomMidi = RandomMidi()
129
+
130
+ for i in tqdm(range(n)):
131
+ parameter_values = [param.generate() for param in param_descriptions]
132
+
133
+ # Dict of parameter values, what our synthesizer expects as input
134
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
135
+
136
+ strategy = {"rhythm_strategy": "free_rhythm",
137
+ "pitch_strategy": "free_pitch",
138
+ "duration_strategy": "random_duration",
139
+ }
140
+ midi_encode, midi = randomMidi(strategy)
141
+ signal = synth.get_melody(parameter_values_raw, midi=midi).numpy()
142
+
143
+ # Path to store each sample with its label
144
+ path = path_name + f"/{i}"
145
+ Path(path).mkdir(parents=True, exist_ok=True)
146
+
147
+ if write_parameter:
148
+ suffix = 'spectrogram'
149
+ for parameter_value in parameter_values:
150
+ suffix += f'_{parameter_value.name}={"%.3f" % parameter_value.value}'
151
+ if write_spectrogram:
152
+ plot_spectrogram(signal, path=path + f"/{suffix}.png", frame_length=1024, frame_step=256)
153
+ else:
154
+ with open(path + f"/{suffix}.txt", "w") as f:
155
+ f.write("test")
156
+ f.close()
157
+
158
+ write(path + f"/{i}.wav", synth.sample_rate, signal)
159
+
160
+ sample = Sample(parameter_values)
161
+
162
+ # Dump label as json
163
+ with open(path + "/label.json", "w") as label_file:
164
+ label = sample.get_values()
165
+ label['midi'] = midi
166
+ # print(len(label["encoding"]))
167
+ json.dump(label, label_file, ensure_ascii=True)
168
+
169
+ print('Data generation done!')
170
+
171
+
172
+ def generate_dataset_for_triplet(n: int,
173
+ path_name="./data/data_triplet_val_10_500",
174
+ sample_rate=16384,
175
+ n_samples_for_note=16384 * 4,
176
+ n_samples_for_melody=16384 * 4,
177
+ n_labels=30,
178
+ write_spectrogram=True):
179
+ """
180
+ Generate dataset of size n for triplet model
181
+
182
+ :param n: Int
+ :param path_name: String--path to save the dataset
+ :param sample_rate: Int
+ :param n_samples_for_note: Int
+ :param n_samples_for_melody: Int
+ :param n_labels: Int--number of synthesizer parameter combinations contained in the dataset (a hyperparameter of the triplet model)
+ :param write_spectrogram: Boolean--whether to write spectrograms
186
+ """
187
+
188
+ if Path(path_name).exists():
+ shutil.rmtree(path_name)
189
+ Path(path_name).mkdir(parents=True, exist_ok=True)
190
+ print("Generating dataset...")
191
+ synth = MelodyGenerator(sample_rate,
192
+ n_samples_for_note, n_samples_for_melody)
193
+ randomMidi = RandomMidi()
194
+
195
+ parameter_values_examples = [[param.generate() for param in param_descriptions] for i in range(n_labels)]
196
+ parameter_values_raw_examples = [{param.name: param.value for param in parameter_values} for parameter_values in
197
+ parameter_values_examples]
198
+
199
+ np.random.seed()
200
+ for i in tqdm(range(n)):
201
+ label_index = np.random.randint(0, n_labels)
202
+ parameter_values = parameter_values_examples[label_index]
203
+ parameter_values_raw = parameter_values_raw_examples[label_index]
204
+
205
+ strategy = {"rhythm_strategy": "free_rhythm",
206
+ "pitch_strategy": "free_pitch",
207
+ "duration_strategy": "random_duration",
208
+ }
209
+ midi_encode, midi = randomMidi(strategy)
210
+ signal = synth.get_melody(parameter_values_raw, midi=midi).numpy()
211
+
212
+ # Path to store each sample with its label
213
+ path = path_name + f"/{i}"
214
+ Path(path).mkdir(parents=True, exist_ok=True)
215
+
216
+ write(path + f"/{i}.wav", synth.sample_rate, signal)
217
+ suffix = 'spectrogram'
218
+ for parameter_value in parameter_values:
219
+ suffix += f'_{parameter_value.name}={"%.3f" % parameter_value.value}'
220
+
221
+ if write_spectrogram:
222
+ hp = get_conf_stft_hyperparameter()
223
+ frame_l = hp['frame_length']
224
+ frame_s = hp['frame_step']
225
+ plot_spectrogram(signal, path=path + f"/{suffix}.png", frame_length=frame_l, frame_step=frame_s)
226
+ else:
227
+ with open(path + f"/{suffix}.txt", "w") as f:
228
+ f.write("test")
229
+ f.close()
230
+
231
+ with open(path + "/label_index.json", "w") as label_index_file:
232
+ index_json = {'index': label_index}
233
+ json.dump(index_json, label_index_file, ensure_ascii=False)
234
+
235
+ # save midi as .txt file
236
+ with open(path + "/midi.txt", "w") as midi_file:
237
+ midi_file.write(str(midi))
238
+ midi_file.close()
239
+
240
+ print('Data generation done!')
241
+
242
+
243
+ def manhattan_distance(SP1, SP2):
244
+ """
245
+ :param SP1: first input synthesizer parameter combination
246
+ :param SP2: second input synthesizer parameter combination
247
+ :return: Float--manhattan distance between SP1 and SP2
248
+ """
249
+
250
+ md = []
251
+ for key in SP1:
252
+ parameter_name = key
253
+ value1 = SP1[parameter_name]
254
+ value2 = SP2[parameter_name]
255
+ bins = parameter_range(parameter_name)
256
+ bin_index1 = np.argmin(np.abs(np.array(bins) - value1))
257
+ bin_index2 = np.argmin(np.abs(np.array(bins) - value2))
258
+
259
+ if parameter_name == "osc_types":
260
+ if bin_index1 == bin_index2:
261
+ d = 0
262
+ else:
263
+ d = 1
264
+ else:
265
+ d = np.abs(bin_index1 - bin_index2) / (len(bins) - 1)
266
+ md.append(d)
267
+
268
+ return np.average(md)
269
+
270
+
271
+ def generate_dataset_for_mixed_input_model(n: int,
272
+ path_name="./data/data_mixed_input",
273
+ sample_rate=16384,
274
+ n_samples_for_note=16384 * 4,
275
+ n_samples_for_melody=16384 * 4
276
+ ):
277
+ """
278
+ Generate dataset of size n for mixed_input_model model
279
+
280
+ :param n: Int
281
+ :param path_name: String--path to save the dataset
282
+ :param sample_rate: Int
283
+ :param n_samples_for_note: Int
284
+ :param n_samples_for_melody: Int
285
+ :return:
286
+ """
287
+
288
+ if Path(path_name).exists():
+ shutil.rmtree(path_name)
289
+ Path(path_name).mkdir(parents=True, exist_ok=True)
290
+ print("Generating dataset...")
291
+ synth = MelodyGenerator(sample_rate,
292
+ n_samples_for_note, n_samples_for_melody)
293
+ randomMidi = RandomMidi()
294
+
295
+ strategy = {"rhythm_strategy": "free_rhythm",
296
+ "pitch_strategy": "free_pitch",
297
+ "duration_strategy": "random_duration",
298
+ }
299
+ strategy0 = {"rhythm_strategy": "single_note_rhythm",
300
+ "pitch_strategy": "fixed_pitch",
301
+ "duration_strategy": "fixed_duration",
302
+ }
303
+ strategy1 = {"rhythm_strategy": "single_note_rhythm",
304
+ "pitch_strategy": "fixed_pitch1",
305
+ "duration_strategy": "fixed_duration",
306
+ }
307
+ strategy2 = {"rhythm_strategy": "single_note_rhythm",
308
+ "pitch_strategy": "fixed_pitch2",
309
+ "duration_strategy": "fixed_duration",
310
+ }
311
+ strategy3 = {"rhythm_strategy": "single_note_rhythm",
312
+ "pitch_strategy": "fixed_pitch3",
313
+ "duration_strategy": "fixed_duration",
314
+ }
315
+ strategy4 = {"rhythm_strategy": "single_note_rhythm",
316
+ "pitch_strategy": "fixed_pitch4",
317
+ "duration_strategy": "fixed_duration",
318
+ }
319
+
320
+ np.random.seed()
321
+ for i in tqdm(range(n)):
322
+ path = path_name + f"/{i}"
323
+ Path(path).mkdir(parents=True, exist_ok=True)
324
+ parameter_values = [param.generate() for param in param_descriptions]
325
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
326
+
327
+ # generate query music
328
+ midi_encode, midi = randomMidi(strategy)
329
+ signal_query = synth.get_melody(parameter_values_raw, midi=midi).numpy()
330
+ write(path + f"/{i}.wav", synth.sample_rate, signal_query)
331
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
332
+
333
+ if np.random.rand() < 0.01: # ~1% exact-match positives (manhattan_distance = 0)
334
+ with open(path + "/label.json", "w") as label_file:
335
+ sample = Sample(parameter_values)
336
+ label = sample.get_values()
337
+ label['manhattan_distance'] = 0.
338
+ json.dump(label, label_file, ensure_ascii=False)
339
+ else:
340
+ with open(path + "/label.json", "w") as label_file:
341
+ query_sp = parameter_values_raw
342
+ parameter_values = [param.generate() for param in param_descriptions]
343
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
344
+ sample = Sample(parameter_values)
345
+ label = sample.get_values()
346
+ md = manhattan_distance(query_sp, parameter_values_raw)
347
+ label['manhattan_distance'] = md
348
+ json.dump(label, label_file, ensure_ascii=False)
349
+
350
+ # generate query music
351
+ midi_encode, midi = randomMidi(strategy0)
352
+ signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
353
+ write(path + f"/{i}_0.wav", synth.sample_rate, signal_single_note)
354
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
355
+
356
+ # generate query music
357
+ midi_encode, midi = randomMidi(strategy1)
358
+ signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
359
+ write(path + f"/{i}_1.wav", synth.sample_rate, signal_single_note)
360
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
361
+
362
+ # generate query music
363
+ midi_encode, midi = randomMidi(strategy2)
364
+ signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
365
+ write(path + f"/{i}_2.wav", synth.sample_rate, signal_single_note)
366
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
367
+
368
+ # generate query music
369
+ midi_encode, midi = randomMidi(strategy3)
370
+ signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
371
+ write(path + f"/{i}_3.wav", synth.sample_rate, signal_single_note)
372
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
373
+
374
+ # generate query music
375
+ midi_encode, midi = randomMidi(strategy4)
376
+ signal_single_note = synth.get_melody(parameter_values_raw, midi=midi).numpy()
377
+ write(path + f"/{i}_4.wav", synth.sample_rate, signal_single_note)
378
+ # plot_spectrogram(signal, path=path + f"/{i}_input.png", frame_length=512, frame_step=256)
379
+
380
+ print('Data generation done!')
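A small worked example of manhattan_distance above (hypothetical parameter values; the bin midpoints come from conf.json):

from data_generation.data_generation import manhattan_distance

# Two settings that differ only in 'attack': 0.001 and 0.03 fall into attack bins 0 and 1.
sp1 = {"attack": 0.001, "osc_types": 2}
sp2 = {"attack": 0.03, "osc_types": 2}

# attack term: |0 - 1| / (6 - 1) = 0.2; osc_types term: 0 (same bin).
# The distance is the mean over parameters: (0.2 + 0.0) / 2 = 0.1.
print(manhattan_distance(sp1, sp2))  # 0.1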
data_generation/decoding.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Dict
2
+ from data_generation.data_generation import param_descriptions
3
+ import numpy as np
4
+
5
+ from melody_synth.melody_generator import MelodyGenerator
6
+ from melody_synth.random_midi import RandomMidi
7
+
8
+
9
+ def decode_label(prediction: np.ndarray,
10
+ sample_rate: int,
11
+ n_samples: int,
12
+ return_params=False,
13
+ discard_parameters=[]):
14
+ """Parses a network prediction array, synthesizes the described audio and returns it.
15
+
16
+ Parameters
17
+ ----------
18
+ prediction: np.ndarray
19
+ The network prediction array
20
+ sample_rate: int
21
+ Sample rate of the audio to generate.
22
+ n_samples: int
23
+ Number of samples per wav file.
24
+ return_params: bool
25
+ Whether or not to also return the parameters alongside the signal
26
+ discard_parameters: List[str]
27
+ Parameter names that should be discarded (set to their default value)
28
+
29
+ Returns
30
+ -------
31
+ np.ndarray:
32
+ The generated signal
33
+ """
34
+
35
+ params: Dict[str, float] = {}
36
+ index = 0
37
+ for i, param_description in enumerate(param_descriptions):
38
+ # Parses the one-hot-encoding of the prediction array
39
+ bits = len(param_description.values)
40
+ curr_prediction = prediction[index:index + bits]
41
+
42
+ hot_index = curr_prediction.argmax()
43
+ params[param_description.name] = param_description.parameter_value(hot_index).value
44
+ index += bits
45
+
46
+ for param_str in discard_parameters:
47
+ params[param_str] = 0 # todo: make this safe and change to default value and not just 0
48
+
49
+ synth = MelodyGenerator(sample_rate,
50
+ n_samples, n_samples)
51
+ randomMidi = RandomMidi()
52
+
53
+ strategy = {"rhythm_strategy": "single_note_rhythm",
54
+ "pitch_strategy": "fixed_pitch",
55
+ "duration_strategy": "fixed_duration",
56
+ }
57
+ midi_encode, midi = randomMidi(strategy)
58
+
59
+ signal = synth.get_melody(params, midi=midi).numpy()
60
+
61
+ if return_params:
62
+ return signal, params
63
+
64
+ return signal
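A minimal sketch of decode_label in use (the prediction array is a random stand-in for a real network output; sample_rate and n_samples mirror the defaults used elsewhere in this repository):

import numpy as np
from data_generation.data_generation import param_descriptions
from data_generation.decoding import decode_label

# One slot per bin of each parameter, concatenated in param_descriptions order.
n_bits = sum(len(p.values) for p in param_descriptions)
prediction = np.random.random(n_bits)

signal, params = decode_label(prediction, sample_rate=16384, n_samples=16384 * 4, return_params=True)
print(params)        # decoded synthesizer parameter values
print(signal.shape)  # synthesized single-note audio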
data_generation/encoding.py ADDED
@@ -0,0 +1,92 @@
1
+ from typing import List, Dict
2
+ import numpy as np
3
+
4
+
5
+ def parameter_range_low_high(parameter_range: List):
6
+ """
7
+ :param parameter_range:List[Float]--midpoints of bins
8
+ :return: List[Float]--lower and upper bounds of bins
9
+ """
10
+ temp1 = np.array(parameter_range[1:])
11
+ temp2 = np.array(parameter_range[:len(temp1)])
12
+ temp1 = 0.5 * (temp1 + temp2)
13
+
14
+ return np.hstack([parameter_range[0], temp1, parameter_range[len(parameter_range) - 1]])
15
+
16
+
17
+ class ParameterValue:
18
+ """Describes a one hot encoded parameter value."""
19
+
20
+ name: str
21
+ value: float
22
+ encoding: List[float]
23
+ index: int
24
+
25
+ def __init__(self, name, value, encoding, index):
26
+ self.name = name
27
+ self.value = value
28
+ self.encoding = encoding
29
+ self.index = index
30
+
31
+
32
+ class ParameterDescription:
33
+ """A description for generating a parameter value."""
34
+
35
+ # Discrete is used to generate samples that don't exactly fit into a bin for training.
36
+ def __init__(self, name, values: List[float], discrete=True):
37
+ self.name = name
38
+ self.values = values
39
+ self.discrete = discrete
40
+ self.parameter_low_high = parameter_range_low_high(values)
41
+
42
+ # one-hot encoding as per paper
43
+ # Value used for specifying a different value than values[index], useful for non-discrete params. todo: too adhoc?
44
+ def parameter_value(self, index, value=None) -> ParameterValue:
45
+ if value is None:
46
+ value = self.values[index]
47
+ encoding = np.zeros(len(self.values), dtype=float)
48
+ encoding[index] = 1.0
49
+ return ParameterValue(
50
+ name=self.name,
51
+ value=value,
52
+ encoding=encoding,
53
+ index=index
54
+ )
55
+
56
+ # random even distribution as per paper
57
+ def generate(self) -> ParameterValue:
58
+ # choose a bin if parameter is discrete
59
+ if self.discrete:
60
+ index = np.random.randint(0, len(self.values))
61
+ return self.parameter_value(index)
62
+ # otherwise generate a random value
63
+ else:
64
+ indexFinder = np.random.uniform(0, 1)
65
+ l = np.linspace(0.0, 1, len(self.values))
66
+ index = np.argmin(np.abs(l - indexFinder))
67
+ value = (self.parameter_low_high[index+1] - self.parameter_low_high[index]) * np.random.uniform(0, 1) + self.parameter_low_high[index]
68
+
69
+ return self.parameter_value(index, value)
70
+
71
+ # get the index of the best matching bin
72
+ def get_bin_index(self, value):
73
+ return np.argmin(np.abs(np.array(self.values) - value))
74
+
75
+ def decode(self, encoding: List[float]) -> ParameterValue:
76
+ index = np.array(encoding).argmax()
77
+ return self.parameter_value(index)
78
+
79
+
80
+ class Sample:
81
+ """Describes the label of one training sample."""
82
+
83
+ parameters: List[ParameterValue]
84
+
85
+ def __init__(self, parameters):
86
+ self.parameters = parameters
87
+
88
+ def get_values(self) -> Dict[str, dict]:
89
+ return {
90
+ "parameters": {p.name: p.value for p in self.parameters},
91
+ "encoding": list(np.hstack([p.encoding for p in self.parameters]))
92
+ }
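A short sketch of how these classes fit together (the midpoint list is copied from the 'attack' entry in conf.json; everything else is illustrative):

from data_generation.encoding import ParameterDescription, Sample

attack = ParameterDescription(name="attack", values=[0.001, 0.03, 0.1, 0.25, 0.40, 0.7], discrete=False)
value = attack.generate()  # a random value drawn inside a randomly chosen bin
print(value.value, value.index, list(value.encoding))

sample = Sample([value])
print(sample.get_values())  # {'parameters': {'attack': ...}, 'encoding': [...]}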
example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
external sources.txt ADDED
@@ -0,0 +1,3 @@
1
+ 1. External datasets have been preprocessed and stored in the directory '/data/external_data'. Links for the datasets can be found in the same directory.
2
+
3
+ 2. External code: Some code in '/model/VAE.py', 'non_random_LFOs.py' and '/melody/complex_torch_synth.py' is based on external code; the references are given in those files.
generate_synthetic_data_online.py ADDED
@@ -0,0 +1,431 @@
1
+ import matplotlib.pyplot as plt
2
+ import librosa
3
+ import matplotlib
4
+ import pandas as pd
5
+ from typing import Optional
6
+ from torch import tensor
7
+ from ddsp.core import tf_float32
8
+ import torch
9
+ from torch import Tensor
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ from torchsynth.config import SynthConfig
13
+ import ddsp
14
+ from pathlib import Path
15
+ from typing import Dict
16
+ from data_generation.encoding import ParameterDescription
17
+ from typing import List
18
+ from configurations.read_configuration import parameter_range, is_discrete, midi_parameter_range, midi_is_discrete
19
+ import shutil
20
+ from tqdm import tqdm
21
+ from scipy.io.wavfile import write
22
+ from melody_synth.complex_torch_synth import DoubleSawSynth, SinSawSynth, SinTriangleSynth, TriangleSawSynth
23
+
24
+ sample_rate = 16000
25
+ n_samples = sample_rate * 4.5
26
+
27
+
28
+ class NoteGenerator:
29
+ """
30
+ This class is responsible for single-note audio generation by function 'get_note'.
31
+ """
32
+
33
+ def __init__(self,
34
+ sample_rate=sample_rate,
35
+ n_samples=sample_rate * 4.5):
36
+ self.sample_rate = sample_rate
37
+ self.n_samples = n_samples
38
+ synthconfig = SynthConfig(
39
+ batch_size=1, reproducible=False, sample_rate=sample_rate,
40
+ buffer_size_seconds=np.float64(n_samples) / np.float64(sample_rate)
41
+ )
42
+ self.Saw_Square_Voice = DoubleSawSynth(synthconfig)
43
+ self.SinSawVoice = SinSawSynth(synthconfig)
44
+ self.SinTriVoice = SinTriangleSynth(synthconfig)
45
+ self.TriSawVoice = TriangleSawSynth(synthconfig)
46
+
47
+ def get_note(self, params: Dict[str, float]):
48
+ osc_amp2 = np.float64(params.get("osc_amp2", 0))
49
+
50
+ if osc_amp2 < 0.45:
51
+ osc1_amp = 0.9
52
+ osc2_amp = osc_amp2
53
+ else:
54
+ osc1_amp = 0.9 - osc_amp2
55
+ osc2_amp = 0.9
56
+
57
+ attack_1 = np.float64(params.get("attack_1", 0))
58
+ decay_1 = np.float64(params.get("decay_1", 0))
59
+ sustain_1 = np.float64(params.get("sustain_1", 0))
60
+ release_1 = np.float64(params.get("release_1", 0))
61
+
62
+ attack_2 = np.float64(params.get("attack_2", 0))
63
+ decay_2 = np.float64(params.get("decay_2", 0))
64
+ sustain_2 = np.float64(params.get("sustain_2", 0))
65
+ release_2 = np.float64(params.get("release_2", 0))
66
+
67
+ amp_mod_freq = params.get("amp_mod_freq", 0)
68
+ amp_mod_depth = params.get("amp_mod_depth", 0)
69
+ amp_mod_waveform = params.get("amp_mod_waveform", 0)
70
+
71
+ pitch_mod_freq_1 = params.get("pitch_mod_freq_1", 0)
72
+ pitch_mod_depth = params.get("pitch_mod_depth", 0)
73
+
74
+ cutoff_freq = params.get("cutoff_freq", 4000)
75
+
76
+ pitch = np.float64(params.get("pitch", 0))
77
+ duration = np.float64(params.get("duration", 0))
78
+
79
+ syn_parameters = {
80
+ ("adsr_1", "attack"): tensor([attack_1]), # [0.0, 2.0]
81
+ ("adsr_1", "decay"): tensor([decay_1]), # [0.0, 2.0]
82
+ ("adsr_1", "sustain"): tensor([sustain_1]), # [0.0, 2.0]
83
+ ("adsr_1", "release"): tensor([release_1]), # [0.0, 2.0]
84
+ ("adsr_1", "alpha"): tensor([5]), # [0.1, 6.0]
85
+
86
+ ("adsr_2", "attack"): tensor([attack_2]), # [0.0, 2.0]
87
+ ("adsr_2", "decay"): tensor([decay_2]), # [0.0, 2.0]
88
+ ("adsr_2", "sustain"): tensor([sustain_2]), # [0.0, 2.0]
89
+ ("adsr_2", "release"): tensor([release_2]), # [0.0, 2.0]
90
+ ("adsr_2", "alpha"): tensor([5]), # [0.1, 6.0]
91
+ ("keyboard", "midi_f0"): tensor([pitch]),
92
+ ("keyboard", "duration"): tensor([duration]),
93
+
94
+ # Mixer parameter
95
+ ("mixer", "vco_1"): tensor([osc1_amp]), # [0, 1]
96
+ ("mixer", "vco_2"): tensor([osc2_amp]), # [0, 1]
97
+
98
+ # Constant parameters:
99
+ ("vco_1", "mod_depth"): tensor([pitch_mod_depth]), # [-96, 96]
100
+ ("vco_1", "tuning"): tensor([0.0]), # [-24.0, 24]
101
+ ("vco_2", "mod_depth"): tensor([pitch_mod_depth]), # [-96, 96]
102
+ ("vco_2", "tuning"): tensor([0.0]), # [-24.0, 24]
103
+
104
+ # LFOs
105
+ ("lfo_amp_sin", "frequency"): tensor([amp_mod_freq]), # [0, 20]
106
+ ("lfo_amp_sin", "mod_depth"): tensor([0]), # [-10, 20]
107
+ ("lfo_pitch_sin_1", "frequency"): tensor([pitch_mod_freq_1]), # [0, 20]
108
+ ("lfo_pitch_sin_1", "mod_depth"): tensor([10]), # [-10, 20]
109
+ ("lfo_pitch_sin_2", "frequency"): tensor([pitch_mod_freq_1]), # [0, 20]
110
+ ("lfo_pitch_sin_2", "mod_depth"): tensor([10]), # [-10, 20]
111
+ }
112
+
113
+ osc_types = params.get("osc_types", 0)
114
+ if osc_types == 0:
115
+ synth = self.SinSawVoice
116
+ syn_parameters[("vco_2", "shape")] = tensor([1])
117
+ elif osc_types == 1:
118
+ synth = self.SinSawVoice
119
+ syn_parameters[("vco_2", "shape")] = tensor([0])
120
+ elif osc_types == 2:
121
+ synth = self.Saw_Square_Voice
122
+ syn_parameters[("vco_1", "shape")] = tensor([1])
123
+ syn_parameters[("vco_2", "shape")] = tensor([0])
124
+ elif osc_types == 3:
125
+ synth = self.SinTriVoice
126
+ elif osc_types == 4:
127
+ synth = self.TriSawVoice
128
+ syn_parameters[("vco_2", "shape")] = tensor([1])
129
+ else:
130
+ synth = self.TriSawVoice
131
+ syn_parameters[("vco_2", "shape")] = tensor([0])
132
+
133
+ synth.set_parameters(syn_parameters)
134
+ audio_out = synth.get_signal(amp_mod_depth, amp_mod_waveform, int(sample_rate * duration), osc1_amp, osc2_amp)
135
+ single_note = audio_out[0].detach().numpy()
136
+
137
+ cutoff_freq = tf_float32(cutoff_freq)
138
+ impulse_response = ddsp.core.sinc_impulse_response(cutoff_freq, 2048, self.sample_rate)
139
+ single_note = tf_float32(single_note)
140
+ return ddsp.core.fft_convolve(single_note[tf.newaxis, :], impulse_response)[0, :]
141
+
142
+
143
+ class MelodyGenerator:
144
+ """
145
+ This class is responsible for multi-note audio generation by function 'get_melody'.
146
+ """
147
+
148
+ def __init__(self,
149
+ sample_rate=sample_rate,
150
+ n_note_samples=sample_rate * 4.5,
151
+ n_melody_samples=sample_rate * 4.5):
152
+ self.sample_rate = sample_rate
153
+ self.noteGenerator = NoteGenerator(sample_rate, n_note_samples)
154
+ self.n_melody_samples = int(n_melody_samples)
155
+
156
+ def get_melody(self, params_list: List[Dict[str, float]], onsets):
157
+ track = np.zeros(self.n_melody_samples)
158
+ for i in range(len(onsets)):
159
+ location = onsets[i]
160
+ single_note = self.noteGenerator.get_note(params_list[i])
161
+ single_note = np.hstack(
162
+ [np.zeros(int(location)), single_note, np.zeros(self.n_melody_samples)])[
163
+ :self.n_melody_samples]
164
+ track = track + single_note
165
+ return track
166
+
167
+
168
+ def plot_log_spectrogram(signal: np.ndarray,
169
+ path: str,
170
+ n_fft=2048,
171
+ frame_length=1024,
172
+ frame_step=256):
173
+ """Write spectrogram."""
174
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
175
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
176
+ magnitude_spectrum = np.abs(amp)
177
+ log_mel = np_power_to_db(magnitude_spectrum)
178
+ matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')
179
+
180
+
181
+ def np_power_to_db(S, amin=1e-16, top_db=80.0):
182
+ """A helper function for scaling."""
183
+
184
+ def np_log10(x):
185
+ numerator = np.log(x)
186
+ denominator = np.log(10)
187
+ return numerator / denominator
188
+
189
+ # Scale magnitude relative to maximum value in S. Zeros in the output
190
+ # correspond to positions where S == ref.
191
+ ref = np.max(S)
192
+
193
+ # element-wise maximum: clamp each entry to at least amin
194
+ log_spec = 10.0 * np_log10(np.maximum(amin, S))
195
+ log_spec -= 10.0 * np_log10(np.maximum(amin, ref))
196
+
197
+ log_spec = np.maximum(log_spec, np.max(log_spec) - top_db)
198
+
199
+ return log_spec
200
+
201
+
202
+ synth = MelodyGenerator()
203
+ param_descriptions: List[ParameterDescription]
204
+
205
+ param_descriptions = [
206
+
207
+ # Oscillator levels
208
+ ParameterDescription(name="osc_amp2",
209
+ values=parameter_range('osc_amp2'),
210
+ discrete=is_discrete('osc_amp2')),
211
+
212
+ # ADSR params
213
+ ParameterDescription(name="attack_1",
214
+ values=parameter_range('attack'),
215
+ discrete=is_discrete('attack')),
216
+ ParameterDescription(name="decay_1",
217
+ values=parameter_range('decay'),
218
+ discrete=is_discrete('decay')),
219
+ ParameterDescription(name="sustain_1",
220
+ values=parameter_range('sustain'),
221
+ discrete=is_discrete('sustain')),
222
+ ParameterDescription(name="release_1",
223
+ values=parameter_range('release'),
224
+ discrete=is_discrete('release')),
225
+ ParameterDescription(name="attack_2",
226
+ values=parameter_range('attack'),
227
+ discrete=is_discrete('attack')),
228
+ ParameterDescription(name="decay_2",
229
+ values=parameter_range('decay'),
230
+ discrete=is_discrete('decay')),
231
+ ParameterDescription(name="sustain_2",
232
+ values=parameter_range('sustain'),
233
+ discrete=is_discrete('sustain')),
234
+ ParameterDescription(name="release_2",
235
+ values=parameter_range('release'),
236
+ discrete=is_discrete('release')),
237
+
238
+ ParameterDescription(name="cutoff_freq",
239
+ values=parameter_range('cutoff_freq'),
240
+ discrete=is_discrete('cutoff_freq')),
241
+ ParameterDescription(name="pitch",
242
+ values=midi_parameter_range('pitch'),
243
+ discrete=midi_is_discrete('pitch')),
244
+ ParameterDescription(name="duration",
245
+ values=midi_parameter_range('duration'),
246
+ discrete=midi_is_discrete('duration')),
247
+
248
+ ParameterDescription(name="amp_mod_freq",
249
+ values=parameter_range('amp_mod_freq'),
250
+ discrete=is_discrete('amp_mod_freq')),
251
+ ParameterDescription(name="amp_mod_depth",
252
+ values=parameter_range('amp_mod_depth'),
253
+ discrete=is_discrete('amp_mod_depth')),
254
+
255
+ ParameterDescription(name="pitch_mod_freq_1",
256
+ values=parameter_range('pitch_mod_freq'),
257
+ discrete=is_discrete('pitch_mod_freq')),
258
+
259
+ ParameterDescription(name="pitch_mod_freq_2",
260
+ values=parameter_range('pitch_mod_freq'),
261
+ discrete=is_discrete('pitch_mod_freq')),
262
+ ParameterDescription(name="pitch_mod_depth",
263
+ values=parameter_range('pitch_mod_depth'),
264
+ discrete=is_discrete('pitch_mod_depth')),
265
+
266
+ # Oscillators types
267
+ # 0 for sin saw, 1 for sin square, 2 for saw square
268
+ # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
269
+ ParameterDescription(name="osc_types",
270
+ values=parameter_range('osc_types'),
271
+ discrete=is_discrete('osc_types')),
272
+ ]
273
+
274
+ frame_length = 1024
275
+ frame_step = 256
276
+ spectrogram_len = 256
277
+
278
+ n_fft = 1024
279
+
280
+
281
+ def generate_synth_dataset_log_muted_512(n: int, path_name="./data/data_log", write_spec=False):
282
+ if Path(path_name).exists():
283
+ shutil.rmtree(path_name)
284
+
285
+ Path(path_name).mkdir(parents=True, exist_ok=True)
286
+ print("Generating dataset...")
287
+
288
+ synthetic_data = np.ones((n, 512, 256))
289
+
290
+ for i in range(n):
291
+ index = i
292
+ parameter_values = [param.generate() for param in param_descriptions]
293
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
294
+ parameter_values_raw["duration"] = 3.0
295
+ parameter_values_raw["pitch"] = 52
296
+ parameter_values_raw["pitch_mod_depth"] = 0.0
297
+ signal = synth.get_melody([parameter_values_raw], [0])
298
+ # mel = librosa.feature.melspectrogram(signal, sr=sample_rate, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)[:,:spectrogram_len]
299
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
300
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
301
+
302
+ synthetic_data[i] = amp[:512, :256]
303
+
304
+ if write_spec:
305
+ write(path_name + f"/{i}.wav", synth.sample_rate, signal)
306
+ plot_log_spectrogram(signal, path=path_name + f"/{i}.png", frame_length=frame_length, frame_step=frame_step)
307
+ print(f"Generating dataset over, {n} samples generated!")
308
+ return synthetic_data
309
+
310
+
311
+ def generate_synth_dataset_log_512(n: int, path_name="./data/data_log", write_spec=False):
312
+ """Generate the synthetic dataset with a progress bar."""
313
+ if Path(path_name).exists():
314
+ shutil.rmtree(path_name)
315
+
316
+ Path(path_name).mkdir(parents=True, exist_ok=True)
317
+ print("Generating dataset...")
318
+
319
+ synthetic_data = np.ones((n, 512, 256))
320
+
321
+ for i in tqdm(range(n)):
322
+ index = i
323
+ parameter_values = [param.generate() for param in param_descriptions]
324
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
325
+ parameter_values_raw["duration"] = 3.0
326
+ parameter_values_raw["pitch"] = 52
327
+ parameter_values_raw["pitch_mod_depth"] = 0.0
328
+ signal = synth.get_melody([parameter_values_raw], [0])
329
+ # mel = librosa.feature.melspectrogram(signal, sr=sample_rate, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)[:,:spectrogram_len]
330
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
331
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
332
+
333
+ synthetic_data[i] = amp[:512, :256]
334
+
335
+ if write_spec:
336
+ write(path_name + f"/{i}.wav", synth.sample_rate, signal)
337
+ plot_log_spectrogram(signal, path=path_name + f"/{i}.png", frame_length=frame_length, frame_step=frame_step)
338
+ print(f"Generating dataset over, {n} samples generated!")
339
+ return synthetic_data
340
+
341
+
342
+ def generate_DANN_dataset_muted(n: int, path_name="./data/data_DANN", write_spec=False):
343
+ """Generate multi-note/single-note spectrogram pairs for DANN training (without a progress bar)."""
344
+ if Path(path_name).exists():
345
+ shutil.rmtree(path_name)
346
+
347
+ Path(path_name).mkdir(parents=True, exist_ok=True)
348
+ print("Generating dataset...")
349
+
350
+ multinote_data = np.ones((n, 512, 256))
351
+ single_data = np.ones((n, 512, 256))
352
+ for i in range(n):
353
+ index = i
354
+ par_list = []
355
+ n_notes = np.random.randint(1, 5)
356
+ onsets = []
357
+ for j in range(n_notes):
358
+ parameter_values = [param.generate() for param in param_descriptions]
359
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
360
+ # parameter_values_raw["duration"] = 0.5
361
+ parameter_values_raw["pitch_mod_depth"] = 0.0
362
+ par_list.append(parameter_values_raw)
363
+ onsets.append(np.random.randint(0, sample_rate * 3))
364
+
365
+ signal = synth.get_melody(par_list, onsets)
366
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
367
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
368
+ multinote_data[i] = amp[:512, :256]
369
+ if write_spec:
370
+ write(path_name + f"/mul_{i}.wav", synth.sample_rate, signal)
371
+ plot_log_spectrogram(signal, path=path_name + f"/mul_{i}.png", frame_length=frame_length,
372
+ frame_step=frame_step)
373
+
374
+ single_par = par_list[np.argmin(onsets)]
375
+ single_par["duration"] = 3.0
376
+ single_par["pitch"] = 52
377
+ signal = synth.get_melody([single_par], [0])
378
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
379
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
380
+ single_data[i] = amp[:512, :256]
381
+ if write_spec:
382
+ write(path_name + f"/single_{i}.wav", synth.sample_rate, signal)
383
+ plot_log_spectrogram(signal, path=path_name + f"/single_{i}.png", frame_length=frame_length,
384
+ frame_step=frame_step)
385
+ print(f"Generating dataset over, {n} samples generated!")
386
+ return multinote_data, single_data
387
+
388
+
389
+ def generate_DANN_dataset(n: int, path_name="./data/data_DANN", write_spec=False):
390
+ """Generate the synthetic dataset for adversarial training."""
391
+ if Path(path_name).exists():
392
+ shutil.rmtree(path_name)
393
+
394
+ Path(path_name).mkdir(parents=True, exist_ok=True)
395
+ print("Generating dataset...")
396
+
397
+ multinote_data = np.ones((n, 512, 256))
398
+ single_data = np.ones((n, 512, 256))
399
+ for i in tqdm(range(n)):
400
+ par_list = []
401
+ n_notes = np.random.randint(1, 5)
402
+ onsets = []
403
+ for j in range(n_notes):
404
+ parameter_values = [param.generate() for param in param_descriptions]
405
+ parameter_values_raw = {param.name: param.value for param in parameter_values}
406
+ parameter_values_raw["pitch_mod_depth"] = 0.0
407
+ par_list.append(parameter_values_raw)
408
+ onsets.append(np.random.randint(0, sample_rate * 3))
409
+
410
+ signal = synth.get_melody(par_list, onsets)
411
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
412
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
413
+ multinote_data[i] = amp[:512, :256]
414
+ if write_spec:
415
+ write(path_name + f"/mul_{i}.wav", synth.sample_rate, signal)
416
+ plot_log_spectrogram(signal, path=path_name + f"/mul_{i}.png", frame_length=frame_length,
417
+ frame_step=frame_step)
418
+
419
+ single_par = par_list[np.argmin(onsets)]
420
+ single_par["duration"] = 3.0
421
+ single_par["pitch"] = 52
422
+ signal = synth.get_melody([single_par], [0])
423
+ stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
424
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
425
+ single_data[i] = amp[:512, :256]
426
+ if write_spec:
427
+ write(path_name + f"/single_{i}.wav", synth.sample_rate, signal)
428
+ plot_log_spectrogram(signal, path=path_name + f"/single_{i}.png", frame_length=frame_length,
429
+ frame_step=frame_step)
430
+ print(f"Generating dataset over, {n} samples generated!")
431
+ return multinote_data, single_data
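A quick sketch of drawing a few random spectrograms with the generator above (note that the function deletes and recreates path_name even when write_spec is False):

from generate_synthetic_data_online import generate_synth_dataset_log_512

spectrograms = generate_synth_dataset_log_512(4, path_name="./data/data_log", write_spec=False)
print(spectrograms.shape)  # (4, 512, 256) power spectrograms, one random single note each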
load_data.py ADDED
@@ -0,0 +1,150 @@
1
+ import joblib
2
+ import numpy as np
3
+ from generate_synthetic_data_online import generate_synth_dataset_log_512, generate_synth_dataset_log_muted_512
4
+ from tools import show_spc, spc_to_VAE_input, VAE_out_put_to_spc, np_log10
5
+ import torch.utils.data as data
6
+
7
+
8
+ class Data_cache():
9
+ """This is a class that stores synthetic data."""
10
+
11
+ def __init__(self, synthetic_data, external_sources):
12
+ self.n_synthetic = np.shape(synthetic_data)[0]
13
+ self.synthetic_data = synthetic_data.astype(np.float32)
14
+ self.external_sources = external_sources.astype(np.float32)
15
+ self.epsilon = 1e-20
16
+
17
+ def get_all_data(self):
18
+ return np.vstack([self.synthetic_data, self.external_sources])
19
+
20
+ def refresh(self):
21
+ self.synthetic_data = generate_synth_dataset(self.n_synthetic, mute=True)
22
+
23
+ def get_data_loader(self, shuffle=True, BATCH_SIZE=8, new_way=False):
24
+ all_data = self.get_all_data()
25
+ our_data = []
26
+ for i in range(len(all_data)):
27
+ if new_way:
28
+ spectrogram = VAE_out_put_to_spc(np.reshape(all_data[i], (1, 512, 256)))
29
+ log_spectrogram = np.log10(spectrogram + self.epsilon)
30
+ our_data.append(log_spectrogram)
31
+ else:
32
+ our_data.append(np.reshape(all_data[i], (1, 512, 256)))
33
+
34
+ iterator = data.DataLoader(our_data, shuffle=shuffle, batch_size=BATCH_SIZE)
35
+ return iterator
36
+
37
+
38
+ def generate_synth_dataset(n_synthetic, mute=False):
39
+ """Generate the synthetic dataset and reshape it into the VAE input format."""
40
+ n_synthetic_sample = n_synthetic
41
+ if mute:
42
+ Input0 = generate_synth_dataset_log_muted_512(n_synthetic_sample)
43
+ else:
44
+ Input0 = generate_synth_dataset_log_512(n_synthetic_sample)
45
+ Input0 = spc_to_VAE_input(Input0)
46
+ Input0 = Input0.reshape(Input0.shape[0], Input0.shape[1], Input0.shape[2], 1)
47
+ return Input0
48
+
49
+
50
+ def read_data(data_path):
51
+ """Read external sources"""
52
+
53
+ data = np.array(joblib.load(data_path))
54
+ data = spc_to_VAE_input(data)
55
+ data = data.reshape(data.shape[0], data.shape[1], data.shape[2], 1)
56
+ return data
57
+
58
+
59
+ def load_data(n_synthetic):
60
+ """Generate the hybrid dataset."""
61
+ Input_synthetic = generate_synth_dataset(n_synthetic)
62
+
63
+ Input_AU = read_data("./data/external_data/ARTURIA_data")
64
+ print("ARTURIA dataset loaded.")
65
+
66
+ Input_NSynth = read_data("./data/external_data/NSynth_data")
67
+ print("NSynth dataset loaded.")
68
+
69
+ Input_SF = read_data("./data/external_data/soundfonts_data")
70
+ Input_SF_256 = np.zeros((337, 512, 256, 1))
71
+ Input_SF_256[:,:,:251,:] += Input_SF
72
+ Input_SF = Input_SF_256
73
+ print("SoundFonts dataset loaded.")
74
+
75
+ Input_google = read_data("./data/external_data/WaveNet_samples")
76
+
77
+ Input_external = np.vstack([Input_AU, Input_NSynth, Input_SF, Input_google])
78
+ data_cache = Data_cache(Input_synthetic, Input_external)
79
+ print(f"Data loaded, data shape: {np.shape(data_cache.get_all_data())}")
80
+ return data_cache
81
+
82
+
83
+ def show_data(dataset_name, n_sample=3, index=-1, new_way=False):
84
+ """Show and return a certain dataset.
85
+
86
+ Parameters
87
+ ----------
88
+ dataset_name: String
89
+ Name of the dataset to show.
90
+ n_sample: int
91
+ Number of samples to show.
92
+ index: int
93
+ If 'index' is greater than or equal to 0, only the 'index'-th sample of the chosen dataset is shown.
94
+
95
+ Returns
96
+ -------
97
+ np.ndarray:
98
+ The displayed dataset.
99
+ """
100
+
101
+ if dataset_name == "ARTURIA":
102
+ data = read_data("./data/external_data/ARTURIA_data")
103
+ elif dataset_name == "NSynth":
104
+ data = read_data("./data/external_data/NSynth_data")
105
+ elif dataset_name == "SoundFonts":
106
+ data = read_data("./data/external_data/soundfonts_data")
107
+ elif dataset_name == "Synthetic":
108
+ data = generate_synth_dataset(int(n_sample * 3))
109
+ else:
110
+ print("Example command: \"!python thesis_main.py show_data -s [ARTURIA, NSynth, SoundFonts, Synthetic] -n 5\"")
111
+ return
112
+
113
+ if index >= 0:
114
+ show_spc(VAE_out_put_to_spc(data[index]))
115
+ else:
116
+ for i in range(n_sample):
117
+ index = np.random.randint(0,len(data))
118
+ print(index)
119
+ show_spc(VAE_out_put_to_spc(data[index]))
120
+ return data
121
+
122
+
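+ # Note: this second definition of show_data shadows the dataset-name variant defined above; only the tensor-batch version is available after import.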
123
+ def show_data(tensor_batch, index=-1, new_way=False):
124
+ if index < 0:
125
+ index = np.random.randint(0, tensor_batch.shape[0])
126
+
127
+ if new_way:
128
+ sample = tensor_batch[index].detach().numpy()
129
+ spectrogram = 10.0 ** sample
130
+ print(f"The {index}-th sample:")
131
+ show_spc(spectrogram)
132
+ else:
133
+ sample = tensor_batch[index].detach().numpy()
134
+ show_spc(VAE_out_put_to_spc(sample))
135
+ # return data
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
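A hedged usage sketch for the cache above. It assumes the pickled external datasets referenced in load_data() are present under ./data/external_data; batch size and sample count are illustrative.

from load_data import load_data

data_cache = load_data(200)                        # 200 synthetic spectrograms plus the external sets
loader = data_cache.get_data_loader(shuffle=True, BATCH_SIZE=8)
for batch in loader:                               # each batch: float32 tensor of shape (8, 1, 512, 256)
    print(batch.shape)
    break
data_cache.refresh()                               # regenerate only the synthetic portion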
melody_synth/complex_torch_synth.py ADDED
@@ -0,0 +1,221 @@
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor, tensor
6
+
7
+ from torchsynth.config import SynthConfig
8
+ from torchsynth.module import (
9
+ ADSR,
10
+ VCA,
11
+ AudioMixer,
12
+ ControlRateUpsample,
13
+ MonophonicKeyboard,
14
+ SineVCO,
15
+ SquareSawVCO,
16
+ VCO, LFO, ModulationMixer,
17
+ )
18
+ from torchsynth.signal import Signal
19
+ from torchsynth.synth import AbstractSynth
20
+
21
+ # from configurations.read_configuration import get_conf_sample_rate
22
+ from melody_synth.non_random_LFOs import SinLFO, SawLFO, TriLFO, SquareLFO, RSawLFO
23
+
24
+
25
+ class TriangleVCO(VCO):
26
+ """A VCO subclass that produces triangle waves."""
27
+
28
+ def oscillator(self, argument: Signal, midi_f0: Tensor) -> Signal:
29
+ return torch.arcsin(torch.sin(argument * 2)) * 2.0 / torch.pi
30
+
31
+
32
+ class AmpModTorchSynth(AbstractSynth):
33
+ """This is an abstract class using the modules provided by 1B1Synth to assemble synthesizers that generate the
34
+ training set. (The implementation of this class references code in TorchSynth) """
35
+
36
+ def __init__(
37
+ self,
38
+ synthconfig: Optional[SynthConfig] = None,
39
+ nebula: Optional[str] = "nebula",
40
+ *args,
41
+ **kwargs,
42
+ ):
43
+ AbstractSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
44
+ self.share_modules = [
45
+ ("keyboard", MonophonicKeyboard),
46
+ ("adsr_1", ADSR),
47
+ ("adsr_2", ADSR),
48
+ ("upsample", ControlRateUpsample),
49
+ ("vca", VCA),
50
+ ("lfo_amp_sin", SinLFO),
51
+ ("lfo_pitch_sin_1", SinLFO),
52
+ ("lfo_pitch_sin_2", SinLFO),
53
+ (
54
+ "mixer",
55
+ AudioMixer,
56
+ {
57
+ "n_input": 2,
58
+ "curves": [1.0, 1.0],
59
+ "names": ["vco_1", "vco_2"],
60
+ },
61
+ )
62
+ ]
63
+
64
+ def output(self) -> Tensor:
65
+ """Synthesizes the signal as Tensor"""
66
+
67
+ midi_f0, note_on_duration = self.keyboard()
68
+ adsr1 = self.adsr_1(note_on_duration)
69
+ adsr1 = self.upsample(adsr1)
70
+
71
+ adsr2 = self.adsr_2(note_on_duration)
72
+ adsr2 = self.upsample(adsr2)
73
+
74
+ amp_modulation = self.lfo_amp_sin()
75
+ amp_modulation = self.upsample(amp_modulation)
76
+
77
+ pitch_modulation_1 = self.lfo_pitch_sin_1()
78
+ pitch_modulation_1 = self.upsample(pitch_modulation_1)
79
+ pitch_modulation_2 = self.lfo_pitch_sin_2()
80
+ pitch_modulation_2 = self.upsample(pitch_modulation_2)
81
+
82
+ vco_amp1 = adsr1 * (amp_modulation * 0.5 + 0.5)
83
+ vco_amp2 = adsr2 * (amp_modulation * 0.5 + 0.5)
84
+ vco_1_out = self.vco_1(midi_f0, pitch_modulation_1)
85
+ vco_1_out = self.vca(vco_1_out, vco_amp1)
86
+
87
+ vco_2_out = self.vco_2(midi_f0, pitch_modulation_2)
88
+ vco_2_out = self.vca(vco_2_out, vco_amp2)
89
+
90
+ return self.mixer(vco_1_out, vco_2_out)
91
+
92
+ def get_signal(self, amp_mod_depth, amp_waveform, duration_l, amp1, amp2):
93
+ """Synthesizes the signal as Tensor"""
94
+
95
+ midi_f0, note_on_duration = self.keyboard()
96
+ adsr1 = self.adsr_1(note_on_duration)
97
+ adsr1 = self.upsample(adsr1)
98
+
99
+ adsr2 = self.adsr_2(note_on_duration)
100
+ adsr2 = self.upsample(adsr2)
101
+
102
+ amp_modulation = self.lfo_amp_sin()
103
+ amp_modulation = self.upsample(amp_modulation)
104
+
105
+ pitch_modulation_1 = self.lfo_pitch_sin_1()
106
+ pitch_modulation_1 = self.upsample(pitch_modulation_1)
107
+ pitch_modulation_2 = self.lfo_pitch_sin_2()
108
+ pitch_modulation_2 = self.upsample(pitch_modulation_2)
109
+
110
+ vco_amp1 = adsr1 * (amp_modulation * 0.5 + 0.5)
111
+ vco_amp2 = adsr2 * (amp_modulation * 0.5 + 0.5)
112
+ vco_1_out = self.vco_1(midi_f0, pitch_modulation_1)
113
+ vco_1_out = self.vca(vco_1_out, vco_amp1)
114
+
115
+ vco_2_out = self.vco_2(midi_f0, pitch_modulation_1)
116
+ vco_2_out = self.vca(vco_2_out, vco_amp2)
117
+
118
+ return self.mixer(vco_1_out, vco_2_out)
119
+
120
+
121
+ class DoubleSawSynth(AmpModTorchSynth):
122
+ """In addition to the shared modules, this synthesizer uses two "SquareSawVCO" modules to generate square and
123
+ sawtooth waves"""
124
+
125
+ def __init__(
126
+ self,
127
+ synthconfig: Optional[SynthConfig] = None,
128
+ nebula: Optional[str] = "saw_square_voice",
129
+ *args,
130
+ **kwargs,
131
+ ):
132
+ AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
133
+
134
+ # Register all modules as children
135
+ module_list = self.share_modules
136
+ module_list.append(("vco_1", SquareSawVCO))
137
+ module_list.append(("vco_2", SquareSawVCO))
138
+ self.add_synth_modules(module_list)
139
+
140
+
141
+ class SinSawSynth(AmpModTorchSynth):
142
+ """In addition to the shared modules, this synthesizer uses a "SinVco" and a "SquareSawVCO" to generate
143
+ sine and sawtooth/square waves """
144
+
145
+ def __init__(
146
+ self,
147
+ synthconfig: Optional[SynthConfig] = None,
148
+ nebula: Optional[str] = "sin_saw_voice",
149
+ *args,
150
+ **kwargs,
151
+ ):
152
+ AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
153
+
154
+ # Register all modules as children
155
+ module_list = self.share_modules
156
+ module_list.append(("vco_1", SineVCO))
157
+ module_list.append(("vco_2", SquareSawVCO))
158
+ self.add_synth_modules(module_list)
159
+
160
+
161
+ class SinTriangleSynth(AmpModTorchSynth):
162
+ """In addition to the shared modules, this synthesizer uses a "SinVco" and a "TriangleVCO" to generate
163
+ sine and triangle waves """
164
+
165
+ def __init__(
166
+ self,
167
+ synthconfig: Optional[SynthConfig] = None,
168
+ nebula: Optional[str] = "sin_tri_voice",
169
+ *args,
170
+ **kwargs,
171
+ ):
172
+ AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
173
+
174
+ # Register all modules as children
175
+ module_list = self.share_modules
176
+ module_list.append(("vco_1", SineVCO))
177
+ module_list.append(("vco_2", TriangleVCO))
178
+ self.add_synth_modules(module_list)
179
+
180
+
181
+ class TriangleSawSynth(AmpModTorchSynth):
182
+ """In addition to the shared modules, this synthesizer uses a "TriangleVCO" and a "SquareSawVCO" to generate
183
+ triangle and sawtooth/square waves """
184
+
185
+ def __init__(
186
+ self,
187
+ synthconfig: Optional[SynthConfig] = None,
188
+ nebula: Optional[str] = "triangle_saw_voice",
189
+ *args,
190
+ **kwargs,
191
+ ):
192
+ AmpModTorchSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)
193
+
194
+ # Register all modules as children
195
+ module_list = self.share_modules
196
+ module_list.append(("vco_1", TriangleVCO))
197
+ module_list.append(("vco_2", SquareSawVCO))
198
+ self.add_synth_modules(module_list)
199
+
200
+
201
+ def amp_mod_with_duration(env, duration_l):
202
+ env_np = env.detach().numpy()[0] + 1e-30
203
+ env_np_shift = np.hstack([[0], env_np[:-1]])
204
+ env_np_sign = (env_np - env_np_shift)[:duration_l] + 1e-30
205
+
206
+ env_np_sign_nor = np.around(env_np_sign / np.abs(env_np_sign))
207
+ env_np_sign_nor_shift = np.hstack([[0], env_np_sign_nor[:-1]])
208
+ extreme_points = (env_np_sign_nor - env_np_sign_nor_shift)
209
+
210
+ (max_loc,) = np.where(extreme_points == -2)
211
+
212
+ n_max = len(max_loc)
213
+ if n_max == 0:
214
+ return env
215
+ else:
216
+ last_max_loc = max_loc[n_max - 1] - 1
217
+ # new_env = np.hstack([env_np[:last_max_loc], np.ones(len(env_np) - last_max_loc) * env_np[last_max_loc]])
218
+ new_env = np.hstack([env_np[:last_max_loc], (env_np[last_max_loc:] * 0.8 + 0.2)])
219
+
220
+
221
+ return tensor([new_env])
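A minimal sketch of driving one of these voices directly, mirroring how melody_generator.py (next file) uses them. The parameter values are illustrative, and anything not set here is left to torchsynth's defaults.

import numpy as np
from torch import tensor
from torchsynth.config import SynthConfig
from melody_synth.complex_torch_synth import SinSawSynth

sample_rate, n_samples = 16000, 16000 * 4
config = SynthConfig(batch_size=1, reproducible=False, sample_rate=sample_rate,
                     buffer_size_seconds=np.float64(n_samples) / np.float64(sample_rate))
synth = SinSawSynth(config)
synth.set_parameters({
    ("keyboard", "midi_f0"): tensor([52.0]),
    ("keyboard", "duration"): tensor([2.0]),
    ("mixer", "vco_1"): tensor([0.7]),
    ("mixer", "vco_2"): tensor([0.3]),
    ("vco_2", "shape"): tensor([1.0]),     # 1.0 = saw-like, 0.0 = square-like
})
audio, _, _ = synth()                      # Tensor of shape (1, n_samples)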
melody_synth/melody_generator.py ADDED
@@ -0,0 +1,121 @@
1
+ from typing import Dict
2
+ import torch
3
+ from ddsp.core import tf_float32
4
+ import tensorflow as tf
5
+ import ddsp
6
+ import numpy as np
7
+ from torch import tensor
8
+ from melody_synth.complex_torch_synth import SinSawSynth, DoubleSawSynth, TriangleSawSynth, SinTriangleSynth
9
+ from torchsynth.config import SynthConfig
10
+
11
+ if torch.cuda.is_available():
12
+ device = "cuda"
13
+ else:
14
+ device = "cpu"
15
+
16
+
17
+ class MelodyGenerator:
18
+ """This is the only external interface of the melody_synth package."""
19
+
20
+ def __init__(self,
21
+ sample_rate: int,
22
+ n_note_samples: int,
23
+ n_melody_samples: int):
24
+ self.sample_rate = sample_rate
25
+ self.n_note_samples = n_note_samples
26
+ self.n_melody_samples = n_melody_samples
27
+ synthconfig = SynthConfig(
28
+ batch_size=1, reproducible=False, sample_rate=sample_rate,
29
+ buffer_size_seconds=np.float64(n_note_samples) / np.float64(sample_rate)
30
+ )
31
+ self.Saw_Square_Voice = DoubleSawSynth(synthconfig)
32
+ self.SinSawVoice = SinSawSynth(synthconfig)
33
+ self.SinTriVoice = SinTriangleSynth(synthconfig)
34
+ self.TriSawVoice = TriangleSawSynth(synthconfig)
35
+
36
+ def get_melody(self, params: Dict[str, float], midi) -> [tf.Tensor]:
37
+ """Renders the given melody (midi) as audio using the given synthesis parameters.
38
+
39
+ Parameters
40
+ ----------
41
+ params: Dict[str, float]
42
+ Dictionary of specifications (see Readme).
43
+ midi: List[(float, float, float)]
44
+ Melody midi (see Readme).
45
+
46
+ Returns
47
+ -------
48
+ audio: tf.Tensor
49
+ The rendered melody audio.
50
+ """
51
+
52
+ osc1_amp = float(params.get("osc1_amp", 0))  # plain float(); np.float was removed in recent NumPy
53
+ osc2_amp = float(params.get("osc2_amp", 0))
54
+ attack = float(params.get("attack", 0))
55
+ decay = float(params.get("decay", 0))
56
+ sustain = float(params.get("sustain", 0))
57
+ release = float(params.get("release", 0))
58
+ cutoff_freq = params.get("cutoff_freq", 4000)
59
+
60
+ syn_parameters = {
61
+ ("adsr", "attack"): tensor([attack]), # [0.0, 2.0]
62
+ ("adsr", "decay"): tensor([decay]), # [0.0, 2.0]
63
+ ("adsr", "sustain"): tensor([sustain]), # [0.0, 2.0]
64
+ ("adsr", "release"): tensor([release]), # [0.0, 2.0]
65
+ ("adsr", "alpha"): tensor([3]), # [0.1, 6.0]
66
+
67
+ # Mixer parameter
68
+ ("mixer", "vco_1"): tensor([osc1_amp]), # [0, 1]
69
+ ("mixer", "vco_2"): tensor([osc2_amp]), # [0, 1]
70
+
71
+ # Constant parameters:
72
+ ("vco_1", "mod_depth"): tensor([0.0]), # [-96, 96]
73
+ ("vco_1", "tuning"): tensor([0.0]), # [-24.0, 24]
74
+ ("vco_2", "mod_depth"): tensor([0.0]), # [-96, 96]
75
+ ("vco_2", "tuning"): tensor([0.0]), # [-24.0, 24]
76
+ }
77
+
78
+ osc_types = params.get("osc_types", 0)
79
+ if osc_types == 0:
80
+ synth = self.SinSawVoice
81
+ syn_parameters[("vco_2", "shape")] = tensor([1])
82
+ elif osc_types == 1:
83
+ synth = self.SinSawVoice
84
+ syn_parameters[("vco_2", "shape")] = tensor([0])
85
+ elif osc_types == 2:
86
+ synth = self.Saw_Square_Voice
87
+ syn_parameters[("vco_1", "shape")] = tensor([1])
88
+ syn_parameters[("vco_2", "shape")] = tensor([0])
89
+ elif osc_types == 3:
90
+ synth = self.SinTriVoice
91
+ elif osc_types == 4:
92
+ synth = self.TriSawVoice
93
+ syn_parameters[("vco_2", "shape")] = tensor([1])
94
+ else:
95
+ synth = self.TriSawVoice
96
+ syn_parameters[("vco_2", "shape")] = tensor([0])
97
+
98
+ track = np.zeros(self.n_melody_samples)
99
+ for i in range(len(midi)):
100
+ (location, pitch, duration) = midi[i]
101
+ syn_parameters[("keyboard", "midi_f0")] = tensor([pitch])
102
+ syn_parameters[("keyboard", "duration")] = tensor([duration])
103
+ synth.set_parameters(syn_parameters)
104
+
105
+ audio_out, parameters, is_train = synth()
106
+ single_note = audio_out[0]
107
+
108
+ single_note = np.hstack(
109
+ [np.zeros(int(location * self.sample_rate)), single_note, np.zeros(self.n_melody_samples)])[
110
+ :self.n_melody_samples]
111
+ track = track + single_note
112
+
113
+ no_cutoff = False
114
+ if no_cutoff:
115
+ return track
116
+ cutoff_freq = tf_float32(cutoff_freq)
117
+ impulse_response = ddsp.core.sinc_impulse_response(cutoff_freq,
118
+ 2048,
119
+ self.sample_rate)
120
+ track = tf_float32(track)
121
+ return ddsp.core.fft_convolve(track[tf.newaxis, :], impulse_response)[0, :]
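A hedged end-to-end sketch of the interface above. The parameter values and the midi line are illustrative; the dictionary keys match what get_melody() reads.

from melody_synth.melody_generator import MelodyGenerator

gen = MelodyGenerator(sample_rate=16000, n_note_samples=16000 * 4, n_melody_samples=16000 * 4)
params = {"osc1_amp": 0.6, "osc2_amp": 0.4, "attack": 0.05, "decay": 0.2,
          "sustain": 0.7, "release": 0.3, "cutoff_freq": 4000, "osc_types": 0}
midi = [(0.0, 52, 1.0), (1.0, 59, 1.5)]        # (onset in seconds, MIDI pitch, duration)
audio = gen.get_melody(params, midi)           # tf.Tensor holding the low-pass-filtered melody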
melody_synth/non_random_LFOs.py ADDED
@@ -0,0 +1,121 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor, tensor
5
+
6
+ from torchsynth.config import SynthConfig
7
+ from torchsynth.module import (
8
+ VCO, LFO, ModulationMixer,
9
+ )
10
+ from torchsynth.signal import Signal
11
+ from torchsynth.synth import AbstractSynth
12
+
13
+
14
+ class SinLFO(LFO):
15
+ """An LFO that generates the sine waveform.
16
+ (The implementation of this class is a modification of the code in TorchSynth) """
17
+
18
+ def output(self, mod_signal: Optional[Signal] = None) -> Signal:
19
+ # This module accepts signals at control rate!
20
+ if mod_signal is not None:
21
+ assert mod_signal.shape == (self.batch_size, self.control_buffer_size)
22
+
23
+ frequency = self.make_control(mod_signal)
24
+ argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
25
+ argument = argument + self.p("initial_phase").unsqueeze(1)
26
+
27
+ shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)
28
+
29
+ mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
30
+ mode[0] = tensor([1.0, 0., 0., 0., 0.])
31
+ mode = torch.pow(mode, self.exponent)
32
+ mode = mode / torch.sum(mode, dim=1, keepdim=True)
33
+ return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
34
+
35
+
36
+ class TriLFO(LFO):
37
+ """An LFO that generates the triangle waveform.
38
+ (The implementation of this class is a modification of the code in TorchSynth) """
39
+
40
+ def output(self, mod_signal: Optional[Signal] = None) -> Signal:
41
+ # This module accepts signals at control rate!
42
+ if mod_signal is not None:
43
+ assert mod_signal.shape == (self.batch_size, self.control_buffer_size)
44
+
45
+ frequency = self.make_control(mod_signal)
46
+ argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
47
+ argument = argument + self.p("initial_phase").unsqueeze(1)
48
+
49
+ shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)
50
+
51
+ mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
52
+ mode[0] = tensor([0.5, 0.5, 0., 0., 0.])
53
+ mode = torch.pow(mode, self.exponent)
54
+ mode = mode / torch.sum(mode, dim=1, keepdim=True)
55
+ return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
56
+
57
+
58
+ class SawLFO(LFO):
59
+ """An LFO that generates the sawtooth waveform.
60
+ (The implementation of this class is a modification of the code in TorchSynth) """
61
+
62
+ def output(self, mod_signal: Optional[Signal] = None) -> Signal:
63
+ # This module accepts signals at control rate!
64
+ if mod_signal is not None:
65
+ assert mod_signal.shape == (self.batch_size, self.control_buffer_size)
66
+
67
+ frequency = self.make_control(mod_signal)
68
+ argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
69
+ argument = argument + self.p("initial_phase").unsqueeze(1)
70
+
71
+ shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)
72
+
73
+ mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
74
+ mode[0] = tensor([0.5, 0., 0.5, 0., 0.])
75
+ mode = torch.pow(mode, self.exponent)
76
+ mode = mode / torch.sum(mode, dim=1, keepdim=True)
77
+ return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
78
+
79
+
80
+ class RSawLFO(LFO):
81
+ """An LFO that generates the reverse (ramp-down) sawtooth waveform.
82
+ (The implementation of this class is a modification of the code in TorchSynth) """
83
+
84
+ def output(self, mod_signal: Optional[Signal] = None) -> Signal:
85
+ # This module accepts signals at control rate!
86
+ if mod_signal is not None:
87
+ assert mod_signal.shape == (self.batch_size, self.control_buffer_size)
88
+
89
+ frequency = self.make_control(mod_signal)
90
+ argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
91
+ argument = argument + self.p("initial_phase").unsqueeze(1)
92
+
93
+ shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)
94
+
95
+ mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
96
+ mode[0] = tensor([0.5, 0., 0.0, 0.5, 0.])
97
+ mode = torch.pow(mode, self.exponent)
98
+ mode = mode / torch.sum(mode, dim=1, keepdim=True)
99
+ return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
100
+
101
+
102
+ class SquareLFO(LFO):
103
+ """An LFO that generates the square waveform.
104
+ (The implementation of this class is a modification of the code in TorchSynth) """
105
+
106
+ def output(self, mod_signal: Optional[Signal] = None) -> Signal:
107
+ # This module accepts signals at control rate!
108
+ if mod_signal is not None:
109
+ assert mod_signal.shape == (self.batch_size, self.control_buffer_size)
110
+
111
+ frequency = self.make_control(mod_signal)
112
+ argument = torch.cumsum(2 * torch.pi * frequency / self.control_rate, dim=1)
113
+ argument = argument + self.p("initial_phase").unsqueeze(1)
114
+
115
+ shapes = torch.stack(self.make_lfo_shapes(argument), dim=1).as_subclass(Signal)
116
+
117
+ mode = torch.stack([self.p(lfo) for lfo in self.lfo_types], dim=1)
118
+ mode[0] = tensor([0.5, 0., 0., 0., 0.5])
119
+ mode = torch.pow(mode, self.exponent)
120
+ mode = mode / torch.sum(mode, dim=1, keepdim=True)
121
+ return torch.matmul(mode.unsqueeze(1), shapes).squeeze(1).as_subclass(Signal)
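All of these subclasses work the same way: instead of leaving the waveform-mixing weights as free parameters, mode[0] is pinned to a fixed vector so that, after the power and normalization steps, only the intended shape(s) contribute. A rough stand-alone check of that weighting step (the exponent value here is an assumption, not taken from torchsynth):

import torch

mode = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0]])      # SinLFO's pinned weights
mode = torch.pow(mode, 2.0)                           # self.exponent assumed to be 2.0 for illustration
mode = mode / torch.sum(mode, dim=1, keepdim=True)
print(mode)                                           # tensor([[1., 0., 0., 0., 0.]])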
melody_synth/random_duration.py ADDED
@@ -0,0 +1,86 @@
1
+ from configurations.read_configuration import midi_parameter_range, midi_is_discrete
2
+ from data_generation.encoding import ParameterDescription
3
+
4
+
5
+ def get_full_duration(midi):
6
+ """Uses "full_duration" strategy to generate random duration (see Readme)."""
7
+ n = len(midi)
8
+ (time_starting_point, pitch) = midi[n - 1]
9
+ new_midi = [(time_starting_point, pitch, 0.5)]
10
+ next_time_starting_point = time_starting_point
11
+ for i in range(n - 1):
12
+ (time_starting_point, pitch) = midi[n - 2 - i]
13
+ duration = (next_time_starting_point - time_starting_point) * 0.9
14
+ new_midi.insert(0, (time_starting_point, pitch, duration))
15
+ next_time_starting_point = time_starting_point
16
+
17
+ return new_midi
18
+
19
+
20
+ def get_random_duration(midi):
21
+ """Uses "random_duration" strategy to generate random duration (see Readme)."""
22
+ parameterDescription = ParameterDescription(name="duration",
23
+ values=midi_parameter_range('duration'),
24
+ discrete=midi_is_discrete('duration'))
25
+ n = len(midi)
26
+ new_midi = []
27
+ for i in range(n):
28
+ (location, pitch) = midi[i]
29
+ duration = float(parameterDescription.generate().value)
30
+ new_midi.append((location, pitch, duration))
31
+ return new_midi
32
+
33
+
34
+ def get_fixed_duration(midi):
35
+ return [(location, pitch, 2.0) for (location, pitch) in midi]
36
+
37
+
38
+ def get_limited_random_duration(midi):
39
+ """Uses "limited_random_duration" strategy to generate random duration (see Readme)."""
40
+ parameterDescription = ParameterDescription(name="duration",
41
+ values=midi_parameter_range('duration'),
42
+ discrete=midi_is_discrete('duration'))
43
+ n = len(midi)
44
+ (time_starting_point, pitch) = midi[n - 1]
45
+ duration = float(parameterDescription.generate().value)
46
+ new_midi = [(time_starting_point, pitch, duration)]
47
+ next_time_starting_point = time_starting_point
48
+ for i in range(n - 1):
49
+ (time_starting_point, pitch) = midi[n - 2 - i]
50
+ max_duration = (next_time_starting_point - time_starting_point) * 0.9
51
+ duration = float(parameterDescription.generate().value)
52
+ duration = min(duration, max_duration)
53
+ new_midi.insert(0, (time_starting_point, pitch, duration))
54
+ next_time_starting_point = time_starting_point
55
+
56
+ return new_midi
57
+
58
+
59
+ class RandomDuration:
60
+ """Third component in the random midi pipeline, responsible for generating a random duration (keyboard hold time) for each note."""
61
+
62
+ def __call__(self, strategy: str, midi, *args, **kwargs):
63
+ """Choose required strategy to generate random duration for each note.
64
+
65
+ Parameters
66
+ ----------
67
+ strategy: str
68
+ Strategy names for random duration (see Readme).
69
+ midi: List[(float, float)]
70
+ Random rhythm and pitch from previous pipeline component.
71
+
72
+ Returns
73
+ -------
74
+ midi: List[(float, float, float)]
75
+ Original input list with duration assigned to each note onset.
76
+ """
77
+
78
+ if strategy == 'random_duration':
79
+ midi = get_random_duration(midi)
80
+ elif strategy == 'limited_random_duration':
81
+ midi = get_limited_random_duration(midi)
82
+ elif strategy == 'fixed_duration':
83
+ midi = get_fixed_duration(midi)
84
+ else:
85
+ midi = get_full_duration(midi)
86
+ return midi
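A small sketch of the dispatcher above. The onset/pitch pairs are illustrative, and the configurations package must be importable for the random strategies.

from melody_synth.random_duration import RandomDuration

midi = [(0.0, 52), (0.5, 55), (1.25, 59)]                  # (onset in seconds, MIDI pitch)
midi = RandomDuration()("limited_random_duration", midi)
# -> [(onset, pitch, duration), ...] with each duration capped at 0.9x the gap to the next onset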
melody_synth/random_midi.py ADDED
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+
3
+ from configurations.read_configuration import get_conf_sample_rate, get_conf_stft_hyperparameter,\
4
+ midi_parameter_range, get_conf_time_resolution, get_conf_max_n_notes
5
+ from melody_synth.random_duration import RandomDuration
6
+ from melody_synth.random_pitch import RandomPitch
7
+ from melody_synth.random_rhythm import RandomRhythm
8
+
9
+
10
+ class RandomMidi:
11
+ """Pipeline generating random midi"""
12
+
13
+ def __init__(self):
14
+ self.randomRhythm = RandomRhythm()
15
+ self.randomPitch = RandomPitch()
16
+ self.randomDuration = RandomDuration()
17
+ self.max_n_notes = get_conf_max_n_notes()
18
+
19
+ def __call__(self, strategy=None, *args, **kwargs):
20
+ """Assembles the pipeline based on the given strategies and returns a random midi.
21
+
22
+ Parameters
23
+ ----------
24
+ strategy: Dict[str, str]
25
+ Strategies names for random rhythm, pitch and duration generation (see Readme).
26
+
27
+ Returns
28
+ -------
29
+ encode, midi: List[int], List[(float, float, float)]
30
+ encode -- Midi's label as a list of 0s and 1s
31
+ midi -- A list of (onset, pitch, duration) tuples, each tuple refers to a note
32
+ """
33
+
34
+ if strategy is None:
35
+ strategy = {"rhythm_strategy": "non-test",
36
+ "pitch_strategy": "random_major",
37
+ "duration_strategy": "limited_random",
38
+ }
39
+
40
+ midi = self.randomRhythm(strategy["rhythm_strategy"])
41
+ midi = self.randomPitch(strategy["pitch_strategy"], midi)
42
+ midi = self.randomDuration(strategy["duration_strategy"], midi)
43
+
44
+ return self.get_encode(midi), midi
45
+
46
+ def get_encode(self, midi):
47
+ """Generate labels for midi
48
+
49
+ Parameters
50
+ ----------
51
+ midi: List[(onset, pitch, duration)]
52
+ A list of (onset, pitch, duration) tuples, each tuple refers to a note
53
+
54
+ Returns
55
+ -------
56
+ encode: List[int]
57
+ Midi's label as a list of 0s and 1s
58
+
59
+ Encoding method
60
+ -------
61
+ One-hot Encoding for each note. Stack all note labels to form midi label.
62
+ """
63
+ duration_range = midi_parameter_range("duration")
64
+ pitch_range = midi_parameter_range("pitch")
65
+ time_resolution = get_conf_time_resolution()
66
+
67
+ pixel_duration = get_conf_stft_hyperparameter()["frame_step"] / get_conf_sample_rate()
68
+ single_note_encode_length = (time_resolution + len(pitch_range) + len(duration_range))
69
+ encode_length = single_note_encode_length * self.max_n_notes
70
+ encode = []
71
+ for i in range(len(midi)):
72
+ (location, pitch, duration) = midi[i]
73
+
74
+ location_index = int(float(location) / pixel_duration)
75
+ if location_index >= time_resolution:
76
+ break
77
+ pitch_index = pitch - pitch_range[0]
78
+ duration_index = np.argmin(np.abs(np.array(duration_range) - duration))
79
+
80
+ single_note_encode = np.zeros(single_note_encode_length)
81
+ single_note_encode[location_index] = 1
82
+ single_note_encode[time_resolution + pitch_index] = 1
83
+ single_note_encode[time_resolution + len(pitch_range) + duration_index] = 1
84
+ encode = np.hstack([encode, single_note_encode])
85
+
86
+ return np.hstack([encode, np.zeros(encode_length)])[:encode_length]
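A hedged sketch of the full pipeline call. The strategy names must match the ones checked inside the three components, and the configurations package is assumed to be available.

from melody_synth.random_midi import RandomMidi

strategy = {"rhythm_strategy": "free_rhythm",
            "pitch_strategy": "random_major",
            "duration_strategy": "limited_random_duration"}
encode, midi = RandomMidi()(strategy)
# encode: flat 0/1 label vector (one block of one-hot fields per note, padded to max_n_notes)
# midi:   list of (onset, pitch, duration) tuples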
melody_synth/random_pitch.py ADDED
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+
3
+ from configurations.read_configuration import midi_parameter_range
4
+
5
+
6
+ class RandomPitch:
7
+ """Second component in the random midi pipeline, responsible for assigning a random pitch to each note onset."""
8
+
9
+ def __init__(self):
10
+ self.pitch_range = midi_parameter_range("pitch")
11
+ self.major = [0, 2, 4, 5, 7, 9, 11]
12
+ self.minor = [0, 2, 3, 5, 7, 8, 10]
13
+
14
+ def __call__(self, strategy: str, onsets, *args, **kwargs):
15
+ """Choose required strategy to generate random pitch for each note.
16
+
17
+ Parameters
18
+ ----------
19
+ strategy: str
20
+ Strategy names for random pitches (see Readme).
21
+ onsets: List[float]
22
+ Random rhythm from previous pipeline component.
23
+
24
+ Returns
25
+ -------
26
+ midi: List[(float, float)]
27
+ Original input list with pitches assigned to each note onset.
28
+ """
29
+
30
+ if strategy == 'random_major':
31
+ return self.get_random_major(onsets)
32
+ elif strategy == 'random_minor':
33
+ return self.get_random_minor(onsets)
34
+ elif strategy == 'fixed_pitch':
35
+ return self.get_fixed_pitch(onsets)
36
+ elif strategy == 'fixed_pitch1':
37
+ return self.get_fixed_pitch1(onsets)
38
+ elif strategy == 'fixed_pitch2':
39
+ return self.get_fixed_pitch2(onsets)
40
+ elif strategy == 'fixed_pitch3':
41
+ return self.get_fixed_pitch3(onsets)
42
+ elif strategy == 'fixed_pitch4':
43
+ return self.get_fixed_pitch4(onsets)
44
+ else:
45
+ return self.get_random_pitch(onsets)
46
+
47
+ def get_random_major(self, midi):
48
+ """Uses "random_major" strategy to generate random pitches (see Readme)."""
49
+ random_scale = np.random.randint(0, 12)
50
+ scale = [one for one in self.pitch_range if (one - random_scale) % 12 in self.major]
51
+ midi = [(duration, scale[np.random.randint(0, len(scale))]) for duration in midi]
52
+ # midi[0] = (midi[0][0], random_scale + self.pitch_range[-1])
53
+ midi[len(midi) - 1] = (midi[len(midi) - 1][0], random_scale + self.pitch_range[0])
54
+ return midi
55
+
56
+ def get_random_pitch(self, midi):
57
+ """Uses "free_pitch" strategy to generate random pitches (see Readme)."""
58
+ return [(duration, np.random.randint(self.pitch_range[0], self.pitch_range[-1])) for duration in midi]
59
+
60
+ def get_fixed_pitch(self, midi):
61
+ """Uses the "fixed_pitch" strategy: every note is assigned MIDI pitch 48 (see Readme)."""
62
+ return [(duration, 48) for duration in midi]
63
+
64
+ def get_fixed_pitch1(self, midi):
65
+ """Uses the "fixed_pitch1" strategy: every note is assigned MIDI pitch 55 (see Readme)."""
66
+ return [(duration, 55) for duration in midi]
67
+
68
+ def get_fixed_pitch2(self, midi):
69
+ """Uses the "fixed_pitch2" strategy: every note is assigned MIDI pitch 62 (see Readme)."""
70
+ return [(duration, 62) for duration in midi]
71
+
72
+ def get_fixed_pitch3(self, midi):
73
+ """Uses the "fixed_pitch3" strategy: every note is assigned MIDI pitch 69 (see Readme)."""
74
+ return [(duration, 69) for duration in midi]
75
+
76
+ def get_fixed_pitch4(self, midi):
77
+ """Uses the "fixed_pitch4" strategy: every note is assigned MIDI pitch 76 (see Readme)."""
78
+ return [(duration, 76) for duration in midi]
79
+
80
+ def get_random_minor(self, midi):
81
+ """Uses "random_minor" strategy to generate random pitches (see Readme)."""
82
+ random_scale = np.random.randint(0, 12)
83
+ scale = [one for one in self.pitch_range if (one - random_scale) % 12 in self.minor]
84
+ midi = [(duration, scale[np.random.randint(0, len(scale))]) for duration in midi]
85
+ # midi[0] = (midi[0][0], random_scale + self.pitch_range[-1])
86
+ midi[len(midi) - 1] = (midi[len(midi) - 1][0], random_scale + self.pitch_range[0])
87
+ return midi
melody_synth/random_rhythm.py ADDED
@@ -0,0 +1,143 @@
1
+ import random
2
+
3
+ import numpy as np
4
+
5
+ from configurations.read_configuration import get_conf_n_sample, get_conf_sample_rate, get_conf_max_n_notes
6
+
7
+
8
+ def get_random_note_type_index(distribution):
9
+ """A helper method that randomly chooses the next note type based on a distribution.
10
+
11
+ Parameters
12
+ ----------
13
+ distribution: List[float]
14
+ Note type distribution.
15
+
16
+ Returns
17
+ -------
18
+ index: int
19
+ The randomly chosen note-type index.
20
+ """
21
+
22
+ r = np.random.random()
23
+ for i in range(len(distribution)):
24
+ r = r - distribution[i]
25
+ if r < 0:
26
+ return i
27
+ return len(distribution) - 1
28
+
29
+
30
+ # Todo: rewrite this part
31
+ def to_onsets_in_seconds(bpm, notes):
32
+ """A helper method that transforms a list of note lengths into a list of note onsets (in seconds).
33
+
34
+ Parameters
35
+ ----------
36
+ bpm: float
37
+ BPM
38
+ notes: List[float]
39
+ Note lengths expressed as fractions of a whole note.
40
+
41
+ Returns
42
+ -------
43
+ onsets: List[float]
44
+ Note onsets in seconds.
45
+ """
46
+
47
+ full_note_length = 4 * 60 / bpm
48
+ onsets = [0]
49
+ for i in range(len(notes)):
50
+ onsets.append(onsets[i] + full_note_length * notes[i])
51
+ return onsets
52
+
53
+
54
+ class RandomRhythm:
55
+ """First component in the random midi pipeline, responsible for generating a random rhythm (note onsets)."""
56
+
57
+ def __init__(self):
58
+ self.note_types = [0, 1, 3 / 4, 0.5, 3 / 8, 0.25, 1 / 8]
59
+ self.first_note_type_distribution = np.array([0, 0.2, 0.05, 0.25, 0.05, 0.3, 0.15])
60
+ self.rhythm_generation_matrix = np.array([
61
+ [0.1, 0.1, 0.25, 0.1, 0.25, 0.2],
62
+ [0.05, 0.25, 0.25, 0.05, 0.3, 0.1],
63
+ [0.1, 0.1, 0.3, 0.05, 0.35, 0.1],
64
+ [0.05, 0.05, 0.2, 0.2, 0.25, 0.25],
65
+ [0.1, 0.05, 0.1, 0.05, 0.4, 0.3],
66
+ [0.1, 0.05, 0.1, 0.1, 0.3, 0.35],
67
+ ])
68
+ # self.bpm = bpm
69
+ self.rhythm_duration = np.array([0, 1, 3 / 4, 0.5, 3 / 8, 0.25])
70
+ self.audio_length = get_conf_n_sample() / get_conf_sample_rate()
71
+ self.bpm_range = [90, 100, 110, 120, 130, 140, 150, 160, 170]
72
+ self.max_n_notes = get_conf_max_n_notes()
73
+
74
+ def __call__(self, strategy: str, *args, **kwargs):
75
+ """Choose required strategy to generate random rhythm (note onsets).
76
+
77
+ Parameters
78
+ ----------
79
+ strategy: str
80
+ Strategy names for random rhythm (see Readme).
81
+
82
+ Returns
83
+ -------
84
+ onsets: List[float]
85
+ A list of floats referring to note onsets in seconds.
86
+ """
87
+ if strategy == 'bpm_based_rhythm':
88
+ rhythm = self.get_bpm_based_rhythm()
89
+ elif strategy == 'free_rhythm':
90
+ rhythm = self.get_free_rhythm()
91
+ elif strategy == 'single_note_rhythm':
92
+ rhythm = self.get_single_note()
93
+ else:
94
+ rhythm = [0.0, 1, 2, 3, 4]
95
+
96
+ return rhythm[:self.max_n_notes]
97
+
98
+ def get_bpm_based_rhythm(self):
99
+ """Uses "bpm_based_rhythm" strategy to generate random rhythm (see Readme)."""
100
+ # Todo: clean up this part
101
+
102
+ bpm = random.choice(self.bpm_range)
103
+
104
+ first_note = get_random_note_type_index(self.first_note_type_distribution)
105
+ note_type_indexes = [first_note]
106
+ current_note_type = first_note
107
+ while True:
108
+ current_note_type = get_random_note_type_index(self.rhythm_generation_matrix[current_note_type - 1]) + 1
109
+ note_type_indexes.append(current_note_type)
110
+
111
+ # Random early stop
112
+ if np.random.random() < 9 / bpm:
113
+ break
114
+
115
+ notes = [self.note_types[note_type_index] for note_type_index in note_type_indexes]
116
+
117
+ onsets = to_onsets_in_seconds(bpm, notes)
118
+ return onsets
119
+
120
+ def get_free_rhythm(self):
121
+ """Uses "free_rhythm" strategy to generate random rhythm (see Readme)."""
122
+ n_notes = np.random.randint(int(self.max_n_notes * 0.6), self.max_n_notes)
123
+ # n_notes = np.random.randint(int(1), self.max_n_notes)
124
+
125
+ onsets = np.random.rand(n_notes)
126
+ onsets.sort()
127
+
128
+ # Avoid notes too close together
129
+ pre = onsets[0]
130
+ n_removed = 0
131
+ for i in range(len(onsets)-1):
132
+ index = i - n_removed + 1
133
+ if (onsets[index] - pre) < 0.05:
134
+ new_onsets = np.delete(onsets, index)
135
+ onsets = new_onsets
136
+ n_removed = n_removed + 1
137
+ else:
138
+ pre = onsets[index]
139
+
140
+ return ((onsets - onsets[0])*self.audio_length).tolist()
141
+
142
+ def get_single_note(self):
143
+ return [0.0]
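A worked example of the onset helper above: at 120 BPM a whole note lasts 4 * 60 / 120 = 2.0 s, so each note length (a fraction of a whole note) is scaled by that factor before being accumulated into onsets.

from melody_synth.random_rhythm import to_onsets_in_seconds

print(to_onsets_in_seconds(120, [0.25, 0.25, 0.5]))   # [0, 0.5, 1.0, 2.0]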
model/VAE.py ADDED
@@ -0,0 +1,230 @@
1
+ import tensorflow as tf
2
+ from keras import backend as K
3
+ from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Flatten, Reshape, Lambda, BatchNormalization
4
+ from keras.models import Model
5
+ import numpy as np
6
+ import threading
7
+
8
+ KL = tf.keras.layers
9
+
10
+
11
+ def cbam_layer(inputs_tensor=None, ratio=None):
12
+ """Source: https://blog.csdn.net/ZXF_1991/article/details/104615942
13
+ The channel attention
14
+ """
15
+ channels = K.int_shape(inputs_tensor)[-1]
16
+
17
+ def share_layer(inputs=None):
18
+ x_ = KL.Conv2D(channels // ratio, (1, 1), strides=1, padding="valid")(inputs)
19
+ x_ = KL.Activation('relu')(x_)
20
+ output_share = KL.Conv2D(channels, (1, 1), strides=1, padding="valid")(x_)
21
+ return output_share
22
+
23
+ x_global_avg_pool = KL.GlobalAveragePooling2D()(inputs_tensor)
24
+ x_global_avg_pool = KL.Reshape((1, 1, channels))(x_global_avg_pool)
25
+ x_global_max_pool = KL.GlobalMaxPool2D()(inputs_tensor)
26
+ x_global_max_pool = KL.Reshape((1, 1, channels))(x_global_max_pool)
27
+ x_global_avg_pool = share_layer(x_global_avg_pool)
28
+ x_global_max_pool = share_layer(x_global_max_pool)
29
+ x = KL.Add()([x_global_avg_pool, x_global_max_pool])
30
+ x = KL.Activation('sigmoid')(x)
31
+ CAM = KL.multiply([inputs_tensor, x])
32
+ output = CAM
33
+ return output
34
+
35
+
36
+ def res_cell(x, n_channel=64, stride=1):
37
+ """The basic unit in the VAE, cell."""
38
+ if stride == -1:
39
+ # upsample cell
40
+ skip = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
41
+ skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=1, padding='same')(skip)
42
+ x = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), strides=2, padding='same')(x)
43
+ x = BatchNormalization()(x)
44
+ x = tf.keras.activations.elu(x)
45
+ x = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), padding='same')(x)
46
+
47
+ elif stride == 2:
48
+ # downsample cell
49
+ skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=2, padding='same')(x)
50
+ x = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')(x)
51
+ x = BatchNormalization()(x)
52
+ x = tf.keras.activations.elu(x)
53
+ x = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')(x)
54
+ else:
55
+ # preserving cell
56
+ skip = tf.identity(x)
57
+ x = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')(x)
58
+ x = BatchNormalization()(x)
59
+ x = tf.keras.activations.elu(x)
60
+ x = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')(x)
61
+
62
+ x = BatchNormalization()(x)
63
+ x = cbam_layer(inputs_tensor=x, ratio=8)
64
+ x = x + skip
65
+ x = tf.keras.activations.elu(x)
66
+ return x
67
+
68
+
69
+ def res_block(x, n_channel=64, upsample=False, n_cells=2):
70
+ """The block is a stack of cells."""
71
+ if upsample:
72
+ x = res_cell(x, n_channel=n_channel, stride=-1)
73
+ else:
74
+ x = res_cell(x, n_channel=n_channel, stride=2)
75
+ for _ in range(n_cells - 1):
76
+ x = res_cell(x, n_channel=n_channel, stride=1)
77
+ return x
78
+
79
+
80
+ def l1_distance(x1, x2):
81
+ return tf.reduce_mean(tf.math.abs(x1 - x2))
82
+
83
+
84
+ def l1_log_distance(x1, x2):
85
+ return tf.reduce_mean(tf.math.abs(tf.math.log(tf.maximum(1e-6, x1)) - tf.math.log(tf.maximum(1e-6, x2))))
86
+
87
+
88
+ img_height = 512
89
+ img_width = 256
90
+ num_channels = 1
91
+ input_shape = (img_height, img_width, num_channels)
92
+ timbre_dim = 20
93
+ n_filters = 64
94
+ act = 'elu'
95
+
96
+
97
+ def compute_latent(x):
98
+ """Re-parameterizing."""
99
+ mu, sigma = x
100
+ batch = K.shape(mu)[0]
101
+ dim = K.int_shape(mu)[1]
102
+ eps = K.random_normal(shape=(batch, dim))
103
+ return mu + K.exp(sigma / 2) * eps
104
+
105
+
106
+ def get_encoder(N2=0, channel_sizes=None):
107
+ """Assemble and return the VAE encoder."""
108
+ if channel_sizes is None:
109
+ channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
110
+ encoder_input = Input(shape=input_shape)
111
+
112
+ encoder_conv = res_block(encoder_input, channel_sizes[0], upsample=False, n_cells=1)
113
+
114
+ for c in channel_sizes[1:]:
115
+ encoder_conv = res_block(encoder_conv, c, upsample=False, n_cells=1 + N2)
116
+
117
+ encoder = Flatten()(encoder_conv)
118
+
119
+ mu_timbre = Dense(timbre_dim)(encoder)
120
+ sigma_timbre = Dense(timbre_dim)(encoder)
121
+ latent_vector = Lambda(compute_latent, output_shape=(timbre_dim,))([mu_timbre, sigma_timbre])
122
+
123
+ kl_loss = -0.5 * (1 + sigma_timbre - tf.square(mu_timbre) - tf.exp(sigma_timbre))
124
+ kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
125
+
126
+ encoder = Model(encoder_input, [latent_vector, kl_loss])
127
+ return encoder
128
+
129
+
130
+ def get_decoder(N2=0, N3=8, channel_sizes=None):
131
+ """Assemble and return the VAE decoder."""
132
+ if channel_sizes is None:
133
+ channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
134
+ conv_shape = [-1, 2 ** (9 - N3), 2 ** (8 - N3), channel_sizes[-1]]
135
+ decoder_input = Input(shape=(timbre_dim,))
136
+
137
+ decoder = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation=act)(decoder_input)
138
+ decoder_conv = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(decoder)
139
+
140
+ for c in list(reversed(channel_sizes))[1:]:
141
+ decoder_conv = res_block(decoder_conv, c, upsample=True, n_cells=1 + N2)
142
+
143
+ decoder_conv = Conv2DTranspose(filters=num_channels, kernel_size=5, strides=2,
144
+ padding='same', activation='sigmoid')(decoder_conv)
145
+
146
+ decoder = Model(decoder_input, decoder_conv)
147
+ return decoder
148
+
149
+
150
+ def VAE(N2=0, N3=8, channel_sizes=None):
151
+ """Assemble and return the VAE."""
152
+ if channel_sizes is None:
153
+ channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
154
+ print("Creating model...")
155
+ assert N2 >= 0, "Please set N2 >= 0"
156
+ assert N3 >= 1, "Please set 1 <= N3 <= 8"
157
+ assert N3 <= 8, "Please set 1 <= N3 <= 8"
158
+ assert N3 == len(channel_sizes), "Please set N3 = len(channel_sizes)"
159
+ encoder = get_encoder(N2, channel_sizes)
160
+ decoder = get_decoder(N2, N3, channel_sizes)
161
+
162
+ # encoder = tf.keras.models.load_model(f"encoder_thesis_record_1.h5")
163
+ # decoder = tf.keras.models.load_model(f"decoder_thesis_record_1.h5")
164
+
165
+ encoder_input1 = Input(shape=input_shape)
166
+ scalar_input1 = Input(shape=(1,))
167
+
168
+ embedding_1_timbre, kl_loss = encoder(encoder_input1)
169
+ reconstruction_1 = decoder(embedding_1_timbre)
170
+
171
+ VAE = Model([encoder_input1, scalar_input1], [reconstruction_1, kl_loss])
172
+ # decoder.summary()
173
+ VAE.summary()
174
+ return encoder, decoder, VAE
175
+
176
+
177
+ def my_thread(data_cache):
178
+ data_cache.refresh()
179
+
180
+
181
+ def train_VAE(vae, encoder, decoder, data_cache, stages, batch_size):
182
+ """Train the VAE.
183
+
184
+ Parameters
185
+ ----------
186
+ vae: keras.engine.functional.Functional
187
+ The VAE.
188
+ encoder: keras.engine.functional.Functional
189
+ The VAE encoder.
190
+ decoder: keras.engine.functional.Functional
191
+ The VAE decoder.
192
+ data_cache: Data_cache
193
+ A Data_cache entity that provides training data.
194
+ stages: List[Dict]
195
+ Defines the training stages; in each stage the synthetic data is refreshed and
196
+ the models are stored once.
197
+
198
+ Returns
199
+ -------
200
+ """
201
+ threshold = 1e-0
202
+ kl_weight = 100.0
203
+
204
+ def weighted_binary_cross_entropy_loss(true, pred):
205
+ b_n = true * tf.math.log(tf.maximum(1e-20, pred)) + (1 - true) * tf.math.log(tf.maximum(1e-20, 1 - pred))
206
+ w = tf.maximum(threshold, true)
207
+ return -tf.reduce_sum(b_n / w) / batch_size
208
+
209
+ def reconstruction_loss(true, pred):
210
+ reconstruction_loss = weighted_binary_cross_entropy_loss(K.flatten(true), K.flatten(pred))
211
+ return K.mean(reconstruction_loss)
212
+
213
+ def kl_loss(true, pred):
214
+ return pred * kl_weight
215
+
216
+ for stage in stages:
217
+ threshold = stage["threshold"]
218
+ kl_weight = stage["kl_weight"]
219
+ vae.compile(tf.keras.optimizers.Adam(learning_rate=stage["learning_rate"]), loss=[reconstruction_loss, kl_loss])
220
+
221
+ Input_all = data_cache.get_all_data()
222
+ n_total = np.shape(Input_all)[0]
223
+
224
+ t = threading.Thread(target=my_thread, args=(data_cache,))
225
+ t.start()
226
+ history = vae.fit([Input_all, np.ones(n_total)], [Input_all, np.ones(n_total)], epochs=stage["n_epoch"],
227
+ batch_size=batch_size)
228
+ t.join()
229
+ encoder.save(f"./models/new_trained_models/encoder.h5")
230
+ decoder.save(f"./models/new_trained_models/decoder.h5")
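A hedged sketch of a full training run with the pieces above. Import paths assume the repository root is on the Python path, load_data() needs the external datasets on disk, and the stage values are illustrative; only the dictionary keys are taken from train_VAE().

from load_data import load_data
from model.VAE import VAE, train_VAE

encoder, decoder, vae = VAE(N2=0, N3=8, channel_sizes=[32, 64, 64, 96, 96, 128, 160, 216])
stages = [
    {"threshold": 1.0, "kl_weight": 100.0, "learning_rate": 1e-4, "n_epoch": 5},
    {"threshold": 0.1, "kl_weight": 10.0, "learning_rate": 1e-4, "n_epoch": 5},
]
data_cache = load_data(200)
train_VAE(vae, encoder, decoder, data_cache, stages, batch_size=8)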
model/VAE_torchV.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ChannelAttention(nn.Module):
6
+ def __init__(self, in_planes, ratio=16):
7
+ super(ChannelAttention, self).__init__()
8
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
9
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
10
+
11
+ self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
12
+ self.relu1 = nn.ReLU()
13
+ self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
14
+
15
+ self.sigmoid = nn.Sigmoid()
16
+
17
+ def forward(self, x):
18
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
19
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
20
+ y = avg_out + max_out
21
+ y = self.sigmoid(y)
22
+
23
+ return x * y.expand_as(x)
24
+
25
+
26
+ class ResCell(nn.Module):
27
+ def __init__(self, input_channel, output_channel, stride=1):
28
+ super(ResCell, self).__init__()
29
+
30
+ self.stride = stride
31
+ self.input_channel = input_channel
32
+ self.output_channel = output_channel
33
+
34
+ if self.stride == -1:
35
+ output_size = ()
36
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
37
+ self.skip = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=1, stride=1, padding=0)
38
+ self.conv1 = nn.ConvTranspose2d(self.input_channel, self.output_channel, kernel_size=5, stride=2, padding=2, output_padding=1)
39
+ self.conv2 = nn.ConvTranspose2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)
40
+
41
+ elif self.stride == 2:
42
+ self.skip = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=1, stride=2, padding=0)
43
+ self.conv1 = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=5, stride=self.stride, padding=2)
44
+ self.conv2 = nn.Conv2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)
45
+
46
+ else:
47
+ self.conv1 = nn.Conv2d(self.input_channel, self.output_channel, kernel_size=5, stride=self.stride, padding=2)
48
+ self.conv2 = nn.Conv2d(self.output_channel, self.output_channel, kernel_size=5, padding=2)
49
+
50
+ self.bn1 = nn.BatchNorm2d(self.output_channel)
51
+ self.bn2 = nn.BatchNorm2d(self.output_channel)
52
+
53
+ # Please replace `CBAM` with the actual module and parameters
54
+ self.cbam = ChannelAttention(self.output_channel)
55
+
56
+ def forward(self, x):
57
+ if self.stride == -1:
58
+ upsampled_x = self.upsample(x)
59
+ skip = self.skip(upsampled_x)
60
+ x = F.elu(self.bn1(self.conv1(x)))
61
+ x = self.conv2(x)
62
+ elif self.stride == 2:
63
+ skip = self.skip(x)
64
+ x = F.elu(self.bn1(self.conv1(x)))
65
+ x = self.conv2(x)
66
+ else:
67
+ skip = x
68
+ x = F.elu(self.bn1(self.conv1(x)))
69
+ x = self.conv2(x)
70
+
71
+ x = self.bn2(x)
72
+ x = self.cbam(x)
73
+ x = x + skip
74
+ x = F.elu(x)
75
+
76
+ return x
77
+
78
+
79
+ class ResBlock(nn.Module):
80
+ def __init__(self, input_channel, output_channel, upsample=False, n_cells=2):
81
+ super(ResBlock, self).__init__()
82
+
83
+ stride = -1 if upsample else 2
84
+ self.cells = nn.ModuleList([ResCell(input_channel, output_channel, stride=stride)])
85
+
86
+ for _ in range(n_cells - 1):
87
+ self.cells.append(ResCell(input_channel, output_channel, stride=1))
88
+
89
+ def forward(self, x):
90
+ for cell in self.cells:
91
+ x = cell(x)
92
+ return x
93
+
94
+
95
+ class Encoder(nn.Module):
96
+ def __init__(self, input_shape, timbre_dim, N2=0, channel_sizes=None):
97
+ super(Encoder, self).__init__()
98
+
99
+ if channel_sizes is None:
100
+ channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
101
+
102
+ self.input_shape = input_shape
103
+ self.timbre_dim = timbre_dim
104
+ self.blocks = nn.ModuleList()
105
+
106
+ self.blocks.append(ResBlock(input_channel=1, output_channel=channel_sizes[0], upsample=False, n_cells=1))
107
+ input_channel = channel_sizes[0]
108
+
109
+ for c in channel_sizes[1:]:
110
+ self.blocks.append(ResBlock(input_channel=input_channel, output_channel=c, upsample=False, n_cells=1 + N2))
111
+ input_channel = c
112
+
113
+ self.flatten = nn.Flatten()
114
+ self.mu_timbre = nn.Linear(self._get_flattened_dim(), timbre_dim)
115
+ self.sigma_timbre = nn.Linear(self._get_flattened_dim(), timbre_dim)
116
+
117
+ def _get_flattened_dim(self):
118
+ x = torch.zeros((1,) + self.input_shape)
119
+ for block in self.blocks:
120
+ x = block(x)
121
+ x = self.flatten(x)
122
+ return x.shape[1]
123
+
124
+ def reparameterize(self, mu, logvar):
125
+ std = torch.exp(0.5*logvar)
126
+ eps = torch.randn_like(std)
127
+ return mu + eps*std
128
+
129
+ def forward(self, x):
130
+ for block in self.blocks:
131
+ x = block(x)
132
+
133
+ x = self.flatten(x)
134
+ mu = self.mu_timbre(x)
135
+ logvar = self.sigma_timbre(x)
136
+ latent_vector = self.reparameterize(mu, logvar)
137
+
138
+ # kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
139
+ # kl_loss = torch.mean(kl_loss)
140
+
141
+ return mu, logvar, latent_vector
142
+
143
+
144
+ class Decoder(nn.Module):
145
+ def __init__(self, timbre_dim, N2=0, N3=8, channel_sizes=None):
146
+ super(Decoder, self).__init__()
147
+
148
+ if channel_sizes is None:
149
+ channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]
150
+
151
+ self.conv_shape = [-1, channel_sizes[-1], 2 ** (9 - N3), 2 ** (8 - N3)]
152
+
153
+ self.dense = nn.Linear(timbre_dim, self.conv_shape[1] * self.conv_shape[2] * self.conv_shape[3])
154
+ self.blocks = nn.ModuleList()
155
+
156
+ input_channel = channel_sizes[-1]
157
+ for c in list(reversed(channel_sizes))[1:]:
158
+ self.blocks.append(ResBlock(input_channel=input_channel, output_channel=c, upsample=True, n_cells=1 + N2))
159
+ input_channel = c
160
+
161
+ self.decoder_conv = nn.ConvTranspose2d(channel_sizes[0], 1, kernel_size=5, stride=2, padding=2, output_padding=1)
162
+
163
+ def forward(self, x):
164
+ x = F.elu(self.dense(x))
165
+ x = x.view(-1, self.conv_shape[1], self.conv_shape[2], self.conv_shape[3])
166
+ for block in self.blocks:
167
+ x = block(x)
168
+
169
+ x = self.decoder_conv(x)
170
+ x = torch.sigmoid(x)
171
+ return x
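A minimal forward-pass sketch for the PyTorch version above. The latent size and batch size are illustrative; the channel sizes fall back to the defaults of this file.

import torch
from model.VAE_torchV import Encoder, Decoder

encoder = Encoder(input_shape=(1, 512, 256), timbre_dim=20)
decoder = Decoder(timbre_dim=20)
x = torch.rand(2, 1, 512, 256)        # two spectrograms scaled to [0, 1]
mu, logvar, z = encoder(x)            # each of shape (2, 20)
recon = decoder(z)                    # (2, 1, 512, 256), sigmoid output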
model/perceptual_label_predictor.py ADDED
@@ -0,0 +1,68 @@
1
+ import tensorflow as tf
2
+ from keras.layers import Input, Dense, Dropout
3
+ from keras.models import Model
4
+ from keras.losses import binary_crossentropy
5
+ from load_data import read_data
6
+ import joblib
7
+ import numpy as np
8
+
9
+ KL = tf.keras.layers
10
+
11
+
12
+ def perceptual_label_predictor():
13
+ """Assemble and return the perceptual_label_predictor."""
14
+ mini_input = Input((20,))
15
+ p = Dense(20, activation='relu')(mini_input)
16
+ p = Dropout(0.2)(p)
17
+ p = Dense(16, activation='relu')(p)
18
+ p = Dropout(0.2)(p)
19
+ p = Dense(5, activation='sigmoid')(p)
20
+ style_predictor = Model(mini_input, p)
21
+ style_predictor.summary()
22
+ return style_predictor
23
+
24
+
25
+ def train_perceptual_label_predictor(perceptual_label_predictor, encoder):
26
+ """Train the perceptual_label_predictor. (Including data loading.)"""
27
+
28
+ Input_synthetic = read_data("./data/labeled_dataset/synthetic_data")
29
+ Input_AU = read_data("./data/external_data/ARTURIA_data")[:100]
30
+
31
+ AU_labels = joblib.load("./data/labeled_dataset/ARTURIA_labels")
32
+ synth_labels = joblib.load("./data/labeled_dataset/synthetic_data_labels")
33
+
34
+ AU_encode = encoder.predict(Input_AU)[0]
35
+ Synth_encode = encoder.predict(Input_synthetic)[0]
36
+
37
+ perceptual_label_predictor.compile(optimizer='adam', loss=binary_crossentropy)
38
+ perceptual_label_predictor.fit(np.vstack([AU_encode, Synth_encode]), np.vstack([AU_labels, synth_labels]), epochs=140, validation_split=0.05, batch_size=16)
39
+ perceptual_label_predictor.save(f"./models/new_trained_models/perceptual_label_predictor.h5")
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
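A hedged sketch of inference with the predictor above. The 20-dimensional codes would normally come from the VAE encoder; random values stand in for them here.

import numpy as np
from model.perceptual_label_predictor import perceptual_label_predictor

predictor = perceptual_label_predictor()
codes = np.random.rand(4, 20).astype("float32")    # e.g. encoder.predict(spectrograms)[0]
labels = predictor.predict(codes)                  # shape (4, 5), values in [0, 1]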
models/decoder_5_13.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7c25c8e241ade613409293d0ba3b37823a129699e802e1a2e9d9bb0074b11d3
3
+ size 16884433
models/encoder_5_13.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3da8ae2c98eea048a5cf4247ed637c67edf8b121e97efd3c7564e8ed4ea9fa51
3
+ size 21627911
models/new_trained_models/perceptual_label_predictor.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c6d2c3df579658bc54e07dd76d0e9efb02831bebf5e2d8c104becac930f0071
3
+ size 46080
models/perceptual_label_predictor.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f59db5a982347ec8c40452b8f77a35432d00253a8ed347726f9b102fc4550d44
3
+ size 46176
new_sound_generation.py ADDED
@@ -0,0 +1,203 @@
1
+ import numpy as np
2
+ import matplotlib
3
+ from pathlib import Path
4
+ import shutil
5
+ from tqdm import tqdm
6
+ from tools import save_results, VAE_out_put_to_spc, show_spc
7
+
8
+
9
+ def test_reconstruction(encoder, decoder, data, n_sample=5, f=0, path_name="./data/test_reconstruction", save_data=False):
10
+ """Generate and show reconstruction results. Randomly reconstruct 'n_sample' samples in 'data'.
11
+ The index of the first reconstructed sample can be set manually with 'f'.
12
+
13
+ Parameters
14
+ ----------
15
+ encoder: keras.engine.functional.Functional
16
+ The VAE encoder.
17
+ decoder: keras.engine.functional.Functional
18
+ The VAE decoder.
19
+ data: numpy array
20
+ Dataset of spectrograms to draw reconstruction samples from.
21
+ n_sample: int
22
+ Number of samples to reconstruct.
23
+ f: int
24
+ Index of the first reconstructed sample.
25
+ path_name: String
26
+ Path to save the results.
27
+ save_data: bool
28
+ Whether to save the results.
29
+
30
+ Returns
31
+ -------
32
+ """
33
+ if save_data:
34
+ if Path(path_name).exists():
35
+ shutil.rmtree(path_name)
36
+ Path(path_name).mkdir(parents=True, exist_ok=True)
37
+
38
+ for i in range(n_sample):
39
+ index = np.random.randint(np.shape(data)[0])
40
+ if i == 0:
41
+ index = f
42
+ print("######################################################")
43
+ print(f"index: {index}")
44
+
45
+ input = data[index]
46
+ print(f"Original:")
47
+ show_spc(VAE_out_put_to_spc(input))
48
+ if save_data:
49
+ save_results(VAE_out_put_to_spc(input), f"{path_name}/origin_{index}.png", f"{path_name}/origin_{index}.wav")
50
+
51
+ input = data[index:index + 1]
52
+ timbre_encode = encoder.predict(input)[0]
53
+
54
+ encode = timbre_encode
55
+
56
+ reconstruction = decoder.predict(encode)[0]
57
+ reconstruction = VAE_out_put_to_spc(reconstruction)
58
+ reconstruction = np.minimum(5000, reconstruction)
59
+ print(f"Reconstruction:")
60
+ show_spc(reconstruction)
61
+ if save_data:
62
+ save_results(reconstruction, f"{path_name}/reconstruction_{index}.png", f"{path_name}/reconstruction_{index}.wav")
63
+
64
+
65
+ def test_interpulation(data0, data1, encoder, decoder, path_name = "./data/test_interpolation", save_data=False):
66
+ """Generate new sounds by latent space interpolation.
67
+
68
+ Parameters
69
+ ----------
70
+ data0: numpy array
71
+ First input for interpolation.
72
+ data1: numpy array
73
+ Second input for interpolation.
74
+ encoder: keras.engine.functional.Functional
75
+ The VAE encoder.
76
+ decoder: keras.engine.functional.Functional
77
+ The VAE decoder.
78
+ path_name: String
79
+ Path to save the results.
80
+ save_data: bool
81
+ Whether to save the results.
82
+
83
+ Returns
84
+ -------
85
+ """
86
+ if save_data:
87
+ if Path(path_name).exists():
88
+ shutil.rmtree(path_name)
89
+ Path(path_name).mkdir(parents=True, exist_ok=True)
90
+
91
+ if save_data:
92
+ save_results(VAE_out_put_to_spc(data0), f"{path_name}/origin_0.png", f"{path_name}/origin_0.wav")
93
+ save_results(VAE_out_put_to_spc(data1), f"{path_name}/origin_1.png", f"{path_name}/origin_1.wav")
94
+
95
+ print("First Original:")
96
+ show_spc(VAE_out_put_to_spc(data0))
97
+ print("Second Original:")
98
+ show_spc(VAE_out_put_to_spc(data1))
99
+ print("######################################################")
100
+ print("Interpolations:")
101
+ data0 = np.reshape(data0, (1, 512, 256, 1))
102
+ data1 = np.reshape(data1, (1, 512, 256, 1))
103
+ timbre_encode0 = encoder.predict(data0)[0]
104
+ timbre_encode1 = encoder.predict(data1)[0]
105
+
106
+ n_f = 8
107
+ for i in tqdm(range(n_f+1)):
108
+ rate = 1 - i/n_f
109
+ new_timbre = rate * timbre_encode0 + (1-rate) * timbre_encode1
110
+ output = decoder.predict(new_timbre)
111
+
112
+ spc = np.reshape(VAE_out_put_to_spc(output), (512,256))
113
+ if save_data:
114
+ save_results(spc, f"{path_name}/test_interpolation_{i}.png", f"{path_name}/test_interpolation_{i}.wav")
115
+ show_spc(spc)
116
+
117
+
118
+ def test_random_sampling(decoder, n_sample=20, mu=np.zeros(20), sigma=np.ones(20), save_data = False, path_name = "./data/test_random_sampling"):
119
+ """Generate new sounds by random sampling in the latent space.
120
+
121
+ Parameters
122
+ ----------
123
+ decoder: keras.engine.functional.Functional
124
+ The VAE decoder.
125
+ path_name: String
126
+ Path to save the results.
127
+ save_data: bool
128
+ Whether to save the results.
129
+
130
+ Returns
131
+ -------
132
+ """
133
+ if save_data:
134
+ if Path(path_name).exists():
135
+ shutil.rmtree(path_name)
136
+ Path(path_name).mkdir(parents=True, exist_ok=True)
137
+
138
+ for i in tqdm(range(n_sample)):
139
+ off_set = np.random.normal(mu,np.square(sigma))
140
+ new_timbre = np.reshape(off_set, (1,20))
141
+
142
+ output = decoder.predict(new_timbre)
143
+
144
+ spc = np.reshape(VAE_out_put_to_spc(output), (512,256))
145
+ if save_data:
146
+ save_results(spc, f"{path_name}/random_sampling_{i}.png", f"{path_name}/random_sampling_{i}.wav")
147
+ show_spc(spc)
148
+
149
+
150
+ def test_style_transform(original, encoder, decoder, perceptual_label_predictor, n_samples=10, save_data = False, goal=0, direction=0, path_name = "./data/test_style_transform"):
151
+ """Generate new sounds by latent space interpolation.
152
+
153
+ Parameters
154
+ ----------
155
+ original: numpy array
156
+ Original for style transform.
157
+ encoder: keras.engine.functional.Functional
158
+ The VAE encoder.
159
+ decoder: keras.engine.functional.Functional
160
+ The VAE decoder.
161
+ perceptual_label_predictor: keras.engine.functional.Functional
162
+ Model that predicts perceptual labels, used to select the output.
163
+ path_name: String
164
+ Path to save the results.
165
+ save_data: bool
166
+ Whether to save the results.
167
+
168
+ Returns
169
+ -------
170
+ """
171
+ if save_data:
172
+ if Path(path_name).exists():
173
+ shutil.rmtree(path_name)
174
+ Path(path_name).mkdir(parents=True, exist_ok=True)
175
+ save_results(VAE_out_put_to_spc(original), f"{path_name}/origin.png", f"{path_name}/origin.wav")
176
+ labels_names = ["metallic", "warm", "breathy", "evolving", "aggressiv"]
177
+ timbre_dim = 20
178
+
179
+ print("Original:")
180
+ show_spc(VAE_out_put_to_spc(original))
181
+ print("######################################################")
182
+ original_code = encoder.predict(np.reshape(original, (1,512,256,1)))[0]
183
+ new_encodes = np.zeros((n_samples, timbre_dim)) + original_code
184
+
185
+ new_encodes = [new_encode + np.random.normal(np.zeros(timbre_dim) * 0.2,np.ones(timbre_dim)) for new_encode in new_encodes]
186
+ new_encodes = np.array(new_encodes, dtype=np.float32)
187
+ perceptual_labels = perceptual_label_predictor.predict(new_encodes)[:,goal]
188
+
189
+ if direction == 0:
190
+ best_index = np.argmin(perceptual_labels)
191
+ suffix = f"less_{labels_names[goal]}"
192
+ else:
193
+ best_index = np.argmax(perceptual_labels)
194
+ suffix = f"more_{labels_names[goal]}"
195
+
196
+ output = decoder.predict(new_encodes[best_index:best_index+1])
197
+
198
+ spc = np.reshape(VAE_out_put_to_spc(output), (512,256))
199
+ if save_data:
200
+ save_results(spc, f"{path_name}/{suffix}.png", f"{path_name}/{suffix}.wav")
201
+ print("Manipulated (suffix):")
202
+ show_spc(spc)
203
+
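A hypothetical driver for these tests (not part of the commit), assuming Keras VAE checkpoints and that load_data returns spectrograms shaped (N, 512, 256, 1); the model paths are placeholders.

    from keras.models import load_model
    from load_data import load_data
    from new_sound_generation import (test_reconstruction, test_interpulation,
                                      test_random_sampling, test_style_transform)

    encoder = load_model("./models/encoder.h5")   # placeholder paths
    decoder = load_model("./models/decoder.h5")
    data = load_data(500)

    test_reconstruction(encoder, decoder, data, n_sample=3)
    test_interpulation(data[0], data[1], encoder, decoder)
    test_random_sampling(decoder, n_sample=5)
    # predictor = load_model("./models/perceptual_label_predictor.h5")
    # test_style_transform(data[0], encoder, decoder, predictor, goal=1, direction=1)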
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
 
1
+ gradio
2
+ ddsp
3
+ torchsynth
4
+ pytorch_lightning==1.7.0
test_audio.wav ADDED
Binary file (522 kB).
 
tools.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+ import librosa
6
+ from scipy.io.wavfile import write
7
+
8
+ k = 1e-16
9
+
10
+
11
+ def np_log10(x):
12
+ numerator = np.log(x + 1e-16)
13
+ denominator = np.log(10)
14
+ return numerator / denominator
15
+
16
+
17
+ def sigmoid(x):
18
+ s = 1 / (1 + np.exp(-x))
19
+ return s
20
+
21
+
22
+ def inv_sigmoid(s):
23
+ x = np.log((s / (1 - s)) + 1e-16)
24
+ return x
25
+
26
+
27
+ def spc_to_VAE_input(spc):
28
+ """Restrict value range from 0 to 1."""
29
+ return spc / (1 + spc)
30
+
31
+
32
+ def VAE_out_put_to_spc(o):
33
+ """Inverse transform of function 'spc_to_VAE_input'."""
34
+ return o / (1 - o + k)
35
+
36
+
37
+ def denoise(spc):
38
+ """Filter back ground noise. (Not used.)"""
39
+ return np.maximum(0, spc - (2e-5))
40
+
41
+
42
+ hop_length = 256
43
+ win_length = 1024
44
+
45
+
46
+ def np_power_to_db(S, amin=1e-16, top_db=80.0):
47
+ """Helper method for scaling."""
48
+ ref = np.max(S)
49
+
50
+ # set fixed value for ref
51
+
52
+ # element-wise maximum against amin
53
+ log_spec = 10.0 * np_log10(np.maximum(amin, S))
54
+ log_spec -= 10.0 * np_log10(np.maximum(amin, ref))
55
+
56
+ log_spec = np.maximum(log_spec, np.max(log_spec) - top_db)
57
+
58
+ return log_spec
59
+
60
+
61
+ def show_spc(spc, resolution=(512, 256)):
62
+ """Show a spectrogram."""
63
+ spc = np.reshape(spc, resolution)
64
+ magnitude_spectrum = np.abs(spc)
65
+ log_spectrum = np_power_to_db(magnitude_spectrum)
66
+ plt.imshow(np.flipud(log_spectrum))
67
+ plt.show()
68
+
69
+
70
+ def save_results(spectrogram, spectrogram_image_path, waveform_path):
71
+ """Save the input 'spectrogram' and its waveform (reconstructed bu Griffin Lim)
72
+ to the paths provided by 'spectrogram_image_path' and 'waveform_path'."""
73
+ # save image
74
+ magnitude_spectrum = np.abs(spectrogram)
75
+ log_spc = np_power_to_db(magnitude_spectrum)
76
+ log_spc = np.reshape(log_spc, (512, 256))
77
+ matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
78
+ origin='lower')
79
+
80
+ # save waveform
81
+ abs_spec = np.zeros((513, 256))
82
+ abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
83
+ rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
84
+ write(waveform_path, 16000, rec_signal)
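A minimal round-trip sketch (not part of the commit) for these helpers, using a random (512, 256) power spectrogram as a stand-in and illustrative output paths.

    import numpy as np
    from tools import spc_to_VAE_input, VAE_out_put_to_spc, np_power_to_db, save_results

    spc = np.abs(np.random.randn(512, 256)) ** 2   # stand-in power spectrogram
    x = spc_to_VAE_input(spc)                      # compress values into [0, 1) for the VAE
    spc_rec = VAE_out_put_to_spc(x)                # approximate inverse mapping
    print(np.max(np.abs(spc - spc_rec)))           # reconstruction error is negligible

    log_spc = np_power_to_db(spc)                  # dB-scaled view, as used for plotting
    save_results(spc, "example.png", "example.wav")  # writes the image and a Griffin-Lim waveform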
webUI/initial_example_encodes.json ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "random": [
3
+ 0.398047679983798,
4
+ 0.3181480556003946,
5
+ 0.4481247707840732,
6
+ 0.17181013678356805,
7
+ 0.38504974079327525,
8
+ 0.1630011861056878,
9
+ 0.2718486665735521,
10
+ 0.4338781507304229,
11
+ 0.6538380475574059,
12
+ 0.4158802639661583,
13
+ 0.23043953925285032,
14
+ 0.10156416680503988,
15
+ 0.30416174813259533,
16
+ 0.5135342367189637,
17
+ 0.5187104878467569,
18
+ 0.3526902627535946,
19
+ 0.7747258429690094,
20
+ 0.2627179357923156,
21
+ 0.9086876170530048,
22
+ 0.9271088722414674,
23
+ 0.7019110161290005,
24
+ 0.7718117435584002,
25
+ 0.36794622268993094,
26
+ 0.201057894940179
27
+ ],
28
+ "few_overtone_fading_out": [
29
+ -1.0641387701034546,
30
+ -0.22593635320663452,
31
+ -3.0761210918426514,
32
+ 0.3395313322544098,
33
+ 3.4432809352874756,
34
+ 1.2180149555206299,
35
+ -3.312405824661255,
36
+ -0.9400798082351685,
37
+ 1.4439473152160645,
38
+ -2.7892191410064697,
39
+ -4.55153751373291,
40
+ 2.2633938789367676,
41
+ -1.9029977321624756,
42
+ 0.37142419815063477,
43
+ 1.7903343439102173,
44
+ 3.5063657760620117,
45
+ -1.5748300552368164,
46
+ -2.555540084838867,
47
+ 0.07989349216222763,
48
+ 0.23952914774417877,
49
+ 0.5571582317352295,
50
+ -1.200455904006958,
51
+ -1.2390071153640747,
52
+ -0.5626499652862549
53
+ ],
54
+ "much_overtone": [
55
+ -1.649719476699829,
56
+ -2.1237597465515137,
57
+ -3.3006417751312256,
58
+ -1.4381871223449707,
59
+ 2.0898361206054688,
60
+ 0.34304752945899963,
61
+ -1.6530671119689941,
62
+ 0.6339190006256104,
63
+ -0.700051486492157,
64
+ -0.6604726910591125,
65
+ -2.3133397102355957,
66
+ 3.0709166526794434,
67
+ -1.2595521211624146,
68
+ 0.6515411138534546,
69
+ 1.8037245273590088,
70
+ 0.17395955324172974,
71
+ -1.1443099975585938,
72
+ -2.599336624145508,
73
+ -1.909640908241272,
74
+ -1.422598123550415,
75
+ 0.5974328517913818,
76
+ -2.559039354324341,
77
+ -2.4977917671203613,
78
+ -0.9264755249023438
79
+ ],
80
+ "much_overtone_high_register": [
81
+ -1.5539246797561646,
82
+ -1.4718247652053833,
83
+ -0.2083689421415329,
84
+ 0.47305312752723694,
85
+ -0.3550421893596649,
86
+ 1.6288657188415527,
87
+ -2.5005292892456055,
88
+ -0.5079396367073059,
89
+ 1.9173517227172852,
90
+ -2.8692283630371094,
91
+ -1.3840413093566895,
92
+ 1.8955140113830566,
93
+ -0.6880097389221191,
94
+ 1.2770217657089233,
95
+ 0.2371762990951538,
96
+ 2.726161241531372,
97
+ -1.7957547903060913,
98
+ -0.28189635276794434,
99
+ -0.6052505373954773,
100
+ -0.557244598865509,
101
+ -0.7524706721305847,
102
+ -0.6265298128128052,
103
+ -1.4746038913726807,
104
+ -2.393972396850586
105
+ ],
106
+ "blurry": [
107
+ 1.3638803958892822,
108
+ -1.7874348163604736,
109
+ 0.4853333532810211,
110
+ -0.9626323580741882,
111
+ -0.2609613537788391,
112
+ -1.5900236368179321,
113
+ -2.30368709564209,
114
+ 0.4847792387008667,
115
+ 1.5201102495193481,
116
+ -0.5147597789764404,
117
+ -2.452840805053711,
118
+ 1.5057097673416138,
119
+ -0.16519485414028168,
120
+ 2.126760721206665,
121
+ 0.7019514441490173,
122
+ 1.612999439239502,
123
+ -0.3407663106918335,
124
+ -2.8276443481445312,
125
+ 1.311924934387207,
126
+ -1.5173300504684448,
127
+ -0.015635617077350616,
128
+ -2.7689361572265625,
129
+ 1.1804192066192627,
130
+ -3.077000141143799
131
+ ],
132
+ "interesting_release": [
133
+ 1.7602112293243408,
134
+ -1.9114476442337036,
135
+ -2.730947256088257,
136
+ -0.6088603138923645,
137
+ 3.317946195602417,
138
+ -0.2716876268386841,
139
+ -2.3106558322906494,
140
+ -2.114469289779663,
141
+ -4.443742752075195,
142
+ -1.0665826797485352,
143
+ -3.0929622650146484,
144
+ 1.1979585886001587,
145
+ -1.6287152767181396,
146
+ -1.537142276763916,
147
+ 2.4184482097625732,
148
+ 0.22694607079029083,
149
+ 0.10934393107891083,
150
+ -0.18058283627033234,
151
+ -2.489964723587036,
152
+ -4.448374271392822,
153
+ 1.2452409267425537,
154
+ 0.05835026502609253,
155
+ 0.8547804355621338,
156
+ 0.8163737654685974
157
+ ],
158
+ "global_trend": [
159
+ -1.0987968444824219,
160
+ -1.1155377626419067,
161
+ 0.14996160566806793,
162
+ -3.165109157562256,
163
+ -2.5396244525909424,
164
+ 1.8292016983032227,
165
+ -3.5159406661987305,
166
+ 3.4396510124206543,
167
+ -2.3765876293182373,
168
+ 0.5692415833473206,
169
+ -1.7827686071395874,
170
+ -0.4062053859233856,
171
+ -1.6925498247146606,
172
+ 0.7511563897132874,
173
+ 0.12510846555233002,
174
+ -0.14617247879505157,
175
+ 0.5096412897109985,
176
+ 2.2399022579193115,
177
+ 0.5798826217651367,
178
+ -1.5942487716674805,
179
+ -0.36588573455810547,
180
+ -0.9877008199691772,
181
+ 4.732168674468994,
182
+ -5.468194007873535
183
+ ],
184
+ "crescendo": [
185
+ 1.3167105913162231,
186
+ -1.1503334045410156,
187
+ -3.488548517227173,
188
+ -1.146520972251892,
189
+ 2.478545665740967,
190
+ -0.5853592753410339,
191
+ -2.1441550254821777,
192
+ -2.1898915767669678,
193
+ -7.137173175811768,
194
+ -0.34099096059799194,
195
+ -3.832253932952881,
196
+ 2.0366034507751465,
197
+ -1.3639447689056396,
198
+ -2.3450658321380615,
199
+ 1.1388293504714966,
200
+ 1.1278795003890991,
201
+ 0.6025446653366089,
202
+ -0.2925209701061249,
203
+ -0.07147964835166931,
204
+ -3.1970367431640625,
205
+ 2.4061689376831055,
206
+ 0.477089524269104,
207
+ -0.8897881507873535,
208
+ 2.827509880065918
209
+ ]
210
+ }