"""Gradio web UI for exploring a spectrogram VAE (examples, sampling, interpolation)."""
import json
import os

import gradio as gr
import librosa
import numpy as np
import torch

from tools import VAE_out_put_to_spc, np_power_to_db
from model.VAE_torchV import Encoder, Decoder

SPECTROGRAM_RESOLUTION = (512, 256, 3)
LATENT_DIM = 24
SAMPLE_RATE = 16000
CHANNEL_SIZES = [64, 64, 64, 96, 96, 128, 160, 216]

device = "cpu"
# Each constructor gets its own copy of the channel list in case either one
# mutates it in place.
encoder = Encoder((1, 512, 256), LATENT_DIM, N2=0, channel_sizes=list(CHANNEL_SIZES)).to(device)
decoder = Decoder(LATENT_DIM, N2=0, N3=8, channel_sizes=list(CHANNEL_SIZES)).to(device)

model_name = "test"
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
encoder.load_state_dict(
    torch.load(f"models/{model_name}_encoder_CA.pt", map_location=torch.device(device))
)
decoder.load_state_dict(
    torch.load(f"models/{model_name}_decoder_CA.pt", map_location=torch.device(device))
)
# The models are only used for inference here: switch any dropout /
# batch-norm layers to evaluation mode.
encoder.eval()
decoder.eval()

# Name -> latent code. "init" is a random code; the rest come from the JSON file.
INIT_ENCODE_CACHE = {"init": np.random.random((LATENT_DIM,))}
with open("webUI/initial_example_encodes.json", "r", encoding="utf-8") as f:
    list_dict = json.load(f)
INIT_ENCODE_CACHE.update({name: np.array(code) for name, code in list_dict.items()})


#################################
def prepare_image(image):
    """Map a dB-valued image to uint8 pixels.

    Values are mapped linearly from [-80, 0] dB to [0, 255]. Values outside
    that range are clipped first, because casting out-of-range floats to
    uint8 does not saturate.
    """
    rescaled = np.clip((image + 80.0) / 80.0, 0.0, 1.0)
    return (255.0 * rescaled).astype(np.uint8)


def encodeBatch2GradioOutput(latent_vector_batch, resolution=(512, 256)):
    """Decode a batch of latent codes into spectrogram images and audio.

    Args:
        latent_vector_batch: tensor of latent codes on ``device``. The callers
            in this file pass shape (N, LATENT_DIM).
        resolution: (frequency_bins, frames) the decoded spectrum is
            reshaped to.

    Returns:
        (images, sample_rate, signals): a list of uint8 RGB spectrogram
        images, the audio sample rate, and a list of Griffin-Lim
        reconstructed waveforms (one per batch element).
    """
    n_bins, n_frames = resolution
    with torch.no_grad():  # inference only: no autograd graph is needed
        reconstruction_batch = decoder(latent_vector_batch).cpu().numpy()

    images, rec_signals = [], []
    for reconstruction in reconstruction_batch:
        spc = np.reshape(VAE_out_put_to_spc(reconstruction), resolution)
        magnitude_spectrum = np.abs(spc)

        # Image: the dB spectrum flipped vertically for display goes in the
        # red and green channels; blue is a constant -60 dB.
        flipped_log_spectrum = np.flipud(np_power_to_db(magnitude_spectrum))
        colorful_spc = np.stack(
            [flipped_log_spectrum, flipped_log_spectrum, np.full(resolution, -60.0)],
            axis=-1,
        )
        images.append(prepare_image(colorful_spc))

        # Audio: the spectrum is treated as power, so its square root is the
        # STFT magnitude. Using abs() avoids NaNs from negative decoder
        # outputs. One zero row is appended so the matrix has
        # n_fft // 2 + 1 rows for n_fft = 2 * n_bins.
        abs_spec = np.zeros((n_bins + 1, n_frames))
        abs_spec[:n_bins, :] = np.sqrt(magnitude_spectrum)
        rec_signals.append(
            librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=2 * n_bins)
        )
    return images, SAMPLE_RATE, rec_signals


