# resnet34-voxceleb1 / conv_asr.py
# Uploaded by yangwang825
# Commit ea74807 (verified): "Upload ResNetForSequenceClassification"
from typing import Optional, Union, Type, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .module import NeuralModule
from .tdnn_attention import (
StatsPoolLayer,
AttentivePoolLayer,
ChannelDependentAttentiveStatisticsPoolLayer,
TdnnModule,
TdnnSeModule,
TdnnSeRes2NetModule,
init_weights
)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 2D convolution with a 3x3 kernel.

    Padding is set equal to ``dilation`` so that, at stride 1, the spatial
    size of the input is preserved for any dilation rate.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (applied to both spatial axes).
        groups: Number of blocked connections from input to output channels.
        dilation: Spacing between kernel elements.

    Returns:
        An ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        groups=groups,
        padding=dilation,
        # `dilation` used to be applied to the padding only, so dilation > 1
        # produced an undilated kernel whose output grew by 2 * (dilation - 1)
        # per spatial dimension.
        dilation=dilation,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a pointwise (1x1) ``nn.Conv2d`` with no bias term.

    Changes the channel count (optionally with striding); used for the
    residual shortcut projections and the bottleneck squeeze/expand convs.
    """
    pointwise_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **pointwise_kwargs)
class BasicBlock(nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    The residual branch is BN -> activation -> conv3x3 -> BN -> activation ->
    conv3x3. The input (passed through ``downsample`` when given) is added to
    the branch output, and no activation is applied after the addition.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before the addition;
            needed whenever the branch output shape differs from the input's.
        groups: Must be 1 (kept for signature parity with the bottleneck blocks).
        base_width: Must be 16 (kept for signature parity).
        dilation: Must be 1; dilation is not supported by this block.
        activation: Activation *class*; it is instantiated once and the same
            instance is used at both activation sites.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 16``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        in_channels,
        out_channels,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation: Type[nn.Module] = nn.ReLU,
    ):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 16:
            # The message used to claim base_width=64, contradicting the check.
            raise ValueError('BasicBlock only supports groups=1 and base_width=16')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.activation = activation()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``shortcut(x) + conv2(act(bn2(conv1(act(bn1(x))))))``."""
        residual = x
        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return out
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate for 4-D feature maps.

    Each channel is global-average-pooled and the pooled vector is passed
    through a bottleneck MLP (``channels -> channels // reduction ->
    channels``) that ends in a sigmoid. Every input channel is then scaled by
    its gate value.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor giving the hidden width of the MLP.
        activation: Activation class placed between the two linear layers.
    """

    def __init__(self, channels, reduction=1, activation=nn.ReLU):
        super(SEBlock, self).__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            activation(),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale ``x`` channel-wise by the learned gate.

        ``x`` must be 4-D; the unpacking below rejects any other rank.
        """
        batch, channels, _, _ = x.size()
        descriptor = self.avg_pool(x).view(batch, channels)  # squeeze
        gate = self.fc(descriptor)  # excite
        return x * gate.view(batch, channels, 1, 1)
