import json
import math
from dataclasses import asdict
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from safetensors.torch import save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import get_cosine_schedule_with_warmup

from .model.config import TTSConfig
from .model.prediction_head import DiffusionHead, VelocityHead  # DiffusionHead is referenced in TrainARTTS.step
from .tts import ARTTSModel


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Return an LR lambda: linear warmup to the base LR, then cosine decay to end_lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        # Multiplier relative to start_lr: decays from 1.0 down to end_lr / start_lr.
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
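
# Illustrative only: how the lr_lambda above scales the optimizer's base LR with this
# module's defaults (n_warmup_steps=500, n_training_steps=300_000, lr=5e-4, end_lr=0.1*lr),
# matching how configure_optimizers wires it into LambdaLR below.
#
#   fn = cosine_schedule_with_warmup(warmup_steps=500, total_steps=300_000,
#                                    start_lr=5e-4, end_lr=5e-5)
#   fn(0)        -> 0.0    (start of linear warmup)
#   fn(500)      -> 1.0    (full base LR at the end of warmup)
#   fn(150_250)  -> 0.55   (halfway through the cosine decay)
#   fn(300_000)  -> 0.1    (floor = end_lr / start_lr)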


class TrainARTTS(ptl.LightningModule):
    def __init__(
        self,
        config: TTSConfig,
        quant_layer: list[int],
        tie_embed: bool = False,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
        mask_text_p: float = 0.0,
        load_weights: str | None = None,
        stop_token_weight: float | None = None,
        stop_loss_factor: float = 0.1,
        stop_loss_warmup: tuple[int, int] | None = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.stop_token_weight = stop_token_weight
        self.stop_loss_factor = stop_loss_factor
        self.save_hyperparameters()
        self.model = ARTTSModel(config)
        if load_weights is not None:
            # Warm-start from a Lightning checkpoint; missing/extra keys are tolerated.
            model = torch.load(load_weights)
            self.load_state_dict(model["state_dict"], strict=False)

    def on_train_epoch_start(self):
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        def to_builtin(obj):
            if isinstance(obj, dict):
                return {k: to_builtin(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, ListConfig):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, DictConfig):
                return {k: to_builtin(v) for k, v in obj.items()}
            else:
                return obj

        cfg = asdict(self.hparams.config)
        cfg = to_builtin(cfg)
        # Fallback in case any OmegaConf containers slipped through to_builtin.
        for k, v in cfg.items():
            if isinstance(v, ListConfig):
                cfg[k] = OmegaConf.to_container(v, resolve=True)
        Path(dir).mkdir(exist_ok=True)
        model_path = Path(dir) / model_filename
        save_file(self.model.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(cfg, f, indent=2)

    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch.get("crossatt_mask")
        text_rel_pos = batch.get("text_rel_pos")
        encoder_mask = batch.get("encoder_mask")
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch.get("y_mask")

        pre_logits = self.model(
            text_ids=text_token,
            audio_inputs=audio_token,
            text_mask=encoder_mask,
            audio_mask=logits_mask,
            crossatt_mask=crossatt_mask,
            crossatt_rel_pos=crossatt_rel_pos,
            stop_tokens=stop_token,
            text_rel_pos=text_rel_pos,
            text_stop_tokens=text_stop_token,
        )

        losses = {}
        if validation and type(self.model.prediction_head) is DiffusionHead:
            # Deterministic time conditioning during validation: spread t evenly across
            # the validation batches so the loss is comparable between epochs.
            t = (
                torch.ones(pre_logits.shape[0], device=pre_logits.device)
                * batch_idx
                / self.trainer.num_val_batches[0]
            )
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
                t=t,
            )
        else:
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
            )

        if self.model.stop_prediction_head is not None and logits_mask is not None:
            if stop_token is None:
                # No explicit stop targets provided: derive them from the padding mask,
                # labeling padded positions (shifted to align with next-frame predictions) as 1.
                stop_token = nn.functional.pad(
                    (~logits_mask)[:, 2:].to(pre_logits), (0, 1)
                )
            else:
                stop_token = stop_token[:, 1:]
            mask = logits_mask[:, 1:]
            losses |= self.model.stop_prediction_head.compute_loss(
                pre_logits[mask],
                stop_token[mask],
            )
        return losses

    def training_step(self, batch, idx):
        losses = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            if "stop" in name:
                # Optionally ramp the stop loss in, starting at step alpha over beta steps.
                if self.hparams.stop_loss_warmup is not None:
                    alpha, beta = self.hparams.stop_loss_warmup
                    warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
                else:
                    warmup = 1.0
                loss *= self.stop_loss_factor * warmup
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def validation_step(self, batch, idx):
        losses = self.step(batch, idx, validation=True)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # scheduler = get_cosine_schedule_with_warmup(
        #     opt,
        #     num_warmup_steps=self.n_warmup_steps,
        #     num_training_steps=self.n_training_steps,
        # )
        # Use end_learning_rate if given, otherwise fall back to the default floor of 0.1 * lr.
        end_lr = (
            self.hparams.end_learning_rate
            if self.hparams.end_learning_rate is not None
            else self.hparams.learning_rate * 0.1
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        return [opt], [{"scheduler": scheduler, "interval": "step"}]


@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()
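

# Rough sketch (not part of the training script): reloading the artifacts written by
# TrainARTTS.save_model_weights_and_config for inference. This assumes TTSConfig can be
# rebuilt directly from the saved JSON; nested sub-configs may need to be converted back
# into their dataclasses first. The "checkpoints/" directory name is hypothetical.
#
#   from safetensors.torch import load_file
#
#   with open("checkpoints/config.json") as f:
#       config = TTSConfig(**json.load(f))
#   model = ARTTSModel(config)
#   model.load_state_dict(load_file("checkpoints/model.st"))
#   model.eval()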