import json
import math
from dataclasses import asdict
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from safetensors.torch import save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import get_cosine_schedule_with_warmup

from .model.config import TTSConfig

# DiffusionHead is needed for the validation-time type check in step(), so it is
# imported here alongside VelocityHead.
from .model.prediction_head import DiffusionHead, VelocityHead
from .tts import ARTTSModel


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
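
# Example (sketch, using this module's defaults): with warmup_steps=500,
# total_steps=300_000, start_lr=5e-4 and end_lr=5e-5, lr_lambda(step) returns
# step / 500 during warmup and then decays cosinely from 1.0 towards
# end_lr / start_lr = 0.1. LambdaLR multiplies the optimizer's base lr by this
# factor, so the effective lr ramps 0 -> 5e-4 and decays to 5e-5 over training.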


class TrainARTTS(ptl.LightningModule):
    def __init__(
        self,
        config: TTSConfig,
        quant_layer: list[int],
        tie_embed: bool = False,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
        mask_text_p: float = 0.0,
        load_weights: str | None = None,
        stop_token_weight: float | None = None,
        stop_loss_factor: float = 0.1,
        stop_loss_warmup: tuple[int, int] | None = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.stop_token_weight = stop_token_weight
        self.stop_loss_factor = stop_loss_factor
        self.save_hyperparameters()
        self.model = ARTTSModel(config)
        if load_weights is not None:
            # Optionally warm-start from a checkpoint containing a "state_dict" entry.
            checkpoint = torch.load(load_weights)
            self.load_state_dict(checkpoint["state_dict"], strict=False)

    def on_train_epoch_start(self):
        # Let epoch-aware batch samplers (e.g. distributed/bucketing samplers) reshuffle.
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def save_model_weights_and_config(
        self,
        dir: str,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        def to_builtin(obj):
            # Recursively convert OmegaConf containers into plain Python types
            # so that json.dump can serialize the config.
            if isinstance(obj, dict):
                return {k: to_builtin(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, ListConfig):
                return [to_builtin(v) for v in obj]
            elif isinstance(obj, DictConfig):
                return {k: to_builtin(v) for k, v in obj.items()}
            else:
                return obj

        cfg = asdict(self.hparams.config)
        cfg = to_builtin(cfg)
        for k, v in cfg.items():
            # Safety net for any top-level container to_builtin did not unwrap.
            if isinstance(v, ListConfig):
                cfg[k] = OmegaConf.to_container(v, resolve=True)
        Path(dir).mkdir(exist_ok=True)
        model_path = Path(dir) / model_filename
        save_file(self.model.state_dict(), model_path)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(cfg, f, indent=2)
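
    # Usage (sketch): export inference artifacts with
    #   module.save_model_weights_and_config("exports/artts")
    # which writes exports/artts/model.st (safetensors weights of self.model) and
    # exports/artts/config.json (the TTSConfig as plain JSON). The directory name
    # is only an example.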

    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch.get("crossatt_mask")
        text_rel_pos = batch.get("text_rel_pos")
        encoder_mask = batch.get("encoder_mask")
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch.get("y_mask")
        pre_logits = self.model(
            text_ids=text_token,
            audio_inputs=audio_token,
            text_mask=encoder_mask,
            audio_mask=logits_mask,
            crossatt_mask=crossatt_mask,
            crossatt_rel_pos=crossatt_rel_pos,
            stop_tokens=stop_token,
            text_rel_pos=text_rel_pos,
            text_stop_tokens=text_stop_token,
        )
        losses = {}
        # Targets are the audio tokens shifted by one step (teacher forcing).
        if validation and type(self.model.prediction_head) is DiffusionHead:
            # deterministic time conditioning during validation
            t = (
                torch.ones(pre_logits.shape[0], device=pre_logits.device)
                * batch_idx
                / self.trainer.num_val_batches[0]
            )
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
                t=t,
            )
        else:
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
            )
        if self.model.stop_prediction_head is not None and logits_mask is not None:
            if stop_token is None:
                # Derive stop targets from the padding mask when explicit stop
                # tokens are absent.
                stop_token = nn.functional.pad(
                    (~logits_mask)[:, 2:].to(pre_logits), (0, 1)
                )
            else:
                stop_token = stop_token[:, 1:]
            mask = logits_mask[:, 1:]
            losses |= self.model.stop_prediction_head.compute_loss(
                pre_logits[mask],
                stop_token[mask],
            )
        return losses

    def training_step(self, batch, idx):
        losses = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            if "stop" in name:
                if self.hparams.stop_loss_warmup is not None:
                    alpha, beta = self.hparams.stop_loss_warmup
                    warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
                else:
                    warmup = 1.0
                loss *= self.stop_loss_factor * warmup
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss
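
    # Stop-loss weighting example (sketch): with stop_loss_factor=0.1 and
    # stop_loss_warmup=(alpha, beta)=(1000, 4000), any loss whose name contains
    # "stop" is scaled by 0.1 * clip((idx - 1000) / 4000, 0, 1): it is ignored for
    # the first 1000 batches, ramps up linearly over the next 4000, and is weighted
    # by 0.1 thereafter. The (1000, 4000) values are illustrative only.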

    def validation_step(self, batch, idx):
        losses = self.step(batch, idx, validation=True)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # scheduler = get_cosine_schedule_with_warmup(
        #     opt,
        #     num_warmup_steps=self.n_warmup_steps,
        #     num_training_steps=self.n_training_steps,
        # )
        # Honor end_learning_rate when set; otherwise decay to 10% of the peak lr.
        end_lr = (
            self.hparams.end_learning_rate
            if self.hparams.end_learning_rate is not None
            else self.hparams.learning_rate * 0.1
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        return [opt], [{"scheduler": scheduler, "interval": "step"}]
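
    # Note: "interval": "step" makes Lightning call scheduler.step() after every
    # optimizer step, which matches the step-based lr_lambda above (warmup and
    # decay are counted in optimizer steps, not epochs).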


# NOTE: the config_path/config_name below are placeholders (not taken from this
# file); point them at this repo's actual Hydra config directory and file.
@hydra.main(version_base=None, config_path="config", config_name="train")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()
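
# Expected Hydra config shape (sketch; only the keys read in main() above are
# taken from this file, everything else is an assumption):
#
#   seed_everything: 1234
#   experiment_name: null          # overwritten at runtime from the decoder name
#   ckpt_path: null                # optional checkpoint to resume from
#   model:
#     _target_: <package>.train.TrainARTTS
#     config: {...}                # TTSConfig fields
#     quant_layer: [0]
#   data:
#     _target_: <package>.data.<DataModule>
#   trainer:
#     _target_: pytorch_lightning.Trainer
#     max_steps: 300000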