class SEBasicBlock(nn.Module):
    """Pre-activation basic residual block with a squeeze-and-excitation gate.

    The residual branch is BN -> activation -> conv3x3 -> BN -> activation ->
    conv3x3 -> SEBlock. The input (passed through ``downsample`` when given)
    is added to the branch output, and no activation follows the addition.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions and the SE gate.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before the addition;
            needed whenever the branch output shape differs from the input's.
        groups: Must be 1 (kept for signature parity with the bottleneck blocks).
        base_width: Must be 16 (kept for signature parity).
        dilation: Must be 1; dilation is not supported by this block.
        activation: Activation *class*; it is instantiated once and shared by
            both activation sites of the residual branch.
        reduction: Channel-reduction divisor passed to ``SEBlock``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 16``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        in_channels,
        out_channels,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation: Type[nn.Module] = nn.ReLU,
        reduction: int = 8
    ):
        super(SEBasicBlock, self).__init__()
        if groups != 1 or base_width != 16:
            # The message used to claim base_width=64, contradicting the check.
            raise ValueError('BasicBlock only supports groups=1 and base_width=16')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.activation = activation()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.se = SEBlock(out_channels, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``shortcut(x) + se(conv2(act(bn2(conv1(act(bn1(x)))))))``."""
        residual = x
        out = self.bn1(x)
        out = self.activation(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.activation(out)
        out = self.conv2(out)
        # The SE gate rescales only the residual branch, before the addition.
        out = self.se(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return out
class SEBottleneck(nn.Module):
    """Pre-activation bottleneck residual block with a squeeze-and-excitation gate.

    The residual branch runs three BN -> activation -> conv stages:
    1x1 (in_channels -> inner width), 3x3 (carrying stride / groups /
    dilation) and 1x1 (inner width -> ``out_channels * expansion``). It is
    then gated by ``SEBlock`` and added to the (optionally downsampled) input.

    The inner width is ``int(out_channels * base_width / 16) * groups``.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation=nn.ReLU,
        reduction: int = 8,
    ):
        super(SEBottleneck, self).__init__()
        inner = int(out_channels * (base_width / 16.)) * groups
        expanded = out_channels * self.expansion
        # One activation instance shared by all three stages.
        self.activation = activation()
        # Stage 1: 1x1 conv from in_channels to the inner width.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv1x1(in_channels, inner, stride=1)
        # Stage 2: 3x3 conv; the only place stride / groups / dilation apply.
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        # Stage 3: 1x1 conv expanding to out_channels * expansion.
        self.bn3 = nn.BatchNorm2d(inner)
        self.conv3 = conv1x1(inner, expanded)
        self.se = SEBlock(expanded, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three pre-activation stages and the SE gate, then add the shortcut."""
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        out = x
        for norm, conv in stages:
            out = conv(self.activation(norm(out)))
        out = self.se(out)
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out
class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block.

    The residual branch runs three BN -> activation -> conv stages:
    1x1 (in_channels -> inner width), 3x3 (carrying stride / groups /
    dilation) and 1x1 (inner width -> ``out_channels * expansion``). It is
    added to the (optionally downsampled) input with no activation after
    the addition.

    The inner width is ``int(out_channels * base_width / 16) * groups``.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 16,
        dilation: int = 1,
        activation=nn.ReLU,
    ):
        super(Bottleneck, self).__init__()
        inner = int(out_channels * (base_width / 16.)) * groups
        expanded = out_channels * self.expansion
        # One activation instance shared by all three stages.
        self.activation = activation()
        # Stage 1: 1x1 conv from in_channels to the inner width.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv1x1(in_channels, inner, stride=1)
        # Stage 2: 3x3 conv; the only place stride / groups / dilation apply.
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        # Stage 3: 1x1 conv expanding to out_channels * expansion.
        self.bn3 = nn.BatchNorm2d(inner)
        self.conv3 = conv1x1(inner, expanded)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three pre-activation stages, then add the shortcut."""
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        out = x
        for norm, conv in stages:
            out = conv(self.activation(norm(out)))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out
class ResNetEncoder(NeuralModule):
    """2-D ResNet encoder that treats its input as a single-channel image.

    The input gets a channel axis. It then passes through a 3x3 conv stem
    (Conv -> BN -> ReLU) and four residual stages, and the channel and
    frequency axes of the result are flattened together. Stages 1-2 use the
    squeeze-and-excitation blocks (``SEBasicBlock`` / ``SEBottleneck``);
    stages 3-4 use the plain blocks (``BasicBlock`` / ``Bottleneck``).

    Args:
        feat_in: Not referenced by this class (the stem always takes a single
            input channel); kept for interface compatibility.
        filters: Base channel count of each of the four stages (first four
            entries are used).
        block_sizes: Number of residual blocks in each of the four stages.
        strides: Stride of the first block of each stage (both spatial axes).
        groups: Convolution groups passed to every block.
        width_per_group: Passed to every block as ``base_width``.
        replace_stride_with_dilation: Three flags for stages 2-4; when a flag
            is set that stage keeps stride 1 and dilates instead. The basic
            blocks reject dilation > 1, so this only works with
            ``block_type='bottleneck'``.
        block_type: ``'basic'`` or ``'bottleneck'``.
        reduction: Channel-reduction divisor of the SE layers in stages 1-2.
        init_mode: Forwarded to ``init_weights`` for every sub-module.

    Raises:
        ValueError: On an unknown ``block_type`` or a
            ``replace_stride_with_dilation`` that does not have 3 entries.
    """

    def __init__(
        self,
        feat_in: int,
        filters: list = [16, 32, 64, 128],
        block_sizes: list = [3, 4, 6, 3],
        strides: list = [1, 2, 2, 1],
        groups: int = 1,
        width_per_group: int = 16,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        block_type: str = 'basic',  # basic, bottleneck
        reduction: int = 8,  # reduction for SE layer
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        if block_type == 'basic':
            self.block_class = BasicBlock
            self.se_block_class = SEBasicBlock
        elif block_type == 'bottleneck':
            self.block_class = Bottleneck
            self.se_block_class = SEBottleneck
        else:
            # Previously fell through and failed later with an AttributeError.
            raise ValueError(
                "block_type should be 'basic' or 'bottleneck', got {!r}".format(block_type)
            )
        # Running channel count / dilation, advanced by each stage builder.
        self.in_channels = filters[0]
        self.dilation = 1
        self.reduction = reduction
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.pre_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=filters[0],
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(inplace=True)
        )
        self.layer1 = self._make_layer_se(self.se_block_class, filters[0], block_sizes[0], strides[0])
        self.layer2 = self._make_layer_se(self.se_block_class, filters[1], block_sizes[1], strides[1], dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(self.block_class, filters[2], block_sizes[2], strides[2], dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(self.block_class, filters[3], block_sizes[3], strides[3], dilate=replace_stride_with_dilation[2])
        # Channels produced by layer4 (before the channel/frequency flatten).
        # layer4 is built from filters[3]; filters[-1] disagreed whenever more
        # than four filters were supplied.
        self.final_dim = filters[3] * self.block_class.expansion
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def _build_stage(
        self,
        block: Type[nn.Module],
        planes: int,
        block_num: int,
        stride: int = 1,
        dilate: bool = False,
        **block_kwargs,
    ) -> nn.Sequential:
        """Build one residual stage; ``block_kwargs`` go to every block.

        Only the first block gets ``stride`` and (when shapes change) a
        1x1-conv + BN projection shortcut. Updates ``self.in_channels`` and,
        when ``dilate`` is set, ``self.dilation``.
        """
        previous_dilation = self.dilation
        if dilate:
            # Trade this stage's stride for dilation of the same factor.
            self.dilation *= stride
            stride = 1
        stage_out = planes * block.expansion
        downsample = None
        if stride != 1 or self.in_channels != stage_out:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, stage_out, stride),
                nn.BatchNorm2d(stage_out),
            )
        layers = [
            block(
                in_channels=self.in_channels,
                out_channels=planes,
                stride=stride,
                downsample=downsample,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                **block_kwargs,
            )
        ]
        self.in_channels = stage_out
        for _ in range(1, block_num):
            layers.append(
                block(
                    in_channels=self.in_channels,
                    out_channels=planes,
                    stride=1,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    **block_kwargs,
                )
            )
        return nn.Sequential(*layers)

    def _make_layer_se(
        self,
        block: Type[Union[SEBasicBlock, SEBottleneck]],
        in_channels: int,
        block_num: int,
        stride: int = 1,
        dilate: bool = False
    ) -> nn.Sequential:
        """Build a stage of SE blocks with base width ``in_channels``.

        ``self.reduction`` is forwarded to every SE block; it was previously
        stored but never used, so the SE layers always used their default.
        """
        return self._build_stage(
            block, in_channels, block_num, stride, dilate, reduction=self.reduction
        )

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        in_channels: int,
        block_num: int,
        stride: int = 1,
        dilate: bool = False
    ) -> nn.Sequential:
        """Build a stage of plain residual blocks with base width ``in_channels``."""
        return self._build_stage(block, in_channels, block_num, stride, dilate)

    def forward(self, audio_signal: torch.Tensor, length: torch.Tensor = None):
        """Encode a batch of features.

        Args:
            audio_signal: Presumably (B, C, T) features, per the unsqueeze
                comment below — TODO confirm against the caller.
            length: Returned unchanged.

        Returns:
            Tuple ``(x, length)`` where ``x`` is the layer4 output with its
            channel and frequency axes flattened into one.
        """
        x = audio_signal
        x = x.unsqueeze(dim=1)  # (B, 1, C, T)
        x = self.pre_conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.flatten(1, 2)
        # NOTE(review): strided stages shorten the time axis, but `length`
        # is passed through unscaled; confirm downstream pooling tolerates it.
        return x, length
class SpeakerDecoder(NeuralModule):
    """
    Speaker decoder: pools encoder features over time, maps them through a
    stack of embedding layers and projects the result onto speaker classes.

    Args:
        feat_in (int): Number of channels being input to this module.
        num_classes (int): Number of unique speakers in the dataset.
        emb_sizes (int or list): Output sizes of the embedding layers; a single
            int means one layer. Defaults to 256. ``forward`` returns the
            embedding of the LAST layer: the output before its ReLU for linear
            layers, or the output of its 1x1 conv for conv layers.
        pool_mode (str): Pooling strategy (case-insensitive):
            'xvector' - mean and variance statistics (StatsPoolLayer),
            'tap' - temporal average pooling, just the mean (StatsPoolLayer),
            'attention' - attention-based pooling (AttentivePoolLayer),
            'ecapa2' - ChannelDependentAttentiveStatisticsPoolLayer.
            'xvector'/'tap' use Linear -> BatchNorm -> ReLU embedding layers;
            'attention'/'ecapa2' use BatchNorm -> 1x1 Conv1d layers.
        angular (bool): If True, the classifier has no bias and outputs cosine
            logits (features and class-weight rows are both L2-normalised).
        attention_channels (int): Passed as ``attention_channels`` to the
            attention pooling layers.
        init_mode (str): Describes how neural network parameters are
            initialized. Options are ['xavier_uniform', 'xavier_normal',
            'kaiming_uniform','kaiming_normal'].
            Defaults to "xavier_uniform".

    Raises:
        ValueError: If ``pool_mode`` is not one of the modes listed above.
    """

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        # Number of leading sub-modules of each embedding layer whose output
        # is reported as the embedding (Linear+BN, or BN+Conv1d).
        self.emb_id = 2
        bias = not self.angular
        if isinstance(emb_sizes, int):
            emb_sizes = [emb_sizes]
        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode in ('xvector', 'tap'):
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'
        elif self.pool_mode == 'ecapa2':
            self._pooling = ChannelDependentAttentiveStatisticsPoolLayer(
                inp_filters=feat_in, attention_channels=attention_channels
            )
            affine_type = 'conv'
        else:
            # Previously failed later with an AttributeError on self._pooling.
            raise ValueError(
                "pool_mode should be one of 'xvector', 'tap', 'attention', "
                "'ecapa2', got {!r}".format(pool_mode)
            )
        shapes = [self._pooling.feat_in] + [int(size) for size in emb_sizes]
        self.emb_layers = nn.ModuleList(
            [
                self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
                for shape_in, shape_out in zip(shapes[:-1], shapes[1:])
            ]
        )
        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def affine_layer(
        self,
        inp_shape,
        out_shape,
        learn_mean=True,
        affine_type='conv',
    ):
        """Build one embedding layer mapping ``inp_shape`` to ``out_shape`` features.

        'conv': BatchNorm1d(inp_shape) -> Conv1d(kernel_size=1).
        Otherwise: Linear -> BatchNorm1d(out_shape, affine=learn_mean) -> ReLU.
        """
        if affine_type == 'conv':
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )
        return layer

    def forward(self, encoder_output, length=None):
        """Return ``(logits, embedding)`` for a batch of encoder outputs.

        The embedding is the last layer's ``layer[:emb_id]`` output with a
        trailing singleton axis squeezed away, if present.
        """
        pool = self._pooling(encoder_output, length)
        embs = []
        for layer in self.emb_layers:
            # Run the leading sub-modules once, keep that as the embedding and
            # finish the layer from it. Calling layer(pool) and
            # layer[:emb_id](pool) separately ran BatchNorm twice per step,
            # double-updating its running statistics in training mode.
            emb = layer[: self.emb_id](pool)
            pool = layer[self.emb_id:](emb)
            embs.append(emb)
        pool = pool.squeeze(-1)
        if self.angular:
            # Cosine logits. The old `for W in parameters(): W = normalize(W)`
            # only rebound a local name and never normalised the weights.
            pool = F.normalize(pool, p=2, dim=1)
            weight = F.normalize(self.final.weight, p=2, dim=1)
            out = F.linear(pool, weight, self.final.bias)
        else:
            out = self.final(pool)
        return out, embs[-1].squeeze(-1)