# Sound_VAE / generate_synthetic_data_online.py
# Uploaded by WeixuanYuan ("Upload 31 files", revision b88cc47).
import matplotlib.pyplot as plt
import librosa
import matplotlib
import pandas as pd
from typing import Optional
from torch import tensor
from ddsp.core import tf_float32
import torch
from torch import Tensor
import numpy as np
import tensorflow as tf
from torchsynth.config import SynthConfig
import ddsp
from pathlib import Path
from typing import Dict
from data_generation.encoding import ParameterDescription
from typing import List
from configurations.read_configuration import parameter_range, is_discrete, midi_parameter_range, midi_is_discrete
import shutil
from tqdm import tqdm
from scipy.io.wavfile import write
from melody_synth.complex_torch_synth import DoubleSawSynth, SinSawSynth, SinTriangleSynth, TriangleSawSynth
# Audio sample rate in Hz shared by the generators in this module.
sample_rate = 16000
# Default buffer length: 4.5 seconds of audio at `sample_rate` (a float, 72000.0).
n_samples = 4.5 * sample_rate
class NoteGenerator:
    """
    This class is responsible for single-note audio generation by function 'get_note'.

    Four torchsynth voices are created once, all sharing one SynthConfig with
    ``batch_size=1``. ``get_note`` selects one of them from ``params["osc_types"]``,
    renders a note, and low-pass filters it with a sinc impulse response.
    """

    def __init__(self,
                 sample_rate=sample_rate,
                 n_samples=sample_rate * 4.5):
        """
        :param sample_rate: audio sample rate in Hz, used by the synths and by the filter.
        :param n_samples: synth buffer length in samples (converted to seconds for SynthConfig).
        """
        self.sample_rate = sample_rate
        self.n_samples = n_samples
        synthconfig = SynthConfig(
            batch_size=1, reproducible=False, sample_rate=sample_rate,
            buffer_size_seconds=np.float64(n_samples) / np.float64(sample_rate)
        )
        self.Saw_Square_Voice = DoubleSawSynth(synthconfig)
        self.SinSawVoice = SinSawSynth(synthconfig)
        self.SinTriVoice = SinTriangleSynth(synthconfig)
        self.TriSawVoice = TriangleSawSynth(synthconfig)

    def get_note(self, params: Dict[str, float]):
        """
        Render one low-pass-filtered note.

        :param params: synthesis parameters keyed by name (see the reads below).
            Missing keys fall back to 0, except ``cutoff_freq``, which falls back to 4000.
            ``pitch`` is passed as the keyboard MIDI note; ``duration`` is in seconds.
        :return: 1-D float32 TensorFlow tensor with the filtered note.
        """
        # Map the single "osc_amp2" control onto two mixer levels.
        # Below 0.45, oscillator 1 stays at 0.9 while oscillator 2 follows osc_amp2.
        # From 0.45 upwards, oscillator 2 is held at 0.9 and oscillator 1 is reduced instead.
        # NOTE(review): the levels jump from (0.9, 0.45) to (0.45, 0.9) at 0.45 -- confirm intended.
        osc_amp2 = np.float64(params.get("osc_amp2", 0))
        if osc_amp2 < 0.45:
            osc1_amp = 0.9
            osc2_amp = osc_amp2
        else:
            osc1_amp = 0.9 - osc_amp2
            osc2_amp = 0.9

        attack_1 = np.float64(params.get("attack_1", 0))
        decay_1 = np.float64(params.get("decay_1", 0))
        sustain_1 = np.float64(params.get("sustain_1", 0))
        release_1 = np.float64(params.get("release_1", 0))
        attack_2 = np.float64(params.get("attack_2", 0))
        decay_2 = np.float64(params.get("decay_2", 0))
        sustain_2 = np.float64(params.get("sustain_2", 0))
        release_2 = np.float64(params.get("release_2", 0))
        amp_mod_freq = params.get("amp_mod_freq", 0)
        amp_mod_depth = params.get("amp_mod_depth", 0)
        amp_mod_waveform = params.get("amp_mod_waveform", 0)
        pitch_mod_freq_1 = params.get("pitch_mod_freq_1", 0)
        pitch_mod_depth = params.get("pitch_mod_depth", 0)
        cutoff_freq = params.get("cutoff_freq", 4000)
        pitch = np.float64(params.get("pitch", 0))
        duration = np.float64(params.get("duration", 0))

        syn_parameters = {
            ("adsr_1", "attack"): tensor([attack_1]),  # [0.0, 2.0]
            ("adsr_1", "decay"): tensor([decay_1]),  # [0.0, 2.0]
            ("adsr_1", "sustain"): tensor([sustain_1]),  # [0.0, 2.0]
            ("adsr_1", "release"): tensor([release_1]),  # [0.0, 2.0]
            ("adsr_1", "alpha"): tensor([5]),  # [0.1, 6.0]
            ("adsr_2", "attack"): tensor([attack_2]),  # [0.0, 2.0]
            ("adsr_2", "decay"): tensor([decay_2]),  # [0.0, 2.0]
            ("adsr_2", "sustain"): tensor([sustain_2]),  # [0.0, 2.0]
            ("adsr_2", "release"): tensor([release_2]),  # [0.0, 2.0]
            ("adsr_2", "alpha"): tensor([5]),  # [0.1, 6.0]
            ("keyboard", "midi_f0"): tensor([pitch]),
            ("keyboard", "duration"): tensor([duration]),
            # Mixer parameter
            ("mixer", "vco_1"): tensor([osc1_amp]),  # [0, 1]
            ("mixer", "vco_2"): tensor([osc2_amp]),  # [0, 1]
            # Constant parameters:
            ("vco_1", "mod_depth"): tensor([pitch_mod_depth]),  # [-96, 96]
            ("vco_1", "tuning"): tensor([0.0]),  # [-24.0, 24]
            ("vco_2", "mod_depth"): tensor([pitch_mod_depth]),  # [-96, 96]
            ("vco_2", "tuning"): tensor([0.0]),  # [-24.0, 24]
            # LFOs
            ("lfo_amp_sin", "frequency"): tensor([amp_mod_freq]),  # [0, 20]
            ("lfo_amp_sin", "mod_depth"): tensor([0]),  # [-10, 20]
            # NOTE(review): both pitch LFOs use pitch_mod_freq_1; the "pitch_mod_freq_2"
            # parameter is never read here -- confirm intended.
            ("lfo_pitch_sin_1", "frequency"): tensor([pitch_mod_freq_1]),  # [0, 20]
            ("lfo_pitch_sin_1", "mod_depth"): tensor([10]),  # [-10, 20]
            ("lfo_pitch_sin_2", "frequency"): tensor([pitch_mod_freq_1]),  # [0, 20]
            ("lfo_pitch_sin_2", "mod_depth"): tensor([10]),  # [-10, 20]
        }

        # Pick the voice and, where applicable, the oscillator shape switches.
        osc_types = params.get("osc_types", 0)
        if osc_types == 0:
            synth = self.SinSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([1])
        elif osc_types == 1:
            synth = self.SinSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([0])
        elif osc_types == 2:
            synth = self.Saw_Square_Voice
            syn_parameters[("vco_1", "shape")] = tensor([1])
            syn_parameters[("vco_2", "shape")] = tensor([0])
        elif osc_types == 3:
            synth = self.SinTriVoice
        elif osc_types == 4:
            synth = self.TriSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([1])
        else:
            synth = self.TriSawVoice
            syn_parameters[("vco_2", "shape")] = tensor([0])
        synth.set_parameters(syn_parameters)

        # Convert the duration (seconds) to samples at this generator's own sample rate.
        # The module-level default rate must not be used here.
        n_note_samples = int(self.sample_rate * duration)
        audio_out = synth.get_signal(amp_mod_depth, amp_mod_waveform, n_note_samples, osc1_amp, osc2_amp)
        single_note = audio_out[0].detach().numpy()

        # Low-pass filter: convolve with a sinc impulse response (window size 2048) at cutoff_freq.
        cutoff_freq = tf_float32(cutoff_freq)
        impulse_response = ddsp.core.sinc_impulse_response(cutoff_freq, 2048, self.sample_rate)
        single_note = tf_float32(single_note)
        return ddsp.core.fft_convolve(single_note[tf.newaxis, :], impulse_response)[0, :]
