# WeixuanYuan's picture
# Upload 66 files
# ae1bdf7 verified
import librosa
import numpy as np
import torch
from tools import np_power_to_db, decode_stft, depad_STFT
def spectrogram_to_Gradio_image(spc):
    """Render a spectrogram as an RGB uint8 image suitable for Gradio.

    Args:
        spc: np.ndarray whose last two axes are (frequency, time). It is
            reshaped to 2-D, so any leading axes must have size 1.

    Returns:
        np.ndarray of shape (frequency, time, 3) and dtype uint8, with the
        rows flipped vertically. Red and green carry the output of
        ``np_power_to_db`` mapped linearly from [-80, 0] onto [0, 255].
        Blue is a constant level of -60 on that scale (value 63).
        Values outside [-80, 0] saturate at 0 / 255.
    """
    frequency_resolution, time_resolution = spc.shape[-2], spc.shape[-1]
    spc = np.reshape(spc, (frequency_resolution, time_resolution))

    # NOTE(review): the magnitude (not its square) is handed to a helper named
    # "power_to_db" -- confirm this is the intended scaling.
    log_spectrum = np_power_to_db(np.abs(spc))
    flipped_log_spectrum = np.flipud(log_spectrum)

    # Every channel is written below, so no fill value is needed.
    colorful_spc = np.empty((frequency_resolution, time_resolution, 3))
    colorful_spc[:, :, 0] = flipped_log_spectrum
    colorful_spc[:, :, 1] = flipped_log_spectrum
    colorful_spc[:, :, 2] = -60.0

    # Map [-80, 0] onto [0, 1]. Clip so out-of-range values saturate instead
    # of wrapping around (or being undefined) in the float -> uint8 cast.
    rescaled = np.clip((colorful_spc + 80.0) / 80.0, 0.0, 1.0)
    return (255.0 * rescaled).astype(np.uint8)
def phase_to_Gradio_image(phase, phase_range=1.0):
    """Render a phase map as an RGB uint8 image suitable for Gradio.

    Args:
        phase: np.ndarray whose last two axes are (frequency, time). It is
            reshaped to 2-D, so any leading axes must have size 1.
        phase_range: positive value such that [-phase_range, phase_range] is
            mapped linearly onto [0, 255]. The default 1.0 keeps the
            historical mapping. Pass ``np.pi`` for raw ``np.angle`` output,
            which otherwise saturates outside +/-1 rad.

    Returns:
        np.ndarray of shape (frequency, time, 3) and dtype uint8, with the
        rows flipped vertically. Red and green carry the mapped phase; blue is
        the constant 51 (0.2 * 255). Out-of-range values saturate at 0 / 255.

    Raises:
        ValueError: if ``phase_range`` is not positive.
    """
    if not phase_range > 0:
        raise ValueError(f"phase_range must be positive, got {phase_range!r}")

    frequency_resolution, time_resolution = phase.shape[-2], phase.shape[-1]
    phase = np.reshape(phase, (frequency_resolution, time_resolution))

    # Map [-phase_range, phase_range] onto [0, 1]. Clip so inputs beyond the
    # range (e.g. np.angle values in [-pi, pi] with the default range)
    # saturate instead of wrapping around in the uint8 cast.
    normalized = np.clip((np.flipud(phase) / phase_range + 1.0) / 2.0, 0.0, 1.0)

    colorful_spc = np.empty((frequency_resolution, time_resolution, 3))
    colorful_spc[:, :, 0] = normalized
    colorful_spc[:, :, 1] = normalized
    colorful_spc[:, :, 2] = 0.2
    return (255.0 * colorful_spc).astype(np.uint8)
