from typing import List, Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import NeuralModule
from .tdnn_attention import (
    StatsPoolLayer,
    AttentivePoolLayer,
    ChannelDependentAttentiveStatisticsPoolLayer,
    TdnnModule,
    TdnnSeModule,
    TdnnSeRes2NetModule,
    init_weights,
)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """2D convolution with kernel_size = 3 and no bias.

    ``padding`` and ``dilation`` are both set to ``dilation`` so that, with
    ``stride=1``, the output has the same spatial size as the input.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        groups=groups,
        padding=dilation,
        # Without this the kernel is not dilated while the padding still grows
        # with `dilation`, so dilated layers would enlarge their output and no
        # longer match the residual branch.
        dilation=dilation,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """2D convolution with kernel_size = 1 and no bias."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    Each convolution is preceded by BatchNorm and the activation. The input
    (passed through ``downsample`` when given) is added to the result.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(
        self,
        in_channels,
        out_channels,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation: Type[nn.Module] = nn.ReLU,
    ):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 16:
            raise ValueError('BasicBlock only supports groups=1 and base_width=16')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # `activation` is a class/factory; a single instance is reused in forward().
        self.activation = activation()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)

        # The shortcut must be reshaped when stride or channel count changes.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block.

    Averages every channel over the two trailing (spatial) dimensions, feeds
    the result through Linear -> activation -> Linear -> Sigmoid, and scales
    the input channels by the resulting per-channel weights.
    """

    def __init__(self, channels, reduction=1, activation=nn.ReLU):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            activation(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of ``x``.

        ``x`` must be 4-D with the channel axis second (it is unpacked as
        ``b, c, _, _``); the returned tensor has the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class SEBasicBlock(nn.Module):
    """``BasicBlock`` with a Squeeze-and-Excitation module on the residual branch.

    The SE module is applied after the second convolution and before the
    shortcut is added.
    """

    expansion = 1

    def __init__(
        self,
        in_channels,
        out_channels,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation: Type[nn.Module] = nn.ReLU,
        reduction: int = 8,
    ):
        super(SEBasicBlock, self).__init__()
        if groups != 1 or base_width != 16:
            raise ValueError('SEBasicBlock only supports groups=1 and base_width=16')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.activation = activation()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        # `reduction` is the hidden-layer shrink factor of the SE bottleneck.
        self.se = SEBlock(out_channels, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)

        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out


class SEBottleneck(nn.Module):
    """Pre-activation bottleneck block with Squeeze-and-Excitation.

    Branch: 1x1 conv to ``width`` channels, 3x3 conv (carries ``stride``,
    ``groups`` and ``dilation``), 1x1 conv to ``out_channels * expansion``,
    then the SE module; each conv is preceded by BatchNorm + activation.
    ``width = int(out_channels * base_width / 16) * groups``.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation=nn.ReLU,
        reduction: int = 8,
    ):
        super(SEBottleneck, self).__init__()
        width = int(out_channels * (base_width / 16.)) * groups
        self.activation = activation()

        # 1x1 convolution to reduce channels
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv1x1(in_channels, width, stride=1)

        # 3x3 convolution
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)

        # 1x1 convolution to restore channels
        self.bn3 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, out_channels * self.expansion)

        # Squeeze-and-Excitation block
        self.se = SEBlock(out_channels * self.expansion, reduction)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        # First 1x1 convolution
        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)

        # 3x3 convolution
        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)

        # Second 1x1 convolution
        out = self.bn3(out)
        out = self.activation(out)
        out = self.conv3(out)

        # Apply SE block
        out = self.se(out)

        # Downsample residual if needed
        if self.downsample is not None:
            residual = self.downsample(x)

        # Add residual
        out += residual
        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block (no SE).

    Branch: 1x1 conv to ``width`` channels, 3x3 conv (carries ``stride``,
    ``groups`` and ``dilation``), 1x1 conv to ``out_channels * expansion``;
    each conv is preceded by BatchNorm + activation.
    ``width = int(out_channels * base_width / 16) * groups``.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation=nn.ReLU,
    ):
        super(Bottleneck, self).__init__()
        width = int(out_channels * (base_width / 16.)) * groups
        self.activation = activation()

        # 1x1 convolution to reduce channels
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv1x1(in_channels, width, stride=1)

        # 3x3 convolution
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)

        # 1x1 convolution to restore channels
        self.bn3 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, out_channels * self.expansion)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        # First 1x1 convolution
        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)

        # 3x3 convolution
        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)

        # Second 1x1 convolution
        out = self.bn3(out)
        out = self.activation(out)
        out = self.conv3(out)

        # Downsample residual if needed
        if self.downsample is not None:
            residual = self.downsample(x)

        # Add residual
        out += residual
        return out


