import torch
import torch.nn as nn
import torch.nn.functional as F

from . import normalizations, activations


class _Chop1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chop_size):
        super().__init__()
        self.chop_size = chop_size

    def forward(self, x):
        return x[..., : -self.chop_size].contiguous()


class Conv1DBlock(nn.Module):
    def __init__(
        self,
        in_chan,
        hid_chan,
        skip_out_chan,
        kernel_size,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(Conv1DBlock, self).__init__()
        self.skip_out_chan = skip_out_chan
        conv_norm = normalizations.get(norm_type)
        in_conv1d = nn.Conv1d(in_chan, hid_chan, 1)
        depth_conv1d = nn.Conv1d(
            hid_chan,
            hid_chan,
            kernel_size,
            padding=padding,
            dilation=dilation,
            groups=hid_chan,
        )
        if causal:
            depth_conv1d = nn.Sequential(depth_conv1d, _Chop1d(padding))
        self.shared_block = nn.Sequential(
            in_conv1d,
            nn.PReLU(),
            conv_norm(hid_chan),
            depth_conv1d,
            nn.PReLU(),
            conv_norm(hid_chan),
        )
        self.res_conv = nn.Conv1d(hid_chan, in_chan, 1)
        if skip_out_chan:
            self.skip_conv = nn.Conv1d(hid_chan, skip_out_chan, 1)

    def forward(self, x):
        r"""Input shape :math:`(batch, feats, seq)`."""
        shared_out = self.shared_block(x)
        res_out = self.res_conv(shared_out)
        if not self.skip_out_chan:
            return res_out
        skip_out = self.skip_conv(shared_out)
        return res_out, skip_out
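

# Illustrative usage sketch (not part of the original module); assumes the
# `normalizations`/`activations` registries resolve "gLN"/"prelu" as in asteroid.
# With padding = dilation * (kernel_size - 1) // 2 the block is length-preserving.
def _example_conv1d_block():  # hypothetical demo helper
    block = Conv1DBlock(
        in_chan=128, hid_chan=512, skip_out_chan=128,
        kernel_size=3, padding=2, dilation=2,
    )
    x = torch.randn(2, 128, 100)  # (batch, feats, seq)
    res, skip = block(x)          # res: (2, 128, 100), skip: (2, 128, 100)
    return res, skip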


class ConvNormAct(nn.Module):
    """Convolution layer followed by normalization and an activation
    (PReLU by default)."""

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
        act_type="prelu",
    ):
        super(ConvNormAct, self).__init__()
        self.conv = nn.Conv1d(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=True,
            groups=groups,
        )
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, x):
        output = self.conv(x)
        output = self.norm(output)
        return self.act(output)
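

# Illustrative usage sketch (not part of the original module): a strided
# depthwise ConvNormAct, as used by the down-sampling layers further below,
# roughly halves the temporal dimension.
def _example_conv_norm_act():  # hypothetical demo helper
    layer = ConvNormAct(64, 64, kernel_size=5, stride=2, groups=64, padding=2)
    return layer(torch.randn(2, 64, 100))  # -> (2, 64, 50)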


class ConvNorm(nn.Module):
    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
    ):
        super(ConvNorm, self).__init__()
        self.conv = nn.Conv1d(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
            groups=groups,
        )
        self.norm = normalizations.get(norm_type)(out_chan)

    def forward(self, x):
        output = self.conv(x)
        return self.norm(output)


class NormAct(nn.Module):
    """Normalization followed by an activation (PReLU by default)."""

    def __init__(
        self,
        out_chan,
        norm_type="gLN",
        act_type="prelu",
    ):
        """
        :param out_chan: number of output channels
        """
        super(NormAct, self).__init__()
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, input):
        output = self.norm(input)
        return self.act(output)


class Video1DConv(nn.Module):
    """1-D convolutional block for the video stream.

    Args:
        in_chan: number of video encoder output channels.
        out_chan: number of bottleneck/skip conv output channels.
        kernel_size: depthwise conv kernel size.
        dilation: depthwise conv dilation.
        residual: whether to use a residual connection.
        skip_con: whether to also return a skip connection.
        first_block: whether this is the first block; the first block has no
            pre-activation/normalization and no residual connection.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        dilation=1,
        residual=True,
        skip_con=True,
        first_block=True,
    ):
        super(Video1DConv, self).__init__()
        self.first_block = first_block
        # The first block is never residual.
        self.residual = residual and not first_block
        self.bn = nn.BatchNorm1d(in_chan) if not first_block else None
        self.relu = nn.ReLU() if not first_block else None
        self.dconv = nn.Conv1d(
            in_chan,
            in_chan,
            kernel_size,
            groups=in_chan,
            dilation=dilation,
            padding=(dilation * (kernel_size - 1)) // 2,
            bias=True,
        )
        self.bconv = nn.Conv1d(in_chan, out_chan, 1)
        self.sconv = nn.Conv1d(in_chan, out_chan, 1)
        self.skip_con = skip_con

    def forward(self, x):
        """
        x: [B, N, T]
        out: [B, N, T]
        """
        if not self.first_block:
            y = self.bn(self.relu(x))
            y = self.dconv(y)
        else:
            y = self.dconv(x)
        if self.skip_con:
            skip = self.sconv(y)
            if self.residual:
                y = y + x
            return skip, y
        y = self.bconv(y)
        if self.residual:
            y = y + x
        return y
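

# Illustrative usage sketch (not part of the original module): with
# skip_con=True the block returns (skip, y); first_block=True disables the
# pre-activation/normalization and the residual path.
def _example_video1d_conv():  # hypothetical demo helper
    first = Video1DConv(256, 256, kernel_size=3, dilation=1, first_block=True)
    later = Video1DConv(256, 256, kernel_size=3, dilation=2, first_block=False)
    v = torch.randn(2, 256, 50)  # (batch, channels, video frames)
    skip, y = first(v)
    skip, y = later(y)           # both stay (2, 256, 50)
    return skip, y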


class Concat(nn.Module):
    def __init__(self, ain_chan, vin_chan, out_chan):
        super(Concat, self).__init__()
        self.ain_chan = ain_chan
        self.vin_chan = vin_chan
        # Project the concatenated features back to out_chan.
        self.conv1d = nn.Sequential(
            nn.Conv1d(ain_chan + vin_chan, out_chan, 1), nn.PReLU()
        )

    def forward(self, a, v):
        # Up-sample video features to the audio length.
        v = torch.nn.functional.interpolate(v, size=a.size(-1))
        # Concat: n x (A+V) x Ta
        y = torch.cat([a, v], dim=1)
        return self.conv1d(y)
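

# Illustrative usage sketch (not part of the original module): the video
# stream (e.g. 25 fps) is interpolated to the audio feature rate before the
# 1x1 fusion conv, so the output keeps the audio temporal resolution.
def _example_concat():  # hypothetical demo helper
    fuse = Concat(ain_chan=128, vin_chan=256, out_chan=128)
    a = torch.randn(2, 128, 3200)  # audio features
    v = torch.randn(2, 256, 100)   # video features
    return fuse(a, v)              # -> (2, 128, 3200)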


class FRCNNBlock(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=1,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
        )
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                ConvNorm(
                    out_chan,
                    out_chan,
                    kernel_size=5,
                    stride=2,
                    groups=out_chan,
                    dilation=1,
                    padding=((5 - 1) // 2) * 1,
                    norm_type=norm_type,
                )
            )
        # ----------Fusion Layer----------
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or j - i == 1:
                    # Same scale or finer neighbour: no conv needed here.
                    fuse_layer.append(None)
                elif i - j == 1:
                    # Coarser neighbour: downsample with a strided depthwise conv.
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            # Edge scales fuse two neighbours, inner scales fuse three.
            n_inputs = 2 if i in (0, upsampling_depth - 1) else 3
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map
        :return: transformed feature map
        """
        residual = x.clone()
        # Reduce: project high-dimensional feature maps to a low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        # Fuse each scale with its resampled neighbours (an empty tensor is
        # concatenated where a neighbour does not exist).
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = output[i].shape[-1]
            y = torch.cat(
                (
                    self.fuse_layers[i][0](output[i - 1])
                    if i - 1 >= 0
                    else torch.Tensor().to(output1.device),
                    output[i],
                    F.interpolate(output[i + 1], size=wav_length, mode="nearest")
                    if i + 1 < self.depth
                    else torch.Tensor().to(output1.device),
                ),
                dim=1,
            )
            x_fuse.append(self.concat_layer[i](y))
        # Bring every fused scale back to the finest resolution and aggregate.
        wav_length = output[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
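

# Illustrative usage sketch (not part of the original module): the block is
# residual, so the output shape equals the input shape regardless of the
# internal multi-scale processing.
def _example_frcnn_block():  # hypothetical demo helper
    block = FRCNNBlock(in_chan=128, out_chan=512, upsampling_depth=4)
    return block(torch.randn(2, 128, 320))  # -> (2, 128, 320)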


class Bottomup(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=1,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
        )
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                ConvNorm(
                    out_chan,
                    out_chan,
                    kernel_size=5,
                    stride=2,
                    groups=out_chan,
                    dilation=1,
                    padding=((5 - 1) // 2) * 1,
                    norm_type=norm_type,
                )
            )

    def forward(self, x):
        residual = x.clone()
        # Reduce: project high-dimensional feature maps to a low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        return residual, output[-1], output
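

# Illustrative usage sketch (not part of the original module): Bottomup
# returns the untouched input (for a later residual add), the coarsest scale,
# and the full pyramid, each scale roughly half the length of the previous.
def _example_bottomup():  # hypothetical demo helper
    enc = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
    residual, coarsest, pyramid = enc(torch.randn(2, 128, 320))
    # residual: (2, 128, 320); pyramid lengths: 320, 160, 80, 40
    return residual, coarsest, pyramid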


class BottomupTCN(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
        )
        # ----------Stride-1 TCN layers (all scales keep the input length)----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            )

    def forward(self, x):
        residual = x.clone()
        # Reduce: project high-dimensional feature maps to a low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        return residual, output[-1], output


class Bottomup_Concat_Topdown(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        # ----------Fusion Layer----------
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or j - i == 1:
                    # Same scale or finer neighbour: no conv needed here.
                    fuse_layer.append(None)
                elif i - j == 1:
                    # Coarser neighbour: downsample with a strided depthwise conv.
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            # Edge scales fuse two neighbours plus the top-down stream,
            # inner scales fuse three neighbours plus the top-down stream.
            n_inputs = 3 if i in (0, upsampling_depth - 1) else 4
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            y = torch.cat(
                (
                    self.fuse_layers[i][0](bottomup[i - 1])
                    if i - 1 >= 0
                    else torch.Tensor().to(bottomup[i].device),
                    bottomup[i],
                    F.interpolate(bottomup[i + 1], size=wav_length, mode="nearest")
                    if i + 1 < self.depth
                    else torch.Tensor().to(bottomup[i].device),
                    F.interpolate(topdown, size=wav_length, mode="nearest"),
                ),
                dim=1,
            )
            x_fuse.append(self.concat_layer[i](y))
        wav_length = bottomup[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
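

# Illustrative usage sketch (not part of the original module): pair Bottomup
# with Bottomup_Concat_Topdown. The top-down tensor must have out_chan
# channels; here the coarsest bottom-up scale stands in for a real top-down cue.
def _example_bottomup_topdown():  # hypothetical demo helper
    enc = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
    fuse = Bottomup_Concat_Topdown(in_chan=128, out_chan=512, upsampling_depth=4)
    residual, coarsest, pyramid = enc(torch.randn(2, 128, 320))
    return fuse(residual, pyramid, coarsest)  # -> (2, 128, 320)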


class Bottomup_Concat_Topdown_TCN(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        # ----------Fusion Layer----------
        # All TCN scales share the same temporal length, so no resampling
        # convs are needed; the (all-None) lists only fix the loop length.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or abs(i - j) == 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            # Edge scales fuse two neighbours plus the top-down stream,
            # inner scales fuse three neighbours plus the top-down stream.
            n_inputs = 3 if i in (0, upsampling_depth - 1) else 4
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            y = torch.cat(
                (
                    bottomup[i - 1]
                    if i - 1 >= 0
                    else torch.Tensor().to(bottomup[i].device),
                    bottomup[i],
                    bottomup[i + 1]
                    if i + 1 < self.depth
                    else torch.Tensor().to(bottomup[i].device),
                    F.interpolate(topdown, size=wav_length, mode="nearest"),
                ),
                dim=1,
            )
            x_fuse.append(self.concat_layer[i](y))
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual


class FRCNNBlockTCN(nn.Module):
    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
        )
        # ----------Stride-1 TCN layers (all scales keep the input length)----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            )
        # ----------Fusion Layer----------
        # All TCN scales share the same temporal length, so no resampling
        # convs are needed; the (all-None) lists only fix the loop length.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or abs(i - j) == 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            # Edge scales fuse two neighbours, inner scales fuse three.
            n_inputs = 2 if i in (0, upsampling_depth - 1) else 3
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map
        :return: transformed feature map
        """
        residual = x.clone()
        # Reduce: project high-dimensional feature maps to a low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = torch.cat(
                (
                    output[i - 1] if i - 1 >= 0 else torch.Tensor().to(output1.device),
                    output[i],
                    output[i + 1]
                    if i + 1 < self.depth
                    else torch.Tensor().to(output1.device),
                ),
                dim=1,
            )
            x_fuse.append(self.concat_layer[i](y))
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual


class TAC(nn.Module):
    r"""Transform-Average-Concatenate inter-microphone-channel permutation
    invariant communication block [1].

    Args:
        input_dim (int): Number of features of the input representation.
        hidden_dim (int, optional): Size of the hidden layers in the TAC operations.
        activation (str, optional): Type of activation used. See asteroid.masknn.activations.
        norm_type (str, optional): Type of normalization layer used. See asteroid.masknn.norms.

    .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`
        as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and
        ``n_chunks``. Output is of the same shape as the input.

    References
        [1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
        speech separation." ICASSP 2020.
    """

    def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_tf = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), activations.get(activation)()
        )
        self.avg_tf = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
        )
        self.concat_tf = nn.Sequential(
            nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
        )
        self.norm = normalizations.get(norm_type)(input_dim)

    def forward(self, x, valid_mics=None):
        r"""
        Args:
            x (:class:`torch.Tensor`): Input multi-channel DPRNN features.
                Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
            valid_mics (:class:`torch.LongTensor`): Tensor containing the effective number of
                microphones for each batch element. Batches can be composed of examples coming
                from arrays with a different number of microphones, in which case the
                ``mic_channels`` dimension is padded.
                E.g. torch.tensor([4, 3]) means the first example has 4 channels and the second 3.
                Shape: :math:`(batch)`.

        Returns:
            output (:class:`torch.Tensor`): Features for each mic_channel after TAC
            inter-channel processing.
            Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
        """
        # Input is 5D because it is multi-channel DPRNN. Single-channel DPRNN is 4D.
        batch_size, nmics, channels, chunk_size, n_chunks = x.size()
        if valid_mics is None:
            valid_mics = torch.LongTensor([nmics] * batch_size)
        # First operation: transform the input for each frame, independently on each mic channel.
        output = self.input_tf(
            x.permute(0, 3, 4, 1, 2).reshape(
                batch_size * nmics * chunk_size * n_chunks, channels
            )
        ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)
        # Mean pooling across channels
        if valid_mics.max() == 0:
            # Fixed geometry array: average over the mic dimension (dim 3).
            mics_mean = output.mean(3)
        else:
            # Only consider valid channels in each batch element: each example
            # can have a different number of microphones.
            mics_mean = [
                output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0)
                for b in range(batch_size)
            ]  # (1, chunk_size, n_chunks, hidden_dim) each
            mics_mean = torch.cat(mics_mean, 0)  # (batch, chunk_size, n_chunks, hidden_dim)
        # The average is processed by a non-linear transform.
        mics_mean = self.avg_tf(
            mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
        )
        mics_mean = (
            mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
            .unsqueeze(3)
            .expand_as(output)
        )
        # Concatenate the transformed average in each channel with the original features
        # and project back to the same number of features.
        output = torch.cat([output, mics_mean], -1)
        output = self.concat_tf(
            output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
        ).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
        output = self.norm(
            output.permute(0, 3, 4, 1, 2).reshape(
                batch_size * nmics, -1, chunk_size, n_chunks
            )
        ).reshape(batch_size, nmics, -1, chunk_size, n_chunks)
        output += x
        return output
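

# Illustrative usage sketch (not part of the original module): two batch
# items with up to 4 mics and DPRNN-style chunked features; valid_mics marks
# how many mic channels are real in each example.
def _example_tac():  # hypothetical demo helper
    tac = TAC(input_dim=64, hidden_dim=128)
    x = torch.randn(2, 4, 64, 50, 10)      # (batch, mics, feats, chunk_size, n_chunks)
    valid_mics = torch.LongTensor([4, 3])  # second example has one padded mic
    return tac(x, valid_mics)              # -> (2, 4, 64, 50, 10)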