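"""Text-to-audio sampling pipelines: a diffusion sampler and a GAN baseline, both
generating VQ-VAE latents of STFT representations that are decoded and inverted
back to waveforms via 'ISTFT+' post-processing."""
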
import librosa
import numpy as np
import torch
from tqdm import tqdm

from tools import rms_normalize, decode_stft, depad_STFT
from model.DiffSynthSampler import DiffSynthSampler

def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                         positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                         freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                         timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample a batch of fixed-length audio clips with a diffusion model, including 'ISTFT+' post-processing."""

    # Latent spatial dimensions: STFT resolution divided by the VAE downsampling factor
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)
    VAE_quantizer, VAE_decoder = VAE._vq_vae, VAE._decoder

    # Encode the positive and negative prompts into CLAP text embeddings
    text2sound_embedding = MMM.get_text_features(
        **CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
    negative_condition = MMM.get_text_features(
        **CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device)

    # Configure the sampler: classifier-free guidance against the negative-prompt
    # embedding, and a respaced (shortened) schedule of `sample_steps` timesteps
    mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
    mySampler.activate_classifier_free_guidance(CFG, negative_condition)
    mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))

    # Broadcast the text embedding across the batch
    condition = text2sound_embedding.repeat(batchsize, 1)

    # Run the reverse diffusion process and keep only the final denoised latents
    latent_representations, initial_noise = mySampler.sample(
        model=uNet, shape=(batchsize, channels, height, width), seed=seed,
        return_tensor=True, condition=condition, sampler=sampler)
    latent_representations = latent_representations[-1]

    # Snap the continuous latents onto the VQ-VAE codebook
    quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)

    if return_latent:
        return quantized_latent_representations.detach()

    # Decode the quantized latents back to (padded, encoded) STFT representations
    reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()

    rec_signals = []

    for STFT in reconstruction_batch:
        # 'ISTFT+' post-processing: undo the STFT encoding and padding, invert
        # the STFT, and RMS-normalize the resulting waveform
        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
        rec_signals.append(rms_normalize(rec_signal))

    return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
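
# A minimal usage sketch (assumes pre-loaded uNet/VAE/MMM/CLAP_tokenizer checkpoints;
# the prompt strings are illustrative only):
#
#   latents, stfts, signals = sample_pipeline_STFT(
#       "cuda", uNet, VAE, MMM, CLAP_tokenizer,
#       positive_prompts="warm analog pad", negative_prompts="noisy",
#       batchsize=4, sample_steps=20, CFG=6, seed=42, return_latent=False)
#   # signals is a list of 4 RMS-normalized waveforms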

def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer,
                             positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
                             freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4,
                             timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
    """Sample fixed-length audio with a GAN generator, including 'ISTFT+' post-processing.

    negative_prompts, sample_steps, CFG, timesteps, noise_strategy and sampler are
    accepted for signature parity with sample_pipeline_STFT but are unused here.
    """

    # Latent spatial dimensions: STFT resolution divided by the VAE downsampling factor
    height = int(freq_resolution / VAE_scale)
    width = int(time_resolution / VAE_scale)
    VAE_quantizer, VAE_decoder = VAE._vq_vae, VAE._decoder

    # Encode the positive prompt into a CLAP text embedding (the GAN path has no
    # classifier-free guidance, so no negative embedding is needed)
    text2sound_embedding = MMM.get_text_features(
        **CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)

    # Broadcast the text embedding across the batch
    condition = text2sound_embedding.repeat(batchsize, 1)

    # Seed the global RNG so the latent noise is reproducible when a seed is given
    if seed is not None:
        torch.manual_seed(seed)

    # A single generator forward pass maps noise + condition straight to latents
    noise = torch.randn(batchsize, channels, height, width).to(device)
    latent_representations = gan_generator(noise, condition)

    # Snap the continuous latents onto the VQ-VAE codebook
    quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)

    if return_latent:
        return quantized_latent_representations.detach()

    # Decode the quantized latents back to (padded, encoded) STFT representations
    reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()

    rec_signals = []

    for STFT in reconstruction_batch:
        # 'ISTFT+' post-processing: undo the STFT encoding and padding, invert
        # the STFT, and RMS-normalize the resulting waveform
        padded_D_rec = decode_stft(STFT)
        D_rec = depad_STFT(padded_D_rec)
        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
        rec_signals.append(rms_normalize(rec_signal))

    return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
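
# A minimal usage sketch for the GAN baseline (hypothetical checkpoint names;
# sample_steps and CFG must be passed but are ignored on this path):
#
#   latents = sample_pipeline_GAN_STFT(
#       "cuda", gan_generator, VAE, MMM, CLAP_tokenizer,
#       positive_prompts="plucked string", negative_prompts="",
#       batchsize=4, sample_steps=0, CFG=0, seed=42)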


def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches,
                                           positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
    """Sample audio with the diffusion model in batches of 8, including 'ISTFT+' post-processing."""

    diffuSynth_signals = []
    for _ in tqdm(range(num_batches)):
        # Each iteration draws a fresh, unseeded batch of 8 signals
        _, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
                                             positive_prompts=positive_prompts, negative_prompts=negative_prompts,
                                             batchsize=8, sample_steps=sample_steps, CFG=CFG,
                                             seed=None, return_latent=False)
        diffuSynth_signals.extend(signals)
    return np.array(diffuSynth_signals)
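
# Example (a sketch; saving via `soundfile` and the 16 kHz rate are assumptions,
# not defined by this module):
#
#   signals = generate_audios_with_diffuSynth_sample(
#       "cuda", uNet, VAE, MMM, CLAP_tokenizer, num_batches=4,
#       positive_prompts="church organ", CFG=6, sample_steps=10)
#   for i, s in enumerate(signals):
#       soundfile.write(f"organ_{i}.wav", s, 16000)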