Spaces:
Paused
Paused
File size: 5,965 Bytes
9d3cb0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import random
import pandas as pd
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from .utils import scale_shift_re
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a classifier-free-guidance prediction with a std-matched copy of itself.

    Implements the guidance rescaling of "Common Diffusion Noise Schedules and
    Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: guided prediction. Its standard deviation is taken separately
            for each index along dim 0 (over all remaining dims).
        noise_pred_text: conditional (text) prediction whose per-item standard
            deviation ``noise_cfg`` is rescaled to match.
        guidance_rescale: weight of the rescaled prediction in the final blend;
            ``1 - guidance_rescale`` is the weight of the original ``noise_cfg``.

    Returns:
        ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.
    """
    def _sample_std(tensor):
        # Std over every axis except the leading one, kept broadcastable (keepdim).
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    std_ratio = _sample_std(noise_pred_text) / _sample_std(noise_cfg)
    # Matching the text-branch std counters the over-exposure caused by strong guidance.
    rescaled = noise_cfg * std_ratio
    # Apply the rescale only partially to avoid "plain looking" results.
    keep = 1 - guidance_rescale
    return guidance_rescale * rescaled + keep * noise_cfg
@torch.no_grad()
def inference(autoencoder, unet, controlnet,
              gt, gt_mask, condition,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              conditioning_scale=1.0,
              device='cuda',
              ):
    """
    Sample latents with the ControlNet-conditioned diffusion model and decode them.

    Starting from Gaussian noise of shape ``(1, params['model']['out_chans'],
    audio_frames)``, the latents are denoised over the scheduler's timesteps.
    At each step:

    - ``unet`` (called with ``forward_model=False``) produces features ``x``.
    - ``controlnet`` turns ``x`` and ``condition`` into ``controlnet_skips``.
    - ``unet.model`` predicts the output passed to ``noise_scheduler.step``.

    The final latents go through ``scale_shift_re`` and ``autoencoder``.

    Args:
        autoencoder: decoder, called as ``autoencoder(embedding=...)``.
        unet: UNet wrapper; must also expose ``.model`` for the second pass.
        controlnet: ControlNet producing the skip features for ``unet.model``.
        gt: optional known latents. When not None, positions where ``gt_mask``
            is False are overwritten with ``gt`` after sampling.
        gt_mask: boolean mask paired with ``gt``; also passed to the UNet as
            ``mae_mask_infer``.
        condition: ControlNet conditioning input (duplicated along dim 0 for CFG).
        tokenizer: text tokenizer. If None, no text is used and classifier-free
            guidance is disabled.
        text_encoder: text encoder whose ``last_hidden_state`` is the context.
        params: config dict; reads ``['text_encoder']['max_length']``,
            ``['model']['out_chans']``, ``['autoencoder']['scale']`` and
            ``['autoencoder']['shift']``.
        noise_scheduler: scheduler providing ``set_timesteps``,
            ``scale_model_input``, ``step`` and ``timesteps``.
        text_raw: prompt(s) for the conditional branch.
        neg_text: prompt(s) for the unconditional branch; defaults to ``[""]``.
        audio_frames: length of the last latent axis.
        guidance_scale: CFG scale. A falsy value (None/0) disables CFG.
        guidance_rescale: when > 0, ``rescale_noise_cfg`` is applied to the CFG
            output.
        ddim_steps: number of scheduler timesteps.
        eta: forwarded to ``noise_scheduler.step``.
        random_seed: seed for the sampling generator; None draws a fresh seed.
        conditioning_scale: forwarded to ``controlnet``.
        device: device for the noise, generator and text tensors.

    Returns:
        The result of ``autoencoder(embedding=pred)``, named ``pred_wav``
        (presumably the generated waveform).
    """
    if neg_text is None:
        neg_text = [""]

    if tokenizer is not None:
        max_length = params['text_encoder']['max_length']
        text, text_mask = _encode_prompt(tokenizer, text_encoder, text_raw,
                                         max_length, device)
        uncond_text, uncond_text_mask = _encode_prompt(tokenizer, text_encoder,
                                                       neg_text, max_length, device)
    else:
        text, text_mask = None, None
        uncond_text, uncond_text_mask = None, None
        # Without text there is no unconditional branch, so CFG is switched off.
        guidance_scale = None

    codec_dim = params['model']['out_chans']
    unet.eval()
    controlnet.eval()

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)

    use_cfg = bool(guidance_scale)
    if use_cfg:
        # Conditional and unconditional branches share one batched forward pass.
        # These inputs do not depend on the timestep, so build them once.
        model_text = torch.cat([text, uncond_text], dim=0)
        model_text_mask = torch.cat([text_mask, uncond_text_mask], dim=0)
        model_condition = torch.cat([condition, condition], dim=0)
        if gt is not None:
            model_gt = torch.cat([gt, gt], dim=0)
            model_gt_mask = torch.cat([gt_mask, gt_mask], dim=0)
        else:
            model_gt, model_gt_mask = None, None
    else:
        model_text, model_text_mask = text, text_mask
        model_condition, model_gt, model_gt_mask = condition, gt, gt_mask

    for t in noise_scheduler.timesteps:
        # scale_model_input's result feeds only the model. step() must receive the
        # unscaled latents (identical for DDIM, where the scaling is a no-op).
        model_input = noise_scheduler.scale_model_input(latents, t)
        if use_cfg:
            model_input = torch.cat([model_input, model_input], dim=0)

        output_pred = _controlnet_unet_forward(
            unet, controlnet, model_input, t,
            model_text, model_text_mask, model_condition,
            model_gt, model_gt_mask, conditioning_scale)

        if use_cfg:
            output_text, output_uncond = torch.chunk(output_pred, 2, dim=0)
            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)

        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                       sample=latents,
                                       eta=eta, generator=generator).prev_sample

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        # Keep the known region: positions where gt_mask is False come from gt.
        pred[~gt_mask] = gt[~gt_mask]
    pred_wav = autoencoder(embedding=pred)
    return pred_wav


def _encode_prompt(tokenizer, text_encoder, prompts, max_length, device):
    """Tokenize ``prompts`` (padded/truncated to ``max_length``) and return (hidden states, bool mask)."""
    batch = tokenizer(prompts,
                      max_length=max_length,
                      padding="max_length", truncation=True, return_tensors="pt")
    input_ids = batch.input_ids.to(device)
    attention_mask = batch.attention_mask.to(device).bool()
    hidden = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    return hidden, attention_mask


def _controlnet_unet_forward(unet, controlnet, sample, t, text, text_mask,
                             condition, gt, gt_mask, conditioning_scale):
    """Run the unet -> controlnet -> unet.model chain for one timestep and return the prediction."""
    x, _ = unet(sample, t, text, context_mask=text_mask,
                cls_token=None, gt=gt, mae_mask_infer=gt_mask,
                forward_model=False)
    controlnet_skips = controlnet(x, t, text,
                                  context_mask=text_mask,
                                  cls_token=None,
                                  condition=condition,
                                  conditioning_scale=conditioning_scale)
    return unet.model(x, t, text,
                      context_mask=text_mask,
                      cls_token=None, controlnet_skips=controlnet_skips)