import librosa
import numpy as np
import torch
import gradio as gr
from scipy.ndimage import zoom

from model.DiffSynthSampler import DiffSynthSampler
from tools import adjust_audio_length, safe_int, pad_STFT, encode_stft
from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, InputBatch2Encode_STFT, \
    encodeBatch2GradioOutput_STFT, add_instrument


def get_triangle_mask(height, width):
    """Return a ``(height, width)`` float mask that is 1 where ``row < (8 / 3) * col`` and 0 elsewhere."""
    slope = 8 / 3
    rows = np.arange(height)[:, np.newaxis]
    cols = np.arange(width)[np.newaxis, :]
    # A broadcast comparison replaces the per-pixel Python loop; float64 matches the dtype np.zeros would give.
    return (rows < slope * cols).astype(np.float64)


def get_inpaint_with_text_module(gradioWebUI, inpaintWithText_state, virtual_instruments_state):
    """Build the "Inpaint" tab and register its Gradio event handlers.

    Args:
        gradioWebUI: Object that supplies the models (``uNet``, ``VAE_encoder``, ``VAE_quantizer``,
            ``VAE_decoder``, ``CLAP``, ``CLAP_tokenizer``), the configuration values read below and the
            ``get_*_slider`` / ``get_*_radio`` / ``get_seed_textbox`` widget factories.
        inpaintWithText_state: Gradio component passed to the handlers as input and output; its value is
            the dict in which this tab keeps encoded origin audio and generated samples.
        virtual_instruments_state: Gradio component whose value is the dict of saved instruments.

    NOTE(review): components are created at call time, so this presumably has to be called inside an
    active ``gr.Blocks`` context - confirm against the caller.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # `width` is recomputed inside the handlers from the actual audio / latent size.
    height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def receive_uopoad_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload,
                                    sound2sound_origin_microphone, inpaintWithText_dict):
        """Encode the selected origin audio and show its spectrogram, phase and latent images.

        The audio chosen by ``sound2sound_origin_source`` ("upload", anything else means microphone) is
        peak-normalised, resized with ``adjust_audio_length``, turned into an STFT and encoded with the VAE
        encoder. The latent representation and its preview images are stored in ``inpaintWithText_dict``
        under source-specific keys.

        Returns:
            A dict mapping output components to new values. Components of the source that is not selected
            receive ``gr.update()`` (no change).
        """
        from_upload = sound2sound_origin_source == "upload"
        selected_audio = sound2sound_origin_upload if from_upload else sound2sound_origin_microphone
        if selected_audio is None:
            # The change event also fires when the widget is cleared or when only the *other* widget
            # changed; there is nothing to encode, so leave every display untouched.
            return {inpaintWithText_state: inpaintWithText_dict}
        origin_sr, origin_audio = selected_audio

        # Peak-normalise; a silent clip has peak 0, so divide by 1 instead of producing NaNs.
        peak = np.max(np.abs(origin_audio))
        origin_audio = origin_audio / (peak if peak > 0 else 1)

        # NOTE(review): the "/ 4" presumably means time_resolution corresponds to 4 seconds - confirm.
        width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale)
        # 256 is the STFT hop length used below.
        audio_length = 256 * (VAE_scale * width - 1)
        origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Todo: justify batchsize to 1
        origin_spectrogram_batch_tensor = torch.from_numpy(
            np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

        # Todo: remove hard-coding
        origin_flipped_log_spectrums, origin_flipped_phases, _origin_signals, origin_latent_representations, \
            quantized_origin_latent_representations = InputBatch2Encode_STFT(
                VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
                quantizer=VAE_quantizer, squared=squared)

        # Render each preview once and reuse it for both the stored state and the UI output.
        latent_image = latent_representation_to_Gradio_image(origin_latent_representations[0])
        quantized_latent_image = latent_representation_to_Gradio_image(quantized_origin_latent_representations[0])

        if from_upload:
            inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_latent_representation_image"] = latent_image.tolist()
            inpaintWithText_dict[
                "sound2sound_origin_upload_quantized_latent_representation_image"] = quantized_latent_image.tolist()
            return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0],
                    sound2sound_origin_phase_upload_image: origin_flipped_phases[0],
                    sound2sound_origin_spectrogram_microphone_image: gr.update(),
                    sound2sound_origin_phase_microphone_image: gr.update(),
                    sound2sound_origin_upload_latent_representation_image: latent_image,
                    sound2sound_origin_upload_quantized_latent_representation_image: quantized_latent_image,
                    sound2sound_origin_microphone_latent_representation_image: gr.update(),
                    sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                    inpaintWithText_state: inpaintWithText_dict}

        inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist()
        inpaintWithText_dict[
            "sound2sound_origin_microphone_latent_representation_image"] = latent_image.tolist()
        inpaintWithText_dict[
            "sound2sound_origin_microphone_quantized_latent_representation_image"] = quantized_latent_image.tolist()
        # Microphone results go to the microphone widgets. sound2sound_sample reads the sketch mask from
        # the microphone spectrogram image, so that image must receive the spectrogram here.
        return {sound2sound_origin_spectrogram_upload_image: gr.update(),
                sound2sound_origin_phase_upload_image: gr.update(),
                sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0],
                sound2sound_origin_phase_microphone_image: origin_flipped_phases[0],
                sound2sound_origin_upload_latent_representation_image: gr.update(),
                sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
                sound2sound_origin_microphone_latent_representation_image: latent_image,
                sound2sound_origin_microphone_quantized_latent_representation_image: quantized_latent_image,
                inpaintWithText_state: inpaintWithText_dict}

    def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone,
                           text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps, sound2sound_origin_source,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area,
                           mask_time_begin, mask_time_end, mask_frequency_begin, mask_frequency_end,
                           inpaintWithText_dict
                           ):
        """Inpaint the stored origin latent under a text prompt and publish the generated batch.

        The mask is the union of the area sketched on the origin spectrogram image and the rectangle given
        by the time / frequency sliders. Generated samples are stored in ``inpaintWithText_dict`` and the
        first one is returned for display.

        Raises:
            NotImplementedError: if ``sound2sound_origin_source`` is neither "upload" nor "microphone".
        """
        # input preprocessing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        # NOTE(review): int() truncates a fractional guidance scale - confirm this is intended.
        CFG = int(sound2sound_guidance_scale)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

        if sound2sound_origin_source == "upload":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
            mask = np.array(sound2sound_origin_spectrogram_upload["mask"])
        elif sound2sound_origin_source == "microphone":
            origin_latent_representations = torch.tensor(
                inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
                device)
            mask = np.array(sound2sound_origin_spectrogram_microphone["mask"])
        else:
            print("Input source not in ['upload', 'microphone']!")
            raise NotImplementedError()

        # 1 where every channel of the sketch mask equals 255, then shrink to latent resolution.
        merged_mask = np.all(mask == 255, axis=2).astype(np.uint8)
        latent_mask = zoom(merged_mask, (1 / VAE_scale, 1 / VAE_scale))
        latent_mask = np.clip(latent_mask, 0, 1)
        print(f"latent_mask.avg = {np.mean(latent_mask)}")

        # Add the slider-defined rectangle: rows are latent frequency pixels, columns are derived from
        # seconds (NOTE(review): "VAE_scale*4" presumably assumes time_resolution spans 4 s - confirm).
        latent_mask[int(mask_frequency_begin):int(mask_frequency_end),
                    int(mask_time_begin*time_resolution/(VAE_scale*4)):int(mask_time_end*time_resolution/(VAE_scale*4))] = 1

        # latent_mask = get_triangle_mask(128, 64)

        print(f"latent_mask.avg = {np.mean(latent_mask)}")
        if sound2sound_inpaint_area == "inpaint masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1,
                                                                                     1).float().to(device)
        # Flip along dim 2 (rows) - presumably because the displayed spectrograms are vertically flipped.
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        # NOTE(review): raises ZeroDivisionError when noising_strength is 0 - the slider presumably
        # excludes 0; confirm.
        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)

        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = origin_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, _initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                     seed=sound2sound_seed,
                                     noising_strength=noising_strength,
                                     guide_img=origin_latent_representations, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sound2sound_sampler)

        # Keep only the last entry returned by the sampler.
        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, _loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = \
            encodeBatch2GradioOutput_STFT(VAE_decoder,
                                          quantized_new_sound_latent_representations,
                                          resolution=(512, width * VAE_scale),
                                          original_STFT_batch=None)

        # Per-sample display payloads, kept in the state so the index slider can switch between them.
        sample_indices = range(sound2sound_batchsize)
        new_sound_latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(new_sound_latent_representations[i]) for i in sample_indices]
        new_sound_quantized_latent_representation_gradio_images = [
            latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i])
            for i in sample_indices]
        new_sound_spectrogram_gradio_images = [new_sound_flipped_log_spectrums[i] for i in sample_indices]
        new_sound_phase_gradio_images = [new_sound_flipped_phases[i] for i in sample_indices]
        new_sound_rec_signals_gradio = [(sample_rate, new_sound_signals[i]) for i in sample_indices]

        inpaintWithText_dict[
            "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict[
            "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

        inpaintWithText_dict["latent_representations"] = new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["quantized_latent_representations"] = \
            quantized_new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["sampler"] = sound2sound_sampler

        return {sound2sound_new_sound_latent_representation_image: new_sound_latent_representation_gradio_images[0],
                sound2sound_new_sound_quantized_latent_representation_image:
                    new_sound_quantized_latent_representation_gradio_images[0],
                sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
                sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
                sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
                sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                           step=1.0,
                                                           visible=True,
                                                           label="Sample index",
                                                           info="Swipe to view other samples"),
                sound2sound_seed_textbox: sound2sound_seed,
                inpaintWithText_state: inpaintWithText_dict}

    def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
        """Display the stored sample selected by the index slider."""
        sample_index = int(sound2sound_sample_index)
        return {sound2sound_new_sound_latent_representation_image:
                    inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_quantized_latent_representation_image:
                    inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][
                    sample_index],
                sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}

    def sound2sound_switch_origin_source(sound2sound_origin_source):
        """Show the widgets of the selected input source and hide those of the other one."""
        if sound2sound_origin_source not in ("upload", "microphone"):
            print("Input source not in ['upload', 'microphone']!")
            return None

        show_upload = sound2sound_origin_source == "upload"
        show_microphone = not show_upload
        return {sound2sound_origin_upload_audio: gr.update(visible=show_upload),
                sound2sound_origin_microphone_audio: gr.update(visible=show_microphone),
                sound2sound_origin_spectrogram_upload_image: gr.update(visible=show_upload),
                sound2sound_origin_phase_upload_image: gr.update(visible=show_upload),
                sound2sound_origin_spectrogram_microphone_image: gr.update(visible=show_microphone),
                sound2sound_origin_phase_microphone_image: gr.update(visible=show_microphone),
                sound2sound_origin_upload_latent_representation_image: gr.update(visible=show_upload),
                sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=show_upload),
                sound2sound_origin_microphone_latent_representation_image: gr.update(visible=show_microphone),
                sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=show_microphone)}

    def save_virtual_instrument(sample_index, virtual_instrument_name, sound2sound_dict, virtual_instruments_dict):
        """Add the selected sample to the saved instruments and confirm it in the name textbox."""
        virtual_instruments_dict = add_instrument(sound2sound_dict, virtual_instruments_dict, virtual_instrument_name,
                                                  sample_index)
        return {virtual_instruments_state: virtual_instruments_dict,
                sound2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                                placeholder=f"Saved as {virtual_instrument_name}!")}

    # ------------------------------------------------------------------ layout
    with gr.Tab("Inpaint"):
        gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)

                # Hidden until a batch has been generated; sound2sound_sample makes it visible.
                sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Tab("Origin sound"):
                    sound2sound_duration_slider = gradioWebUI.get_duration_slider()
                    sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
                                                               label="Input source")

                    sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
                                                               interactive=True, visible=True)
                    sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
                                                                   interactive=True, visible=False)
                    with gr.Row(variant="panel"):
                        # tool="sketch": sound2sound_sample reads the drawn area from the image's "mask" entry.
                        sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
                                                                               type="numpy", height=600,
                                                                               visible=True, tool="sketch")
                        sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
                                                                         type="numpy", height=600,
                                                                         visible=True)
                        sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
                                                                                   type="numpy", height=600,
                                                                                   visible=False, tool="sketch")
                        sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
                                                                             type="numpy", height=600,
                                                                             visible=False)
                    sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"],
                                                              value="inpaint masked")

                with gr.Tab("Sound2sound settings"):
                    sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                    sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                    sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                    sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
                    sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                    sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()

                with gr.Tab("Mask prototypes"):
                    with gr.Tab("Mask along time axis"):
                        mask_time_begin_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01,
                                                           label="Begin time")
                        mask_time_end_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01,
                                                         label="End time")
                    with gr.Tab("Mask along frequency axis"):
                        mask_frequency_begin_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1,
                                                                label="Begin freq pixel")
                        mask_frequency_end_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1,
                                                              label="End freq pixel")

            with gr.Column(scale=1):
                sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
                with gr.Row(variant="panel"):
                    sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                                       height=600, scale=1)
                    sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                                 height=600, scale=1)

                with gr.Row(variant="panel"):
                    sound2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                                     placeholder="Name of your instrument")
                    sound2sound_save_instrument_button = gr.Button(variant="primary",
                                                                   value="Save instrument",
                                                                   scale=1)

        with gr.Row(variant="panel"):
            sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
                                                                             type="numpy", height=800,
                                                                             visible=True)
            sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=True)

            sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
                                                                                 type="numpy", height=800,
                                                                                 visible=False)
            sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=False)

            sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
                                                                         type="numpy", height=800)
            sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
                label="New sound quantized latent representation", type="numpy", height=800)

    # ------------------------------------------------------------------ event wiring
    # Both audio widgets share one handler; the radio value decides which recording is encoded.
    origin_audio_inputs = [sound2sound_duration_slider,
                           sound2sound_origin_source_radio,
                           sound2sound_origin_upload_audio,
                           sound2sound_origin_microphone_audio,
                           inpaintWithText_state]
    origin_audio_outputs = [sound2sound_origin_spectrogram_upload_image,
                            sound2sound_origin_phase_upload_image,
                            sound2sound_origin_spectrogram_microphone_image,
                            sound2sound_origin_phase_microphone_image,
                            sound2sound_origin_upload_latent_representation_image,
                            sound2sound_origin_upload_quantized_latent_representation_image,
                            sound2sound_origin_microphone_latent_representation_image,
                            sound2sound_origin_microphone_quantized_latent_representation_image,
                            inpaintWithText_state]

    sound2sound_origin_upload_audio.change(receive_uopoad_origin_audio,
                                           inputs=list(origin_audio_inputs),
                                           outputs=list(origin_audio_outputs))
    sound2sound_origin_microphone_audio.change(receive_uopoad_origin_audio,
                                               inputs=list(origin_audio_inputs),
                                               outputs=list(origin_audio_outputs))

    sound2sound_sample_button.click(sound2sound_sample,
                                    inputs=[sound2sound_origin_spectrogram_upload_image,
                                            sound2sound_origin_spectrogram_microphone_image,
                                            text2sound_prompts_textbox,
                                            text2sound_negative_prompts_textbox,
                                            sound2sound_batchsize_slider,
                                            sound2sound_guidance_scale_slider,
                                            sound2sound_sampler_radio,
                                            sound2sound_sample_steps_slider,
                                            sound2sound_origin_source_radio,
                                            sound2sound_noising_strength_slider,
                                            sound2sound_seed_textbox,
                                            sound2sound_inpaint_area_radio,
                                            mask_time_begin_slider,
                                            mask_time_end_slider,
                                            mask_frequency_begin_slider,
                                            mask_frequency_end_slider,
                                            inpaintWithText_state],
                                    outputs=[sound2sound_new_sound_latent_representation_image,
                                             sound2sound_new_sound_quantized_latent_representation_image,
                                             sound2sound_new_sound_spectrogram_image,
                                             sound2sound_new_sound_phase_image,
                                             sound2sound_new_sound_audio,
                                             sound2sound_sample_index_slider,
                                             sound2sound_seed_textbox,
                                             inpaintWithText_state])

    sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                           inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                           outputs=[sound2sound_new_sound_latent_representation_image,
                                                    sound2sound_new_sound_quantized_latent_representation_image,
                                                    sound2sound_new_sound_spectrogram_image,
                                                    sound2sound_new_sound_phase_image,
                                                    sound2sound_new_sound_audio])

    sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                           inputs=[sound2sound_origin_source_radio],
                                           outputs=[sound2sound_origin_upload_audio,
                                                    sound2sound_origin_microphone_audio,
                                                    sound2sound_origin_spectrogram_upload_image,
                                                    sound2sound_origin_phase_upload_image,
                                                    sound2sound_origin_spectrogram_microphone_image,
                                                    sound2sound_origin_phase_microphone_image,
                                                    sound2sound_origin_upload_latent_representation_image,
                                                    sound2sound_origin_upload_quantized_latent_representation_image,
                                                    sound2sound_origin_microphone_latent_representation_image,
                                                    sound2sound_origin_microphone_quantized_latent_representation_image])

    sound2sound_save_instrument_button.click(save_virtual_instrument,
                                             inputs=[sound2sound_sample_index_slider,
                                                     sound2sound_instrument_name_textbox,
                                                     inpaintWithText_state,
                                                     virtual_instruments_state],
                                             outputs=[virtual_instruments_state,
                                                      sound2sound_instrument_name_textbox])