groadabike committed
Commit 7c6b998 · verified · 1 Parent(s): 9478e21

Upload 2 files

Files changed (2)
  1. dynamic_source_separator.py +23 -0
  2. tasnet.py +533 -0
dynamic_source_separator.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+ from huggingface_hub import ModelCard
+
+ from tasnet import ConvTasNetStereo
+
+
+ class DynamicSourceSeparator(torch.nn.Module, PyTorchModelHubMixin):
+     def __init__(self, pre_trained_models):
+         super(DynamicSourceSeparator, self).__init__()
+         self.models = nn.ModuleDict(pre_trained_models)
+
+     def forward(self, mixture, indicator):
+         separated_sources = {}
+         for instrument, active in indicator.items():
+             if active:
+                 model = self.models[instrument]
+                 est_source = model(mixture)
+                 separated_sources[instrument] = est_source[:, 0, :, :]
+             else:
+                 separated_sources[instrument] = torch.zeros_like(mixture)
+         return separated_sources
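For context, a minimal usage sketch of `DynamicSourceSeparator` follows. The stem names and the randomly initialised models below are illustrative assumptions standing in for real pre-trained checkpoints; only `ConvTasNetStereo`, the `pre_trained_models` dict, and the per-instrument `indicator` contract come from the files in this commit.

import torch

from dynamic_source_separator import DynamicSourceSeparator
from tasnet import ConvTasNetStereo

# One single-source ConvTasNetStereo per stem (hypothetical stem names;
# randomly initialised here, pre-trained checkpoints in practice).
stems = ["vocals", "drums", "bass", "other"]
models = {name: ConvTasNetStereo(C=1, audio_channels=2) for name in stems}

separator = DynamicSourceSeparator(models)
separator.eval()

mixture = torch.randn(1, 2, 44100)  # [batch, audio_channels, samples], 1 s at 44.1 kHz
indicator = {"vocals": True, "drums": True, "bass": False, "other": False}

with torch.no_grad():
    sources = separator(mixture, indicator)

# Active stems hold the model estimate; inactive stems are all-zero tensors.
for name, audio in sources.items():
    print(name, tuple(audio.shape))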
tasnet.py ADDED
@@ -0,0 +1,533 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Created on 2018/12
+ # Author: Kaituo XU
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
+ # Here is the original license:
+ # The MIT License (MIT)
+ #
+ # Copyright (c) 2018 Kaituo XU
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+ EPS = 1e-8
+
+
+ def overlap_and_add(signal, frame_step):
+     outer_dimensions = signal.size()[:-2]
+     frames, frame_length = signal.size()[-2:]
+
+     subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
+     subframe_step = frame_step // subframe_length
+     subframes_per_frame = frame_length // subframe_length
+     output_size = frame_step * (frames - 1) + frame_length
+     output_subframes = output_size // subframe_length
+
+     subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
+
+     frame = torch.arange(0, output_subframes, device=signal.device).unfold(
+         0, subframes_per_frame, subframe_step
+     )
+     frame = frame.long()  # signal may be on GPU or CPU
+     frame = frame.contiguous().view(-1)
+
+     result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
+     result.index_add_(-2, frame, subframe_signal)
+     result = result.view(*outer_dimensions, -1)
+     return result
+
+
+ class ConvTasNetStereo(nn.Module, PyTorchModelHubMixin):
+     def __init__(
+         self,
+         N=256,
+         L=20,
+         B=256,
+         H=512,
+         P=3,
+         X=8,
+         R=4,
+         C=4,
+         audio_channels=1,
+         samplerate=44100,
+         norm_type="gLN",
+         causal=False,
+         mask_nonlinear="relu",
+     ):
+         """
+         Args:
+             N: Number of filters in autoencoder
+             L: Length of the filters (in samples)
+             B: Number of channels in bottleneck 1 × 1-conv block
+             H: Number of channels in convolutional blocks
+             P: Kernel size in convolutional blocks
+             X: Number of convolutional blocks in each repeat
+             R: Number of repeats
+             C: Number of speakers
+             norm_type: BN, gLN, cLN
+             causal: causal or non-causal
+             mask_nonlinear: which non-linear function to use to generate the mask
+         """
+         super().__init__()
+         # Hyper-parameters
+         self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = (
+             N,
+             L,
+             B,
+             H,
+             P,
+             X,
+             R,
+             C,
+         )
+         self.norm_type = norm_type
+         self.causal = causal
+         self.mask_nonlinear = mask_nonlinear
+         self.audio_channels = audio_channels
+         self.samplerate = samplerate
+         # Components
+         self.encoder = Encoder(L, N, audio_channels)
+         self.separator = TemporalConvNet(
+             N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear
+         )
+         self.decoder = Decoder(N, L, audio_channels)
+         # init
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_normal_(p)
+
+     def valid_length(self, length):
+         return length
+
+     def forward(self, mixture):
+         """
+         Args:
+             mixture: [M, audio_channels, T], M is batch size, T is #samples
+         Returns:
+             est_source: [M, C, audio_channels, T]
+         """
+         mixture_w = self.encoder(mixture)
+         est_mask = self.separator(mixture_w)
+         est_source = self.decoder(mixture_w, est_mask)
+
+         # T changed after conv1d in encoder, fix it here
+         T_origin = mixture.size(-1)
+         T_conv = est_source.size(-1)
+         est_source = F.pad(est_source, (0, T_origin - T_conv))
+         return est_source
+
+     def serialize(self):
+         """Serialize model and output dictionary.
+
+         Returns:
+             dict, serialized model with keys `model_args` and `state_dict`.
+         """
+         import pytorch_lightning as pl  # Not used in torch.hub
+
+         model_conf = dict(
+             model_name=self.__class__.__name__,
+             state_dict=self.get_state_dict(),
+             # model_args=self.get_model_args(),
+         )
+         # Additional infos
+         infos = dict()
+         infos["software_versions"] = dict(
+             torch_version=torch.__version__,
+             pytorch_lightning_version=pl.__version__,
+             asteroid_version="0.7.0",
+         )
+         model_conf["infos"] = infos
+         return model_conf
+
+     def get_state_dict(self):
+         """In case the state dict needs to be modified before sharing the model."""
+         return self.state_dict()
+
+     def get_model_args(self):
+         """Arguments needed to re-instantiate the model.
+
+         NOTE: expects `self.encoder.filterbank`, `self.masker` and
+         `self.encoder_activation`, which this class does not define.
+         """
+         fb_config = self.encoder.filterbank.get_config()
+         masknet_config = self.masker.get_config()
+         # Assert both dicts are disjoint
+         if not all(k not in fb_config for k in masknet_config):
+             raise AssertionError(
+                 "Filterbank and Mask network config share common keys. Merging them is"
+                 " not safe."
+             )
+         # Merge all args under model_args.
+         model_args = {
+             **fb_config,
+             **masknet_config,
+             "encoder_activation": self.encoder_activation,
+         }
+         return model_args
+
+
+ class Encoder(nn.Module):
+     """Estimation of the nonnegative mixture weight by a 1-D conv layer."""
+
+     def __init__(self, L, N, audio_channels):
+         super().__init__()
+         # Hyper-parameters
+         self.L, self.N = L, N
+         # Components
+         # 50% overlap
+         self.conv1d_U = nn.Conv1d(
+             audio_channels, N, kernel_size=L, stride=L // 2, bias=False
+         )
+
+     def forward(self, mixture):
+         """
+         Args:
+             mixture: [M, audio_channels, T], M is batch size, T is #samples
+         Returns:
+             mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
+         """
+         mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
+         return mixture_w
+
+
+ class Decoder(nn.Module):
+     def __init__(self, N, L, audio_channels):
+         super().__init__()
+         # Hyper-parameters
+         self.N, self.L = N, L
+         self.audio_channels = audio_channels
+         # Components
+         self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
+
+     def forward(self, mixture_w, est_mask):
+         """
+         Args:
+             mixture_w: [M, N, K]
+             est_mask: [M, C, N, K]
+         Returns:
+             est_source: [M, C, audio_channels, T]
+         """
+         # D = W * M
+         source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
+         source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
+         # S = DV
+         est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
+         m, c, k, _ = est_source.size()
+         est_source = (
+             est_source.view(m, c, k, self.audio_channels, -1)
+             .transpose(2, 3)
+             .contiguous()
+         )
+         est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
+         return est_source
+
+
+ class TemporalConvNet(nn.Module):
+     def __init__(
+         self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"
+     ):
+         """
+         Args:
+             N: Number of filters in autoencoder
+             B: Number of channels in bottleneck 1 × 1-conv block
+             H: Number of channels in convolutional blocks
+             P: Kernel size in convolutional blocks
+             X: Number of convolutional blocks in each repeat
+             R: Number of repeats
+             C: Number of speakers
+             norm_type: BN, gLN, cLN
+             causal: causal or non-causal
+             mask_nonlinear: which non-linear function to use to generate the mask
+         """
+         super().__init__()
+         # Hyper-parameters
+         self.C = C
+         self.mask_nonlinear = mask_nonlinear
+         # Components
+         # [M, N, K] -> [M, N, K]
+         layer_norm = ChannelwiseLayerNorm(N)
+         # [M, N, K] -> [M, B, K]
+         bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
+         # [M, B, K] -> [M, B, K]
+         repeats = []
+         for _r in range(R):
+             blocks = []
+             for x in range(X):
+                 dilation = 2**x
+                 padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
+                 blocks += [
+                     TemporalBlock(
+                         B,
+                         H,
+                         P,
+                         stride=1,
+                         padding=padding,
+                         dilation=dilation,
+                         norm_type=norm_type,
+                         causal=causal,
+                     )
+                 ]
+             repeats += [nn.Sequential(*blocks)]
+         temporal_conv_net = nn.Sequential(*repeats)
+         # [M, B, K] -> [M, C*N, K]
+         mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
+         # Put together
+         self.network = nn.Sequential(
+             layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
+         )
+
+     def forward(self, mixture_w):
+         """
+         Keep this API the same as TasNet's.
+         Args:
+             mixture_w: [M, N, K], M is batch size
+         Returns:
+             est_mask: [M, C, N, K]
+         """
+         M, N, K = mixture_w.size()
+         score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
+         score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
+         if self.mask_nonlinear == "softmax":
+             est_mask = F.softmax(score, dim=1)
+         elif self.mask_nonlinear == "relu":
+             est_mask = F.relu(score)
+         else:
+             raise ValueError("Unsupported mask non-linear function")
+         return est_mask
+
+
+ class TemporalBlock(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         norm_type="gLN",
+         causal=False,
+     ):
+         super().__init__()
+         # [M, B, K] -> [M, H, K]
+         conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+         prelu = nn.PReLU()
+         norm = chose_norm(norm_type, out_channels)
+         # [M, H, K] -> [M, B, K]
+         dsconv = DepthwiseSeparableConv(
+             out_channels,
+             in_channels,
+             kernel_size,
+             stride,
+             padding,
+             dilation,
+             norm_type,
+             causal,
+         )
+         # Put together
+         self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
+
+     def forward(self, x):
+         """
+         Args:
+             x: [M, B, K]
+         Returns:
+             [M, B, K]
+         """
+         residual = x
+         out = self.net(x)
+         # TODO: P = 3 works fine here, but P = 2 may need extra padding?
+         return out + residual  # it seems to work better without F.relu
+         # return F.relu(out + residual)
+
+
+ class DepthwiseSeparableConv(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         norm_type="gLN",
+         causal=False,
+     ):
+         super().__init__()
+         # Use `groups` option to implement depthwise convolution
+         # [M, H, K] -> [M, H, K]
+         depthwise_conv = nn.Conv1d(
+             in_channels,
+             in_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=in_channels,
+             bias=False,
+         )
+         if causal:
+             chomp = Chomp1d(padding)
+         prelu = nn.PReLU()
+         norm = chose_norm(norm_type, in_channels)
+         # [M, H, K] -> [M, B, K]
+         pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+         # Put together
+         if causal:
+             self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
+         else:
+             self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
+
+     def forward(self, x):
+         """
+         Args:
+             x: [M, H, K]
+         Returns:
+             result: [M, B, K]
+         """
+         return self.net(x)
+
+
+ class Chomp1d(nn.Module):
+     """To ensure the output length is the same as the input."""
+
+     def __init__(self, chomp_size):
+         super().__init__()
+         self.chomp_size = chomp_size
+
+     def forward(self, x):
+         """
+         Args:
+             x: [M, H, Kpad]
+         Returns:
+             [M, H, K]
+         """
+         return x[:, :, : -self.chomp_size].contiguous()
+
+
+ def chose_norm(norm_type, channel_size):
+     """The input of normalization will be (M, C, K), where M is batch size,
+     C is channel size, and K is sequence length.
+     """
+     if norm_type == "gLN":
+         return GlobalLayerNorm(channel_size)
+     elif norm_type == "cLN":
+         return ChannelwiseLayerNorm(channel_size)
+     elif norm_type == "id":
+         return nn.Identity()
+     else:  # norm_type == "BN"
+         # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
+         # along M and K, so this BN usage is right.
+         return nn.BatchNorm1d(channel_size)
+
+
+ # TODO: Use nn.LayerNorm to implement cLN to speed it up
+ class ChannelwiseLayerNorm(nn.Module):
+     """Channel-wise Layer Normalization (cLN)"""
+
+     def __init__(self, channel_size):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+         self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         self.gamma.data.fill_(1)
+         self.beta.data.zero_()
+
+     def forward(self, y):
+         """
+         Args:
+             y: [M, N, K], M is batch size, N is channel size, K is length
+         Returns:
+             cLN_y: [M, N, K]
+         """
+         mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
+         var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
+         cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+         return cLN_y
+
+
+ class GlobalLayerNorm(nn.Module):
+     """Global Layer Normalization (gLN)"""
+
+     def __init__(self, channel_size):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+         self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         self.gamma.data.fill_(1)
+         self.beta.data.zero_()
+
+     def forward(self, y):
+         """
+         Args:
+             y: [M, N, K], M is batch size, N is channel size, K is length
+         Returns:
+             gLN_y: [M, N, K]
+         """
+         # NOTE: since torch 1.0, torch.mean() supports a list of dims
+         mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
+         var = (
+             (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+         )
+         gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+         return gLN_y
+
+
+ if __name__ == "__main__":
+     torch.manual_seed(123)
+     M, N, L, T = 2, 3, 4, 12
+     K = 2 * T // L - 1
+     B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
+     mixture = torch.randint(3, (M, 1, T)).float()  # [M, audio_channels, T]
+     # test Encoder
+     encoder = Encoder(L, N, 1)
+     encoder.conv1d_U.weight.data = torch.randint(
+         2, encoder.conv1d_U.weight.size()
+     ).float()
+     mixture_w = encoder(mixture)
+     print("mixture", mixture)
+     print("U", encoder.conv1d_U.weight)
+     print("mixture_w", mixture_w)
+     print("mixture_w size", mixture_w.size())
+
+     # test TemporalConvNet
+     separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
+     est_mask = separator(mixture_w)
+     print("est_mask", est_mask)
+
+     # test Decoder
+     decoder = Decoder(N, L, audio_channels=1)
+     est_mask = torch.randint(2, (M, C, N, K)).float()  # [M, C, N, K]
+     est_source = decoder(mixture_w, est_mask)
+     print("est_source", est_source)
+
+     # test Conv-TasNet
+     conv_tasnet = ConvTasNetStereo(N, L, B, H, P, X, R, C, norm_type=norm_type)
+     est_source = conv_tasnet(mixture)
+     print("est_source", est_source)
+     print("est_source size", est_source.size())