ORI-Muchim commited on
Commit
a82e053
1 Parent(s): 057b06e

Upload 3 files

Browse files
Files changed (3) hide show
  1. pqmf.py +116 -0
  2. stft.py +295 -0
  3. stft_loss.py +136 -0
pqmf.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Pseudo QMF modules."""
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from scipy.signal.windows import kaiser
13
+
14
+
15
+ def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
16
+ """Design prototype filter for PQMF.
17
+ This method is based on `A Kaiser window approach for the design of prototype
18
+ filters of cosine modulated filterbanks`_.
19
+ Args:
20
+ taps (int): The number of filter taps.
21
+ cutoff_ratio (float): Cut-off frequency ratio.
22
+ beta (float): Beta coefficient for kaiser window.
23
+ Returns:
24
+ ndarray: Impluse response of prototype filter (taps + 1,).
25
+ .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
26
+ https://ieeexplore.ieee.org/abstract/document/681427
27
+ """
28
+ # check the arguments are valid
29
+ assert taps % 2 == 0, "The number of taps mush be even number."
30
+ assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
31
+
32
+ # make initial filter
33
+ omega_c = np.pi * cutoff_ratio
34
+ with np.errstate(invalid='ignore'):
35
+ h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
36
+ / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
37
+ h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
38
+
39
+ # apply kaiser window
40
+ w = kaiser(taps + 1, beta)
41
+ h = h_i * w
42
+
43
+ return h
44
+
45
+
46
+ class PQMF(torch.nn.Module):
47
+ """PQMF module.
48
+ This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
49
+ .. _`Near-perfect-reconstruction pseudo-QMF banks`:
50
+ https://ieeexplore.ieee.org/document/258122
51
+ """
52
+
53
+ def __init__(self, device, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
54
+ """Initilize PQMF module.
55
+ Args:
56
+ subbands (int): The number of subbands.
57
+ taps (int): The number of filter taps.
58
+ cutoff_ratio (float): Cut-off frequency ratio.
59
+ beta (float): Beta coefficient for kaiser window.
60
+ """
61
+ super(PQMF, self).__init__()
62
+
63
+ # define filter coefficient
64
+ h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
65
+ h_analysis = np.zeros((subbands, len(h_proto)))
66
+ h_synthesis = np.zeros((subbands, len(h_proto)))
67
+ for k in range(subbands):
68
+ h_analysis[k] = 2 * h_proto * np.cos(
69
+ (2 * k + 1) * (np.pi / (2 * subbands)) *
70
+ (np.arange(taps + 1) - ((taps - 1) / 2)) +
71
+ (-1) ** k * np.pi / 4)
72
+ h_synthesis[k] = 2 * h_proto * np.cos(
73
+ (2 * k + 1) * (np.pi / (2 * subbands)) *
74
+ (np.arange(taps + 1) - ((taps - 1) / 2)) -
75
+ (-1) ** k * np.pi / 4)
76
+
77
+ # convert to tensor
78
+ analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).cuda(device)
79
+ synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0).cuda(device)
80
+
81
+ # register coefficients as beffer
82
+ self.register_buffer("analysis_filter", analysis_filter)
83
+ self.register_buffer("synthesis_filter", synthesis_filter)
84
+
85
+ # filter for downsampling & upsampling
86
+ updown_filter = torch.zeros((subbands, subbands, subbands)).float().cuda(device)
87
+ for k in range(subbands):
88
+ updown_filter[k, k, 0] = 1.0
89
+ self.register_buffer("updown_filter", updown_filter)
90
+ self.subbands = subbands
91
+
92
+ # keep padding info
93
+ self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
94
+
95
+ def analysis(self, x):
96
+ """Analysis with PQMF.
97
+ Args:
98
+ x (Tensor): Input tensor (B, 1, T).
99
+ Returns:
100
+ Tensor: Output tensor (B, subbands, T // subbands).
101
+ """
102
+ x = F.conv1d(self.pad_fn(x), self.analysis_filter)
103
+ return F.conv1d(x, self.updown_filter, stride=self.subbands)
104
+
105
+ def synthesis(self, x):
106
+ """Synthesis with PQMF.
107
+ Args:
108
+ x (Tensor): Input tensor (B, subbands, T // subbands).
109
+ Returns:
110
+ Tensor: Output tensor (B, 1, T).
111
+ """
112
+ # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands.
113
+ # Not sure this is the correct way, it is better to check again.
114
+ # TODO(kan-bayashi): Understand the reconstruction procedure
115
+ x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
116
+ return F.conv1d(self.pad_fn(x), self.synthesis_filter)
stft.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BSD 3-Clause License
3
+ Copyright (c) 2017, Prem Seetharaman
4
+ All rights reserved.
5
+ * Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+ * Redistributions of source code must retain the above copyright notice,
8
+ this list of conditions and the following disclaimer.
9
+ * Redistributions in binary form must reproduce the above copyright notice, this
10
+ list of conditions and the following disclaimer in the
11
+ documentation and/or other materials provided with the distribution.
12
+ * Neither the name of the copyright holder nor the names of its
13
+ contributors may be used to endorse or promote products derived from this
14
+ software without specific prior written permission.
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ """
26
+
27
+ import torch
28
+ import numpy as np
29
+ import torch.nn.functional as F
30
+ from torch.autograd import Variable
31
+ from scipy.signal import get_window
32
+ from librosa.util import pad_center, tiny
33
+ import librosa.util as librosa_util
34
+
35
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
36
+ n_fft=800, dtype=np.float32, norm=None):
37
+ """
38
+ # from librosa 0.6
39
+ Compute the sum-square envelope of a window function at a given hop length.
40
+ This is used to estimate modulation effects induced by windowing
41
+ observations in short-time fourier transforms.
42
+ Parameters
43
+ ----------
44
+ window : string, tuple, number, callable, or list-like
45
+ Window specification, as in `get_window`
46
+ n_frames : int > 0
47
+ The number of analysis frames
48
+ hop_length : int > 0
49
+ The number of samples to advance between frames
50
+ win_length : [optional]
51
+ The length of the window function. By default, this matches `n_fft`.
52
+ n_fft : int > 0
53
+ The length of each analysis frame.
54
+ dtype : np.dtype
55
+ The data type of the output
56
+ Returns
57
+ -------
58
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
59
+ The sum-squared envelope of the window function
60
+ """
61
+ if win_length is None:
62
+ win_length = n_fft
63
+
64
+ n = n_fft + hop_length * (n_frames - 1)
65
+ x = np.zeros(n, dtype=dtype)
66
+
67
+ # Compute the squared window at the desired length
68
+ win_sq = get_window(window, win_length, fftbins=True)
69
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
70
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
71
+
72
+ # Fill the envelope
73
+ for i in range(n_frames):
74
+ sample = i * hop_length
75
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
76
+ return x
77
+
78
+
79
+ class STFT(torch.nn.Module):
80
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
81
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
82
+ window='hann'):
83
+ super(STFT, self).__init__()
84
+ self.filter_length = filter_length
85
+ self.hop_length = hop_length
86
+ self.win_length = win_length
87
+ self.window = window
88
+ self.forward_transform = None
89
+ scale = self.filter_length / self.hop_length
90
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
91
+
92
+ cutoff = int((self.filter_length / 2 + 1))
93
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
94
+ np.imag(fourier_basis[:cutoff, :])])
95
+
96
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
97
+ inverse_basis = torch.FloatTensor(
98
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
99
+
100
+ if window is not None:
101
+ assert(filter_length >= win_length)
102
+ # get window and zero center pad it to filter_length
103
+ fft_window = get_window(window, win_length, fftbins=True)
104
+ fft_window = pad_center(fft_window, filter_length)
105
+ fft_window = torch.from_numpy(fft_window).float()
106
+
107
+ # window the bases
108
+ forward_basis *= fft_window
109
+ inverse_basis *= fft_window
110
+
111
+ self.register_buffer('forward_basis', forward_basis.float())
112
+ self.register_buffer('inverse_basis', inverse_basis.float())
113
+
114
+ def transform(self, input_data):
115
+ num_batches = input_data.size(0)
116
+ num_samples = input_data.size(1)
117
+
118
+ self.num_samples = num_samples
119
+
120
+ # similar to librosa, reflect-pad the input
121
+ input_data = input_data.view(num_batches, 1, num_samples)
122
+ input_data = F.pad(
123
+ input_data.unsqueeze(1),
124
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
125
+ mode='reflect')
126
+ input_data = input_data.squeeze(1)
127
+
128
+ forward_transform = F.conv1d(
129
+ input_data,
130
+ Variable(self.forward_basis, requires_grad=False),
131
+ stride=self.hop_length,
132
+ padding=0)
133
+
134
+ cutoff = int((self.filter_length / 2) + 1)
135
+ real_part = forward_transform[:, :cutoff, :]
136
+ imag_part = forward_transform[:, cutoff:, :]
137
+
138
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
139
+ phase = torch.autograd.Variable(
140
+ torch.atan2(imag_part.data, real_part.data))
141
+
142
+ return magnitude, phase
143
+
144
+ def inverse(self, magnitude, phase):
145
+ recombine_magnitude_phase = torch.cat(
146
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
147
+
148
+ inverse_transform = F.conv_transpose1d(
149
+ recombine_magnitude_phase,
150
+ Variable(self.inverse_basis, requires_grad=False),
151
+ stride=self.hop_length,
152
+ padding=0)
153
+
154
+ if self.window is not None:
155
+ window_sum = window_sumsquare(
156
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
157
+ win_length=self.win_length, n_fft=self.filter_length,
158
+ dtype=np.float32)
159
+ # remove modulation effects
160
+ approx_nonzero_indices = torch.from_numpy(
161
+ np.where(window_sum > tiny(window_sum))[0])
162
+ window_sum = torch.autograd.Variable(
163
+ torch.from_numpy(window_sum), requires_grad=False)
164
+ window_sum = window_sum.to(inverse_transform.device()) if magnitude.is_cuda else window_sum
165
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
166
+
167
+ # scale by hop ratio
168
+ inverse_transform *= float(self.filter_length) / self.hop_length
169
+
170
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
171
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
172
+
173
+ return inverse_transform
174
+
175
+ def forward(self, input_data):
176
+ self.magnitude, self.phase = self.transform(input_data)
177
+ reconstruction = self.inverse(self.magnitude, self.phase)
178
+ return reconstruction
179
+
180
+
181
+ class OnnxSTFT(torch.nn.Module):
182
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
183
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
184
+ window='hann'):
185
+ super(OnnxSTFT, self).__init__()
186
+ self.filter_length = filter_length
187
+ self.hop_length = hop_length
188
+ self.win_length = win_length
189
+ self.window = window
190
+ self.forward_transform = None
191
+ scale = self.filter_length / self.hop_length
192
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
193
+
194
+ cutoff = int((self.filter_length / 2 + 1))
195
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
196
+ np.imag(fourier_basis[:cutoff, :])])
197
+
198
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
199
+ inverse_basis = torch.FloatTensor(
200
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
201
+
202
+ if window is not None:
203
+ assert(filter_length >= win_length)
204
+ # get window and zero center pad it to filter_length
205
+ fft_window = get_window(window, win_length, fftbins=True)
206
+ fft_window = pad_center(fft_window, filter_length)
207
+ fft_window = torch.from_numpy(fft_window).float()
208
+
209
+ # window the bases
210
+ forward_basis *= fft_window
211
+ inverse_basis *= fft_window
212
+
213
+ self.register_buffer('forward_basis', forward_basis.float())
214
+ self.register_buffer('inverse_basis', inverse_basis.float())
215
+
216
+ def transform(self, input_data):
217
+ num_batches = input_data.size(0)
218
+ num_samples = input_data.size(1)
219
+
220
+ self.num_samples = num_samples
221
+
222
+ # similar to librosa, reflect-pad the input
223
+ input_data = input_data.view(num_batches, 1, num_samples)
224
+ input_data = F.pad(
225
+ input_data.unsqueeze(1),
226
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
227
+ mode='reflect')
228
+ input_data = input_data.squeeze(1)
229
+
230
+ forward_transform = F.conv1d(
231
+ input_data,
232
+ Variable(self.forward_basis, requires_grad=False),
233
+ stride=self.hop_length,
234
+ padding=0)
235
+
236
+ cutoff = int((self.filter_length / 2) + 1)
237
+ real_part = forward_transform[:, :cutoff, :]
238
+ imag_part = forward_transform[:, cutoff:, :]
239
+
240
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
241
+ phase = torch.autograd.Variable(
242
+ torch.atan2(imag_part.data, real_part.data))
243
+
244
+ return magnitude, phase
245
+
246
+ def inverse(self, magnitude, phase):
247
+ recombine_magnitude_phase = torch.cat(
248
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
249
+
250
+ inverse_transform = F.conv_transpose1d(
251
+ recombine_magnitude_phase,
252
+ Variable(self.inverse_basis, requires_grad=False),
253
+ stride=self.hop_length,
254
+ padding=0)
255
+
256
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
257
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
258
+
259
+ return inverse_transform
260
+
261
+ def forward(self, input_data):
262
+ self.magnitude, self.phase = self.transform(input_data)
263
+ reconstruction = self.inverse(self.magnitude, self.phase)
264
+ return reconstruction
265
+
266
+
267
+ class TorchSTFT(torch.nn.Module):
268
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
269
+ super().__init__()
270
+ self.filter_length = filter_length
271
+ self.hop_length = hop_length
272
+ self.win_length = win_length
273
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
274
+
275
+ def transform(self, input_data):
276
+ forward_transform = torch.stft(
277
+ input_data,
278
+ self.filter_length, self.hop_length, self.win_length, window=self.window,
279
+ return_complex=True)
280
+
281
+ return torch.abs(forward_transform), torch.angle(forward_transform)
282
+
283
+ def inverse(self, magnitude, phase):
284
+ inverse_transform = torch.istft(
285
+ magnitude * torch.exp(phase * 1j),
286
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
287
+
288
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
289
+
290
+ def forward(self, input_data):
291
+ self.magnitude, self.phase = self.transform(input_data)
292
+ reconstruction = self.inverse(self.magnitude, self.phase)
293
+ return reconstruction
294
+
295
+
stft_loss.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """STFT-based Loss modules."""
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def stft(x, fft_size, hop_size, win_length, window):
13
+ """Perform STFT and convert to magnitude spectrogram.
14
+ Args:
15
+ x (Tensor): Input signal tensor (B, T).
16
+ fft_size (int): FFT size.
17
+ hop_size (int): Hop size.
18
+ win_length (int): Window length.
19
+ window (str): Window function type.
20
+ Returns:
21
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
22
+ """
23
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window.to(x.device))
24
+ real = x_stft[..., 0]
25
+ imag = x_stft[..., 1]
26
+
27
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
28
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
29
+
30
+
31
+ class SpectralConvergengeLoss(torch.nn.Module):
32
+ """Spectral convergence loss module."""
33
+
34
+ def __init__(self):
35
+ """Initilize spectral convergence loss module."""
36
+ super(SpectralConvergengeLoss, self).__init__()
37
+
38
+ def forward(self, x_mag, y_mag):
39
+ """Calculate forward propagation.
40
+ Args:
41
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
42
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
43
+ Returns:
44
+ Tensor: Spectral convergence loss value.
45
+ """
46
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
47
+
48
+
49
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
50
+ """Log STFT magnitude loss module."""
51
+
52
+ def __init__(self):
53
+ """Initilize los STFT magnitude loss module."""
54
+ super(LogSTFTMagnitudeLoss, self).__init__()
55
+
56
+ def forward(self, x_mag, y_mag):
57
+ """Calculate forward propagation.
58
+ Args:
59
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
60
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
61
+ Returns:
62
+ Tensor: Log STFT magnitude loss value.
63
+ """
64
+ return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
65
+
66
+
67
+ class STFTLoss(torch.nn.Module):
68
+ """STFT loss module."""
69
+
70
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
71
+ """Initialize STFT loss module."""
72
+ super(STFTLoss, self).__init__()
73
+ self.fft_size = fft_size
74
+ self.shift_size = shift_size
75
+ self.win_length = win_length
76
+ self.window = getattr(torch, window)(win_length)
77
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
78
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
79
+
80
+ def forward(self, x, y):
81
+ """Calculate forward propagation.
82
+ Args:
83
+ x (Tensor): Predicted signal (B, T).
84
+ y (Tensor): Groundtruth signal (B, T).
85
+ Returns:
86
+ Tensor: Spectral convergence loss value.
87
+ Tensor: Log STFT magnitude loss value.
88
+ """
89
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
90
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
91
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
92
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
93
+
94
+ return sc_loss, mag_loss
95
+
96
+
97
+ class MultiResolutionSTFTLoss(torch.nn.Module):
98
+ """Multi resolution STFT loss module."""
99
+
100
+ def __init__(self,
101
+ fft_sizes=[1024, 2048, 512],
102
+ hop_sizes=[120, 240, 50],
103
+ win_lengths=[600, 1200, 240],
104
+ window="hann_window"):
105
+ """Initialize Multi resolution STFT loss module.
106
+ Args:
107
+ fft_sizes (list): List of FFT sizes.
108
+ hop_sizes (list): List of hop sizes.
109
+ win_lengths (list): List of window lengths.
110
+ window (str): Window function type.
111
+ """
112
+ super(MultiResolutionSTFTLoss, self).__init__()
113
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
114
+ self.stft_losses = torch.nn.ModuleList()
115
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
116
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
117
+
118
+ def forward(self, x, y):
119
+ """Calculate forward propagation.
120
+ Args:
121
+ x (Tensor): Predicted signal (B, T).
122
+ y (Tensor): Groundtruth signal (B, T).
123
+ Returns:
124
+ Tensor: Multi resolution spectral convergence loss value.
125
+ Tensor: Multi resolution log STFT magnitude loss value.
126
+ """
127
+ sc_loss = 0.0
128
+ mag_loss = 0.0
129
+ for f in self.stft_losses:
130
+ sc_l, mag_l = f(x, y)
131
+ sc_loss += sc_l
132
+ mag_loss += mag_l
133
+ sc_loss /= len(self.stft_losses)
134
+ mag_loss /= len(self.stft_losses)
135
+
136
+ return sc_loss, mag_loss