DiffuSynthV0.2 / webUI /natural_language_guided /sound2sound_with_text.py
WeixuanYuan's picture
Upload 66 files
ae1bdf7 verified
import gradio as gr
import librosa
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from tools import pad_STFT, encode_stft
from tools import safe_int, adjust_audio_length
from webUI.natural_language_guided.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, \
latent_representation_to_Gradio_image
# Build the "Sound2Sound" tab of the web UI.
#
# The function reads models and settings from `gradioWebUI`, defines the event
# callbacks, creates the Gradio components and finally wires the callbacks to
# the components.  `sound2sound_with_text_state` and `virtual_instruments_state`
# are used as inputs and outputs of those callbacks.
#
# NOTE(review): leading indentation is missing in this copy of the file;
# everything below appears to belong to the body of this function -- confirm
# against the upstream source before re-indenting.
def get_sound2sound_with_text_module(gradioWebUI, sound2sound_with_text_state, virtual_instruments_state):
# Load configurations
# Denoising network handed to the sampler as `model`.
uNet = gradioWebUI.uNet
freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
VAE_scale = gradioWebUI.VAE_scale
# Default latent size: spectrogram resolution divided by the VAE scale.
# `height` and `channels` are used for the sampler's latent shape; the
# callbacks compute their own local `width`.
height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
# Total number of diffusion timesteps, respaced per request in sound2sound_sample.
timesteps = gradioWebUI.timesteps
VAE_encoder = gradioWebUI.VAE_encoder
VAE_quantizer = gradioWebUI.VAE_quantizer
VAE_decoder = gradioWebUI.VAE_decoder
# CLAP model and tokenizer turn text prompts into conditioning vectors.
CLAP = gradioWebUI.CLAP
CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
device = gradioWebUI.device
# Forwarded unchanged to InputBatch2Encode_STFT.
squared = gradioWebUI.squared
sample_rate = gradioWebUI.sample_rate
noise_strategy = gradioWebUI.noise_strategy
def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source,
                                sound2sound_origin_upload, sound2sound_origin_microphone,
                                sound2sound_with_text_dict, virtual_instruments_dict):
    """Encode the selected origin sound and publish it to the UI and the states.

    The audio of the active input source is peak-normalised, fitted to the
    length implied by ``sound2sound_duration``, converted to an STFT and passed
    through the VAE encoder/quantizer.  The resulting latent representation is
    stored in ``sound2sound_with_text_dict`` and the sound is registered as a
    virtual instrument (``"s2sup"`` for uploads, ``"s2sre"`` for recordings).
    Both instrument entries share one schema.

    Args:
        sound2sound_duration: Value of the duration slider; determines the
            latent width and therefore the target audio length.
        sound2sound_origin_source: ``"upload"`` selects the uploaded audio; any
            other value selects the microphone recording.
        sound2sound_origin_upload: ``(sample_rate, samples)`` pair or ``None``.
        sound2sound_origin_microphone: ``(sample_rate, samples)`` pair or ``None``.
        sound2sound_with_text_dict: State dict of this tab; updated in place.
        virtual_instruments_dict: Dict holding the ``"virtual_instruments"``
            mapping; updated in place.

    Returns:
        A dict mapping the origin preview images and both state objects to
        their new values.  Components of the inactive source receive a no-op
        ``gr.update()``.  If the selected audio is missing or empty, nothing
        is changed.
    """
    is_upload = sound2sound_origin_source == "upload"
    source_tag = "upload" if is_upload else "microphone"
    selected_audio = sound2sound_origin_upload if is_upload else sound2sound_origin_microphone

    # Start from "change nothing"; the entries of the active source are
    # overwritten below once the audio has been encoded successfully.
    outputs = {
        sound2sound_origin_spectrogram_upload_image: gr.update(),
        sound2sound_origin_phase_upload_image: gr.update(),
        sound2sound_origin_spectrogram_microphone_image: gr.update(),
        sound2sound_origin_phase_microphone_image: gr.update(),
        sound2sound_origin_upload_latent_representation_image: gr.update(),
        sound2sound_origin_upload_quantized_latent_representation_image: gr.update(),
        sound2sound_origin_microphone_latent_representation_image: gr.update(),
        sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
        sound2sound_with_text_state: sound2sound_with_text_dict,
        virtual_instruments_state: virtual_instruments_dict,
    }

    # The selected audio can be None (e.g. the component was cleared);
    # unpacking it would raise a TypeError.
    if selected_audio is None:
        return outputs
    origin_sr, origin_audio = selected_audio
    if np.size(origin_audio) == 0:
        return outputs

    # Peak-normalise.  A silent clip has peak 0; dividing by 1 instead keeps
    # the dtype promotion of the division while avoiding NaNs.
    peak = np.max(np.abs(origin_audio))
    origin_audio = origin_audio / (peak if peak > 0 else 1)

    width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
    audio_length = 256 * (VAE_scale * width - 1)
    origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

    D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
    padded_D = pad_STFT(D)
    encoded_D = encode_stft(padded_D)

    # Todo: justify batchsize to 1
    origin_spectrogram_batch_tensor = torch.from_numpy(
        np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

    # Todo: remove hard-coding
    (origin_flipped_log_spectrums, origin_flipped_phases, origin_signals,
     origin_latent_representations, quantized_origin_latent_representations) = InputBatch2Encode_STFT(
        VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
        quantizer=VAE_quantizer, squared=squared)

    # Embedding of the empty prompt, used as both the positive and the
    # negative condition of the stored instrument.
    default_condition = CLAP.get_text_features(
        **CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to("cpu").detach().numpy()

    latent = origin_latent_representations[0]
    quantized_latent = quantized_origin_latent_representations[0]
    latent_image = latent_representation_to_Gradio_image(latent)
    quantized_latent_image = latent_representation_to_Gradio_image(quantized_latent)

    # Store plain lists in the tab state under source-specific keys.
    sound2sound_with_text_dict[
        f"origin_{source_tag}_latent_representations"] = origin_latent_representations.tolist()
    sound2sound_with_text_dict[
        f"sound2sound_origin_{source_tag}_latent_representation_image"] = latent_image.tolist()
    sound2sound_with_text_dict[
        f"sound2sound_origin_{source_tag}_quantized_latent_representation_image"] = quantized_latent_image.tolist()

    # Register the sound as a virtual instrument.  Uploads and recordings use
    # the same schema: numpy latents, a (sample_rate, signal) pair and both
    # the spectrogram and the phase image.
    virtual_instruments = virtual_instruments_dict["virtual_instruments"]
    virtual_instruments["s2sup" if is_upload else "s2sre"] = {
        "condition": default_condition,
        "negative_condition": default_condition,  # care!!!
        "CFG": 1,
        "latent_representation": latent.to("cpu").detach().numpy(),
        "quantized_latent_representation": quantized_latent.to("cpu").detach().numpy(),
        "sampler": "ddim",
        "signal": (sample_rate, origin_audio),
        "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
        "phase_gradio_image": origin_flipped_phases[0],
    }
    virtual_instruments_dict["virtual_instruments"] = virtual_instruments

    # Pick the preview components that belong to the active source.
    if is_upload:
        spectrogram_component = sound2sound_origin_spectrogram_upload_image
        phase_component = sound2sound_origin_phase_upload_image
        latent_component = sound2sound_origin_upload_latent_representation_image
        quantized_latent_component = sound2sound_origin_upload_quantized_latent_representation_image
    else:
        spectrogram_component = sound2sound_origin_spectrogram_microphone_image
        phase_component = sound2sound_origin_phase_microphone_image
        latent_component = sound2sound_origin_microphone_latent_representation_image
        quantized_latent_component = sound2sound_origin_microphone_quantized_latent_representation_image

    outputs[spectrogram_component] = origin_flipped_log_spectrums[0]
    outputs[phase_component] = origin_flipped_phases[0]
    outputs[latent_component] = latent_image
    outputs[quantized_latent_component] = quantized_latent_image
    return outputs
def sound2sound_sample(sound2sound_prompts, sound2sound_negative_prompts, sound2sound_batchsize,
                       sound2sound_guidance_scale, sound2sound_sampler,
                       sound2sound_sample_steps,
                       sound2sound_origin_source,
                       sound2sound_noising_strength, sound2sound_seed, sound2sound_dict, virtual_instruments_dict):
    """Generate a batch of new sounds guided by the stored origin sound and a text prompt.

    The latent representation of the selected origin sound is read from
    ``sound2sound_dict``, repeated ``sound2sound_batchsize`` times and passed to
    the sampler's ``img_guided_sample`` together with the CLAP embedding of the
    positive prompt.  The embedding of the negative prompt is used for
    classifier-free guidance.  The sampled latents are quantized and decoded,
    and the per-sample Gradio payloads are stored in ``sound2sound_dict`` so
    the sample-index slider can page through them.

    Args:
        sound2sound_prompts: Positive text prompt.
        sound2sound_negative_prompts: Negative text prompt.
        sound2sound_batchsize: Number of samples to generate.
        sound2sound_guidance_scale: Classifier-free guidance scale.
        sound2sound_sampler: Sampler name forwarded to the sampler.
        sound2sound_sample_steps: Number of sampling steps before rescaling by
            the noising strength.
        sound2sound_origin_source: ``"upload"`` or ``"microphone"``.
        sound2sound_noising_strength: Strength of the noise applied to the
            origin latent; must be greater than 0.
        sound2sound_seed: Seed; falls back to 12345678 via ``safe_int``.
        sound2sound_dict: State dict of this tab; updated in place.
        virtual_instruments_dict: Passed through unchanged.

    Returns:
        A dict mapping the result components, the sample-index slider, the
        seed textbox and both state objects to their new values.

    Raises:
        NotImplementedError: If the origin source is neither ``"upload"`` nor
            ``"microphone"``.
        gr.Error: If the noising strength is not positive, or no origin sound
            has been encoded for the selected source yet.
    """
    # input processing
    sound2sound_seed = safe_int(sound2sound_seed, 12345678)
    sound2sound_batchsize = int(sound2sound_batchsize)
    noising_strength = sound2sound_noising_strength
    sound2sound_sample_steps = int(sound2sound_sample_steps)
    # NOTE(review): int() truncates fractional guidance scales -- confirm the
    # guidance slider only yields whole numbers.
    CFG = int(sound2sound_guidance_scale)

    if sound2sound_origin_source not in ("upload", "microphone"):
        print("Input source not in ['upload', 'microphone']!")
        raise NotImplementedError()

    # The step count below is divided by the noising strength, so reject
    # non-positive values before doing any expensive work.
    if not noising_strength > 0:
        raise gr.Error("Noising strength must be greater than 0.")

    # The latent is written by the audio `change` handler; without an origin
    # sound there is nothing to guide the sampling with.
    latent_key = f"origin_{sound2sound_origin_source}_latent_representations"
    if latent_key not in sound2sound_dict:
        raise gr.Error("Please provide an origin sound before generating.")
    origin_latent_representations = torch.tensor(
        sound2sound_dict[latent_key]).repeat(sound2sound_batchsize, 1, 1, 1).to(device)

    # sound2sound
    text2sound_embedding = CLAP.get_text_features(
        **CLAP_tokenizer([sound2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)

    mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
    unconditional_condition = CLAP.get_text_features(
        **CLAP_tokenizer([sound2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
    mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

    # Scale the schedule length by 1 / noising_strength before respacing.
    normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
    mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

    condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

    # Todo: remove-hard coding
    width = origin_latent_representations.shape[-1]
    new_sound_latent_representations, _ = mySampler.img_guided_sample(
        model=uNet, shape=(sound2sound_batchsize, channels, height, width),
        seed=sound2sound_seed,
        noising_strength=noising_strength,
        guide_img=origin_latent_representations, return_tensor=True,
        condition=condition,
        sampler=sound2sound_sampler)

    # Keep only the last entry of the returned sequence.
    new_sound_latent_representations = new_sound_latent_representations[-1]

    # Quantize new sound latent representations
    quantized_new_sound_latent_representations, _, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
    new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = \
        encodeBatch2GradioOutput_STFT(VAE_decoder,
                                      quantized_new_sound_latent_representations,
                                      resolution=(512, width * VAE_scale),
                                      original_STFT_batch=None)

    # Per-sample payloads, kept in the state so the slider can switch samples.
    sample_indices = range(sound2sound_batchsize)
    latent_images = [latent_representation_to_Gradio_image(new_sound_latent_representations[i])
                     for i in sample_indices]
    quantized_latent_images = [
        latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i])
        for i in sample_indices]
    spectrogram_images = [new_sound_flipped_log_spectrums[i] for i in sample_indices]
    phase_images = [new_sound_flipped_phases[i] for i in sample_indices]
    rec_signals = [(sample_rate, new_sound_signals[i]) for i in sample_indices]

    sound2sound_dict["new_sound_latent_representation_gradio_images"] = latent_images
    sound2sound_dict["new_sound_quantized_latent_representation_gradio_images"] = quantized_latent_images
    sound2sound_dict["new_sound_spectrogram_gradio_images"] = spectrogram_images
    sound2sound_dict["new_sound_phase_gradio_images"] = phase_images
    sound2sound_dict["new_sound_rec_signals_gradio"] = rec_signals

    # Show the first sample and reveal the slider for the remaining ones.
    return {sound2sound_new_sound_latent_representation_image: latent_images[0],
            sound2sound_new_sound_quantized_latent_representation_image: quantized_latent_images[0],
            sound2sound_new_sound_spectrogram_image: spectrogram_images[0],
            sound2sound_new_sound_phase_image: phase_images[0],
            sound2sound_new_sound_audio: rec_signals[0],
            sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                       step=1.0,
                                                       visible=True,
                                                       label="Sample index",
                                                       info="Swipe to view other samples"),
            sound2sound_seed_textbox: sound2sound_seed,
            sound2sound_with_text_state: sound2sound_dict,
            virtual_instruments_state: virtual_instruments_dict}
def show_sound2sound_sample(sound2sound_sample_index, sound2sound_with_text_dict):
    """Display the stored generation result selected by the sample-index slider.

    Args:
        sound2sound_sample_index: Slider value; converted to ``int``.
        sound2sound_with_text_dict: State dict holding the per-sample lists
            written by ``sound2sound_sample``.

    Returns:
        A dict mapping each result component to the stored entry at the
        requested index.
    """
    chosen = int(sound2sound_sample_index)
    # Each result component is paired with the state key of its per-sample list.
    component_to_state_key = (
        (sound2sound_new_sound_latent_representation_image,
         "new_sound_latent_representation_gradio_images"),
        (sound2sound_new_sound_quantized_latent_representation_image,
         "new_sound_quantized_latent_representation_gradio_images"),
        (sound2sound_new_sound_spectrogram_image, "new_sound_spectrogram_gradio_images"),
        (sound2sound_new_sound_phase_image, "new_sound_phase_gradio_images"),
        (sound2sound_new_sound_audio, "new_sound_rec_signals_gradio"),
    )
    return {component: sound2sound_with_text_dict[state_key][chosen]
            for component, state_key in component_to_state_key}
def sound2sound_switch_origin_source(sound2sound_origin_source):
    """Toggle visibility of the upload and microphone widgets.

    Args:
        sound2sound_origin_source: ``"upload"`` or ``"microphone"``.

    Returns:
        A dict of ``gr.update(visible=...)`` entries showing the widgets of the
        chosen source and hiding those of the other one.  For any other value
        a message is printed and ``None`` is returned.
    """
    if sound2sound_origin_source not in ("upload", "microphone"):
        print("Input source not in ['upload', 'microphone']!")
        return None

    # Exactly one of the two widget groups is visible at a time.
    show_upload = sound2sound_origin_source == "upload"
    show_microphone = not show_upload
    return {
        sound2sound_origin_upload_audio: gr.update(visible=show_upload),
        sound2sound_origin_microphone_audio: gr.update(visible=show_microphone),
        sound2sound_origin_spectrogram_upload_image: gr.update(visible=show_upload),
        sound2sound_origin_phase_upload_image: gr.update(visible=show_upload),
        sound2sound_origin_spectrogram_microphone_image: gr.update(visible=show_microphone),
        sound2sound_origin_phase_microphone_image: gr.update(visible=show_microphone),
        sound2sound_origin_upload_latent_representation_image: gr.update(visible=show_upload),
        sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=show_upload),
        sound2sound_origin_microphone_latent_representation_image: gr.update(visible=show_microphone),
        sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=show_microphone),
    }
# ---------------------------------------------------------------------------
# Layout of the "Sound2Sound" tab.
# NOTE(review): indentation was lost in this copy, so the nesting of the
# `with` containers below cannot be read off the text -- restore it from the
# upstream file before running.
# ---------------------------------------------------------------------------
with gr.Tab("Sound2Sound"):
gr.Markdown("Generate new sound based on a given sound!")
with gr.Row(variant="panel"):
with gr.Column(scale=3):
# Prompt text boxes.  The negative-prompt box carries a `text2sound_` name
# but is the one passed to sound2sound_sample when the button is clicked.
sound2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
with gr.Column(scale=1):
sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
# Created hidden; sound2sound_sample makes it visible and sets its maximum
# to batchsize - 1 once samples exist.
sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
label="Sample index",
info="Swipe to view other samples")
with gr.Row(variant="panel"):
with gr.Column(scale=1):
with gr.Tab("Origin sound"):
sound2sound_duration_slider = gradioWebUI.get_duration_slider()
# "upload" is the initial source, so the upload widgets start visible and
# the microphone widgets start hidden; sound2sound_switch_origin_source
# flips the visibility when the radio changes.
sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload",
label="Input source")
sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload",
interactive=True, visible=True)
sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone",
interactive=True, visible=False)
with gr.Row(variant="panel"):
# Spectrogram and phase previews of the origin sound, one pair per source.
sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram",
type="numpy", height=600,
visible=True)
sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase",
type="numpy", height=600,
visible=True)
sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram",
type="numpy", height=600,
visible=False)
sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase",
type="numpy", height=600,
visible=False)
with gr.Tab("Sound2sound settings"):
# Sampling controls are created by factory methods of gradioWebUI.
sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()
with gr.Column(scale=1):
# Output widgets for the generated sound.
sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False)
with gr.Row(variant="panel"):
sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
height=600, scale=1)
sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
height=600, scale=1)
with gr.Row(variant="panel"):
# Latent-representation previews: origin (one pair per source, only the
# active source's pair is visible) followed by the generated sound.
sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
type="numpy", height=800,
visible=True)
sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
label="Original quantized latent representation", type="numpy", height=800, visible=True)
sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
type="numpy", height=800,
visible=False)
sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
label="Original quantized latent representation", type="numpy", height=800, visible=False)
sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
type="numpy", height=800)
sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
label="New sound quantized latent representation", type="numpy", height=800)
# ---------------------------------------------------------------------------
# Event wiring.  Component lists that several listeners share are named once
# and copied per listener.
# ---------------------------------------------------------------------------

