# Sound_VAE / load_data.py
# Uploaded by WeixuanYuan ("Upload 31 files", commit b88cc47).
import joblib
import numpy as np
from generate_synthetic_data_online import generate_synth_dataset_log_512, generate_synth_dataset_log_muted_512
from tools import show_spc, spc_to_VAE_input, VAE_out_put_to_spc, np_log10
import torch.utils.data as data
class Data_cache():
    """Hold the synthetic and external data and serve them as one DataLoader.

    Both arrays are kept as ``float32``. The synthetic part can be regenerated
    with :meth:`refresh`; the external part is kept as passed in.
    """

    def __init__(self, synthetic_data, external_sources):
        """Store both data sources as float32 arrays.

        Parameters
        ----------
        synthetic_data: np.ndarray
            Synthetic samples; the first axis is the sample axis.
        external_sources: np.ndarray
            External samples; must be stackable below ``synthetic_data``
            with ``np.vstack``.
        """
        # Remembered so that refresh() regenerates the same number of samples.
        self.n_synthetic = np.shape(synthetic_data)[0]
        self.synthetic_data = synthetic_data.astype(np.float32)
        self.external_sources = external_sources.astype(np.float32)
        # Added before log10 in get_data_loader(new_way=True) to avoid log10(0).
        self.epsilon = 1e-20

    def get_all_data(self):
        """Return the synthetic samples followed by the external samples."""
        return np.vstack([self.synthetic_data, self.external_sources])

    def refresh(self):
        """Replace the synthetic samples with a newly generated muted batch."""
        fresh = generate_synth_dataset(self.n_synthetic, mute=True)
        # Cast exactly like __init__ does; otherwise the dtype of the stacked
        # data (and of every batch) could change after the first refresh.
        self.synthetic_data = fresh.astype(np.float32)

    def get_data_loader(self, shuffle=True, BATCH_SIZE=8, new_way=False):
        """Build a DataLoader over all samples, each reshaped to (1, 512, 256).

        Parameters
        ----------
        shuffle: bool
            Forwarded to ``DataLoader``.
        BATCH_SIZE: int
            Forwarded to ``DataLoader`` as ``batch_size``.
        new_way: bool
            If True, every reshaped sample is passed through
            ``VAE_out_put_to_spc`` and then ``log10(x + epsilon)``;
            otherwise the reshaped sample is used unchanged.

        Returns
        -------
        torch.utils.data.DataLoader
            Loader over the list of per-sample arrays.
        """
        samples = [
            np.reshape(sample, (1, 512, 256)) for sample in self.get_all_data()
        ]
        if new_way:
            samples = [
                np.log10(VAE_out_put_to_spc(sample) + self.epsilon)
                for sample in samples
            ]
        return data.DataLoader(samples, shuffle=shuffle, batch_size=BATCH_SIZE)
def generate_synth_dataset(n_synthetic, mute=False):
    """Generate synthetic samples and convert them to VAE input layout.

    Parameters
    ----------
    n_synthetic: int
        Number of samples requested from the generator.
    mute: bool
        If True, ``generate_synth_dataset_log_muted_512`` is used instead of
        ``generate_synth_dataset_log_512``.

    Returns
    -------
    np.ndarray
        Output of ``spc_to_VAE_input`` with a trailing axis of length 1 added.
    """
    if mute:
        generator = generate_synth_dataset_log_muted_512
    else:
        generator = generate_synth_dataset_log_512

    vae_input = spc_to_VAE_input(generator(n_synthetic))
    dims = vae_input.shape
    # Append a singleton last axis.
    return vae_input.reshape(dims[0], dims[1], dims[2], 1)
def read_data(data_path):
    """Read one external dataset and convert it to VAE input layout.

    Parameters
    ----------
    data_path: str
        Path of a file readable by ``joblib.load``.

    Returns
    -------
    np.ndarray
        Output of ``spc_to_VAE_input`` with a trailing axis of length 1 added.
    """
    # NOTE: joblib.load unpickles its input; only point it at trusted files.
    loaded = np.array(joblib.load(data_path))
    vae_input = spc_to_VAE_input(loaded)
    dims = vae_input.shape
    # Append a singleton last axis.
    return vae_input.reshape(dims[0], dims[1], dims[2], 1)
