sob111 committed on
Commit
e8da02f
·
verified ·
1 Parent(s): ecaba54

Upload io.py

Browse files
Files changed (1) hide show
  1. io.py +70 -0
io.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle as pickle_tts
3
+ from typing import Any, Callable, Dict, Union
4
+
5
+ import fsspec
6
+ import torch
7
+
8
+ from TTS.utils.generic_utils import get_user_data_dir
9
+
10
+
11
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that transparently maps the legacy ``mozilla_voice_tts``
    module path onto the current ``TTS`` package, so old checkpoints
    pickled under the former name can still be loaded."""

    def find_class(self, module, name):
        # Rewrite the module path before delegating the actual lookup.
        remapped = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(remapped, name)
16
+
17
+
18
class AttrDict(dict):
    """A ``dict`` subclass whose keys double as instance attributes
    (``d.key`` and ``d["key"]`` refer to the same storage)."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Alias the attribute namespace to the mapping itself so both
        # access styles stay in sync.
        self.__dict__ = self
25
+
26
+
27
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: torch.device or str.
        cache: If True, cache a remote file locally for subsequent calls. It is cached under
            `get_user_data_dir()/tts_cache`. Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    # Local paths (or caching disabled) are opened directly — no cache layer.
    if not cache or os.path.isdir(path) or os.path.isfile(path):
        with fsspec.open(path, "rb") as f:
            return torch.load(f, map_location=map_location, **kwargs)
    # Remote path with caching: route through fsspec's "filecache" protocol so
    # repeated loads of the same URL hit the local copy under tts_cache.
    cached_target = f"filecache::{path}"
    cache_options = {"cache_storage": str(get_user_data_dir("tts_cache"))}
    with fsspec.open(cached_target, filecache=cache_options, mode="rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
55
+
56
+
57
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a model checkpoint from `checkpoint_path` into `model`.

    Args:
        model: Model instance with a `load_state_dict` method.
        checkpoint_path: Any path or URL accepted by `load_fsspec`.
        use_cuda: If True, move the model to GPU after loading.
        eval: If True, switch the model to evaluation mode.
        cache: Forwarded to `load_fsspec`; cache remote files locally.

    Returns:
        Tuple of (model, state) where `state` is the full checkpoint dict.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Legacy checkpoints reference the renamed ``mozilla_voice_tts`` package;
        # retry with the renaming unpickler. Patch ``pickle_tts.Unpickler`` only
        # for the duration of this call — the previous code left the global
        # ``pickle`` module patched for the whole process.
        original_unpickler = pickle_tts.Unpickler
        pickle_tts.Unpickler = RenamingUnpickler
        try:
            state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
        finally:
            pickle_tts.Unpickler = original_unpickler
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state