class MelodyGenerator:
    """
    This class is responsible for multi-note audio generation by function 'get_melody'.

    Each note is rendered by a shared NoteGenerator and added into a fixed-length
    track, starting at the note's onset sample.
    """

    def __init__(self,
                 sample_rate=sample_rate,
                 n_note_samples=None,
                 n_melody_samples=sample_rate * 4.5):
        """
        :param sample_rate: audio sample rate in Hz.
        :param n_note_samples: synth buffer length for a single note, in samples.
            ``None`` (the default) means ``sample_rate * 4.5``, i.e. 4.5 seconds.
        :param n_melody_samples: length of the mixed output track, in samples.
        """
        self.sample_rate = sample_rate
        if n_note_samples is None:
            n_note_samples = sample_rate * 4.5
        self.noteGenerator = NoteGenerator(sample_rate, n_note_samples)
        self.n_melody_samples = int(n_melody_samples)

    def get_melody(self, params_list: List[Dict[str, float]], onsets):
        """
        Render several notes and mix them into a single track.

        :param params_list: per-note parameter dicts; ``params_list[i]`` is rendered at ``onsets[i]``.
        :param onsets: non-negative start offsets in samples, one per note.
        :return: float64 numpy array of length ``n_melody_samples``. Notes that run past
            the end of the track are truncated; notes starting at or after the end are dropped.
        :raises ValueError: if an onset is negative.
        """
        n_total = self.n_melody_samples
        track = np.zeros(n_total)
        for i, location in enumerate(onsets):
            start = int(location)
            if start < 0:
                raise ValueError(f"onsets must be non-negative, got {location}")
            note = np.asarray(self.noteGenerator.get_note(params_list[i]))
            # Add the note in place instead of padding it to full track length first.
            end = min(start + note.shape[0], n_total)
            if end > start:
                track[start:end] += note[:end - start]
        return track
