import librosa
import numpy as np
import torch
from tqdm import tqdm
from tools import rms_normalize, decode_stft, depad_STFT
from model.DiffSynthSampler import DiffSynthSampler

def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                         positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                         freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                         timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample fixed-length audio with a diffusion model, including 'ISTFT+' post-processing."""
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)
    # Only the quantizer and decoder are used in this pipeline.
    VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder

    # Encode the positive and negative prompts into CLAP text embeddings.
    text2sound_embedding = \
        MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
    negative_condition = \
        MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device)

    mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
    mySampler.activate_classifier_free_guidance(CFG, negative_condition)
    # Respace the full timestep schedule down to `sample_steps` evenly spaced steps.
    mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))

    condition = text2sound_embedding.repeat(batchsize, 1)
    latent_representations, initial_noise = \
        mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
                         return_tensor=True, condition=condition, sampler=sampler)
    latent_representations = latent_representations[-1]  # keep only the final denoising step

    quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
    if return_latent:
        return quantized_latent_representations.detach()

    reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
    rec_signals = []
    for STFT in reconstruction_batch:
        # 'ISTFT+' post-processing: decode the spectrogram, strip padding,
        # invert to a waveform, and RMS-normalize.
        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
        rec_signals.append(rms_normalize(rec_signal))
    return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
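

# A minimal usage sketch (assumptions: `uNet`, `VAE`, `MMM`, and `CLAP_tokenizer` are the
# project's pretrained models, already loaded; prompt text and batch settings are illustrative):
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     latents, stfts, signals = sample_pipeline_STFT(
#         device, uNet, VAE, MMM, CLAP_tokenizer,
#         positive_prompts="a bright plucked synthesizer note", negative_prompts="",
#         batchsize=4, sample_steps=20, CFG=6, seed=0, return_latent=False)
#     # `signals` holds one RMS-normalized waveform per batch element.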

def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer,
                             positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                             freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                             timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample fixed-length audio with a GAN generator, including 'ISTFT+' post-processing."""
    # Diffusion-specific arguments (negative_prompts, sample_steps, CFG, seed, timesteps,
    # noise_strategy, sampler) are accepted for signature parity with sample_pipeline_STFT
    # but are unused here.
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)
    VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder

    # Encode the positive prompt into a CLAP text embedding.
    text2sound_embedding = \
        MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
    condition = text2sound_embedding.repeat(batchsize, 1)

    # A single conditional generator forward pass replaces the iterative diffusion sampling.
    noise = torch.randn(batchsize, channels, height, width).to(device)
    latent_representations = gan_generator(noise, condition)

    quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
    if return_latent:
        return quantized_latent_representations.detach()

    reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
    rec_signals = []
    for STFT in reconstruction_batch:
        # 'ISTFT+' post-processing, identical to the diffusion pipeline above.
        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
        rec_signals.append(rms_normalize(rec_signal))
    return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
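

# A minimal usage sketch (assumption: `gan_generator` is the project's pretrained
# conditional generator; the other models are shared with the diffusion pipeline):
#
#     latents, stfts, signals = sample_pipeline_GAN_STFT(
#         device, gan_generator, VAE, MMM, CLAP_tokenizer,
#         positive_prompts="a deep sub bass", negative_prompts="",
#         batchsize=4, sample_steps=20, CFG=6, return_latent=False)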

def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches,
                                           positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
    """Sample num_batches * 8 audios with the diffusion model, including 'ISTFT+' post-processing."""
    diffuSynth_signals = []
    for _ in tqdm(range(num_batches)):
        _, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                                             positive_prompts=positive_prompts, negative_prompts=negative_prompts,
                                             batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None,
                                             return_latent=False)
        diffuSynth_signals.extend(signals)
    return np.array(diffuSynth_signals)
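

# A minimal end-to-end sketch (assumptions: the pretrained models are loaded as above,
# and the sample rate passed to soundfile matches the project's training data; 16000 is
# only a placeholder):
#
#     import soundfile as sf
#     signals = generate_audios_with_diffuSynth_sample(
#         device, uNet, VAE, MMM, CLAP_tokenizer, num_batches=2,
#         positive_prompts="a warm electric piano chord", CFG=6, sample_steps=10)
#     for i, sig in enumerate(signals):
#         sf.write(f"diffuSynth_{i}.wav", sig, samplerate=16000)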