import librosa
import numpy as np
import torch
from tqdm import tqdm

from tools import rms_normalize, decode_stft, depad_STFT
from model.DiffSynthSampler import DiffSynthSampler

# ISTFT settings shared by every sampling pipeline in this module.
_ISTFT_HOP_LENGTH = 256
_ISTFT_WIN_LENGTH = 1024


def _embed_prompt(MMM, CLAP_tokenizer, prompt, device):
    """Return the text feature vector of a single prompt, moved to `device`."""
    # The prompt is wrapped in a one-element list, so index 0 selects its
    # (only) feature row.
    tokens = CLAP_tokenizer([prompt], padding=True, return_tensors="pt")
    return MMM.get_text_features(**tokens)[0].to(device)


def _spectrograms_to_audio(reconstruction_batch):
    """Turn each decoded spectrogram into an RMS-normalized waveform."""
    rec_signals = []
    for STFT in reconstruction_batch:
        D_rec = depad_STFT(decode_stft(STFT))
        rec_signal = librosa.istft(D_rec,
                                   hop_length=_ISTFT_HOP_LENGTH,
                                   win_length=_ISTFT_WIN_LENGTH)
        rec_signals.append(rms_normalize(rec_signal))
    return rec_signals


def _finalize_latents(VAE, latent_representations, return_latent):
    """Quantize latents and, unless only latents are wanted, decode to audio.

    Returns the detached quantized latents when `return_latent` is true,
    otherwise the tuple (detached quantized latents, decoded spectrogram
    batch as a numpy array, list of waveforms).
    """
    # Every output is detached right away, so no autograd graph is needed for
    # the quantizer/decoder pass.
    with torch.no_grad():
        quantized_latent_representations, _, _ = VAE._vq_vae(latent_representations)
        if return_latent:
            return quantized_latent_representations.detach()
        reconstruction_batch = \
            VAE._decoder(quantized_latent_representations).to("cpu").detach().numpy()

    rec_signals = _spectrograms_to_audio(reconstruction_batch)
    return quantized_latent_representations.detach(), reconstruction_batch, rec_signals


def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                         positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                         freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                         timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample fixed-length audio using a diffusion model, including 'ISTFT+' post-processing.

    Args:
        device: Device the text embeddings are moved to.
        uNet: Denoising model handed to the sampler.
        VAE: Object exposing `_vq_vae` (quantizer) and `_decoder`.
        MMM: Model providing `get_text_features`.
        CLAP_tokenizer: Tokenizer called on the prompts.
        positive_prompts: Prompt used as the sampling condition.
        negative_prompts: Prompt used as the classifier-free-guidance
            negative condition.
        batchsize: Number of samples generated in one call.
        sample_steps: Number of steps the `timesteps` schedule is respaced to.
        CFG: Classifier-free-guidance value passed to the sampler.
        seed: Seed forwarded to the sampler.
        freq_resolution, time_resolution: Spectrogram size; divided by
            `VAE_scale` to obtain the latent height and width.
        channels: Number of latent channels.
        VAE_scale: Downscaling factor between spectrogram and latent.
        timesteps: Length of the full diffusion schedule.
        noise_strategy: Forwarded to `DiffSynthSampler`.
        sampler: Sampler name forwarded to `DiffSynthSampler.sample`.
        return_latent: If true (the default), only the latents are returned.

    Returns:
        The detached quantized latents when `return_latent` is true; otherwise
        a tuple (detached quantized latents, decoded spectrogram batch as a
        numpy array, list of RMS-normalized waveforms).
    """
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)

    text2sound_embedding = _embed_prompt(MMM, CLAP_tokenizer, positive_prompts, device)
    negative_condition = _embed_prompt(MMM, CLAP_tokenizer, negative_prompts, device)

    mySampler = DiffSynthSampler(timesteps, height=height, channels=channels,
                                 noise_strategy=noise_strategy, mute=True)
    mySampler.activate_classifier_free_guidance(CFG, negative_condition)
    # Keep `sample_steps` evenly spaced steps out of the full schedule.
    mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))

    condition = text2sound_embedding.repeat(batchsize, 1)
    latent_representations, _ = \
        mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
                         return_tensor=True, condition=condition, sampler=sampler)

    # Only the last entry of the sampler output is used as the final latent.
    latent_representations = latent_representations[-1]

    return _finalize_latents(VAE, latent_representations, return_latent)


def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer,
                             positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                             freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                             timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample fixed-length audio using a GAN, including 'ISTFT+' post-processing.

    The signature mirrors `sample_pipeline_STFT`. `negative_prompts`,
    `sample_steps`, `CFG`, `timesteps`, `noise_strategy` and `sampler` are
    accepted for call compatibility but are not used by this pipeline.

    Args:
        device: Device the text embedding and the input noise are moved to.
        gan_generator: Callable invoked as `gan_generator(noise, condition)`.
        VAE: Object exposing `_vq_vae` (quantizer) and `_decoder`.
        MMM: Model providing `get_text_features`.
        CLAP_tokenizer: Tokenizer called on the prompt.
        positive_prompts: Prompt used as the generator condition.
        batchsize: Number of samples generated in one call.
        seed: If not None, seeds a dedicated CPU generator for the input
            noise so the call is reproducible. With None the global torch RNG
            is used.
        freq_resolution, time_resolution: Spectrogram size; divided by
            `VAE_scale` to obtain the latent height and width.
        channels: Number of latent channels.
        VAE_scale: Downscaling factor between spectrogram and latent.
        return_latent: If true (the default), only the latents are returned.

    Returns:
        The detached quantized latents when `return_latent` is true; otherwise
        a tuple (detached quantized latents, decoded spectrogram batch as a
        numpy array, list of RMS-normalized waveforms).
    """
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)

    text2sound_embedding = _embed_prompt(MMM, CLAP_tokenizer, positive_prompts, device)
    condition = text2sound_embedding.repeat(batchsize, 1)

    if seed is None:
        noise = torch.randn(batchsize, channels, height, width)
    else:
        # A private generator gives reproducible noise without touching the
        # global RNG state.
        generator = torch.Generator().manual_seed(seed)
        noise = torch.randn(batchsize, channels, height, width, generator=generator)
    noise = noise.to(device)

    latent_representations = gan_generator(noise, condition)

    return _finalize_latents(VAE, latent_representations, return_latent)


def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts,
                                           negative_prompts="", CFG=6, sample_steps=10, batchsize=8):
    """Sample audios using a diffusion model, including 'ISTFT+' post-processing.

    Runs `sample_pipeline_STFT` `num_batches` times with `seed=None` and
    collects the waveforms of every batch.

    Args:
        device, uNet, VAE, MMM, CLAP_tokenizer: Forwarded to
            `sample_pipeline_STFT`.
        num_batches: Number of batches to sample.
        positive_prompts: Prompt used as the sampling condition.
        negative_prompts: Negative prompt for classifier-free guidance.
        CFG: Classifier-free-guidance value.
        sample_steps: Number of sampling steps per batch.
        batchsize: Number of samples per batch (defaults to 8).

    Returns:
        A numpy array built from all generated waveforms.
    """
    diffuSynth_signals = []
    for _ in tqdm(range(num_batches)):
        _, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                                             positive_prompts=positive_prompts,
                                             negative_prompts=negative_prompts,
                                             batchsize=batchsize, sample_steps=sample_steps,
                                             CFG=CFG, seed=None, return_latent=False)
        diffuSynth_signals.extend(signals)
    return np.array(diffuSynth_signals)