def plot_log_spectrogram(signal: np.ndarray,
                         path: str,
                         n_fft=1024,
                         frame_length=1024,
                         frame_step=256):
    """
    Write the log-power spectrogram of ``signal`` to ``path`` as an image.

    :param signal: mono audio signal.
    :param path: output image path.
    :param n_fft: FFT size of the STFT.
    :param frame_length: STFT window length in samples.
    :param frame_step: STFT hop length in samples.
    """
    stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)
    # Power spectrum |STFT|^2 (already non-negative).
    power = np.square(np.real(stft)) + np.square(np.imag(stft))
    # dB relative to the loudest bin (0 dB); the colour scale is fixed to [-100, 0] dB.
    log_power = np_power_to_db(power)
    plt.imsave(path, log_power, vmin=-100, vmax=0, origin='lower')
def np_power_to_db(S, amin=1e-16, top_db=80.0):
    """
    Convert a power spectrogram to decibels relative to its maximum value.

    :param S: array-like of non-negative power values.
    :param amin: floor applied to ``S`` and to the reference before taking the log,
        so that zeros do not produce ``-inf``.
    :param top_db: dynamic range kept below the peak. Output values more than ``top_db``
        dB below the maximum are raised to that threshold. ``None`` disables clipping.
    :return: float array with the same shape as ``S``. Entries equal to ``max(S)`` map to
        0 dB, so the maximum output value is 0.
    :raises ValueError: if ``top_db`` is negative.
    """
    if top_db is not None and top_db < 0:
        raise ValueError("top_db must be non-negative")
    S = np.asarray(S)
    if S.size == 0:
        return np.zeros(S.shape)
    # Scale magnitude relative to the maximum value in S. Zeros in the output
    # correspond to positions where S == ref.
    ref = np.max(S)
    # Element-wise floor at amin before the log.
    log_spec = 10.0 * np.log10(np.maximum(amin, S))
    log_spec -= 10.0 * np.log10(np.maximum(amin, ref))
    if top_db is not None:
        log_spec = np.maximum(log_spec, np.max(log_spec) - top_db)
    return log_spec
