import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers.wandb import WandbLogger
from safetensors.torch import save_file
from torch.nn.utils import clip_grad_norm_

from codec.models import WavVAE, WavVAEConfig
from codec.models.wavvae.discriminators import (MultiPeriodDiscriminator,
                                                MultiResolutionDiscriminator)
from codec.models.wavvae.loss import (DiscriminatorLoss, FeatureMatchingLoss,
                                      GeneratorLoss, MelSpecReconstructionLoss)


class TrainWavVAE(pl.LightningModule):
    def __init__(
        self,
        config: WavVAEConfig,
        sample_rate: int,
        initial_learning_rate: float,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        kl_div_coeff: float = 1e-5,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        clip_grad_norm: float | None = None,
        f_min: int = 0,
        f_max: Optional[int] = None,
        mrd_fft_sizes: tuple[int, int, int] = (2048, 1024, 512),
        mel_hop_length: int = 256,
        log_audio_every_n_epoch: int = 5,
        log_n_audio_batches: int = 32,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.wavvae = WavVAE(config)

        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiresddisc = MultiResolutionDiscriminator(
            fft_sizes=tuple(mrd_fft_sizes)
        )

        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            hop_length=mel_hop_length,
        )

        self.train_discriminator = False
        # GAN training with two optimizers requires manual optimization.
        self.automatic_optimization = False
        self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff

    def save_model_weights_and_config(
        self,
        dir: str | Path,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        cfg = self.hparams.config
        model_path = Path(dir) / model_filename
        save_file(self.wavvae.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)
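    # A minimal sketch of reloading the exported weights outside Lightning
    # (assumes a directory previously passed to save_model_weights_and_config;
    # the "exported/" path is a placeholder):
    #
    #   from safetensors.torch import load_file
    #   with open("exported/config.json") as f:
    #       config = WavVAEConfig(**json.load(f))
    #   model = WavVAE(config)
    #   model.load_state_dict(load_file("exported/model.st"))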
    def configure_optimizers(self):
        disc_params = [
            {"params": self.multiperioddisc.parameters()},
            {"params": self.multiresddisc.parameters()},
        ]
        gen_params = [
            {"params": self.wavvae.parameters()},
        ]

        opt_disc = torch.optim.AdamW(
            disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
        )
        opt_gen = torch.optim.AdamW(
            gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
        )

        # Two optimizers step on each batch, so each cosine schedule covers
        # half of trainer.max_steps.
        max_steps = self.trainer.max_steps // 2
        scheduler_disc = transformers.get_cosine_schedule_with_warmup(
            opt_disc,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=max_steps,
        )
        scheduler_gen = transformers.get_cosine_schedule_with_warmup(
            opt_gen,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=max_steps,
        )

        return (
            [opt_disc, opt_gen],
            [
                {"scheduler": scheduler_disc, "interval": "step"},
                {"scheduler": scheduler_gen, "interval": "step"},
            ],
        )

    def forward(self, audio_input, **kwargs):
        audio_output, kl_div = self.wavvae(audio_input)
        return audio_output, kl_div

    def training_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        opt_disc, opt_gen = self.optimizers()

        # Discriminator update: score real audio against a reconstruction
        # produced without generator gradients.
        if self.train_discriminator:
            with torch.no_grad():
                audio_hat, kl_div = self(audio_input, **kwargs)

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
            )
            loss_mrd, loss_mrd_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
            )
            loss_mp /= len(loss_mp_real)
            loss_mrd /= len(loss_mrd_real)
            loss_disc = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd

            self.log("discriminator/total", loss_disc, prog_bar=True)
            self.log("discriminator/multi_period_loss", loss_mp)
            self.log("discriminator/multi_res_loss", loss_mrd)

            opt_disc.zero_grad()
            self.manual_backward(loss_disc)
            if self.hparams.clip_grad_norm is not None:
                max_norm = self.hparams.clip_grad_norm
                clip_grad_norm_(self.multiperioddisc.parameters(), max_norm=max_norm)
                clip_grad_norm_(self.multiresddisc.parameters(), max_norm=max_norm)
            opt_disc.step()

        # Generator update.
        audio_hat, kl_div = self(audio_input, **kwargs)
        if self.train_discriminator:
            _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
            loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
            loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
            loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
            loss_fm_mp = self.feat_matching_loss(
                fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp
            ) / len(fmap_rs_mp)
            loss_fm_mrd = self.feat_matching_loss(
                fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd
            ) / len(fmap_rs_mrd)

            self.log("generator/multi_period_loss", loss_gen_mp)
            self.log("generator/multi_res_loss", loss_gen_mrd)
            self.log("generator/feature_matching_mp", loss_fm_mp)
            self.log("generator/feature_matching_mrd", loss_fm_mrd)
            self.log("generator/kl_div", kl_div)
        else:
            # During mel-spectrogram pretraining the adversarial and
            # feature-matching terms are disabled; without this fallback the
            # loss below would reference undefined names.
            loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0

        mel_loss = self.melspec_loss(audio_hat, audio_input)
        loss = (
            loss_gen_mp
            + self.hparams.mrd_loss_coeff * loss_gen_mrd
            + loss_fm_mp
            + self.hparams.mrd_loss_coeff * loss_fm_mrd
            + self.mel_loss_coeff * mel_loss
            + self.hparams.kl_div_coeff * kl_div
        )

        self.log("generator/total_loss", loss, prog_bar=True)
        self.log("mel_loss_coeff", self.mel_loss_coeff)
        self.log("generator/mel_loss", mel_loss)

        opt_gen.zero_grad()
        self.manual_backward(loss)
        if self.hparams.clip_grad_norm is not None:
            max_norm = self.hparams.clip_grad_norm
            clip_grad_norm_(self.wavvae.parameters(), max_norm=max_norm)
        opt_gen.step()

        # Lightning does not step schedulers automatically under manual
        # optimization, so advance both cosine schedules once per batch.
        sch_disc, sch_gen = self.lr_schedulers()
        sch_disc.step()
        sch_gen.step()

    def validation_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        audio_hat, _ = self(audio_input, **kwargs)

        if self.current_epoch % self.hparams.log_audio_every_n_epoch == 0:
            wavs = [x.numpy(force=True) for x in audio_hat.unbind(0)]
            if batch_idx == 0:
                # Reset the buffer at the start of each logged epoch.
                self._audios_to_log = wavs
            elif batch_idx < self.hparams.log_n_audio_batches:
                self._audios_to_log += wavs
            elif batch_idx == self.hparams.log_n_audio_batches:
                self.logger.log_audio(
                    "audio",
                    self._audios_to_log,
                    step=self.global_step,
                    sample_rate=[
                        self.wavvae.sampling_rate
                        for _ in range(len(self._audios_to_log))
                    ],
                )

        mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
        total_loss = mel_loss

        return {
            "val_loss": total_loss,
            "mel_loss": mel_loss,
            "audio_input": audio_input[0],
            "audio_pred": audio_hat[0],
        }

    @property
    def global_step(self):
        """Override global_step so that it returns the total number of
        batches processed."""
        return self.trainer.fit_loop.epoch_loop.total_batch_idx
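    # Illustrative schedule (the numbers are assumptions, not defaults): with
    # pretrain_mel_steps=10_000, batches 0..9_999 optimize the generator on
    # the mel + KL objective alone; from batch 10_000 onward the hook below
    # enables train_discriminator and the adversarial and feature-matching
    # terms enter the loss.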
    def on_train_batch_start(self, *args):
        if self.global_step >= self.hparams.pretrain_mel_steps:
            self.train_discriminator = True
        else:
            self.train_discriminator = False

    def on_train_batch_end(self, *args):
        def mel_loss_coeff_decay(current_step, num_cycles=0.5):
            max_steps = self.trainer.max_steps // 2
            if current_step < self.hparams.num_warmup_steps:
                return 1.0
            progress = float(current_step - self.hparams.num_warmup_steps) / float(
                max(1, max_steps - self.hparams.num_warmup_steps)
            )
            return max(
                0.0,
                0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
            )

        if self.hparams.decay_mel_coeff:
            self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(
                self.global_step + 1
            )


if __name__ == "__main__":

    class WavVAECLI(LightningCLI):
        def after_instantiate_classes(self):
            hparams = self.model.hparams
            kl_factor = "{:.1e}".format(hparams.kl_div_coeff)
            latent_dim = hparams.config["latent_dim"]
            frame_rate = self.model.wavvae.frame_rate
            dataset_name = (
                Path(self.datamodule.train_config.filelist_path).with_suffix("").name
            )
            name = (
                f"WavVAE_kl{kl_factor}_framerate{frame_rate}hz"
                f"_latentdim{latent_dim}_dataset{dataset_name}"
            )

            if self.trainer.logger:
                # Attach the W&B logger to the trainer; the original local
                # assignment created the logger but never used it.
                self.trainer.logger = WandbLogger(
                    log_model=False,
                    project="codec",
                    name=name,
                )

            model_checkpoint_cb = ModelCheckpoint(
                monitor="generator/mel_loss",
                dirpath="checkpoints/wavvae",
                filename=name + "_epoch{epoch:02d}",
                save_last=True,
            )
            self.trainer.callbacks.append(model_checkpoint_cb)

    WavVAECLI(
        save_config_kwargs={"overwrite": True},
        parser_kwargs={"parser_mode": "omegaconf"},
    )
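# Example invocation (a sketch: the script name, config path, and override
# values below are assumptions; YAML keys must match the signatures above):
#
#   python train_wavvae.py fit --config configs/wavvae.yaml \
#       --model.sample_rate 24000 \
#       --model.initial_learning_rate 2e-4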