import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Literal

import pytorch_lightning as pl
import torch
from safetensors.torch import safe_open, save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torchaudio.transforms import Resample
from transformers import WavLMModel

from .models import PatchVAE, PatchVAEConfig, WavVAE
from .models.components.convnext import SwiGLU
from .models.patchvae.modules import convnext_factory


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Return an LR lambda with linear warmup followed by a cosine decay from
    ``start_lr`` to ``end_lr`` (expressed as a multiplier of ``start_lr``)."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Clamp progress so the LR settles at end_lr once total_steps is exceeded.
        progress = min(
            1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps)
        )
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


class TrainPatchVAE(pl.LightningModule):
    """Lightning training wrapper for PatchVAE with optional KL and SSL
    feature-alignment (REPA) auxiliary losses."""

    def __init__(
        self,
        config: PatchVAEConfig,
        lr: float = 1e-4,
        end_lr: float | None = None,
        weight_decay: float = 0.01,
        cfg_drop_rate: float = 0.0,
        cfg_drop_rate_per_sample: float = 0.1,
        total_steps: int = 2_000_000,
        warmup_steps: int = 0,
        mean_std_from_file: str | None = None,
        train_sigma: float = 1e-4,
        kl_div_factor: float = 1e-4,
        wavvae_pretrained_path: str | None = None,
        drop_vae_rate: float = 0.0,
        ssl_repa_factor: float = 1.0,
        ssl_repa_head: bool = False,
        ssl_repa_head_type: Literal["convnext", "swiglu"] = "swiglu",
        ssl_repa_head_target: Literal["flow", "prior"] = "prior",
        ssl_repa_head_num_layers: int = 4,
        ssl_repa_source: Literal["wavlm-base-plus"] = "wavlm-base-plus",
        ssl_feat_dim: int = 768,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Optionally load latent mean/std statistics used to scale the latents.
        if mean_std_from_file is not None:
            with safe_open(mean_std_from_file, framework="pt") as f:
                config.latent_scaling = (
                    f.get_tensor("mean").tolist(),
                    f.get_tensor("std").tolist(),
                )

        self.patchvae = PatchVAE(config)
        self.apply(self._init_weights)

        # Frozen WavVAE decoder, used only to render waveforms for the SSL loss.
        self.wavvae = None
        if wavvae_pretrained_path is not None:
            self.wavvae = WavVAE.from_pretrained_local(wavvae_pretrained_path).eval()

        self.ssl_model = None
        if ssl_repa_head:
            if ssl_repa_source == "wavlm-base-plus":
                self.ssl_model = WavLMModel.from_pretrained(
                    "microsoft/wavlm-base-plus"
                ).eval()
                self.ssl_resampling = Resample(
                    orig_freq=self.wavvae.sampling_rate, new_freq=16000
                )
                for p in self.ssl_model.parameters():
                    p.requires_grad = False
            if ssl_repa_head_type == "convnext":
                self.ssl_repa_head = nn.Sequential(
                    convnext_factory(
                        config.hidden_dim,
                        ssl_repa_head_num_layers,
                    ),
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
            elif ssl_repa_head_type == "swiglu":
                self.ssl_repa_head = nn.Sequential(
                    *[
                        SwiGLU(config.hidden_dim)
                        for _ in range(ssl_repa_head_num_layers)
                    ],
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
        else:
            self.ssl_repa_head = None

    def _init_weights(self, m: nn.Module):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def configure_optimizers(self):
        params = [{"params": self.patchvae.parameters()}]
        if self.ssl_repa_head is not None:
            params += [{"params": self.ssl_repa_head.parameters()}]
        opt = torch.optim.AdamW(
            params,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        if self.hparams.end_lr is not None:
            scheduler = LambdaLR(
                opt,
                lr_lambda=cosine_schedule_with_warmup(
                    warmup_steps=self.hparams.warmup_steps,
                    total_steps=self.hparams.total_steps,
                    start_lr=self.hparams.lr,
                    end_lr=self.hparams.end_lr,
                ),
            )
            scheduler = {"scheduler": scheduler, "interval": "step"}
            return [opt], [scheduler]
        return opt

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        cfg = self.hparams.config
        model_path = Path(dir) / model_filename
        save_file(self.patchvae.state_dict(), model_path.with_suffix(".st"))
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx):
        z = batch["audio_z"]
        t = torch.rand(z.shape[0], device=z.device)
        drop_cond_rate = self.hparams.cfg_drop_rate_per_sample
        drop_vae_rate = self.hparams.drop_vae_rate
        flow_loss, ae_loss, prior = self.patchvae(
            z,
            t,
            sigma=self.hparams.train_sigma,
            drop_vae_rate=drop_vae_rate,
            drop_cond_rate=drop_cond_rate,
        )
        self.log("train_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss

        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("train_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        for x in ["_mu_mean", "_mu_std", "_logvar_mean", "_logvar_std"]:
            stat = ae_loss.get(x)
            if stat is not None:
                self.log(x, stat, prog_bar=False)

        if self.ssl_repa_head is not None:
            target = self.hparams.ssl_repa_head_target
            if target == "prior":
                target = prior
            elif target == "flow":
                raise NotImplementedError
            # Extract frozen SSL features from the decoded waveform.
            with torch.inference_mode():
                wav = self.wavvae.decode(z)
                wav = self.ssl_resampling(wav)
                wav = torch.nn.functional.pad(wav, (40, 40))
                ssl_feats = self.ssl_model(
                    wav, output_hidden_states=True
                ).hidden_states
                ssl_feat = ssl_feats[10]
                ssl_feat = torch.nn.functional.avg_pool1d(
                    ssl_feat.transpose(-1, -2),
                    kernel_size=8,
                    stride=4,
                    padding=2,
                ).transpose(-1, -2)
            # Clone outside inference mode so the target can participate in autograd.
            ssl_feat = ssl_feat.clone()
            B, N, D = ssl_feat.shape
            repa_pred = self.ssl_repa_head(target)
            ssl_repa_loss = nn.functional.cosine_embedding_loss(
                repa_pred.reshape(-1, D),
                ssl_feat.reshape(-1, D),
                torch.ones(1).to(repa_pred),
            )
            total_loss += self.hparams.ssl_repa_factor * ssl_repa_loss
            self.log("train_repa_loss", ssl_repa_loss, prog_bar=True)
        return total_loss

    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx):
        z = batch["audio_z"]
        # Sweep t deterministically over the validation set instead of sampling it.
        t = (
            torch.ones(z.shape[0], device=z.device)
            * batch_idx
            / self.trainer.num_val_batches[0]
        )
        flow_loss, ae_loss, prior = self.patchvae(
            z, t, sigma=self.hparams.train_sigma, drop_cond=False
        )
        self.log("val_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss
        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("val_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        return total_loss


if __name__ == "__main__":
    from pytorch_lightning.cli import LightningCLI

    LightningCLI(
        TrainPatchVAE,
        save_config_callback=None,
        parser_kwargs={"parser_mode": "omegaconf"},
    )
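
# Example invocation (sketch, not part of the original module): LightningCLI exposes
# the standard `fit` subcommand, and parser_mode="omegaconf" lets the run be driven
# by a YAML config. The script and config paths below are hypothetical placeholders.
#
#   python train_patchvae.py fit --config configs/patchvae.yaml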