# Module-level melody renderer shared by the dataset generators; it is built at import time.
synth = MelodyGenerator(sample_rate=sample_rate)
# Each entry is:
#   (key in the generated parameter dict, configuration key, read via the midi_* helpers?)
# The list order is the order in which ParameterDescription.generate() gets called
# for every note.
_PARAMETER_SOURCES = [
    # Oscillator levels
    ("osc_amp2", "osc_amp2", False),
    # ADSR params (both envelopes share the same configured ranges)
    ("attack_1", "attack", False),
    ("decay_1", "decay", False),
    ("sustain_1", "sustain", False),
    ("release_1", "release", False),
    ("attack_2", "attack", False),
    ("decay_2", "decay", False),
    ("sustain_2", "sustain", False),
    ("release_2", "release", False),
    ("cutoff_freq", "cutoff_freq", False),
    # pitch and duration are read with midi_parameter_range / midi_is_discrete
    ("pitch", "pitch", True),
    ("duration", "duration", True),
    ("amp_mod_freq", "amp_mod_freq", False),
    ("amp_mod_depth", "amp_mod_depth", False),
    ("pitch_mod_freq_1", "pitch_mod_freq", False),
    ("pitch_mod_freq_2", "pitch_mod_freq", False),
    ("pitch_mod_depth", "pitch_mod_depth", False),
    # Oscillators types
    # 0 for sin saw, 1 for sin square, 2 for saw square
    # 3 for sin triangle, 4 for triangle saw, 5 for triangle square
    ("osc_types", "osc_types", False),
]

param_descriptions: List[ParameterDescription] = [
    ParameterDescription(
        name=param_name,
        values=(midi_parameter_range if from_midi else parameter_range)(config_key),
        discrete=(midi_is_discrete if from_midi else is_discrete)(config_key),
    )
    for param_name, config_key, from_midi in _PARAMETER_SOURCES
]
# STFT window length and hop (in samples), passed to plot_log_spectrogram.
frame_length, frame_step = 1024, 256
# Nominal spectrogram length (frames) and FFT size.
spectrogram_len, n_fft = 256, 1024
def _reset_output_dir(path_name):
    """Delete ``path_name`` (and everything in it) if it exists, then recreate it empty."""
    output_dir = Path(path_name)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)


def _power_spectrogram_512(signal):
    """Return |STFT(signal)|^2 cropped to the first 512 frequency bins and the first 256 frames."""
    stft = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
    power = np.square(np.real(stft)) + np.square(np.imag(stft))
    return power[:512, :256]


def _random_note_params():
    """Draw one value per entry of ``param_descriptions`` and switch pitch modulation off."""
    generated = [param.generate() for param in param_descriptions]
    note_params = {param.name: param.value for param in generated}
    note_params["pitch_mod_depth"] = 0.0
    return note_params


def _pin_note(note_params):
    """Return a copy of ``note_params`` fixed to a 3-second note at MIDI pitch 52."""
    pinned = dict(note_params)
    pinned["duration"] = 3.0
    pinned["pitch"] = 52
    return pinned


def _save_example(path_name, stem, signal):
    """Write ``signal`` to ``<stem>.wav`` and its log spectrogram to ``<stem>.png`` under ``path_name``."""
    write(f"{path_name}/{stem}.wav", synth.sample_rate, signal)
    plot_log_spectrogram(signal, path=f"{path_name}/{stem}.png", frame_length=frame_length,
                         frame_step=frame_step)


def _generate_single_note_dataset(n, path_name, write_spec, show_progress):
    """Shared implementation of the two ``generate_synth_dataset_log*_512`` functions."""
    _reset_output_dir(path_name)
    print("Generating dataset...")
    synthetic_data = np.ones((n, 512, 256))
    indices = tqdm(range(n)) if show_progress else range(n)
    for i in indices:
        note_params = _pin_note(_random_note_params())
        signal = synth.get_melody([note_params], [0])
        synthetic_data[i] = _power_spectrogram_512(signal)
        if write_spec:
            _save_example(path_name, f"{i}", signal)
    print(f"Generating dataset over, {n} samples generated!")
    return synthetic_data


