# pardi-speech / codec / train_patchvae.py
# Author: Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# Commit: 56cfa73
import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Literal
import pytorch_lightning as pl
import torch
from safetensors.torch import safe_open, save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torchaudio.transforms import Resample
from transformers import WavLMModel
from .models import PatchVAE, PatchVAEConfig, WavVAE
from .models.components.convnext import SwiGLU
from .models.patchvae.modules import convnext_factory
def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Build an LR-lambda for ``torch.optim.lr_scheduler.LambdaLR``.

    The returned callable maps the global step to a multiplicative factor of
    ``start_lr``: linear warmup from 0 to 1 over ``warmup_steps``, then a
    cosine decay from 1 down to ``end_lr / start_lr`` at ``total_steps``,
    holding ``end_lr / start_lr`` thereafter.

    Args:
        warmup_steps: number of linear-warmup steps (0 disables warmup).
        total_steps: step at which the schedule reaches ``end_lr``.
        start_lr: base learning rate the factor multiplies (must be non-zero).
        end_lr: final learning rate after decay.

    Returns:
        A ``step -> float`` function suitable as ``LambdaLR``'s ``lr_lambda``.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup; max(1, ...) guards against warmup_steps == 0.
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so training past total_steps holds end_lr: without the clamp,
        # cos(pi * progress) is periodic and the LR would rise again.
        progress = min(progress, 1.0)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        # Interpolate the multiplier between 1 (at progress=0) and
        # end_lr / start_lr (at progress=1).
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
class TrainPatchVAE(pl.LightningModule):
    """LightningModule that trains a :class:`PatchVAE` prior on pre-encoded audio latents.

    The training objective is a flow-matching loss plus (optionally) a KL
    divergence term and an SSL representation-alignment ("REPA") loss that
    pulls intermediate features toward frozen WavLM hidden states.

    Batches are expected to carry pre-computed WavVAE latents under the
    ``"audio_z"`` key; the raw waveform is only reconstructed (via a frozen
    WavVAE decoder) when the SSL REPA head is enabled.
    """

    def __init__(
        self,
        config: PatchVAEConfig,
        lr: float = 1e-4,
        end_lr: float | None = None,
        weight_decay: float = 0.01,
        cfg_drop_rate: float = 0.0,
        cfg_drop_rate_per_sample: float = 0.1,
        total_steps: int = 2_000_000,
        warmup_steps: int = 0,
        mean_std_from_file: str | None = None,
        train_sigma: float = 1e-4,
        kl_div_factor: float = 1e-4,
        wavvae_pretrained_path: str | None = None,
        drop_vae_rate: float = 0.0,
        ssl_repa_factor: float = 1.0,
        ssl_repa_head: bool = False,
        ssl_repa_head_type: Literal["convnext", "swiglu"] = "swiglu",
        ssl_repa_head_target: Literal["flow", "prior"] = "prior",
        ssl_repa_head_num_layers: int = 4,
        ssl_repa_source: Literal["wavlm-base-plus"] = "wavlm-base-plus",
        ssl_feat_dim: int = 768,
    ):
        """Configure the model, optional frozen WavVAE decoder and SSL branch.

        Args:
            config: PatchVAE architecture config; ``latent_scaling`` may be
                overwritten from ``mean_std_from_file``.
            lr: peak learning rate.
            end_lr: final LR for the cosine schedule; ``None`` disables the
                scheduler (constant LR).
            weight_decay: AdamW weight decay.
            cfg_drop_rate: NOTE(review): not referenced anywhere in this
                module's visible code (``cfg_drop_rate_per_sample`` is used
                instead) — possibly dead, confirm before removing.
            cfg_drop_rate_per_sample: per-sample conditioning-drop probability
                forwarded to the PatchVAE as ``drop_cond_rate``.
            total_steps: horizon of the cosine LR schedule.
            warmup_steps: linear LR warmup steps.
            mean_std_from_file: optional safetensors file with ``mean`` and
                ``std`` tensors used as latent scaling.
            train_sigma: ``sigma`` passed to the PatchVAE forward (flow noise
                level, presumably — semantics live in PatchVAE).
            kl_div_factor: weight of the KL term in the total loss.
            wavvae_pretrained_path: local path of a pretrained WavVAE; loaded
                frozen-ish (``eval()``) and used only to decode latents for
                the SSL branch.
            drop_vae_rate: forwarded to the PatchVAE forward.
            ssl_repa_factor: weight of the SSL REPA loss.
            ssl_repa_head: enable the SSL REPA branch (loads WavLM).
            ssl_repa_head_type: projection head architecture.
            ssl_repa_head_target: which intermediate activation is projected;
                only ``"prior"`` is implemented (``"flow"`` raises).
            ssl_repa_head_num_layers: depth of the projection head.
            ssl_repa_source: SSL teacher; only WavLM-base-plus supported.
            ssl_feat_dim: teacher feature dim (768 for WavLM-base).
        """
        super().__init__()
        # Makes every constructor argument available via self.hparams and
        # checkpointed by Lightning.
        self.save_hyperparameters()
        if mean_std_from_file is not None:
            # Load per-dim latent mean/std and bake them into the config.
            with safe_open(mean_std_from_file, framework="pt") as f:
                config.latent_scaling = (
                    f.get_tensor("mean").tolist(),
                    f.get_tensor("std").tolist(),
                )
        self.patchvae = PatchVAE(config)
        # NOTE(review): applied before the SSL repa head is created below, so
        # the head keeps PyTorch's default initialization — confirm intended.
        self.apply(self._init_weights)
        self.wavvae = None
        if wavvae_pretrained_path is not None:
            self.wavvae = WavVAE.from_pretrained_local(wavvae_pretrained_path).eval()
        self.ssl_model = None
        if ssl_repa_head:
            if ssl_repa_source == "wavlm-base-plus":
                self.ssl_model = WavLMModel.from_pretrained(
                    "microsoft/wavlm-base-plus"
                ).eval()
            # WavLM expects 16 kHz input.
            # NOTE(review): assumes self.wavvae was set above, i.e. that
            # wavvae_pretrained_path is always given when ssl_repa_head is
            # True — otherwise this raises AttributeError on None.
            self.ssl_resampling = Resample(
                orig_freq=self.wavvae.sampling_rate, new_freq=16000
            )
            # Freeze the SSL teacher; it is only a feature extractor.
            for p in self.ssl_model.parameters():
                p.requires_grad = False
            if ssl_repa_head_type == "convnext":
                self.ssl_repa_head = nn.Sequential(
                    convnext_factory(
                        config.hidden_dim,
                        ssl_repa_head_num_layers,
                    ),
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
            elif ssl_repa_head_type == "swiglu":
                self.ssl_repa_head = nn.Sequential(
                    *[
                        SwiGLU(config.hidden_dim)
                        for _ in range(ssl_repa_head_num_layers)
                    ],
                    nn.Linear(config.hidden_dim, ssl_feat_dim),
                )
        else:
            # Sentinel checked in configure_optimizers / training_step.
            self.ssl_repa_head = None

    def _init_weights(self, m: nn.Module):
        """Truncated-normal init for conv/linear weights, zero bias.

        NOTE(review): assumes every Conv1d/Linear in this module has a bias;
        a bias-free layer would make ``constant_(m.bias, 0)`` fail on None.
        """
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def configure_optimizers(self):
        """AdamW over PatchVAE (+ repa head) with optional cosine schedule.

        Frozen WavVAE / WavLM parameters are deliberately excluded from the
        optimizer. Returns a bare optimizer when ``end_lr`` is None,
        otherwise the (optimizers, schedulers) pair with per-step stepping.
        """
        params = [{"params": self.patchvae.parameters()}]
        if self.ssl_repa_head is not None:
            params += [{"params": self.ssl_repa_head.parameters()}]
        opt = torch.optim.AdamW(
            params,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        if self.hparams.end_lr is not None:
            scheduler = LambdaLR(
                opt,
                lr_lambda=cosine_schedule_with_warmup(
                    warmup_steps=self.hparams.warmup_steps,
                    total_steps=self.hparams.total_steps,
                    start_lr=self.hparams.lr,
                    end_lr=self.hparams.end_lr,
                ),
            )
            # interval="step": advance the schedule every optimizer step,
            # not every epoch.
            scheduler = {"scheduler": scheduler, "interval": "step"}
            return [opt], [scheduler]
        return opt

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        """Export PatchVAE weights (safetensors) and its config (JSON).

        NOTE(review): ``dir`` shadows the builtin and, despite the
        ``str | None`` annotation, ``None`` would crash ``Path(dir)`` —
        callers must pass a real directory path.
        """
        cfg = self.hparams.config
        model_path = Path(dir) / model_filename
        # with_suffix(".st") normalizes whatever extension was passed in.
        save_file(self.patchvae.state_dict(), model_path.with_suffix(".st"))
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(asdict(cfg), f, indent=2)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        """One optimization step: flow loss + optional KL and SSL REPA losses.

        ``batch["audio_z"]`` holds WavVAE latents (batch-first; exact shape
        is defined by the upstream dataset — not visible here).
        """
        z = batch["audio_z"]
        # Random flow time per sample in [0, 1).
        t = torch.rand(z.shape[0], device=z.device)
        drop_cond_rate = self.hparams.cfg_drop_rate_per_sample
        drop_vae_rate = self.hparams.drop_vae_rate
        flow_loss, ae_loss, prior = self.patchvae(
            z,
            t,
            sigma=self.hparams.train_sigma,
            drop_vae_rate=drop_vae_rate,
            drop_cond_rate=drop_cond_rate,
        )
        self.log("train_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss
        # ae_loss is a dict of optional diagnostics/auxiliary terms.
        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("train_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        # Log posterior statistics when the PatchVAE reports them.
        for x in ["_mu_mean", "_mu_std", "_logvar_mean", "_logvar_std"]:
            stat = ae_loss.get(x)
            if stat is not None:
                self.log(x, stat, prog_bar=False)
        if self.ssl_repa_head is not None:
            target = self.hparams.ssl_repa_head_target
            if target == "prior":
                # Project the prior activations toward the SSL features.
                target = prior
            elif target == "flow":
                raise NotImplementedError
            # Teacher features are computed without autograd bookkeeping.
            with torch.inference_mode():
                wav = self.wavvae.decode(z)
                wav = self.ssl_resampling(wav)
                # NOTE(review): 40-sample symmetric pad, presumably to align
                # WavLM frame centering — confirm against the feature hop.
                wav = torch.nn.functional.pad(wav, (40, 40))
                ssl_feats = self.ssl_model(wav, output_hidden_states=True).hidden_states
                # Layer index 10 of WavLM's hidden_states tuple.
                ssl_feat = ssl_feats[10]
                # Pool over time (kernel 8, stride 4) — presumably matches
                # the PatchVAE token rate; avg_pool1d needs (B, D, N).
                ssl_feat = torch.nn.functional.avg_pool1d(
                    ssl_feat.transpose(-1, -2),
                    kernel_size=8,
                    stride=4,
                    padding=2,
                ).transpose(-1, -2)
            # clone() outside inference_mode: turns the inference tensor into
            # a regular one so it can participate in the loss graph.
            ssl_feat = ssl_feat.clone()
            B, N, D = ssl_feat.shape
            repa_pred = self.ssl_repa_head(target)
            # Target label 1 => maximize cosine similarity (broadcast over rows).
            ssl_repa_loss = nn.functional.cosine_embedding_loss(
                repa_pred.reshape(-1, D),
                ssl_feat.reshape(-1, D),
                torch.ones(1).to(repa_pred),
            )
            total_loss += self.hparams.ssl_repa_factor * ssl_repa_loss
            self.log("train_repa_loss", ssl_repa_loss, prog_bar=True)
        return total_loss

    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        """Validation: flow (+KL) loss at a deterministic per-batch flow time.

        ``t`` sweeps [0, 1) across the validation set (batch_idx / num
        batches) so every epoch evaluates the same time grid.
        """
        z = batch["audio_z"]
        t = (
            torch.ones(z.shape[0], device=z.device)
            * batch_idx
            / self.trainer.num_val_batches[0]
        )
        # NOTE(review): passes drop_cond=False here vs. drop_cond_rate=...
        # in training_step — confirm PatchVAE.forward accepts both kwargs.
        flow_loss, ae_loss, prior = self.patchvae(
            z, t, sigma=self.hparams.train_sigma, drop_cond=False
        )
        self.log("val_flow_loss", flow_loss, prog_bar=True)
        total_loss = flow_loss
        if ae_loss.get("kl_div") is not None:
            kl_div = ae_loss.get("kl_div")
            self.log("val_kl_div", kl_div, prog_bar=True)
            total_loss += self.hparams.kl_div_factor * kl_div
        return total_loss
if __name__ == "__main__":
    # Script entry point: delegate argument parsing, config loading and the
    # fit/validate loop to Lightning's CLI.
    from pytorch_lightning.cli import LightningCLI

    LightningCLI(
        TrainPatchVAE,
        # Don't write a resolved copy of the config next to the checkpoints.
        save_config_callback=None,
        # omegaconf parser mode enables variable interpolation in YAML configs.
        parser_kwargs={"parser_mode": "omegaconf"},
    )