# Preview images of the origin sound, upload and microphone variants.
origin_preview_images = [sound2sound_origin_spectrogram_upload_image,
                         sound2sound_origin_phase_upload_image,
                         sound2sound_origin_spectrogram_microphone_image,
                         sound2sound_origin_phase_microphone_image,
                         sound2sound_origin_upload_latent_representation_image,
                         sound2sound_origin_upload_quantized_latent_representation_image,
                         sound2sound_origin_microphone_latent_representation_image,
                         sound2sound_origin_microphone_quantized_latent_representation_image]

# Widgets showing one generated sample.
new_sound_views = [sound2sound_new_sound_latent_representation_image,
                   sound2sound_new_sound_quantized_latent_representation_image,
                   sound2sound_new_sound_spectrogram_image,
                   sound2sound_new_sound_phase_image,
                   sound2sound_new_sound_audio]

shared_states = [sound2sound_with_text_state, virtual_instruments_state]

origin_audio_inputs = [sound2sound_duration_slider,
                       sound2sound_origin_source_radio,
                       sound2sound_origin_upload_audio,
                       sound2sound_origin_microphone_audio] + shared_states

# Both audio widgets use the same handler; it picks the audio to encode from
# the value of the source radio.
for origin_audio_widget in (sound2sound_origin_upload_audio, sound2sound_origin_microphone_audio):
    origin_audio_widget.change(receive_upload_origin_audio,
                               inputs=list(origin_audio_inputs),
                               outputs=origin_preview_images + shared_states)

# "Generate" runs the sampler and refreshes the result widgets, the sample
# slider, the seed textbox and both states.
sound2sound_sample_button.click(sound2sound_sample,
                                inputs=[sound2sound_prompts_textbox,
                                        text2sound_negative_prompts_textbox,
                                        sound2sound_batchsize_slider,
                                        sound2sound_guidance_scale_slider,
                                        sound2sound_sampler_radio,
                                        sound2sound_sample_steps_slider,
                                        sound2sound_origin_source_radio,
                                        sound2sound_noising_strength_slider,
                                        sound2sound_seed_textbox] + shared_states,
                                outputs=new_sound_views + [sound2sound_sample_index_slider,
                                                           sound2sound_seed_textbox] + shared_states)

# Moving the slider shows another sample of the last generated batch.
sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                       inputs=[sound2sound_sample_index_slider, sound2sound_with_text_state],
                                       outputs=list(new_sound_views))

# Switching the source toggles visibility of the matching widgets.
sound2sound_origin_source_radio.change(sound2sound_switch_origin_source,
                                       inputs=[sound2sound_origin_source_radio],
                                       outputs=[sound2sound_origin_upload_audio,
                                                sound2sound_origin_microphone_audio] + origin_preview_images)