def _generate_dann_dataset(n, path_name, write_spec, show_progress):
    """Shared implementation of ``generate_DANN_dataset`` and ``generate_DANN_dataset_muted``."""
    _reset_output_dir(path_name)
    print("Generating dataset...")
    multinote_data = np.ones((n, 512, 256))
    single_data = np.ones((n, 512, 256))
    indices = tqdm(range(n)) if show_progress else range(n)
    for i in indices:
        # A melody of 1-4 random notes, each starting within the first 3 seconds.
        n_notes = np.random.randint(1, 5)
        par_list = []
        onsets = []
        for _ in range(n_notes):
            par_list.append(_random_note_params())
            onsets.append(np.random.randint(0, sample_rate * 3))
        signal = synth.get_melody(par_list, onsets)
        multinote_data[i] = _power_spectrogram_512(signal)
        if write_spec:
            _save_example(path_name, f"mul_{i}", signal)
        # Paired target: the earliest note rendered alone at onset 0 with a fixed pitch and duration.
        # A copy is pinned, so par_list is left untouched.
        single_par = _pin_note(par_list[np.argmin(onsets)])
        signal = synth.get_melody([single_par], [0])
        single_data[i] = _power_spectrogram_512(signal)
        if write_spec:
            _save_example(path_name, f"single_{i}", signal)
    print(f"Generating dataset over, {n} samples generated!")
    return multinote_data, single_data


def generate_synth_dataset_log_muted_512(n: int, path_name="./data/data_log", write_spec=False):
    """
    Generate ``n`` single-note power spectrograms without a progress bar.

    Each note uses random synth parameters with duration 3.0 s, MIDI pitch 52 and
    no pitch modulation. ``path_name`` is deleted and recreated on every call.

    :param n: number of examples.
    :param path_name: output directory for the optional WAV/PNG files (wiped first).
    :param write_spec: also write ``<i>.wav`` and ``<i>.png`` for every example.
    :return: float array of shape ``(n, 512, 256)``.
    """
    return _generate_single_note_dataset(n, path_name, write_spec, show_progress=False)


def generate_synth_dataset_log_512(n: int, path_name="./data/data_log", write_spec=False):
    """
    Generate ``n`` single-note power spectrograms, showing a tqdm progress bar.

    Each note uses random synth parameters with duration 3.0 s, MIDI pitch 52 and
    no pitch modulation. ``path_name`` is deleted and recreated on every call.

    :param n: number of examples.
    :param path_name: output directory for the optional WAV/PNG files (wiped first).
    :param write_spec: also write ``<i>.wav`` and ``<i>.png`` for every example.
    :return: float array of shape ``(n, 512, 256)``.
    """
    return _generate_single_note_dataset(n, path_name, write_spec, show_progress=True)


def generate_DANN_dataset_muted(n: int, path_name="./data/data_DANN", write_spec=False):
    """
    Generate the paired multi-note / single-note dataset without a progress bar.

    For each example, a melody of 1-4 random notes is rendered. Its earliest note is
    then rendered alone (onset 0, duration 3.0 s, MIDI pitch 52). ``path_name`` is
    deleted and recreated on every call.

    :param n: number of example pairs.
    :param path_name: output directory for the optional WAV/PNG files (wiped first).
    :param write_spec: also write ``mul_<i>.wav/.png`` and ``single_<i>.wav/.png``.
    :return: ``(multinote_data, single_data)``, each a float array of shape ``(n, 512, 256)``.
    """
    return _generate_dann_dataset(n, path_name, write_spec, show_progress=False)


def generate_DANN_dataset(n: int, path_name="./data/data_DANN", write_spec=False):
    """
    Generate the paired multi-note / single-note dataset for adversarial training,
    showing a tqdm progress bar.

    For each example, a melody of 1-4 random notes is rendered. Its earliest note is
    then rendered alone (onset 0, duration 3.0 s, MIDI pitch 52). ``path_name`` is
    deleted and recreated on every call.

    :param n: number of example pairs.
    :param path_name: output directory for the optional WAV/PNG files (wiped first).
    :param write_spec: also write ``mul_<i>.wav/.png`` and ``single_<i>.wav/.png``.
    :return: ``(multinote_data, single_data)``, each a float array of shape ``(n, 512, 256)``.
    """
    return _generate_dann_dataset(n, path_name, write_spec, show_progress=True)