import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Literal

import pytorch_lightning as pl
import torch
from safetensors.torch import safe_open, save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torchaudio.transforms import Resample
from transformers import WavLMModel

from .models import PatchVAE, PatchVAEConfig, WavVAE
from .models.components.convnext import SwiGLU
from .models.patchvae.modules import convnext_factory


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Return an LR multiplier: linear warmup, then cosine decay from start_lr to end_lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
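
# Usage sketch (hypothetical values): LambdaLR multiplies the optimizer's base lr by
# this factor, so the factor ramps 0 -> 1 over warmup, then follows a cosine from
# 1.0 down to end_lr / start_lr.
#   lr_lambda = cosine_schedule_with_warmup(1_000, 10_000, start_lr=1e-4, end_lr=1e-5)
#   lr_lambda(0)       # 0.0 (start of warmup)
#   lr_lambda(1_000)   # 1.0 (warmup done, lr == start_lr)
#   lr_lambda(10_000)  # 0.1 (lr == end_lr)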

class TrainPatchVAE(pl.LightningModule):
    def __init__(
        self,
        config: PatchVAEConfig,
        lr: float = 1e-4,
        end_lr: float | None = None,
        weight_decay: float = 0.01,
        cfg_drop_rate: float = 0.0,
        cfg_drop_rate_per_sample: float = 0.1,
        total_steps: int = 2_000_000,
        warmup_steps: int = 0,
        mean_std_from_file: str | None = None,
        train_sigma: float = 1e-4,
        kl_div_factor: float = 1e-4,
        wavvae_pretrained_path: str | None = None,
        drop_vae_rate: float = 0.0,
        ssl_repa_factor: float = 1.0,
        ssl_repa_head: bool = False,
        ssl_repa_head_type: Literal["convnext", "swiglu"] = "swiglu",
        ssl_repa_head_target: Literal["flow", "prior"] = "prior",
        ssl_repa_head_num_layers: int = 4,
        ssl_repa_source: Literal["wavlm-base-plus"] = "wavlm-base-plus",
        ssl_feat_dim: int = 768,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Optionally override the latent scaling (mean/std) from a safetensors file.
        if mean_std_from_file is not None:
            with safe_open(mean_std_from_file, framework="pt") as f:
                config.latent_scaling = (
                    f.get_tensor("mean").tolist(),
                    f.get_tensor("std").tolist(),
                )

        self.patchvae = PatchVAE(config)
        self.apply(self._init_weights)

        # Frozen WavVAE, used only to decode waveforms for the SSL targets.
        self.wavvae = None
        if wavvae_pretrained_path is not None:
            self.wavvae = WavVAE.from_pretrained_local(wavvae_pretrained_path).eval()

        self.ssl_model = None
        if ssl_repa_head:
            # Frozen SSL feature extractor providing the REPA regression targets.
            if ssl_repa_source == "wavlm-base-plus":
                self.ssl_model = WavLMModel.from_pretrained(
                    "microsoft/wavlm-base-plus"
                ).eval()
                self.ssl_resampling = Resample(
                    orig_freq=self.wavvae.sampling_rate, new_freq=16000
                )
            for p in self.ssl_model.parameters():
                p.requires_grad = False
            if ssl_repa_head_type == "convnext":
                self.ssl_repa_head = nn.Sequential(
                    convnext_factory(
                        config.hidden_dim,
                        ssl_repa_head_num_layers,
                    ),
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
            elif ssl_repa_head_type == "swiglu":
                self.ssl_repa_head = nn.Sequential(
                    *[
                        SwiGLU(config.hidden_dim)
                        for _ in range(ssl_repa_head_num_layers)
                    ],
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
        else:
            self.ssl_repa_head = None
    def _init_weights(self, m: nn.Module):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)
    def configure_optimizers(self):
        params = [{"params": self.patchvae.parameters()}]
        if self.ssl_repa_head is not None:
            params += [{"params": self.ssl_repa_head.parameters()}]
        opt = torch.optim.AdamW(
            params,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        if self.hparams.end_lr is not None:
            scheduler = LambdaLR(
                opt,
                lr_lambda=cosine_schedule_with_warmup(
                    warmup_steps=self.hparams.warmup_steps,
                    total_steps=self.hparams.total_steps,
                    start_lr=self.hparams.lr,
                    end_lr=self.hparams.end_lr,
                ),
            )
            scheduler = {"scheduler": scheduler, "interval": "step"}
            return [opt], [scheduler]
        return opt
    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        cfg = self.hparams.config
        model_path = Path(dir) / model_filename
        save_file(self.patchvae.state_dict(), model_path.with_suffix(".st"))
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)
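
    # Sketch of loading the exported weights back (assumes PatchVAEConfig is a plain
    # dataclass whose fields round-trip through JSON; latent_scaling may need to be
    # converted back to a tuple):
    #   from safetensors.torch import load_file
    #   with open("config.json") as f:
    #       cfg = PatchVAEConfig(**json.load(f))
    #   model = PatchVAE(cfg)
    #   model.load_state_dict(load_file("model.st"))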
    def training_step(self, batch: dict[str, torch.Tensor], batch_idx):
        z = batch["audio_z"]
        t = torch.rand(z.shape[0], device=z.device)
        drop_cond_rate = self.hparams.cfg_drop_rate_per_sample
        drop_vae_rate = self.hparams.drop_vae_rate
        flow_loss, ae_loss, prior = self.patchvae(
            z,
            t,
            sigma=self.hparams.train_sigma,
            drop_vae_rate=drop_vae_rate,
            drop_cond_rate=drop_cond_rate,
        )
        self.log("train_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss

        # KL term (only present when the encoder returns a distribution).
        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("train_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        for x in ["_mu_mean", "_mu_std", "_logvar_mean", "_logvar_std"]:
            stat = ae_loss.get(x)
            if stat is not None:
                self.log(x, stat, prog_bar=False)

        # SSL REPA loss: align PatchVAE features with frozen WavLM features
        # extracted from the decoded waveform.
        if self.ssl_repa_head is not None:
            target = self.hparams.ssl_repa_head_target
            if target == "prior":
                target = prior
            elif target == "flow":
                raise NotImplementedError
            with torch.inference_mode():
                wav = self.wavvae.decode(z)
                wav = self.ssl_resampling(wav)
                wav = torch.nn.functional.pad(wav, (40, 40))
                ssl_feats = self.ssl_model(wav, output_hidden_states=True).hidden_states
                ssl_feat = ssl_feats[10]
                # Pool the WavLM frames down to the PatchVAE token rate.
                ssl_feat = torch.nn.functional.avg_pool1d(
                    ssl_feat.transpose(-1, -2),
                    kernel_size=8,
                    stride=4,
                    padding=2,
                ).transpose(-1, -2)
            # Clone outside inference mode so the target can enter the loss graph.
            ssl_feat = ssl_feat.clone()
            B, N, D = ssl_feat.shape
            repa_pred = self.ssl_repa_head(target)
            ssl_repa_loss = nn.functional.cosine_embedding_loss(
                repa_pred.reshape(-1, D),
                ssl_feat.reshape(-1, D),
                torch.ones(1).to(repa_pred),
            )
            total_loss += self.hparams.ssl_repa_factor * ssl_repa_loss
            self.log("train_repa_loss", ssl_repa_loss, prog_bar=True)
        return total_loss
    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx):
        z = batch["audio_z"]
        # Deterministic flow times spread over [0, 1) across the validation batches.
        t = (
            torch.ones(z.shape[0], device=z.device)
            * batch_idx
            / self.trainer.num_val_batches[0]
        )
        flow_loss, ae_loss, prior = self.patchvae(
            z, t, sigma=self.hparams.train_sigma, drop_cond=False
        )
        self.log("val_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss
        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("val_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        return total_loss

if __name__ == "__main__":
    from pytorch_lightning.cli import LightningCLI

    LightningCLI(
        TrainPatchVAE,
        save_config_callback=None,
        parser_kwargs={"parser_mode": "omegaconf"},
    )
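
    # Example invocation (script and config paths are placeholders; LightningCLI
    # provides the standard "fit" subcommand and --config flag):
    #   python train_patchvae.py fit --config configs/patchvae.yaml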