def latent_representation_to_Gradio_image(latent_representation, upscale_factor=8):
    """Render a latent representation as an enlarged uint8 image for Gradio.

    Each channel is min-max normalized independently to [0, 255]. Channels
    are then moved to the last axis, every pixel is repeated
    ``upscale_factor`` times along both spatial axes, and the rows are
    flipped vertically. The input is never modified.

    Args:
        latent_representation: torch.Tensor or np.ndarray indexed as
            (channels, height, width). Presumably 4 channels so the result
            can be displayed as RGBA -- confirm against callers.
        upscale_factor: repetition factor along height and width
            (default 8).

    Returns:
        np.ndarray of shape (height * upscale_factor, width * upscale_factor,
        channels) and dtype uint8. A constant channel is rendered as 0.
    """
    if not isinstance(latent_representation, np.ndarray):
        latent_representation = latent_representation.to("cpu").detach().numpy()

    # Work on a copy. For CPU tensors, .numpy() shares memory with the tensor,
    # so normalizing in place would overwrite the caller's latent.
    image = np.array(latent_representation, copy=True)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)

    for channel in range(image.shape[0]):
        plane = image[channel]
        min_val = plane.min()
        span = plane.max() - min_val
        if span > 0:
            image[channel] = (plane - min_val) / span * 255
        else:
            # A constant (or NaN) channel would give 0/0 = NaN, which has no
            # defined uint8 value.
            image[channel] = 0.0

    image_transposed = np.transpose(image, (1, 2, 0))
    enlarged_image = np.repeat(image_transposed, upscale_factor, axis=0)
    enlarged_image = np.repeat(enlarged_image, upscale_factor, axis=1)
    return np.flipud(enlarged_image).astype(np.uint8)
def InputBatch2Encode_STFT(encoder, STFT_batch, resolution=(512, 256), quantizer=None, squared=True):
    """Encode a batch of STFT tensors and render every input for display.

    Args:
        encoder: model whose first parameter's device receives the batch.
            Without a quantizer it must return ``(mu, logvar, latent)``;
            with one, it must return the latent directly.
        STFT_batch: torch.Tensor batch of padded, encoded STFTs.
        resolution: (frequency, time) pair; unpacked but otherwise unused.
        quantizer: optional callable returning
            ``(quantized, loss, (_, _, _))`` for the latent batch.
        squared: unused.

    Returns:
        Tuple of (spectrogram images, phase images, signals, latent batch,
        quantized latent batch or None). The signals are reconstructed
        with ``librosa.istft(hop_length=256, win_length=1024)``.
    """
    # TODO: `resolution` is only unpacked (which still validates its arity);
    # `squared` is ignored entirely.
    _frequency_resolution, _time_resolution = resolution
    device = next(encoder.parameters()).device

    if quantizer is None:
        _mu, _logvar, latent_representation_batch = encoder(STFT_batch.to(device))
        quantized_latent_representation_batch = None
    else:
        latent_representation_batch = encoder(STFT_batch.to(device))
        quantized_latent_representation_batch, _loss, (_, _, _) = quantizer(latent_representation_batch)

    spectrogram_images = []
    phase_images = []
    signals = []
    for stft_array in STFT_batch.to("cpu").detach().numpy():
        complex_stft = depad_STFT(decode_stft(stft_array))
        spectrogram_images.append(spectrogram_to_Gradio_image(np.abs(complex_stft)))
        phase_images.append(phase_to_Gradio_image(np.angle(complex_stft)))
        signals.append(librosa.istft(complex_stft, hop_length=256, win_length=1024))

    return (spectrogram_images, phase_images, signals,
            latent_representation_batch, quantized_latent_representation_batch)
