ORI-Muchim committed
Commit 46ee079
1 Parent(s): 6f33d96

Update utils.py

Files changed (1)
  1. utils.py +200 -168
utils.py CHANGED
@@ -11,216 +11,248 @@ import torch
 
 MATPLOTLIB_FLAG = False
 
-logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
 logger = logging
 
 
 def load_checkpoint(checkpoint_path, model, optimizer=None):
-  assert os.path.isfile(checkpoint_path)
-  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-  iteration = checkpoint_dict['iteration']
-  learning_rate = checkpoint_dict['learning_rate']
-  if optimizer is not None:
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
-  saved_state_dict = checkpoint_dict['model']
-  if hasattr(model, 'module'):
-    state_dict = model.module.state_dict()
-  else:
-    state_dict = model.state_dict()
-  new_state_dict = {}
-  for k, v in state_dict.items():
-    try:
-      new_state_dict[k] = saved_state_dict[k]
-    except:
-      logger.info("%s is not in the checkpoint" % k)
-      new_state_dict[k] = v
-  if hasattr(model, 'module'):
-    model.module.load_state_dict(new_state_dict)
-  else:
-    model.load_state_dict(new_state_dict)
-  logger.info("Loaded checkpoint '{}' (iteration {})".format(
-    checkpoint_path, iteration))
-  return model, optimizer, learning_rate, iteration
 
 
 def plot_spectrogram_to_numpy(spectrogram):
-  global MATPLOTLIB_FLAG
-  if not MATPLOTLIB_FLAG:
-    import matplotlib
-    matplotlib.use("Agg")
-    MATPLOTLIB_FLAG = True
-    mpl_logger = logging.getLogger('matplotlib')
-    mpl_logger.setLevel(logging.WARNING)
-  import matplotlib.pylab as plt
-  import numpy as np
-
-  fig, ax = plt.subplots(figsize=(10, 2))
-  im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                 interpolation='none')
-  plt.colorbar(im, ax=ax)
-  plt.xlabel("Frames")
-  plt.ylabel("Channels")
-  plt.tight_layout()
-
-  fig.canvas.draw()
-  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-  plt.close()
-  return data
 
 
 def plot_alignment_to_numpy(alignment, info=None):
-  global MATPLOTLIB_FLAG
-  if not MATPLOTLIB_FLAG:
-    import matplotlib
-    matplotlib.use("Agg")
-    MATPLOTLIB_FLAG = True
-    mpl_logger = logging.getLogger('matplotlib')
-    mpl_logger.setLevel(logging.WARNING)
-  import matplotlib.pylab as plt
-  import numpy as np
-
-  fig, ax = plt.subplots(figsize=(6, 4))
-  im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
-                 interpolation='none')
-  fig.colorbar(im, ax=ax)
-  xlabel = 'Decoder timestep'
-  if info is not None:
-    xlabel += '\n\n' + info
-  plt.xlabel(xlabel)
-  plt.ylabel('Encoder timestep')
-  plt.tight_layout()
-
-  fig.canvas.draw()
-  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-  plt.close()
-  return data
 
 
 def load_wav_to_torch(full_path):
-  sampling_rate, data = read(full_path)
-  return torch.FloatTensor(data.astype(np.float32)), sampling_rate
 
 
 def load_filepaths_and_text(filename, split="|"):
-  with open(filename, encoding='utf-8') as f:
-    filepaths_and_text = [line.strip().split(split) for line in f]
-  return filepaths_and_text
 
 
 def get_hparams(init=True):
-  parser = argparse.ArgumentParser()
-  parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
-                      help='JSON file for configuration')
-  parser.add_argument('-m', '--model', type=str, required=True,
-                      help='Model name')
-
-  args = parser.parse_args()
-  model_dir = os.path.join("./logs", args.model)
-
-  if not os.path.exists(model_dir):
-    os.makedirs(model_dir)
-
-  config_path = args.config
-  config_save_path = os.path.join(model_dir, "config.json")
-  if init:
-    with open(config_path, "r") as f:
-      data = f.read()
-    with open(config_save_path, "w") as f:
-      f.write(data)
-  else:
-    with open(config_save_path, "r") as f:
-      data = f.read()
-  config = json.loads(data)
-
-  hparams = HParams(**config)
-  hparams.model_dir = model_dir
-  return hparams
 
 
 def get_hparams_from_dir(model_dir):
-  config_save_path = os.path.join(model_dir, "config.json")
-  with open(config_save_path, "r") as f:
-    data = f.read()
-  config = json.loads(data)
 
-  hparams = HParams(**config)
-  hparams.model_dir = model_dir
-  return hparams
 
 
 def get_hparams_from_file(config_path):
-  with open(config_path, "r", encoding="utf-8") as f:
-    data = f.read()
-  config = json.loads(data)
 
-  hparams = HParams(**config)
-  return hparams
 
 
 def check_git_hash(model_dir):
-  source_dir = os.path.dirname(os.path.realpath(__file__))
-  if not os.path.exists(os.path.join(source_dir, ".git")):
-    logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
-      source_dir
-    ))
-    return
 
-  cur_hash = subprocess.getoutput("git rev-parse HEAD")
 
-  path = os.path.join(model_dir, "githash")
-  if os.path.exists(path):
-    saved_hash = open(path).read()
-    if saved_hash != cur_hash:
-      logger.warn("git hash values are different. {}(saved) != {}(current)".format(
-        saved_hash[:8], cur_hash[:8]))
-  else:
-    open(path, "w").write(cur_hash)
 
 
 def get_logger(model_dir, filename="train.log"):
-  global logger
-  logger = logging.getLogger(os.path.basename(model_dir))
-  logger.setLevel(logging.DEBUG)
-
-  formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-  if not os.path.exists(model_dir):
-    os.makedirs(model_dir)
-  h = logging.FileHandler(os.path.join(model_dir, filename))
-  h.setLevel(logging.DEBUG)
-  h.setFormatter(formatter)
-  logger.addHandler(h)
-  return logger
 
 
 class HParams():
-  def __init__(self, **kwargs):
-    for k, v in kwargs.items():
-      if type(v) == dict:
-        v = HParams(**v)
-      self[k] = v
-
-  def keys(self):
-    return self.__dict__.keys()
 
-  def items(self):
-    return self.__dict__.items()
 
-  def values(self):
-    return self.__dict__.values()
 
-  def __len__(self):
-    return len(self.__dict__)
 
-  def __getitem__(self, key):
-    return getattr(self, key)
 
-  def __setitem__(self, key, value):
-    return setattr(self, key, value)
 
-  def __contains__(self, key):
-    return key in self.__dict__
 
-  def __repr__(self):
-    return self.__dict__.__repr__()
 
 
 MATPLOTLIB_FLAG = False
 
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging
 
 
 def load_checkpoint(checkpoint_path, model, optimizer=None):
+  assert os.path.isfile(checkpoint_path)
+  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
+  iteration = checkpoint_dict['iteration']
+  learning_rate = checkpoint_dict['learning_rate']
+  if optimizer is not None:
+    optimizer.load_state_dict(checkpoint_dict['optimizer'])
+  saved_state_dict = checkpoint_dict['model']
+  if hasattr(model, 'module'):
+    state_dict = model.module.state_dict()
+  else:
+    state_dict = model.state_dict()
+  new_state_dict= {}
+  for k, v in state_dict.items():
+    try:
+      new_state_dict[k] = saved_state_dict[k]
+    except:
+      logger.info("%s is not in the checkpoint" % k)
+      new_state_dict[k] = v
+  if hasattr(model, 'module'):
+    model.module.load_state_dict(new_state_dict)
+  else:
+    model.load_state_dict(new_state_dict)
+  logger.info("Loaded checkpoint '{}' (iteration {})" .format(
+    checkpoint_path, iteration))
+  return model, optimizer, learning_rate, iteration
+
+
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+  logger.info("Saving model and optimizer state at iteration {} to {}".format(
+    iteration, checkpoint_path))
+  if hasattr(model, 'module'):
+    state_dict = model.module.state_dict()
+  else:
+    state_dict = model.state_dict()
+  torch.save({'model': state_dict,
+              'iteration': iteration,
+              'optimizer': optimizer.state_dict(),
+              'learning_rate': learning_rate}, checkpoint_path)
+
+
+def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
+  for k, v in scalars.items():
+    writer.add_scalar(k, v, global_step)
+  for k, v in histograms.items():
+    writer.add_histogram(k, v, global_step)
+  for k, v in images.items():
+    writer.add_image(k, v, global_step, dataformats='HWC')
+  for k, v in audios.items():
+    writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+  f_list = glob.glob(os.path.join(dir_path, regex))
+  f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+  x = f_list[-1]
+  print(x)
+  return x
 
 
 def plot_spectrogram_to_numpy(spectrogram):
+  global MATPLOTLIB_FLAG
+  if not MATPLOTLIB_FLAG:
+    import matplotlib
+    matplotlib.use("Agg")
+    MATPLOTLIB_FLAG = True
+    mpl_logger = logging.getLogger('matplotlib')
+    mpl_logger.setLevel(logging.WARNING)
+  import matplotlib.pylab as plt
+  import numpy as np
+
+  fig, ax = plt.subplots(figsize=(10,2))
+  im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                 interpolation='none')
+  plt.colorbar(im, ax=ax)
+  plt.xlabel("Frames")
+  plt.ylabel("Channels")
+  plt.tight_layout()
+
+  fig.canvas.draw()
+  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+  plt.close()
+  return data
 
 
 def plot_alignment_to_numpy(alignment, info=None):
+  global MATPLOTLIB_FLAG
+  if not MATPLOTLIB_FLAG:
+    import matplotlib
+    matplotlib.use("Agg")
+    MATPLOTLIB_FLAG = True
+    mpl_logger = logging.getLogger('matplotlib')
+    mpl_logger.setLevel(logging.WARNING)
+  import matplotlib.pylab as plt
+  import numpy as np
+
+  fig, ax = plt.subplots(figsize=(6, 4))
+  im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
+                 interpolation='none')
+  fig.colorbar(im, ax=ax)
+  xlabel = 'Decoder timestep'
+  if info is not None:
+    xlabel += '\n\n' + info
+  plt.xlabel(xlabel)
+  plt.ylabel('Encoder timestep')
+  plt.tight_layout()
+
+  fig.canvas.draw()
+  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+  plt.close()
+  return data
 
 
 def load_wav_to_torch(full_path):
+  sampling_rate, data = read(full_path)
+  return torch.FloatTensor(data.astype(np.float32)), sampling_rate
 
 
 def load_filepaths_and_text(filename, split="|"):
+  with open(filename, encoding='utf-8') as f:
+    filepaths_and_text = [line.strip().split(split) for line in f]
+  return filepaths_and_text
 
 
 def get_hparams(init=True):
+  parser = argparse.ArgumentParser()
+  parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
+                      help='JSON file for configuration')
+  parser.add_argument('-m', '--model', type=str, required=True,
+                      help='Model name')
+
+  args = parser.parse_args()
+  model_dir = os.path.join("../models", args.model)
+
+  if not os.path.exists(model_dir):
+    os.makedirs(model_dir)
+
+  config_path = args.config
+  config_save_path = os.path.join(model_dir, "config.json")
+  if init:
+    with open(config_path, "r") as f:
+      data = f.read()
+    with open(config_save_path, "w") as f:
+      f.write(data)
+  else:
+    with open(config_save_path, "r") as f:
+      data = f.read()
+  config = json.loads(data)
+
+  hparams = HParams(**config)
+  hparams.model_dir = model_dir
+  return hparams
 
 
 def get_hparams_from_dir(model_dir):
+  config_save_path = os.path.join(model_dir, "config.json")
+  with open(config_save_path, "r") as f:
+    data = f.read()
+  config = json.loads(data)
 
+  hparams = HParams(**config)
+  hparams.model_dir = model_dir
+  return hparams
 
 
 def get_hparams_from_file(config_path):
+  with open(config_path, "r") as f:
+    data = f.read()
+  config = json.loads(data)
 
+  hparams = HParams(**config)
+  return hparams
 
 
 def check_git_hash(model_dir):
+  source_dir = os.path.dirname(os.path.realpath(__file__))
+  if not os.path.exists(os.path.join(source_dir, ".git")):
+    logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
+      source_dir
+    ))
+    return
 
+  cur_hash = subprocess.getoutput("git rev-parse HEAD")
 
+  path = os.path.join(model_dir, "githash")
+  if os.path.exists(path):
+    saved_hash = open(path).read()
+    if saved_hash != cur_hash:
+      logger.warn("git hash values are different. {}(saved) != {}(current)".format(
+        saved_hash[:8], cur_hash[:8]))
+  else:
+    open(path, "w").write(cur_hash)
 
 
 def get_logger(model_dir, filename="train.log"):
+  global logger
+  logger = logging.getLogger(os.path.basename(model_dir))
+  logger.setLevel(logging.DEBUG)
+
+  formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+  if not os.path.exists(model_dir):
+    os.makedirs(model_dir)
+  h = logging.FileHandler(os.path.join(model_dir, filename))
+  h.setLevel(logging.DEBUG)
+  h.setFormatter(formatter)
+  logger.addHandler(h)
+  return logger
 
 
 class HParams():
+  def __init__(self, **kwargs):
+    for k, v in kwargs.items():
+      if type(v) == dict:
+        v = HParams(**v)
+      self[k] = v
+
+  def keys(self):
+    return self.__dict__.keys()
 
+  def items(self):
+    return self.__dict__.items()
 
+  def values(self):
+    return self.__dict__.values()
 
+  def __len__(self):
+    return len(self.__dict__)
 
+  def __getitem__(self, key):
+    return getattr(self, key)
 
+  def __setitem__(self, key, value):
+    return setattr(self, key, value)
 
+  def __contains__(self, key):
+    return key in self.__dict__
 
+  def __repr__(self):
+    return self.__dict__.__repr__()
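
For context on the new helpers: besides switching the log level to DEBUG and the model directory to "../models", this commit adds save_checkpoint, summarize, and latest_checkpoint_path alongside the existing load_checkpoint. A minimal, self-contained sketch (not part of the commit) of how they fit together follows; the toy Linear model, the "./logs/demo" directory, and the logged loss value are illustrative placeholders, and SummaryWriter assumes the tensorboard package is installed.

import os

import torch
from torch.utils.tensorboard import SummaryWriter  # assumes tensorboard is installed

import utils

os.makedirs("./logs/demo", exist_ok=True)   # hypothetical run directory
model = torch.nn.Linear(4, 4)               # toy stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

# Write a checkpoint in the dict layout that load_checkpoint() expects.
utils.save_checkpoint(model, optimizer, 2e-4, 100,
                      checkpoint_path="./logs/demo/G_100.pth")

# Find the newest G_*.pth (sorted by the digits in its path) and resume from it.
ckpt = utils.latest_checkpoint_path("./logs/demo", "G_*.pth")
model, optimizer, lr, iteration = utils.load_checkpoint(ckpt, model, optimizer)

# Push scalars to TensorBoard through the new summarize() helper.
writer = SummaryWriter("./logs/demo")
utils.summarize(writer, iteration, scalars={"loss/demo": 0.0})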
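
One caveat carried over unchanged from the old version: both plotting helpers still convert figures to arrays via np.fromstring(fig.canvas.tostring_rgb(), ...), and both calls are deprecated in newer NumPy and Matplotlib releases. A hedged drop-in sketch (again, not part of the commit) using the RGBA buffer on the Agg backend:

import matplotlib
matplotlib.use("Agg")                     # same backend the helpers select
import matplotlib.pylab as plt
import numpy as np

fig, ax = plt.subplots(figsize=(10, 2))
ax.imshow(np.random.rand(80, 100), aspect="auto", origin="lower")  # dummy spectrogram

fig.canvas.draw()
# buffer_rgba() replaces tostring_rgb(); frombuffer() replaces fromstring().
data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[..., :3].copy()  # drop alpha
plt.close(fig)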