import torch class STFT: def __init__(self, n_fft, hop_length, win_length): self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length self.window = torch.hann_window(win_length) def __call__(self, y): self.window = self.window.to(y.device) stft_matrix = torch.stft( y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=False, center=True, pad_mode='reflect' ) return stft_matrix class iSTFT: def __init__(self, n_fft, hop_length, win_length): self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length self.window = torch.hann_window(win_length) def __call__(self, X): self.window = self.window.to(X.device) X = torch.view_as_complex(X.contiguous()) return torch.istft( X, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, center=True )