# pardi-speech/codec/train_wavvae.py
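"""Training entry point for the WavVAE waveform codec.

A WavVAE autoencoder is trained with a GAN-style objective: multi-period and
multi-resolution discriminators provide adversarial and feature-matching losses,
complemented by a mel-spectrogram reconstruction loss and a KL-divergence
penalty on the latent. Training uses PyTorch Lightning with manual optimization
and a LightningCLI entry point.
"""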
import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers.wandb import WandbLogger
from safetensors.torch import save_file
from torch.nn.utils import clip_grad_norm_
from codec.models import WavVAE, WavVAEConfig
from codec.models.wavvae.discriminators import (MultiPeriodDiscriminator,
MultiResolutionDiscriminator)
from codec.models.wavvae.loss import (DiscriminatorLoss, FeatureMatchingLoss,
GeneratorLoss,
MelSpecReconstructionLoss)
class TrainWavVAE(pl.LightningModule):
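    """LightningModule that trains WavVAE against a pair of discriminators.

    Each training step optionally updates the discriminators (after
    ``pretrain_mel_steps``) and then updates the generator on adversarial,
    feature-matching, mel-reconstruction and KL losses.
    """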
def __init__(
self,
config: WavVAEConfig,
sample_rate: int,
initial_learning_rate: float,
num_warmup_steps: int = 0,
mel_loss_coeff: float = 45,
mrd_loss_coeff: float = 1.0,
kl_div_coeff: float = 1e-5,
pretrain_mel_steps: int = 0,
decay_mel_coeff: bool = False,
clip_grad_norm: float | None = None,
f_min: int = 0,
f_max: Optional[int] = None,
mrd_fft_sizes: tuple[int, int, int] = (2048, 1024, 512),
mel_hop_length: int = 256,
log_audio_every_n_epoch: int = 5,
log_n_audio_batches: int = 32,
):
super().__init__()
self.save_hyperparameters()
self.wavvae = WavVAE(config)
self.multiperioddisc = MultiPeriodDiscriminator()
self.multiresddisc = MultiResolutionDiscriminator(
fft_sizes=tuple(mrd_fft_sizes)
)
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
self.feat_matching_loss = FeatureMatchingLoss()
self.melspec_loss = MelSpecReconstructionLoss(
sample_rate=sample_rate,
f_min=f_min,
f_max=f_max,
hop_length=mel_hop_length,
)
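        # Adversarial training is switched on in `on_train_batch_start` once
        # `pretrain_mel_steps` have elapsed; until then only the mel/KL losses apply.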
self.train_discriminator = False
self.automatic_optimization = False
self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
def save_model_weights_and_config(
self,
        dir: str,
model_filename: str = "model.st",
config_filename: str = "config.json",
):
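        """Write the WavVAE weights as safetensors and the config as JSON into `dir`."""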
cfg = self.hparams.config
model_path = Path(dir) / model_filename
save_file(self.wavvae.state_dict(), model_path)
with open(Path(dir) / config_filename, "w") as f:
json.dump(asdict(cfg), f, indent=2)
def configure_optimizers(self):
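        """Two AdamW optimizers (discriminators, generator), each paired with a
        warmed-up cosine schedule spanning half of ``trainer.max_steps``."""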
disc_params = [
{"params": self.multiperioddisc.parameters()},
{"params": self.multiresddisc.parameters()},
]
gen_params = [
{"params": self.wavvae.parameters()},
]
opt_disc = torch.optim.AdamW(
disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
)
opt_gen = torch.optim.AdamW(
gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
)
max_steps = self.trainer.max_steps // 2
scheduler_disc = transformers.get_cosine_schedule_with_warmup(
opt_disc,
num_warmup_steps=self.hparams.num_warmup_steps,
num_training_steps=max_steps,
)
scheduler_gen = transformers.get_cosine_schedule_with_warmup(
opt_gen,
num_warmup_steps=self.hparams.num_warmup_steps,
num_training_steps=max_steps,
)
return (
[opt_disc, opt_gen],
[
{"scheduler": scheduler_disc, "interval": "step"},
{"scheduler": scheduler_gen, "interval": "step"},
],
)
def forward(self, audio_input, **kwargs):
audio_output, kl_div = self.wavvae(audio_input)
return audio_output, kl_div
def training_step(self, batch, batch_idx, **kwargs):
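        """Run one manual-optimization step: discriminators first (if enabled), then the generator."""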
audio_input = batch
opt_disc, opt_gen = self.optimizers()
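        # 1) Discriminator update, using a reconstruction produced under no_grad
        #    so only the discriminators receive gradients.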
if self.train_discriminator:
with torch.no_grad():
audio_hat, kl_div = self(audio_input, **kwargs)
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
y=audio_input,
y_hat=audio_hat,
**kwargs,
)
real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(
y=audio_input,
y_hat=audio_hat,
**kwargs,
)
loss_mp, loss_mp_real, _ = self.disc_loss(
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
)
loss_mrd, loss_mrd_real, _ = self.disc_loss(
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
)
loss_mp /= len(loss_mp_real)
loss_mrd /= len(loss_mrd_real)
loss_disc = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
self.log("discriminator/total", loss_disc, prog_bar=True)
self.log("discriminator/multi_period_loss", loss_mp)
self.log("discriminator/multi_res_loss", loss_mrd)
opt_disc.zero_grad()
self.manual_backward(loss_disc)
if self.hparams.clip_grad_norm is not None:
max_norm = self.hparams.clip_grad_norm
clip_grad_norm_(self.multiperioddisc.parameters(), max_norm=max_norm)
clip_grad_norm_(self.multiresddisc.parameters(), max_norm=max_norm)
opt_disc.step()
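        # 2) Generator update: adversarial + feature-matching + mel + KL terms.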
audio_hat, kl_div = self(audio_input, **kwargs)
if self.train_discriminator:
_, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
y=audio_input,
y_hat=audio_hat,
**kwargs,
)
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
y=audio_input,
y_hat=audio_hat,
**kwargs,
)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
loss_fm_mp = self.feat_matching_loss(
fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp
) / len(fmap_rs_mp)
loss_fm_mrd = self.feat_matching_loss(
fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd
) / len(fmap_rs_mrd)
self.log("generator/multi_period_loss", loss_gen_mp)
self.log("generator/multi_res_loss", loss_gen_mrd)
self.log("generator/feature_matching_mp", loss_fm_mp)
self.log("generator/feature_matching_mrd", loss_fm_mrd)
self.log("generator/kl_div", kl_div)
mel_loss = self.melspec_loss(audio_hat, audio_input)
loss = (
loss_gen_mp
+ self.hparams.mrd_loss_coeff * loss_gen_mrd
+ loss_fm_mp
+ self.hparams.mrd_loss_coeff * loss_fm_mrd
+ self.mel_loss_coeff * mel_loss
+ self.hparams.kl_div_coeff * kl_div
)
self.log("generator/total_loss", loss, prog_bar=True)
self.log("mel_loss_coeff", self.mel_loss_coeff)
self.log("generator/mel_loss", mel_loss)
        opt_gen.zero_grad()
        self.manual_backward(loss)
        if self.hparams.clip_grad_norm is not None:
            max_norm = self.hparams.clip_grad_norm
            clip_grad_norm_(self.wavvae.parameters(), max_norm=max_norm)
        opt_gen.step()
        # With automatic optimization disabled, Lightning does not advance the LR
        # schedulers on its own, so step both cosine schedules here.
        sch_disc, sch_gen = self.lr_schedulers()
        sch_disc.step()
        sch_gen.step()
def validation_step(self, batch, batch_idx, **kwargs):
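        """Reconstruct the batch, periodically log example audio, and report the mel loss."""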
audio_input = batch
audio_hat, _ = self(audio_input, **kwargs)
        if self.current_epoch % self.hparams.log_audio_every_n_epoch == 0:
            wavs = [x.numpy(force=True) for x in audio_hat.unbind(0)]
            if batch_idx == 0:
                # Reset the buffer at the start of each logged validation epoch.
                self._audios_to_log = []
            if batch_idx < self.hparams.log_n_audio_batches:
                self._audios_to_log += wavs
elif batch_idx == self.hparams.log_n_audio_batches:
self.logger.log_audio(
"audio",
self._audios_to_log,
step=self.global_step,
sample_rate=[
self.wavvae.sampling_rate
for _ in range(len(self._audios_to_log))
],
)
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
total_loss = mel_loss
return {
"val_loss": total_loss,
"mel_loss": mel_loss,
"audio_input": audio_input[0],
"audio_pred": audio_hat[0],
}
@property
def global_step(self):
"""
Override global_step so that it returns the total number of batches processed
"""
return self.trainer.fit_loop.epoch_loop.total_batch_idx
def on_train_batch_start(self, *args):
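        # Enable adversarial training once the mel-only pretraining phase has finished.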
if self.global_step >= self.hparams.pretrain_mel_steps:
self.train_discriminator = True
else:
self.train_discriminator = False
def on_train_batch_end(self, *args):
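        # Optionally decay the mel-loss weight along a half-cosine, mirroring the LR schedule.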
def mel_loss_coeff_decay(current_step, num_cycles=0.5):
max_steps = self.trainer.max_steps // 2
if current_step < self.hparams.num_warmup_steps:
return 1.0
progress = float(current_step - self.hparams.num_warmup_steps) / float(
max(1, max_steps - self.hparams.num_warmup_steps)
)
return max(
0.0,
0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
if self.hparams.decay_mel_coeff:
self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(
self.global_step + 1
)
if __name__ == "__main__":
class WavVAECLI(LightningCLI):
def after_instantiate_classes(self):
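            # Derive a descriptive run name (KL weight, frame rate, latent size, dataset)
            # used for both the W&B run and the checkpoint filenames.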
hparams = self.model.hparams
kl_factor = "{:.1e}".format(hparams.kl_div_coeff)
            latent_dim = hparams.config.latent_dim
frame_rate = self.model.wavvae.frame_rate
dataset_name = (
Path(self.datamodule.train_config.filelist_path).with_suffix("").name
)
name = f"WavVAE_kl{kl_factor}_framerate{frame_rate}hz_latentdim{latent_dim}_dataset{dataset_name}"
            if self.trainer.logger:
                # Attach a W&B run carrying the descriptive name built above.
                self.trainer.logger = WandbLogger(
                    log_model=False,
                    project="codec",
                    name=name,
                )
model_checkpoint_cb = ModelCheckpoint(
monitor="generator/mel_loss",
dirpath="checkpoints/wavvae",
filename=name + "_epoch{epoch:02d}",
save_last=True,
)
self.trainer.callbacks.append(model_checkpoint_cb)
WavVAECLI(
save_config_kwargs={"overwrite": True},
parser_kwargs={"parser_mode": "omegaconf"},
)
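    # Example invocation (config path is illustrative, not part of the repo):
    #   python -m codec.train_wavvae fit --config configs/wavvae.yaml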