import gradio as gr
import numpy as np
import torch

from tools import safe_int
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \
    add_instrument


def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state):
    """Build the "Text2sound_GAN" tab and wire up its callbacks.

    Must be called while a Gradio layout context is active, because the
    components are created as a side effect of the call.

    Args:
        gradioWebUI: Project object supplying the models (``GAN_generator``,
            ``VAE_quantizer``, ``VAE_decoder``, ``CLAP``, ``CLAP_tokenizer``),
            the configuration (``freq_resolution``, ``time_resolution``,
            ``VAE_scale``, ``channels``, ``device``, ``squared``,
            ``sample_rate``) and the factories for the shared input widgets.
        text2sound_state: Gradio state component holding the dict in which
            the most recently generated batch is cached.
        virtual_instruments_state: Not used by this tab; kept so the
            signature stays compatible with existing callers.
    """
    # Load configurations
    gan_generator = gradioWebUI.GAN_generator
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # Only the latent height is fixed here; the latent width depends on the
    # requested duration and is therefore computed per request.
    height, channels = int(freq_resolution / VAE_scale), gradioWebUI.channels

    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate

    def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                          text2sound_duration, text2sound_guidance_scale, text2sound_sampler,
                          text2sound_sample_steps, text2sound_seed, text2sound_dict):
        """Generate a batch of sounds with the GAN and display the first one.

        The parameters mirror the ``inputs`` list of the sampling button.
        ``text2sound_negative_prompts`` and ``text2sound_sample_steps`` are
        accepted because they are wired to the button but are not used by
        the GAN; the guidance scale and sampler name are only recorded in
        the state dict.

        Returns:
            A dict mapping output components to their new values: the first
            sample's images and audio, the parsed seed, the updated state
            dict and a slider update exposing one index per sample.
        """
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)
        CFG = int(text2sound_guidance_scale)

        # Draw the noise from a dedicated generator so that the seed shown in
        # the UI actually determines the result, without touching the global
        # RNG state. The modulo keeps arbitrary user input inside the seed
        # range accepted by torch.
        noise_generator = torch.Generator()
        noise_generator.manual_seed(text2sound_seed % (2 ** 63))

        # Pure inference: no autograd graph is needed for any of the models.
        with torch.no_grad():
            text2sound_embedding = CLAP.get_text_features(
                **CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)
            condition = text2sound_embedding.repeat(text2sound_batchsize, 1)

            noise = torch.randn(text2sound_batchsize, channels, height, width,
                                generator=noise_generator).to(device)
            latent_representations = gan_generator(noise, condition)

            quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)

            # Both axes of the output resolution are derived from the latent
            # shape instead of hard-coding the frequency axis.
            flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(
                VAE_decoder, quantized_latent_representations,
                resolution=(int(height * VAE_scale), width * VAE_scale),
                centralized=False, squared=squared)

        sample_indices = range(text2sound_batchsize)
        latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(latent_representations[i]) for i in sample_indices]
        quantized_latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(quantized_latent_representations[i]) for i in sample_indices]
        new_sound_spectrogram_gradio_images = [flipped_log_spectrums[i] for i in sample_indices]
        new_sound_rec_signals_gradio = [(sample_rate, rec_signals[i]) for i in sample_indices]

        # Cache the whole batch so the index slider can switch between
        # samples without regenerating them.
        text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
        text2sound_dict["quantized_latent_representations"] = \
            quantized_latent_representations.to("cpu").detach().numpy()
        text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
        text2sound_dict["quantized_latent_representation_gradio_images"] = \
            quantized_latent_representation_gradio_images
        text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
        text2sound_dict["guidance_scale"] = CFG
        text2sound_dict["sampler"] = text2sound_sampler

        return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][0],
                text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
                text2sound_seed_textbox: text2sound_seed,
                text2sound_state: text2sound_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0,
                                                          step=1, visible=True, label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Display the cached sample selected by the index slider.

        Records the chosen index in ``text2sound_dict["sample_index"]`` and
        returns a dict mapping the image and audio components to the cached
        outputs for that index.
        """
        sample_index = int(sample_index)
        text2sound_dict["sample_index"] = sample_index
        return {text2sound_latent_representation_image:
                    text2sound_dict["latent_representation_gradio_images"][sample_index],
                text2sound_quantized_latent_representation_image:
                    text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                text2sound_sampled_spectrogram_image:
                    text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
                text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

    with gr.Tab("Text2sound_GAN"):
        gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                # Hidden until a batch exists; gan_random_sample reveals it
                # and sets its range to the batch size.
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(scale=1):
                text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy",
                                                                height=420)
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")

        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation",
                                                              type="numpy", height=200, width=100)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100)

    text2sound_sampling_button.click(gan_random_sample,
                                     inputs=[text2sound_prompts_textbox,
                                             text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_guidance_scale_slider,
                                             text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             text2sound_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              text2sound_state,
                                              text2sound_sample_index_slider])

    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, text2sound_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_audio])