def get_example_module(encodeCache):
    """Build the "Examples" tab: decode a chosen cached latent code."""

    def show_example(selected_example, encodeCache):
        if selected_example not in encodeCache:
            raise gr.Error("Please choose an example first.")
        example_encode = torch.Tensor(
            np.reshape(encodeCache[selected_example], (-1, LATENT_DIM))
        ).to(device)
        images, sample_rate, rec_signals = encodeBatch2GradioOutput(example_encode)
        return images[0], (sample_rate, rec_signals[0]), encodeCache

    with gr.Tab("Examples"):
        gr.Markdown("Some predefined examples.")
        with gr.Row():
            with gr.Column():
                selected_example = gr.Dropdown(
                    list(INIT_ENCODE_CACHE.keys()),
                    label="Examples",
                    info="Choose one example! More samples coming.",
                )
                example_button = gr.Button(value="Show example")
            with gr.Column():
                example_image_output = gr.Image(label="Spectrogram", type="numpy")
                example_image_output.style(height=250, width=600)
                example_audio_output = gr.Audio(type="numpy", label="Play the example!")

    example_button.click(
        show_example,
        inputs=[selected_example, encodeCache],
        outputs=[example_image_output, example_audio_output, encodeCache],
    )


def get_reconstruction_module():
    """Build the "Reconstruction" tab (placeholder: it shows random noise)."""

    def do_nothing(image_input):
        return np.random.random(SPECTROGRAM_RESOLUTION)

    with gr.Tab("Reconstruction"):
        gr.Markdown("Test reconstruction.")
        with gr.Row():
            with gr.Column():
                test_reconstruction_input = gr.Number(label="Batch_index")
                test_reconstruction_button = gr.Button(value="Generate")
            with gr.Column():
                test_reconstruction_output = gr.Image(label="Reconstruction", type="numpy")
                test_reconstruction_output.style(height=250, width=600)

    test_reconstruction_button.click(
        do_nothing, inputs=test_reconstruction_input, outputs=test_reconstruction_output
    )


def get_interpolation_module(encodeCache):
    """Build the "Interpolation" tab: blend two cached latent codes linearly."""

    def interpolate(first_name, second_name, ratio_value, encodeCache):
        if first_name not in encodeCache or second_name not in encodeCache:
            raise gr.Error("Please choose both inputs first.")
        first_encode = torch.Tensor(np.reshape(encodeCache[first_name], (-1, LATENT_DIM)))
        second_encode = torch.Tensor(np.reshape(encodeCache[second_name], (-1, LATENT_DIM)))
        ratio = torch.Tensor([ratio_value])
        # ratio is the weight of the first input (values outside [0, 1] extrapolate).
        interpolation_encode = first_encode * ratio + second_encode * (1 - ratio)
        # Decode all three codes as one (3, LATENT_DIM) batch, the same
        # (N, LATENT_DIM) layout the other tabs pass to the decoder.
        batch = torch.cat((first_encode, second_encode, interpolation_encode), dim=0).to(device)
        images, sample_rate, rec_signals = encodeBatch2GradioOutput(batch)
        return (
            images[0], (sample_rate, rec_signals[0]),
            images[1], (sample_rate, rec_signals[1]),
            images[2], (sample_rate, rec_signals[2]),
            encodeCache,
        )

    def refresh_interpolation_input(encodeCache):
        choices = list(encodeCache.keys())
        return gr.Dropdown.update(choices=choices), gr.Dropdown.update(choices=choices), encodeCache

    with gr.Tab("Interpolation"):
        gr.Markdown(
            "Test Interpolation. Sounds that you sampled can be added to the input "
            "dropdown by clicking [Refresh]."
        )
        with gr.Row():
            with gr.Column():
                interpolation_refresh_button = gr.Button(value="Refresh")
                with gr.Row():
                    first_interpolation_input = gr.Dropdown(
                        list(INIT_ENCODE_CACHE.keys()), label="First input"
                    )
                    second_interpolation_input = gr.Dropdown(
                        list(INIT_ENCODE_CACHE.keys()), label="Second input"
                    )
                with gr.Row():
                    first_input_audio = gr.Audio(type="numpy", label="Play first input")
                    first_input_audio.style(length=125)
                    second_input_audio = gr.Audio(type="numpy", label="Play second input")
                    second_input_audio.style(length=125)
                interpolation_input_ratio = gr.Slider(
                    minimum=-0.20, maximum=1.20, value=0.5, step=0.01,
                    label="Ratio of the first input.",
                )
                interpolation_button = gr.Button(value="Interpolate")
            with gr.Column():
                with gr.Row():
                    first_input_spectrogram = gr.Image(label="First Input", type="numpy")
                    first_input_spectrogram.style(height=250, width=125)
                    interpolation_spectrogram = gr.Image(label="Interpolation", type="numpy")
                    interpolation_spectrogram.style(height=250, width=125)
                    second_input_spectrogram = gr.Image(label="Second Input", type="numpy")
                    second_input_spectrogram.style(height=250, width=125)
                interpolation_audio = gr.Audio(type="numpy", label="Play interpolation")
                interpolation_audio.style(length=125)

    interpolation_refresh_button.click(
        refresh_interpolation_input,
        inputs=[encodeCache],
        outputs=[first_interpolation_input, second_interpolation_input, encodeCache],
    )
    interpolation_button.click(
        interpolate,
        inputs=[first_interpolation_input, second_interpolation_input,
                interpolation_input_ratio, encodeCache],
        outputs=[first_input_spectrogram, first_input_audio,
                 second_input_spectrogram, second_input_audio,
                 interpolation_spectrogram, interpolation_audio, encodeCache],
    )


