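"""PyTorch Lightning training script for WavVAE: a waveform variational
autoencoder trained adversarially with multi-period and multi-resolution
discriminators, plus mel-spectrogram reconstruction and KL-divergence losses."""
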
import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers.wandb import WandbLogger
from safetensors.torch import save_file
from torch.nn.utils import clip_grad_norm_

from codec.models import WavVAE, WavVAEConfig
from codec.models.wavvae.discriminators import (MultiPeriodDiscriminator,
                                                MultiResolutionDiscriminator)
from codec.models.wavvae.loss import (DiscriminatorLoss, FeatureMatchingLoss,
                                      GeneratorLoss,
                                      MelSpecReconstructionLoss)


class TrainWavVAE(pl.LightningModule):
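    """Adversarial trainer for WavVAE.

    Uses manual optimization to update the WavVAE generator and two
    discriminators (multi-period and multi-resolution) with separate
    optimizers, combining adversarial, feature-matching, mel-spectrogram
    reconstruction, and KL-divergence losses.
    """
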
    def __init__(
        self,
        config: WavVAEConfig,
        sample_rate: int,
        initial_learning_rate: float,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        kl_div_coeff: float = 1e-5,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        clip_grad_norm: float | None = None,
        f_min: int = 0,
        f_max: Optional[int] = None,
        mrd_fft_sizes: tuple[int, int, int] = (2048, 1024, 512),
        mel_hop_length: int = 256,
        log_audio_every_n_epoch: int = 5,
        log_n_audio_batches: int = 32,
    ):
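        """
        ``mel_loss_coeff`` weights the mel reconstruction term (cosine-decayed
        when ``decay_mel_coeff`` is set), ``mrd_loss_coeff`` weights the
        multi-resolution discriminator terms, ``kl_div_coeff`` weights the KL
        regularizer, and the discriminators only start training after
        ``pretrain_mel_steps`` steps.
        """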
        super().__init__()
        self.save_hyperparameters()
        self.wavvae = WavVAE(config)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiresddisc = MultiResolutionDiscriminator(
            fft_sizes=tuple(mrd_fft_sizes)
        )
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            hop_length=mel_hop_length,
        )
        # Discriminators stay disabled until pretrain_mel_steps is reached
        # (toggled in on_train_batch_start).
        self.train_discriminator = False
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff

    def save_model_weights_and_config(
        self,
        dir: str,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        cfg = self.hparams.config
        model_path = Path(dir) / model_filename
        save_file(self.wavvae.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)
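
    # A minimal load-back sketch (assumes the saved config JSON maps 1:1 onto
    # WavVAEConfig fields):
    #
    #   from safetensors.torch import load_file
    #   with open("config.json") as f:
    #       config = WavVAEConfig(**json.load(f))
    #   model = WavVAE(config)
    #   model.load_state_dict(load_file("model.st"))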

    def configure_optimizers(self):
        disc_params = [
            {"params": self.multiperioddisc.parameters()},
            {"params": self.multiresddisc.parameters()},
        ]
        gen_params = [
            {"params": self.wavvae.parameters()},
        ]
        opt_disc = torch.optim.AdamW(
            disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
        )
        opt_gen = torch.optim.AdamW(
            gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
        )
        max_steps = self.trainer.max_steps // 2
        scheduler_disc = transformers.get_cosine_schedule_with_warmup(
            opt_disc,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=max_steps,
        )
        scheduler_gen = transformers.get_cosine_schedule_with_warmup(
            opt_gen,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=max_steps,
        )
        return (
            [opt_disc, opt_gen],
            [
                {"scheduler": scheduler_disc, "interval": "step"},
                {"scheduler": scheduler_gen, "interval": "step"},
            ],
        )

    def forward(self, audio_input, **kwargs):
        audio_output, kl_div = self.wavvae(audio_input)
        return audio_output, kl_div

    def training_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        opt_disc, opt_gen = self.optimizers()

        # Discriminator step: score real audio against reconstructions
        # generated without gradient tracking.
        if self.train_discriminator:
            with torch.no_grad():
                audio_hat, kl_div = self(audio_input, **kwargs)
            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
            )
            loss_mrd, loss_mrd_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
            )
            loss_mp /= len(loss_mp_real)
            loss_mrd /= len(loss_mrd_real)
            loss_disc = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
            self.log("discriminator/total", loss_disc, prog_bar=True)
            self.log("discriminator/multi_period_loss", loss_mp)
            self.log("discriminator/multi_res_loss", loss_mrd)
            opt_disc.zero_grad()
            self.manual_backward(loss_disc)
            if self.hparams.clip_grad_norm is not None:
                max_norm = self.hparams.clip_grad_norm
                clip_grad_norm_(self.multiperioddisc.parameters(), max_norm=max_norm)
                clip_grad_norm_(self.multiresddisc.parameters(), max_norm=max_norm)
            opt_disc.step()
        # Generator step: adversarial and feature-matching terms (active once
        # the discriminators are enabled) plus mel reconstruction and KL terms.
        audio_hat, kl_div = self(audio_input, **kwargs)
        if self.train_discriminator:
            _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
                y=audio_input,
                y_hat=audio_hat,
                **kwargs,
            )
            loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
            loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
            loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
            loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
            loss_fm_mp = self.feat_matching_loss(
                fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp
            ) / len(fmap_rs_mp)
            loss_fm_mrd = self.feat_matching_loss(
                fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd
            ) / len(fmap_rs_mrd)
            self.log("generator/multi_period_loss", loss_gen_mp)
            self.log("generator/multi_res_loss", loss_gen_mrd)
            self.log("generator/feature_matching_mp", loss_fm_mp)
            self.log("generator/feature_matching_mrd", loss_fm_mrd)
        else:
            # During mel-only pretraining the adversarial terms contribute nothing.
            loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0.0
        self.log("generator/kl_div", kl_div)
        mel_loss = self.melspec_loss(audio_hat, audio_input)
        loss = (
            loss_gen_mp
            + self.hparams.mrd_loss_coeff * loss_gen_mrd
            + loss_fm_mp
            + self.hparams.mrd_loss_coeff * loss_fm_mrd
            + self.mel_loss_coeff * mel_loss
            + self.hparams.kl_div_coeff * kl_div
        )
        self.log("generator/total_loss", loss, prog_bar=True)
        self.log("mel_loss_coeff", self.mel_loss_coeff)
        self.log("generator/mel_loss", mel_loss)
        opt_gen.zero_grad()
        self.manual_backward(loss)
        if self.hparams.clip_grad_norm is not None:
            max_norm = self.hparams.clip_grad_norm
            clip_grad_norm_(self.wavvae.parameters(), max_norm=max_norm)
        opt_gen.step()

    def validation_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        audio_hat, _ = self(audio_input, **kwargs)
        # Collect reconstructions from the first log_n_audio_batches batches,
        # then log them once per logging epoch.
        if self.current_epoch % self.hparams.log_audio_every_n_epoch == 0:
            wavs = [x.numpy(force=True) for x in audio_hat.unbind(0)]
            if batch_idx == 0:
                self._audios_to_log = []
            if batch_idx < self.hparams.log_n_audio_batches:
                self._audios_to_log += wavs
            elif batch_idx == self.hparams.log_n_audio_batches:
                self.logger.log_audio(
                    "audio",
                    self._audios_to_log,
                    step=self.global_step,
                    sample_rate=[
                        self.wavvae.sampling_rate
                        for _ in range(len(self._audios_to_log))
                    ],
                )
        mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
        total_loss = mel_loss
        return {
            "val_loss": total_loss,
            "mel_loss": mel_loss,
            "audio_input": audio_input[0],
            "audio_pred": audio_hat[0],
        }

    @property
    def global_step(self):
        """
        Override global_step so that it returns the total number of batches processed.
        """
        return self.trainer.fit_loop.epoch_loop.total_batch_idx

    def on_train_batch_start(self, *args):
        if self.global_step >= self.hparams.pretrain_mel_steps:
            self.train_discriminator = True
        else:
            self.train_discriminator = False

    def on_train_batch_end(self, *args):
        def mel_loss_coeff_decay(current_step, num_cycles=0.5):
            # Cosine decay over the same horizon and warmup as the LR schedules.
            max_steps = self.trainer.max_steps // 2
            if current_step < self.hparams.num_warmup_steps:
                return 1.0
            progress = float(current_step - self.hparams.num_warmup_steps) / float(
                max(1, max_steps - self.hparams.num_warmup_steps)
            )
            return max(
                0.0,
                0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
            )

        if self.hparams.decay_mel_coeff:
            self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(
                self.global_step + 1
            )


if __name__ == "__main__":

    class WavVAECLI(LightningCLI):
        def after_instantiate_classes(self):
            hparams = self.model.hparams
            kl_factor = "{:.1e}".format(hparams.kl_div_coeff)
            latent_dim = hparams.config["latent_dim"]
            frame_rate = self.model.wavvae.frame_rate
            dataset_name = (
                Path(self.datamodule.train_config.filelist_path).with_suffix("").name
            )
            name = (
                f"WavVAE_kl{kl_factor}_framerate{frame_rate}hz"
                f"_latentdim{latent_dim}_dataset{dataset_name}"
            )
            if self.trainer.logger:
                self.trainer.logger = WandbLogger(
                    log_model=False,
                    project="codec",
                    name=name,
                )
            model_checkpoint_cb = ModelCheckpoint(
                monitor="generator/mel_loss",
                dirpath="checkpoints/wavvae",
                filename=name + "_epoch{epoch:02d}",
                save_last=True,
            )
            self.trainer.callbacks.append(model_checkpoint_cb)

    WavVAECLI(
        save_config_kwargs={"overwrite": True},
        parser_kwargs={"parser_mode": "omegaconf"},
    )
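
# Example launch (hypothetical script and config paths; standard LightningCLI
# subcommand usage is assumed):
#   python train_wavvae.py fit --config configs/wavvae.yaml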