def load_data(n_synthetic):
    """Generate the hybrid dataset.

    Generates ``n_synthetic`` synthetic samples, reads the ARTURIA, NSynth,
    SoundFonts and WaveNet datasets from ``./data/external_data`` and wraps
    everything in a ``Data_cache``.

    Parameters
    ----------
    n_synthetic: int
        Number of synthetic samples to generate.

    Returns
    -------
    Data_cache
        Cache holding the synthetic samples and the stacked external samples.

    Raises
    ------
    ValueError
        If the SoundFonts data has more than 256 entries along its third axis.
    """
    # Width the SoundFonts data is zero-padded to along its third axis.
    # NOTE(review): presumably the width of the other datasets - confirm.
    target_frames = 256

    Input_synthetic = generate_synth_dataset(n_synthetic)

    Input_AU = read_data("./data/external_data/ARTURIA_data")
    print("ARTURIA dataset loaded.")

    Input_NSynth = read_data("./data/external_data/NSynth_data")
    print("NSynth dataset loaded.")

    Input_SF = read_data("./data/external_data/soundfonts_data")
    # Take the sizes from the loaded array instead of hard-coding 337 samples
    # and 251 frames, so a different SoundFonts file does not break padding.
    n_sf_samples, n_bins, n_frames = Input_SF.shape[:3]
    if n_frames > target_frames:
        raise ValueError(
            f"SoundFonts data has {n_frames} frames, "
            f"cannot pad to {target_frames}."
        )
    Input_SF_256 = np.zeros((n_sf_samples, n_bins, target_frames, 1))
    Input_SF_256[:, :, :n_frames, :] = Input_SF
    Input_SF = Input_SF_256
    print("SoundFonts dataset loaded.")

    Input_google = read_data("./data/external_data/WaveNet_samples")

    Input_external = np.vstack([Input_AU, Input_NSynth, Input_SF, Input_google])
    data_cache = Data_cache(Input_synthetic, Input_external)
    print(f"Data loaded, data shape: {np.shape(data_cache.get_all_data())}")
    return data_cache
def _show_named_dataset(dataset_name, n_sample=3, index=-1, new_way=False):
    """Load the dataset called ``dataset_name``, show samples, and return it."""
    external_paths = {
        "ARTURIA": "./data/external_data/ARTURIA_data",
        "NSynth": "./data/external_data/NSynth_data",
        "SoundFonts": "./data/external_data/soundfonts_data",
    }
    if dataset_name in external_paths:
        dataset = read_data(external_paths[dataset_name])
    elif dataset_name == "Synthetic":
        dataset = generate_synth_dataset(int(n_sample * 3))
    else:
        print("Example command: \"!python thesis_main.py show_data -s [ARTURIA, NSynth, SoundFonts, Synthetic] -n 5\"")
        return None

    if index >= 0:
        show_spc(VAE_out_put_to_spc(dataset[index]))
    else:
        # No explicit index: show n_sample randomly drawn samples.
        for _ in range(n_sample):
            picked = np.random.randint(0, len(dataset))
            print(picked)
            show_spc(VAE_out_put_to_spc(dataset[picked]))
    return dataset


def _show_tensor_sample(tensor_batch, index=-1, new_way=False):
    """Show one sample of a tensor batch; a random one if ``index`` < 0."""
    if index < 0:
        index = np.random.randint(0, tensor_batch.shape[0])
    # .cpu() is a no-op for CPU tensors and makes .numpy() work otherwise.
    sample = tensor_batch[index].detach().cpu().numpy()
    if new_way:
        # Undo the log10 applied by Data_cache.get_data_loader(new_way=True).
        spectrogram = 10.0 ** sample
        print(f"The {index}-th sample:")
        show_spc(spectrogram)
    else:
        show_spc(VAE_out_put_to_spc(sample))


def show_data(*args, **kwargs):
    """Show samples of a named dataset or of a tensor batch.

    This module used to define ``show_data`` twice, so the second definition
    silently replaced the first. Both call conventions are accepted here and
    selected by the type of the first argument.

    Named dataset: ``show_data(dataset_name, n_sample=3, index=-1, new_way=False)``

    dataset_name: String
        Name of the dataset to show: "ARTURIA", "NSynth", "SoundFonts" or
        "Synthetic". Any other name prints an example command and returns None.
    n_sample: int
        Number of samples to show.
    index: int
        Setting 'index' larger equal 0 shows the 'index'-th sample in the
        desired dataset instead of ``n_sample`` random ones.
    new_way: bool
        Accepted but not used for named datasets.

    Tensor batch: ``show_data(tensor_batch, index=-1, new_way=False)``

    tensor_batch: torch.Tensor
        Batch to pick a sample from; indexed along its first axis.
    index: int
        Sample to show; a negative value picks a random sample.
    new_way: bool
        If True the sample is shown as ``10 ** sample``; otherwise it is
        converted with ``VAE_out_put_to_spc`` first.

    Returns
    -------
    np.ndarray or None:
        The showed dataset for the named-dataset form; None otherwise.
    """
    first = args[0] if args else kwargs.get("dataset_name")
    if isinstance(first, str):
        return _show_named_dataset(*args, **kwargs)
    return _show_tensor_sample(*args, **kwargs)