def encodeBatch2GradioOutput_STFT(decoder, latent_vector_batch, resolution=(512, 256), original_STFT_batch=None):
    """Decode a batch of latent vectors into Gradio images and audio.

    For every decoded STFT this produces a spectrogram image, a phase image
    and a signal (``librosa.istft`` with ``hop_length=256, win_length=1024``).
    When ``original_STFT_batch`` is given, the same three outputs are also
    produced after replacing channel 0 of each decoded STFT with channel 0 of
    the matching original STFT. Channel 0 is presumably the amplitude
    channel, per the ``*_with_original_amp`` names.

    Args:
        decoder: model called on the latent batch. Tensor inputs are moved
            to the device of its first parameter.
        latent_vector_batch: np.ndarray or torch.Tensor of latent vectors.
        resolution: (frequency, time) pair; unpacked but otherwise unused.
        original_STFT_batch: optional np.ndarray or torch.Tensor indexed as
            ``[index, 0, :, :]``, in the same batch order as the latents.

    Returns:
        Six lists: spectrogram images, phase images, signals, and the same
        three for the original-channel-0 variant. The last three are empty
        when ``original_STFT_batch`` is None.
    """
    # TODO: remove resolution hard-coding (currently only unpacked).
    _frequency_resolution, _time_resolution = resolution

    def _render(stft):
        """Turn one padded, encoded STFT array into (spectrogram img, phase img, signal)."""
        complex_stft = depad_STFT(decode_stft(stft))
        return (spectrogram_to_Gradio_image(np.abs(complex_stft)),
                phase_to_Gradio_image(np.angle(complex_stft)),
                librosa.istft(complex_stft, hop_length=256, win_length=1024))

    if isinstance(latent_vector_batch, np.ndarray):
        latent_vector_batch = torch.from_numpy(latent_vector_batch)
    if torch.is_tensor(latent_vector_batch):
        # Move tensor input to the decoder's device too, not just numpy input.
        latent_vector_batch = latent_vector_batch.to(next(decoder.parameters()).device)

    # The result is immediately converted to numpy, so no autograd graph is needed.
    with torch.no_grad():
        reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    if torch.is_tensor(original_STFT_batch):
        # numpy slice assignment cannot read from GPU or grad-requiring tensors.
        original_STFT_batch = original_STFT_batch.detach().to("cpu").numpy()

    flipped_log_spectrums, flipped_phases, rec_signals = [], [], []
    flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp = [], [], []
    for index, stft in enumerate(reconstruction_batch):
        spectrogram_image, phase_image, signal = _render(stft)
        flipped_log_spectrums.append(spectrogram_image)
        flipped_phases.append(phase_image)
        rec_signals.append(signal)

        if original_STFT_batch is not None:
            mixed_stft = stft.copy()
            mixed_stft[0, :, :] = original_STFT_batch[index, 0, :, :]
            spectrogram_image, phase_image, signal = _render(mixed_stft)
            flipped_log_spectrums_with_original_amp.append(spectrogram_image)
            flipped_phases_with_original_amp.append(phase_image)
            rec_signals_with_original_amp.append(signal)

    return flipped_log_spectrums, flipped_phases, rec_signals, \
        flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp
def add_instrument(source_dict, virtual_instruments_dict, virtual_instrument_name, sample_index):
    """Store one sample from ``source_dict`` as a named virtual instrument.

    The per-sample entries at ``sample_index`` are copied from
    ``source_dict``, together with its shared ``"sampler"``. The new record is
    written under ``virtual_instrument_name`` into
    ``virtual_instruments_dict["virtual_instruments"]``, replacing any entry
    with the same name.

    Returns:
        The same ``virtual_instruments_dict`` object, updated in place.
    """
    def pick(key):
        return source_dict[key][sample_index]

    instruments = virtual_instruments_dict["virtual_instruments"]
    instruments[virtual_instrument_name] = {
        "latent_representation": pick("latent_representations"),
        "quantized_latent_representation": pick("quantized_latent_representations"),
        "sampler": source_dict["sampler"],
        "signal": pick("new_sound_rec_signals_gradio"),
        "spectrogram_gradio_image": pick("new_sound_spectrogram_gradio_images"),
        "phase_gradio_image": pick("new_sound_phase_gradio_images"),
    }
    virtual_instruments_dict["virtual_instruments"] = instruments
    return virtual_instruments_dict