class ResNetEncoder(NeuralModule):
    """2D ResNet encoder over a (B, C, T) feature map.

    The input gets a singleton channel axis, then passes through a 3x3 stem
    conv and four residual stages. Stages 1-2 use the SE variant of the chosen
    block type; stages 3-4 use the plain variant. The output's channel and
    second spatial axes are flattened together.

    Args:
        feat_in: accepted for interface compatibility; not used by the code in
            this class (NOTE(review): confirm whether callers rely on it).
        filters: per-stage channel counts (before block ``expansion``).
        block_sizes: number of residual blocks per stage.
        strides: per-stage stride of the first block (applied to both
            spatial axes, since the convolutions are 2D).
        groups, width_per_group: grouped-conv settings forwarded to blocks;
            ``BasicBlock``/``SEBasicBlock`` only accept 1 and 16.
        replace_stride_with_dilation: three flags for stages 2-4; when set,
            that stage uses dilation instead of its stride.
        block_type: 'basic' or 'bottleneck'.
        reduction: reduction factor of the SE modules in stages 1-2.
        init_mode: forwarded to ``init_weights`` for every sub-module.

    Raises:
        ValueError: unknown ``block_type`` or a malformed
            ``replace_stride_with_dilation``.
    """

    def __init__(
        self,
        feat_in: int,
        filters: list = [16, 32, 64, 128],
        block_sizes: list = [3, 4, 6, 3],
        strides: list = [1, 2, 2, 1],
        groups: int = 1,
        width_per_group: int = 16,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        block_type: str = 'basic',  # basic, bottleneck
        reduction: int = 8,  # reduction for SE layer
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()

        if block_type == 'basic':
            self.block_class = BasicBlock
            self.se_block_class = SEBasicBlock
        elif block_type == 'bottleneck':
            self.block_class = Bottleneck
            self.se_block_class = SEBottleneck
        else:
            raise ValueError(
                f"Unsupported block_type '{block_type}'; expected 'basic' or 'bottleneck'"
            )

        # Running channel count / dilation, mutated by the _make_layer* helpers.
        self.in_channels = filters[0]
        self.dilation = 1
        self.reduction = reduction

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )

        self.groups = groups
        self.base_width = width_per_group

        # Stem: lift the single input channel to filters[0] channels.
        self.pre_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=filters[0], kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(inplace=True),
        )

        self.layer1 = self._make_layer_se(self.se_block_class, filters[0], block_sizes[0], strides[0])
        self.layer2 = self._make_layer_se(
            self.se_block_class, filters[1], block_sizes[1], strides[1], dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            self.block_class, filters[2], block_sizes[2], strides[2], dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            self.block_class, filters[3], block_sizes[3], strides[3], dilate=replace_stride_with_dilation[2]
        )

        # NOTE(review): this counts only the channel axis of the last stage;
        # forward() flattens channels with the next axis, so the returned
        # feature size is final_dim * (size of that axis). Confirm callers.
        self.final_dim = filters[-1] * self.block_class.expansion

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def _make_layer_se(
        self,
        block: Type[Union[SEBasicBlock, SEBottleneck]],
        in_channels: int,
        block_num: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build a stage of ``block_num`` SE residual blocks.

        ``in_channels`` is the per-block ``out_channels``; the stage emits
        ``in_channels * block.expansion`` channels. Every block receives
        ``reduction=self.reduction``. Updates ``self.in_channels`` and, when
        ``dilate`` is true, ``self.dilation`` (the stride is then replaced by
        dilation).
        """
        norm_layer = nn.BatchNorm2d
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the shortcut whenever the first block changes shape.
        if stride != 1 or self.in_channels != in_channels * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, in_channels * block.expansion, stride),
                norm_layer(in_channels * block.expansion),
            )

        layers = [
            block(
                in_channels=self.in_channels,
                out_channels=in_channels,
                stride=stride,
                downsample=downsample,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                reduction=self.reduction,
            )
        ]
        self.in_channels = in_channels * block.expansion
        for _ in range(1, block_num):
            layers.append(
                block(
                    in_channels=self.in_channels,
                    out_channels=in_channels,
                    stride=1,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    reduction=self.reduction,
                )
            )

        return nn.Sequential(*layers)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        in_channels: int,
        block_num: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build a stage of ``block_num`` plain residual blocks.

        ``in_channels`` is the per-block ``out_channels``; the stage emits
        ``in_channels * block.expansion`` channels. Updates
        ``self.in_channels`` and, when ``dilate`` is true, ``self.dilation``.
        """
        norm_layer = nn.BatchNorm2d
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the shortcut whenever the first block changes shape.
        if stride != 1 or self.in_channels != in_channels * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, in_channels * block.expansion, stride),
                norm_layer(in_channels * block.expansion),
            )

        layers = [
            block(
                in_channels=self.in_channels,
                out_channels=in_channels,
                stride=stride,
                downsample=downsample,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
            )
        ]
        self.in_channels = in_channels * block.expansion
        for _ in range(1, block_num):
            layers.append(
                block(
                    in_channels=self.in_channels,
                    out_channels=in_channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, audio_signal: torch.Tensor, length: Optional[torch.Tensor] = None):
        """Encode ``audio_signal`` and return ``(features, length)``.

        ``features`` is the last stage's 4-D output with dims 1 and 2
        flattened into one. ``length`` is returned unchanged.
        NOTE(review): strided stages may shrink the last axis, but ``length``
        is not rescaled — confirm downstream pooling expects input-frame
        lengths.
        """
        x = audio_signal
        x = x.unsqueeze(dim=1)  # (B, 1, C, T)

        x = self.pre_conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.flatten(1, 2)
        return x, length


class SpeakerDecoder(NeuralModule):
    """
    Speaker Decoder maps encoder outputs to speaker embeddings and then to
    per-speaker logits through a final linear layer.

    Args:
        feat_in (int): Number of channels being input to this module
        num_classes (int): Number of unique speakers in dataset
        emb_sizes (int or list): sizes of the embedding layers stacked after
            pooling; an int means a single layer. Defaults to 256.
            The returned speaker embedding comes from the LAST of these layers.
        pool_mode (str): Pooling strategy type. Options are 'xvector', 'tap',
            'attention', 'ecapa2'. Defaults to 'xvector' (mean and variance).
            tap (temporal average pooling: just mean)
            attention (attention based pooling)
            ecapa2 (ChannelDependentAttentiveStatisticsPoolLayer)
        angular (bool): if True the final layer has no bias, and both the
            pooled features and the final-layer weight rows are L2-normalized
            before the product, so the logits are cosine similarities.
        attention_channels (int): attention width for the 'attention' and
            'ecapa2' pooling modes.
        init_mode (str): Describes how neural network parameters are initialized. Options are
            ['xavier_uniform', 'xavier_normal', 'kaiming_uniform','kaiming_normal'].
            Defaults to "xavier_uniform".

    Raises:
        ValueError: if ``pool_mode`` is not one of the supported options.
    """

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        # Number of leading sub-modules of each embedding layer whose output is
        # reported as the embedding (see affine_layer / forward).
        self.emb_id = 2
        bias = not self.angular
        emb_sizes = [emb_sizes] if isinstance(emb_sizes, int) else emb_sizes

        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode in ('xvector', 'tap'):
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'
        elif self.pool_mode == 'ecapa2':
            self._pooling = ChannelDependentAttentiveStatisticsPoolLayer(
                inp_filters=feat_in, attention_channels=attention_channels
            )
            affine_type = 'conv'
        else:
            raise ValueError(
                f"Unsupported pool_mode '{pool_mode}'; expected one of "
                "'xvector', 'tap', 'attention', 'ecapa2'"
            )

        shapes = [self._pooling.feat_in, *(int(size) for size in emb_sizes)]

        self.emb_layers = nn.ModuleList(
            self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            for shape_in, shape_out in zip(shapes[:-1], shapes[1:])
        )

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def affine_layer(
        self,
        inp_shape,
        out_shape,
        learn_mean=True,
        affine_type='conv',
    ):
        """Build one embedding layer.

        'conv': BatchNorm1d(inp_shape) -> Conv1d(kernel_size=1).
        otherwise: Linear -> BatchNorm1d(out_shape, affine=learn_mean) -> ReLU.
        """
        if affine_type == 'conv':
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer

    def forward(self, encoder_output, length=None):
        """Return ``(logits, embedding)``.

        ``embedding`` is the output of the first ``emb_id`` sub-modules of
        the last embedding layer (pre-ReLU for 'linear' layers, the whole
        layer for 'conv' layers), with a trailing singleton dim squeezed.
        """
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            # Compute the embedding once and finish the layer from it, instead
            # of running the leading sub-modules twice (which also updated
            # BatchNorm running statistics twice per batch in training).
            emb = layer[: self.emb_id](pool)
            pool = layer[self.emb_id:](emb)
            embs.append(emb)

        pool = pool.squeeze(-1)
        if self.angular:
            # Normalize both operands so the logits are cosine similarities.
            # (Rebinding the loop variable over self.final.parameters(), as
            # before, never changed the weights actually used.)
            pool = F.normalize(pool, p=2, dim=1)
            weight = F.normalize(self.final.weight, p=2, dim=1)
            out = F.linear(pool, weight, self.final.bias)
        else:
            out = self.final(pool)

        return out, embs[-1].squeeze(-1)