XTTSv2-est / TTS /model.py
Rasmus Lellep
initial commit
5a03f53
raw
history blame
2.24 kB
import os
from abc import abstractmethod
from typing import Any, Union
import torch
from coqpit import Coqpit
from trainer import TrainerModel
# pylint: skip-file
class BaseTrainerModel(TrainerModel):
"""BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS.
Every new Coqui model must inherit it.
"""
@staticmethod
@abstractmethod
def init_from_config(config: Coqpit) -> "BaseTrainerModel":
"""Init the model and all its attributes from the given config.
Override this depending on your model.
"""
...
@abstractmethod
def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
"""Forward pass for inference.
It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
is considered to be the main output and you can add any other auxiliary outputs as you want.
We don't use `*kwargs` since it is problematic with the TorchScript API.
Args:
input (torch.Tensor): [description]
aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
Returns:
Dict: [description]
"""
outputs_dict = {"model_outputs": None}
...
return outputs_dict
@abstractmethod
def load_checkpoint(
self,
config: Coqpit,
checkpoint_path: Union[str, os.PathLike[Any]],
eval: bool = False,
strict: bool = True,
cache: bool = False,
) -> None:
"""Load a model checkpoint file and get ready for training or inference.
Args:
config (Coqpit): Model configuration.
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls.
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
"""
...