# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://github.com/jik876/hifi-gan

import os

import torch


def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weights of convolution layers in place.

    Suitable for use with ``nn.Module.apply``. Any module whose class name
    contains ``"Conv"`` (e.g. ``Conv1d``, ``ConvTranspose1d``) has its
    ``weight`` tensor filled with samples from a normal distribution;
    every other module is left untouched.

    Args:
        m: The module to (possibly) initialise.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        # ``.data`` writes the parameter in place without recording autograd history.
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    """Return the per-side padding for a length-preserving stride-1 convolution.

    A dilated kernel spans ``dilation * (kernel_size - 1)`` samples beyond
    the centre tap. Padding half of that on each side keeps the output
    length equal to the input length when ``kernel_size`` is odd.

    Args:
        kernel_size: Convolution kernel size (expected >= 1).
        dilation: Convolution dilation factor.

    Returns:
        The padding amount as an ``int``.
    """
    # Exact integer floor division instead of int(x / 2): no float round-trip,
    # and the result is identical for the non-negative values valid inputs give.
    return (kernel_size * dilation - dilation) // 2


def load_checkpoint(filepath, device):
    """Load a checkpoint file written with ``torch.save``.

    Args:
        filepath: Path to the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location`` so stored
            tensors are remapped onto this device (e.g. ``"cpu"``).

    Returns:
        The deserialized checkpoint object (presumably a dict, given how
        callers name it -- confirm against the saving code).

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print("Loading '{}'".format(filepath))
    # NOTE(security): torch.load uses pickle-based deserialization; only load
    # checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


class AttrDict(dict):
    """A ``dict`` whose entries are also accessible as attributes.

    ``d["key"]`` and ``d.key`` read and write the same entry. Nested dicts
    are not converted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Make the instance namespace the dict itself, so attribute access and
        # item access share storage. This forms a self-reference cycle, so
        # instances are freed by the cyclic garbage collector.
        self.__dict__ = self