# pardi-speech/tts/train_tts.py
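"""Train the autoregressive pardi-speech TTS model with PyTorch Lightning.

Launched through Hydra (see `main` at the bottom of this file).
"""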

import json
import math
from dataclasses import asdict
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as ptl
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from safetensors.torch import save_file
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

from .model.config import TTSConfig
from .model.prediction_head import VelocityHead
from .tts import ARTTSModel
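

# Learning-rate schedule: linear warmup to the base lr over `warmup_steps`,
# then a half-cosine decay from `start_lr` down to `end_lr` over the rest of
# training. Returned as a multiplier for torch.optim.lr_scheduler.LambdaLR
# (wired up in TrainARTTS.configure_optimizers below). With the defaults used
# in this file (n_warmup_steps=500, n_training_steps=300000, end lr = 0.1 * lr)
# the factor ramps to 1.0 over 500 steps, then decays to 0.1.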
def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        # LambdaLR scales the optimiser's base lr, so return a factor that
        # decays from 1.0 down to end_lr / start_lr.
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


class TrainARTTS(ptl.LightningModule):
    def __init__(
        self,
        config: TTSConfig,
        quant_layer: list[int],
        tie_embed: bool = False,
        learning_rate: float = 5e-4,
        end_learning_rate: float | None = None,
        weight_decay: float = 0.1,
        betas: tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
        mask_text_p: float = 0.0,
        load_weights: str | None = None,
        stop_token_weight: float | None = None,
        stop_loss_factor: float = 0.1,
        stop_loss_warmup: tuple[int, int] | None = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.stop_token_weight = stop_token_weight
        self.stop_loss_factor = stop_loss_factor
        self.save_hyperparameters()
        self.model = ARTTSModel(config)
        if load_weights is not None:
            # Warm-start: load on CPU and tolerate missing/unexpected keys
            # (e.g. a freshly initialised prediction head).
            checkpoint = torch.load(load_weights, map_location="cpu")
            self.load_state_dict(checkpoint["state_dict"], strict=False)
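
    # Keep bucketing/distributed batch samplers in step with the epoch so any
    # epoch-dependent shuffling stays deterministic across workers.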
    def on_train_epoch_start(self):
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
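
    # Export inference artifacts: model weights as safetensors plus the
    # dataclass config as JSON, so generation code can rebuild the model
    # without a Lightning checkpoint.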
    def save_model_weights_and_config(
        self,
        dir: str,
        model_filename: str = "model.st",
        config_filename: str = "config.json",
    ):
        def to_builtin(obj):
            # Recursively convert OmegaConf containers into plain dicts/lists
            # so the config serialises cleanly to JSON.
            if isinstance(obj, (dict, DictConfig)):
                return {k: to_builtin(v) for k, v in obj.items()}
            elif isinstance(obj, (list, ListConfig)):
                return [to_builtin(v) for v in obj]
            else:
                return obj

        cfg = to_builtin(asdict(self.hparams.config))
        Path(dir).mkdir(parents=True, exist_ok=True)
        save_file(self.model.state_dict(), Path(dir) / model_filename)
        with open(Path(dir) / config_filename, "w") as f:
            json.dump(cfg, f, indent=2)
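
    # Shared train/val step: teacher-forced forward pass through the AR
    # decoder, then per-head losses collected into a dict. Targets are the
    # audio tokens shifted by one frame, masked by the padding mask (y_mask).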
    def step(self, batch, batch_idx: int, validation: bool = False):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"].squeeze(2)
        crossatt_mask = batch.get("crossatt_mask")
        text_rel_pos = batch.get("text_rel_pos")
        encoder_mask = batch.get("encoder_mask")
        stop_token = batch.get("stop_token")
        text_stop_token = batch.get("text_stop_token")
        crossatt_rel_pos = batch.get("crossatt_rel_pos")
        logits_mask = batch.get("y_mask")
        pre_logits = self.model(
            text_ids=text_token,
            audio_inputs=audio_token,
            text_mask=encoder_mask,
            audio_mask=logits_mask,
            crossatt_mask=crossatt_mask,
            crossatt_rel_pos=crossatt_rel_pos,
            stop_tokens=stop_token,
            text_rel_pos=text_rel_pos,
            text_stop_tokens=text_stop_token,
        )
        losses = {}
        if validation and isinstance(self.model.prediction_head, VelocityHead):
            # Time-conditioned head: use deterministic time conditioning
            # during validation, spreading t evenly across the validation
            # batches instead of sampling it.
            t = (
                torch.ones(pre_logits.shape[0], device=pre_logits.device)
                * batch_idx
                / self.trainer.num_val_batches[0]
            )
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
                t=t,
            )
        else:
            losses |= self.model.prediction_head.compute_loss(
                pre_logits,
                audio_token[:, 1:],
                mask=logits_mask[:, 1:] if logits_mask is not None else None,
            )
        if self.model.stop_prediction_head is not None and logits_mask is not None:
            if stop_token is None:
                # No explicit stop targets: derive them from the (shifted)
                # padding mask so the head learns to fire at end of audio.
                stop_token = nn.functional.pad(
                    (~logits_mask)[:, 2:].to(pre_logits), (0, 1)
                )
            else:
                stop_token = stop_token[:, 1:]
            mask = logits_mask[:, 1:]
            losses |= self.model.stop_prediction_head.compute_loss(
                pre_logits[mask],
                stop_token[mask],
            )
        return losses
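
    # Total loss = prediction-head loss + stop loss scaled by stop_loss_factor.
    # stop_loss_warmup = (start, duration) optionally ramps the stop-loss
    # weight from 0 to 1 as a function of the batch index.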
    def training_step(self, batch, idx):
        losses = self.step(batch, idx)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
            if "stop" in name:
                if self.hparams.stop_loss_warmup is not None:
                    alpha, beta = self.hparams.stop_loss_warmup
                    warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
                else:
                    warmup = 1.0
                loss *= self.stop_loss_factor * warmup
            total_loss += loss
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def validation_step(self, batch, idx):
        losses = self.step(batch, idx, validation=True)
        total_loss = 0.0
        for name, loss in losses.items():
            self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
            total_loss += loss
        self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
        return total_loss

    def configure_optimizers(self):
        # Single parameter group: uniform weight decay across all weights.
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(
            params,
            lr=self.learning_rate,
            betas=self.betas,
        )
        scheduler = LambdaLR(
            opt,
            lr_lambda=cosine_schedule_with_warmup(
                warmup_steps=self.hparams.n_warmup_steps,
                total_steps=self.hparams.n_training_steps,
                start_lr=self.hparams.learning_rate,
                # end_learning_rate, when given, overrides the default floor
                # of 10% of the base learning rate.
                end_lr=self.hparams.end_learning_rate
                if self.hparams.end_learning_rate is not None
                else self.hparams.learning_rate * 0.1,
            ),
        )
        # Step the scheduler every optimiser step, not once per epoch.
        return [opt], [{"scheduler": scheduler, "interval": "step"}]
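

# Hydra entry point: instantiates the LightningModule, datamodule and trainer
# from hydra_configs/config.yaml, then starts (or resumes) training.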
@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    ptl.seed_everything(cfg.seed_everything)
    model = hydra.utils.instantiate(cfg.model)
    cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
    datamodule = hydra.utils.instantiate(cfg.data)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()