Spaces:
Runtime error
Runtime error
| import os | |
| import pickle as pickle_tts | |
| from typing import Any, Callable, Dict, Union | |
| import fsspec | |
| import torch | |
| from TTS.utils.generic_utils import get_user_data_dir | |
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that maps the legacy ``mozilla_voice_tts`` package onto ``TTS``.

    Objects pickled by the old package record class paths under
    ``mozilla_voice_tts``. Before the normal class lookup, every occurrence of
    that string in the module path is replaced with ``"TTS"``.
    """

    def find_class(self, module, name):
        legacy_package, current_package = "mozilla_voice_tts", "TTS"
        resolved_module = module.replace(legacy_package, current_package)
        return super().find_class(resolved_module, name)
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    The instance attribute namespace is the dictionary itself, so
    ``d["key"]`` and ``d.key`` refer to the same storage.
    """

    def __init__(self, *positional, **named):
        super().__init__(*positional, **named)
        # Alias the attribute namespace to the mapping itself.
        self.__dict__ = self
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Local path or any URL understood by ``fsspec``.
        map_location: Forwarded unchanged to ``torch.load``.
        cache: If True and ``path`` is not an existing local file or directory,
            the file is opened through fsspec's ``filecache`` protocol, with the
            cache stored under ``get_user_data_dir("tts_cache")``.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``. If the
            caller does not pass ``weights_only`` and the installed torch's
            ``torch.load`` accepts it, it defaults to False here.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.
    """
    import inspect  # local import: only needed for the torch feature probe below

    # Force full unpickling so older checkpoints that store non-tensor objects
    # still load on torch versions whose torch.load defaults to
    # weights_only=True. Only inject the argument when this torch's torch.load
    # actually declares it. Older releases forward unknown kwargs to the
    # unpickler, which rejects them with a TypeError.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
    # so `path` must point at a trusted checkpoint.
    if "weights_only" not in kwargs and "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False

    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        # Non-local path: read through fsspec's filecache so a local copy is kept.
        with fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        ) as f:
            return torch.load(f, map_location=map_location, **kwargs)
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load ``state["model"]`` from a checkpoint into ``model``.

    Args:
        model: Object whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path: Local path or any fsspec-supported URL, passed to ``load_fsspec``.
        use_cuda: If True, call ``model.cuda()`` after loading.
        eval: If True, call ``model.eval()`` after loading.
        cache: Forwarded to ``load_fsspec``.

    Returns:
        Tuple ``(model, state)``, where ``state`` is the full loaded checkpoint object.

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Checkpoints saved by the legacy ``mozilla_voice_tts`` package reference
        # modules that no longer exist, so retry with an unpickler that renames them.
        # torch.load builds its unpickler from ``pickle_module.Unpickler``, so the
        # rename must be installed on the module object itself. ``pickle_tts`` is
        # the stdlib ``pickle`` module, so restore the original class afterwards
        # instead of leaving the whole process patched.
        original_unpickler = pickle_tts.Unpickler
        pickle_tts.Unpickler = RenamingUnpickler
        try:
            state = load_fsspec(
                checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache
            )
        finally:
            pickle_tts.Unpickler = original_unpickler
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state