import math

import hydra
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig
from super_monotonic_align import maximum_path
from torch.optim.lr_scheduler import LambdaLR

from model.config import PlayHeadConfig
from playhead import PlayHead
from train_tts import TrainARTTS


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Linear warmup followed by a cosine decay from start_lr to end_lr,
    returned as a multiplier of start_lr for use with LambdaLR."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


def expand(x, r):
    """Repeat every timestep of a (batch, n, dim) tensor r times along the sequence axis."""
    b, n, d = x.shape
    return x.unsqueeze(2).repeat(1, 1, r, 1).reshape(b, r * n, d)


class TrainPlayHead(ptl.LightningModule):
    def __init__(
        self,
        tts_checkpoint_path: str,
        playhead_config: PlayHeadConfig,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
    ):
        super().__init__()
        cfg = playhead_config
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.selected_cross_attention_heads = cfg.selected_cross_attention_heads
        self.avg_pool_stride = cfg.avg_pool_stride
        self.target_lag = cfg.target_lag
        self.save_hyperparameters()

        self.model = PlayHead(playhead_config)

        # Load the pretrained AR TTS model and freeze it: only the PlayHead is
        # trained, the TTS model is used solely to produce cross-attention maps.
        tts_lightning_module = TrainARTTS.load_from_checkpoint(tts_checkpoint_path)
        self.tts_model = tts_lightning_module.model.eval()
        for p in self.tts_model.parameters():
            p.requires_grad = False

    def on_train_epoch_start(self):
        # Reshuffle bucketed batches every epoch when the sampler supports it.
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def save_model_weights_and_config(
        self,
        dir: str | None,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        # cfg = self.hparams.config
        # Path(dir).mkdir(exist_ok=True)
        # model_path = Path(dir) / model_filename
        # save_file(self.model.state_dict(), model_path)
        # with open(Path(dir) / config_filename, "w") as f:
        #     json.dump(asdict(cfg), f, indent=2)
        pass

    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch["crossatt_mask"]
        text_rel_pos = batch["text_rel_pos"]
        encoder_mask = batch["encoder_mask"]
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch["y_mask"]

        # Teacher-forced pass through the frozen TTS model; only the
        # cross-attention maps stored by its decoder layers are needed.
        with torch.inference_mode():
            _ = self.tts_model(
                text_ids=text_token,
                audio_inputs=audio_token,
                text_mask=encoder_mask,
                audio_mask=logits_mask,
                crossatt_mask=crossatt_mask,
                crossatt_rel_pos=crossatt_rel_pos,
                stop_tokens=stop_token,
                text_rel_pos=text_rel_pos,
                text_stop_tokens=text_stop_token,
            )

        # Collect the cross-attention map of every decoder layer that has one.
        atts = []
        for layer in self.tts_model.audio_decoder.decoder_layers:
            if layer.crossatt is not None:
                atts.append(layer.crossatt.att)
        num_sinks = self.tts_model.num_sink_tokens

        # Sum the selected (layer, head) attention maps, shaped (batch, text, audio).
        selected_ca_heads = torch.stack(
            [
                atts[i][:, j].transpose(-1, -2)
                for i, j in self.selected_cross_attention_heads
            ]
        )
        summed_ca = selected_ca_heads.sum(0)

        # Average-pool along the text axis (sink tokens excluded), giving one
        # attention row per chunk of `avg_pool_stride` text tokens.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            summed_ca[:, num_sinks:].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)

        # Monotonic alignment search on the pooled attention yields a hard
        # text-chunk / audio-frame alignment.
        mas_from_avg_pool = maximum_path(
            avg_pool_ca.clone(),
            mask=crossatt_mask[:, :-1, :: self.avg_pool_stride].transpose(-1, -2),
        )

        target = torch.arange(
            mas_from_avg_pool.shape[1], device=mas_from_avg_pool.device
        )
        if self.target_lag > 0:
            # Delay the alignment by `target_lag` audio frames; the first `lag`
            # frames are pinned to text chunk 0.
            lag = self.target_lag
            mas_from_avg_pool = torch.roll(mas_from_avg_pool, lag, dims=2)
            mas_from_avg_pool[:, 0, :lag] = 1.0
            mas_from_avg_pool[:, 1:, :lag] = 0.0
            # logits_mask[:, :lag] = False

        # Per audio frame, the index of the aligned text chunk.
        target = (mas_from_avg_pool * target[:, None]).max(dim=1).values

        # The PlayHead input keeps the sink-token attention plus the pooled map.
        sink_ca = summed_ca[:, :num_sinks]
        input_ca = torch.cat((sink_ca, avg_pool_ca), dim=1)

        # The position target is predicted modulo the PlayHead cycle length.
        target = target % self.model.cycle_len
        return self.model(input_ca, target, logits_mask[:, :-1]), input_ca, target

    def training_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def validation_step(self, batch, idx):
        losses, _, _ = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        # Step-wise schedule: linear warmup, then cosine decay down to
        # end_learning_rate (defaults to 10% of the peak learning rate).
        end_lr = (
            self.hparams.end_learning_rate
            if self.hparams.end_learning_rate is not None
            else self.hparams.learning_rate * 0.1
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                end_lr=end_lr,
            ),
        )
        return [opt], [{"scheduler": scheduler, "interval": "step"}]


@hydra.main(config_path="playhead_configs/", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = "PlayHead"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()
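

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the training code above): the
# script is launched through Hydra, so runs are driven by command-line
# overrides. The config keys read in main() ("model", "data", "trainer",
# "seed_everything", "ckpt_path") come from this file; the script name and
# the concrete override values below are hypothetical examples.
#
#   python train_playhead.py \
#       model.tts_checkpoint_path=/path/to/tts.ckpt \
#       trainer.max_steps=300000 \
#       seed_everything=1234
# ---------------------------------------------------------------------------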