import os
import pickle as pickle_tts
import types
from typing import Any, Callable, Dict, Union

import fsspec
import torch

from TTS.utils.generic_utils import get_user_data_dir


class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that rewrites the legacy ``mozilla_voice_tts`` module path to ``TTS``.

    Any module path containing ``mozilla_voice_tts`` is resolved as if it said
    ``TTS`` instead, so objects pickled under the old package name can be
    located in the current package.
    """

    def find_class(self, module, name):
        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)


class AttrDict(dict):
    """A dict whose keys are also accessible as instance attributes.

    ``__dict__`` is pointed at the dict itself, so ``d.key`` and ``d["key"]``
    read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like ``torch.load`` but can also load from fsspec locations (e.g. ``s3://``, ``gs://``).

    Args:
        path: Local path or any URL that ``fsspec.open`` accepts.
        map_location: Forwarded unchanged to ``torch.load``.
        cache: When True and ``path`` is not an existing local file or
            directory, the file is opened through fsspec's ``filecache::``
            protocol with its storage under ``get_user_data_dir("tts_cache")``.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the file.
    """
    # Force full unpickling so old checkpoints (which contain arbitrary pickled
    # Python objects) keep loading on torch versions where ``weights_only``
    # defaults to True. An explicit ``weights_only`` from the caller wins.
    # SECURITY NOTE(review): full unpickling can execute arbitrary code. With a
    # remote ``path`` only load checkpoints from trusted sources, or pass
    # ``weights_only=True`` when the checkpoint format permits it.
    kwargs.setdefault("weights_only", False)
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        with fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        ) as f:
            return torch.load(f, map_location=map_location, **kwargs)
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)


def _renaming_pickle_module() -> types.ModuleType:
    """Return a private ``pickle`` stand-in whose ``Unpickler`` is ``RenamingUnpickler``.

    All other attributes are copied from the real ``pickle`` module, so it can be
    passed to ``torch.load`` as ``pickle_module`` without modifying the real
    ``pickle`` module that the rest of the process uses.
    """
    shim = types.ModuleType(pickle_tts.__name__)
    shim.__dict__.update(vars(pickle_tts))
    shim.Unpickler = RenamingUnpickler
    return shim


def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint into ``model`` and return ``(model, state)``.

    The checkpoint is loaded with ``map_location`` set to CPU. If loading raises
    ``ModuleNotFoundError`` (presumably a checkpoint pickled under the old
    ``mozilla_voice_tts`` package name), it is loaded again with an unpickler
    that rewrites that name to ``TTS``.

    Args:
        model: Object whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path: Local path or fsspec URL of the checkpoint.
        use_cuda: If True, call ``model.cuda()`` after loading.
        eval: If True, call ``model.eval()`` after loading.
        cache: Forwarded to ``load_fsspec``.

    Returns:
        Tuple of ``model`` (modified in place) and the loaded checkpoint object.
    """
    cpu = torch.device("cpu")
    try:
        state = load_fsspec(checkpoint_path, map_location=cpu, cache=cache)
    except ModuleNotFoundError:
        # Hand torch a private pickle module rather than assigning
        # ``pickle.Unpickler = RenamingUnpickler``; that assignment would change
        # unpickling for every other user of ``pickle`` in the whole process.
        state = load_fsspec(
            checkpoint_path,
            map_location=cpu,
            pickle_module=_renaming_pickle_module(),
            cache=cache,
        )
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state