import functools
import math

from torch.nn import functional as F
import torchaudio


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def unfold(a, kernel_size, stride):
    """Given input of size `[*OT, T]`, output a Tensor of size `[*OT, F, K]`
    with `K` the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    See https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    # Build overlapping frames as a strided view, avoiding a copy.
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def weights_init(m):
    # Initialize conv weights from N(0, 0.02) and BatchNorm2d weights
    # from N(1, 0.02) with zero bias; other modules are left untouched.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def write(wav, filename, sr):
    # Normalize the audio only if needed to prevent clipping (peak above 1).
    wav = wav / max(wav.abs().max().item(), 1)
    torchaudio.save(filename, wav.cpu(), sr)
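

# Minimal usage sketch for `unfold` and `capture_init`; the shapes, values,
# and the `Model` class below are illustrative assumptions, not part of the
# module's API.
if __name__ == "__main__":
    import torch

    # `unfold` turns a 1-D signal of length 10 into ceil(10 / stride) = 5
    # frames of length `kernel_size`, zero-padding the tail as needed.
    signal = torch.arange(10, dtype=torch.float32)
    frames = unfold(signal, kernel_size=4, stride=2)
    assert frames.shape == (5, 4)
    # Frame `i` starts at index `i * stride` of the (padded) input.
    assert frames[1, 0] == signal[2]

    # `capture_init` records the constructor arguments so the object can
    # later be re-created from `self._init_args_kwargs`.
    class Model:
        @capture_init
        def __init__(self, channels, depth=4):
            self.channels = channels
            self.depth = depth

    model = Model(64, depth=6)
    assert model._init_args_kwargs == ((64,), {"depth": 6})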