def get_random_sampling_module(encodeCache, current_encode):
    """Build the "Random sampling" tab: sample N(0, sigma^2) codes and save them.

    ``current_encode`` holds the most recent sample, or None before the first one.
    """

    def random_sample(sigma):
        if sigma is None:
            raise gr.Error("Please enter sigma.")
        random_encode = torch.Tensor([sigma]) * torch.randn(1, LATENT_DIM)
        images, sample_rate, rec_signals = encodeBatch2GradioOutput(random_encode.to(device))
        return images[0], (sample_rate, rec_signals[0]), random_encode.numpy()

    def save_sample(save_name, current_encode, encodeCache):
        if current_encode is None:
            return "Please generate one sample.", current_encode, encodeCache
        name = (save_name or "").strip()
        if not name:
            return "The save name is empty.", current_encode, encodeCache
        encodeCache[name] = current_encode
        return "Sample saved.", current_encode, encodeCache

    with gr.Tab("Random sampling"):
        gr.Markdown("Sample new sound! Feel free to name and save your samples!")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    sigma = gr.Number(value=1.0, label="sigma")
                random_sampling_button = gr.Button(value="Sample")
            with gr.Column():
                random_sampling_spectrogram = gr.Image(label="Random sampling", type="numpy")
                random_sampling_spectrogram.style(height=250, width=600)
                random_sampling_audio = gr.Audio(type="numpy", label="Play the sample")
                random_sampling_audio.style(length=125)
                save_name_input = gr.Textbox(label="Name your sound")
                save_button = gr.Button(value="save")
                save_name_output = gr.Textbox(label="Save it for interpolation")

    random_sampling_button.click(
        random_sample,
        inputs=[sigma],
        outputs=[random_sampling_spectrogram, random_sampling_audio, current_encode],
    )
    save_button.click(
        save_sample,
        inputs=[save_name_input, current_encode, encodeCache],
        outputs=[save_name_output, current_encode, encodeCache],
    )


with gr.Blocks() as demo:
    gr.Markdown(
        "WebUI for [DL for sound generation]. webUI version:1.0. "
        "Model version: [UNDERFIT_torch_15_5_2023]."
    )
    # Most recent random sample; None until the user samples once.
    current_encode = gr.State(value=None)
    # Name -> latent code cache shared by all tabs (gr.State copies the
    # initial value for each session).
    initial_examples = gr.State(value=INIT_ENCODE_CACHE)
    get_example_module(initial_examples)
    # get_reconstruction_module()  # disabled: placeholder that only shows noise
    get_random_sampling_module(initial_examples, current_encode)
    get_interpolation_module(initial_examples)

demo.launch()