```python
import torch
from torch import nn, Tensor
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.utils.torch_utils import randn_tensor
from diffusers import UNet2DConditionGuidedModel, HeunDiscreteScheduler

from audioldm.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from audioldm.utils import default_audioldm_config


class ConsistencyTTA(nn.Module):

    def __init__(self):
        super().__init__()

        # Initialize the consistency U-Net
        unet_model_config_path = 'tango_diffusion_light.json'
        unet_config = UNet2DConditionGuidedModel.load_config(unet_model_config_path)
        self.unet = UNet2DConditionGuidedModel.from_config(unet_config, subfolder="unet")

        unet_weight_path = "consistencytta_clapft_ckpt/unet_state_dict.pt"
        unet_weight_sd = torch.load(unet_weight_path, map_location='cpu')
        self.unet.load_state_dict(unet_weight_sd)

        # Initialize FLAN-T5 tokenizer and text encoder
        text_encoder_name = 'google/flan-t5-large'
        self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
        self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
        self.text_encoder.eval(); self.text_encoder.requires_grad_(False)

        # Initialize the VAE
        raw_vae_path = "consistencytta_clapft_ckpt/vae_state_dict.pt"
        raw_vae_sd = torch.load(raw_vae_path, map_location="cpu")
        vae_state_dict, scale_factor = raw_vae_sd["state_dict"], raw_vae_sd["scale_factor"]

        config = default_audioldm_config('audioldm-s-full')
        vae_config = config["model"]["params"]["first_stage_config"]["params"]
        vae_config["scale_factor"] = scale_factor

        self.vae = AutoencoderKL(**vae_config)
        self.vae.load_state_dict(vae_state_dict)
        self.vae.eval(); self.vae.requires_grad_(False)

        # Initialize the STFT
        self.fn_STFT = TacotronSTFT(
            config["preprocessing"]["stft"]["filter_length"],   # default 1024
            config["preprocessing"]["stft"]["hop_length"],       # default 160
            config["preprocessing"]["stft"]["win_length"],       # default 1024
            config["preprocessing"]["mel"]["n_mel_channels"],    # default 64
            config["preprocessing"]["audio"]["sampling_rate"],   # default 16000
            config["preprocessing"]["mel"]["mel_fmin"],          # default 0
            config["preprocessing"]["mel"]["mel_fmax"],          # default 8000
        )
        self.fn_STFT.eval(); self.fn_STFT.requires_grad_(False)

        self.scheduler = HeunDiscreteScheduler.from_pretrained(
            pretrained_model_name_or_path='stabilityai/stable-diffusion-2-1', subfolder="scheduler"
        )

    def train(self, mode: bool = True):
        # Only the U-Net is ever trained; the other components stay frozen in eval mode.
        self.unet.train(mode)
        for model in [self.text_encoder, self.vae, self.fn_STFT]:
            model.eval()
        return self

    def eval(self):
        return self.train(mode=False)

    def check_eval_mode(self):
        for model, name in zip(
            [self.text_encoder, self.vae, self.fn_STFT, self.unet],
            ['text_encoder', 'vae', 'fn_STFT', 'unet']
        ):
            assert model.training is False, f"The {name} is not in eval mode."
            for param in model.parameters():
                assert param.requires_grad is False, f"The {name} is not frozen."

    def encode_text(self, prompt, max_length=None, padding=True):
        device = self.text_encoder.device
        if max_length is None:
            max_length = self.tokenizer.model_max_length

        batch = self.tokenizer(
            prompt, max_length=max_length, padding=padding,
            truncation=True, return_tensors="pt"
        )
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)

        prompt_embeds = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]
        bool_prompt_mask = (attention_mask == 1).to(device)  # Convert to boolean
        return prompt_embeds, bool_prompt_mask

    def encode_text_classifier_free(self, prompt: str, num_samples_per_prompt: int):
        # Get conditional embeddings
        cond_prompt_embeds, cond_prompt_mask = self.encode_text(prompt)
        cond_prompt_embeds = cond_prompt_embeds.repeat_interleave(
            num_samples_per_prompt, 0
        )
        cond_prompt_mask = cond_prompt_mask.repeat_interleave(
            num_samples_per_prompt, 0
        )

        # Get unconditional embeddings for classifier-free guidance
        uncond_tokens = [""] * len(prompt)
        negative_prompt_embeds, uncond_prompt_mask = self.encode_text(
            uncond_tokens, max_length=cond_prompt_embeds.shape[1], padding="max_length"
        )
        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
            num_samples_per_prompt, 0
        )
        uncond_prompt_mask = uncond_prompt_mask.repeat_interleave(
            num_samples_per_prompt, 0
        )

        # For classifier-free guidance, we need to do two forward passes.
        # We concatenate the unconditional and text embeddings into a single batch.
        prompt_embeds = torch.cat([negative_prompt_embeds, cond_prompt_embeds])
        prompt_mask = torch.cat([uncond_prompt_mask, cond_prompt_mask])

        return prompt_embeds, prompt_mask, cond_prompt_embeds, cond_prompt_mask

    def forward(
        self, prompt: str, cfg_scale_input: float = 3., cfg_scale_post: float = 1.,
        num_steps: int = 1, num_samples: int = 1, sr: int = 16000
    ):
        self.check_eval_mode()
        device = self.text_encoder.device
        use_cf_guidance = cfg_scale_post > 1.

        # Get prompt embeddings
        prompt_embeds_cf, prompt_mask_cf, prompt_embeds, prompt_mask = \
            self.encode_text_classifier_free(prompt, num_samples)
        encoder_states, encoder_att_mask = \
            (prompt_embeds_cf, prompt_mask_cf) if use_cf_guidance \
            else (prompt_embeds, prompt_mask)

        # Prepare noise
        num_channels_latents = self.unet.config.in_channels
        latent_shape = (len(prompt) * num_samples, num_channels_latents, 256, 16)
        noise = randn_tensor(
            latent_shape, generator=None, device=device, dtype=prompt_embeds.dtype
        )

        # Query the inference scheduler to obtain the time steps.
        # The time steps are spread between 0 and the number of training time steps.
        self.scheduler.set_timesteps(18, device=device)  # Set this to the training steps first
        z_N = noise * self.scheduler.init_noise_sigma

        def calc_zhat_0(z_n: Tensor, t: int):
            """ Query the consistency model to get zhat_0, which is the denoised embedding.
            Args:
                z_n (Tensor): The noisy embedding.
                t (int):      The time step.
            Returns:
                Tensor: The denoised embedding.
            """
            # Expand the latents if we are doing classifier-free guidance
            z_n_input = torch.cat([z_n] * 2) if use_cf_guidance else z_n
            # Scale model input as required for some schedulers
            z_n_input = self.scheduler.scale_model_input(z_n_input, t)

            # Get zhat_0 from the model
            zhat_0 = self.unet(
                z_n_input, t, guidance=cfg_scale_input,
                encoder_hidden_states=encoder_states, encoder_attention_mask=encoder_att_mask
            ).sample

            # Perform external classifier-free guidance
            if use_cf_guidance:
                zhat_0_uncond, zhat_0_cond = zhat_0.chunk(2)
                zhat_0 = (1 - cfg_scale_post) * zhat_0_uncond + cfg_scale_post * zhat_0_cond

            return zhat_0

        # Query the consistency model
        zhat_0 = calc_zhat_0(z_N, self.scheduler.timesteps[0])

        # Iteratively query the consistency model if requested
        self.scheduler.set_timesteps(num_steps, device=device)

        for t in self.scheduler.timesteps[1::2]:  # 2 is the order of the scheduler
            zhat_n = self.scheduler.add_noise(zhat_0, torch.randn_like(zhat_0), t)
            # Calculate the new zhat_0
            zhat_0 = calc_zhat_0(zhat_n, t)

        mel = self.vae.decode_first_stage(zhat_0.float())
        return self.vae.decode_to_waveform(mel)[:, :int(sr * 9.5)]  # Truncate to 9.5 seconds
```
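
For reference, a minimal sketch of how this wrapper could be run end to end is shown below. It assumes the checkpoint files referenced above (`tango_diffusion_light.json` and the `consistencytta_clapft_ckpt/` state dicts) are available locally; the example prompt, the output filename, and the use of `soundfile` for saving are illustrative choices, and the conversion step accounts for `decode_to_waveform` possibly returning either a tensor or a NumPy array.

```python
# Minimal usage sketch (assumes the checkpoints referenced above exist locally).
import soundfile as sf  # only used here to write the generated audio to disk
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConsistencyTTA().eval().to(device)

# Example prompt; forward() accepts a list of prompts and returns one clip per prompt/sample.
prompt = ["A dog barks while birds chirp in the background."]
with torch.no_grad():
    wav = model(prompt, cfg_scale_input=3., cfg_scale_post=1., num_steps=1, num_samples=1)

clip = wav[0]
if torch.is_tensor(clip):  # decode_to_waveform may return a tensor or a NumPy array
    clip = clip.float().cpu().numpy()
sf.write("generated.wav", clip, 16000)  # sampling rate matches the default sr=16000
```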