import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from torch.nn import init
from torch.nn.parameter import Parameter

from models.mossformer_gan_se.fsmn import UniDeepFsmn
from models.mossformer_gan_se.conv_module import ConvModule
from models.mossformer_gan_se.mossformer import MossFormer
from models.mossformer_gan_se.se_layer import SELayer
from models.mossformer_gan_se.get_layer_from_string import get_layer
from models.mossformer_gan_se.discriminator import Discriminator

# Check if the installed version of PyTorch is 1.9.0 or higher
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
class MossFormerGAN_SE_16K(nn.Module):
    """
    MossFormerGAN_SE_16K: A GAN-based speech enhancement model for 16 kHz input audio.

    This model integrates a synchronous attention network (SyncANet) for
    feature extraction. In training mode it additionally instantiates a
    discriminator for adversarial training; in inference mode the
    discriminator is omitted.

    Args:
        args (Namespace): Arguments containing configuration parameters,
            including 'fft_len' and 'mode'.
    """

    def __init__(self, args):
        """Initializes the MossFormerGAN_SE_16K model."""
        super(MossFormerGAN_SE_16K, self).__init__()
        # Initialize SyncANet with the specified number of channels and features
        self.model = SyncANet(num_channel=64, num_features=args.fft_len // 2 + 1)
        # Initialize the discriminator only in training mode
        if args.mode == 'train':
            self.discriminator = Discriminator(ndf=16)
        else:
            self.discriminator = None

    def forward(self, x):
        """
        Defines the forward pass of the MossFormerGAN_SE_16K model.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time, freq],
                where the two channels hold the real and imaginary STFT parts.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensors representing the real and imaginary parts.
        """
        output_real, output_imag = self.model(x)  # Get real and imaginary outputs from the enhancement model
        return output_real, output_imag
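
# --- Usage sketch (illustrative only) ---
# A minimal sketch of driving the model end to end. The Namespace below and
# the fft_len value of 400 (=> 201 frequency bins at 16 kHz) are assumptions
# for illustration, not values fixed by this file.
#
#   from argparse import Namespace
#   args = Namespace(fft_len=400, mode='inference')
#   model = MossFormerGAN_SE_16K(args)
#   x = torch.randn(1, 2, 100, 201)  # [batch, real/imag, frames, freq bins]
#   real, imag = model(x)            # each: [1, 1, 100, 201]
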
class FSMN_Wrap(nn.Module):
    """
    FSMN_Wrap: A wrapper around the UniDeepFsmn module to facilitate
    integration into the larger model architecture.

    Args:
        nIn (int): Number of input features.
        nHidden (int): Number of hidden features in the FSMN (default is 128).
        lorder (int): Order of the FSMN (default is 20).
        nOut (int): Number of output features (default is 128; unused here,
            since the FSMN is constructed with nHidden outputs).
    """

    def __init__(self, nIn, nHidden=128, lorder=20, nOut=128):
        """Initializes the FSMN_Wrap module with specified parameters."""
        super(FSMN_Wrap, self).__init__()
        # Initialize the UniDeepFsmn module
        self.fsmn = UniDeepFsmn(nIn, nHidden, lorder, nHidden)

    def forward(self, x):
        """
        Defines the forward pass of the FSMN_Wrap module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, time, height].

        Returns:
            torch.Tensor: Output tensor of the same shape [batch_size, channels, time, height].
        """
        # Shape of input x: [b, c, T, h]
        b, c, T, h = x.size()
        # Permute and fold so the FSMN sees sequences of shape [b*T, h, c]
        x = x.permute(0, 2, 3, 1)            # [b, T, h, c]
        x = torch.reshape(x, (b * T, h, c))  # [b*T, h, c]
        # Pass through the FSMN
        output = self.fsmn(x)  # output: [b*T, h, c]
        # Reshape back to the original layout
        output = torch.reshape(output, (b, T, h, c))  # [b, T, h, c]
        return output.permute(0, 3, 1, 2)  # Final output shape: [b, c, T, h]
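
# --- Shape sketch (illustrative only) ---
# Assuming UniDeepFsmn maps (batch, seq_len, nIn) -> (batch, seq_len, nIn),
# as its use above implies, FSMN_Wrap is shape-preserving:
#
#   wrap = FSMN_Wrap(nIn=64, nHidden=64, lorder=5)
#   y = wrap(torch.randn(2, 64, 10, 101))  # -> [2, 64, 10, 101]
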
class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet: A dilated dense network for feature extraction.

    This network consists of a series of dilated convolutions organized in a
    dense-block structure, allowing for efficient feature reuse and capturing
    multi-scale information.

    Args:
        depth (int): The number of layers in the dense block (default is 4).
        in_channels (int): The number of input channels for the first layer (default is 64).
    """

    def __init__(self, depth=4, in_channels=64):
        """Initializes the DilatedDenseNet with specified depth and input channels."""
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Padding for the first layer
        self.twidth = 2  # Temporal width for convolutions
        self.kernel_size = (self.twidth, 3)  # Kernel size for convolutions
        # Initialize dilated convolution, padding, normalization, activation, and FSMN for each layer
        for i in range(self.depth):
            dil = 2 ** i  # Dilation factor for the current layer
            pad_length = self.twidth + (dil - 1) * (self.twidth - 1) - 1  # Causal padding length in time
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((1, 1, pad_length, 0), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1)))  # Dilated convolution layer
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation function
            setattr(self, 'fsmn{}'.format(i + 1),
                    FSMN_Wrap(nIn=self.in_channels, nHidden=self.in_channels, lorder=5, nOut=self.in_channels))

    def forward(self, x):
        """
        Defines the forward pass for the DilatedDenseNet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the dense network.
        """
        skip = x  # Initialize the skip connection with the input
        for i in range(self.depth):
            # Apply padding, convolution, normalization, activation, and FSMN in sequence
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            out = getattr(self, 'fsmn{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)  # Concatenate outputs for dense connectivity
        return out  # Return the final layer's output
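
# --- Padding arithmetic (illustrative) ---
# Each layer pads only on the "past" side of the time axis, so the dilated
# (2, 3) kernel is causal in time and every layer is shape-preserving:
#
#   effective kernel height = 1 + (twidth - 1) * dil = 1 + dil
#   pad_length              = twidth + (dil - 1) * (twidth - 1) - 1 = dil
#   H_out = H + dil - (1 + dil) + 1 = H
#
# For example (shapes assume FSMN_Wrap is shape-preserving, see above):
#
#   net = DilatedDenseNet(depth=4, in_channels=64)
#   y = net(torch.randn(2, 64, 10, 101))  # -> [2, 64, 10, 101]
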
class DenseEncoder(nn.Module):
    """
    DenseEncoder: A dense encoding module for feature extraction from input data.

    This module consists of a series of convolutional layers followed by a
    dilated dense network for robust feature learning.

    Args:
        in_channel (int): Number of input channels for the encoder.
        channels (int): Number of output channels for each convolutional layer (default is 64).
    """

    def __init__(self, in_channel, channels=64):
        """Initializes the DenseEncoder with specified input channels and feature size."""
        super(DenseEncoder, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, channels, (1, 1), (1, 1)),  # Initial 1x1 convolution
            nn.InstanceNorm2d(channels, affine=True),         # Normalization layer
            nn.PReLU(channels)                                # Activation function
        )
        self.dilated_dense = DilatedDenseNet(depth=4, in_channels=channels)  # Dilated dense network
        self.conv_2 = nn.Sequential(
            nn.Conv2d(channels, channels, (1, 3), (1, 2), padding=(0, 1)),  # Frequency-downsampling convolution
            nn.InstanceNorm2d(channels, affine=True),                       # Normalization layer
            nn.PReLU(channels)                                              # Activation function
        )

    def forward(self, x):
        """
        Defines the forward pass for the DenseEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channel, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the encoder.
        """
        x = self.conv_1(x)         # First convolutional layer
        x = self.dilated_dense(x)  # Dilated dense network
        x = self.conv_2(x)         # Second convolutional layer (halves the frequency axis)
        return x
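
# --- Shape sketch (illustrative only) ---
# conv_2 downsamples only the frequency axis: with kernel (1, 3), stride
# (1, 2) and padding (0, 1), W_out = (W + 2 - 3) // 2 + 1 = (W - 1) // 2 + 1,
# e.g. 201 -> 101 frequency bins, while the time axis is unchanged:
#
#   enc = DenseEncoder(in_channel=3, channels=64)
#   y = enc(torch.randn(2, 3, 10, 201))  # -> [2, 64, 10, 101]
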
class SPConvTranspose2d(nn.Module):
    """
    SPConvTranspose2d: A sub-pixel convolution layer used in place of a
    transposed convolution.

    This module implements upsampling via an ordinary convolution followed by
    a sub-pixel (pixel-shuffle) rearrangement along the last axis, allowing
    for efficient upsampling and feature extraction.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (tuple): Size of the convolution kernel.
        r (int): Upsampling rate (default is 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        """Initializes the SPConvTranspose2d with specified parameters."""
        super(SPConvTranspose2d, self).__init__()
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)  # Padding for the input
        self.out_channels = out_channels  # Store the number of output channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))  # Convolution layer
        self.r = r  # Store the upsampling rate

    def forward(self, x):
        """
        Defines the forward pass for the SPConvTranspose2d module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channels, height, width].

        Returns:
            torch.Tensor: Output tensor after the sub-pixel upsampling operation.
        """
        x = self.pad1(x)  # Apply padding to the input
        out = self.conv(x)  # Perform the convolution
        batch_size, nchannels, H, W = out.shape  # Get the output shape
        out = out.view((batch_size, self.r, nchannels // self.r, H, W))  # Split channels into r groups
        out = out.permute(0, 2, 3, 4, 1)  # Move the group axis next to the width axis
        out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))  # Interleave: width becomes W * r
        return out
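
# --- Sub-pixel upsampling sketch ---
# The convolution produces r groups of out_channels feature maps; interleaving
# them along the last axis multiplies the width by r (a 1-D "pixel shuffle").
# This snippet is illustrative, not part of the model definition:
#
#   sp = SPConvTranspose2d(64, 64, (1, 3), r=2)
#   y = sp(torch.randn(2, 64, 10, 50))  # -> [2, 64, 10, 100]
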
class MaskDecoder(nn.Module):
    """
    MaskDecoder: A decoder module for estimating masks used in audio processing.

    This module utilizes a dilated dense network to capture features and
    applies sub-pixel convolution to upscale the output. It produces
    a mask that can be applied to the magnitude of audio signals.

    Args:
        num_features (int): The number of features in the output mask.
        num_channel (int): The number of channels in intermediate layers (default is 64).
        out_channel (int): The number of output channels for the final output mask (default is 1).
    """

    def __init__(self, num_features, num_channel=64, out_channel=1):
        """Initializes the MaskDecoder with specified parameters."""
        super(MaskDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel upsampling
        self.conv_1 = nn.Conv2d(num_channel, out_channel, (1, 2))  # Convolution to produce the mask
        self.norm = nn.InstanceNorm2d(out_channel, affine=True)  # Normalization layer
        self.prelu = nn.PReLU(out_channel)  # Activation function
        self.final_conv = nn.Conv2d(out_channel, out_channel, (1, 1))  # Final convolution layer
        self.prelu_out = nn.PReLU(num_features, init=-0.25)  # Final activation for the output mask

    def forward(self, x):
        """
        Defines the forward pass for the MaskDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output mask tensor after processing through the decoder.
        """
        x = self.dense_block(x)  # Feature extraction using the dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.conv_1(x)  # Convolution to estimate the mask
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.final_conv(x).permute(0, 3, 2, 1).squeeze(-1)  # Final convolution and rearrangement
        return self.prelu_out(x).permute(0, 2, 1).unsqueeze(1)  # Restore [B, 1, T, num_features]
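
# --- Shape walk-through (illustrative), for num_features=201, num_channel=64 ---
# (intermediate shapes assume the dilated dense block is shape-preserving)
#
#   input          [B, 64, T, 101]
#   dense_block -> [B, 64, T, 101]
#   sub_pixel   -> [B, 64, T, 202]   (width doubled, r=2)
#   conv_1      -> [B, 1,  T, 201]   (kernel (1, 2): 202 - 2 + 1 = 201)
#   output      -> [B, 1,  T, 201]   mask aligned with the 201-bin magnitude
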
class ComplexDecoder(nn.Module):
    """
    ComplexDecoder: A decoder module for estimating complex-valued outputs.

    This module processes features through a dilated dense network and a
    sub-pixel convolution layer to generate two output channels representing
    the real and imaginary parts of the complex output.

    Args:
        num_channel (int): The number of channels in intermediate layers (default is 64).
    """

    def __init__(self, num_channel=64):
        """Initializes the ComplexDecoder with specified parameters."""
        super(ComplexDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel upsampling
        self.prelu = nn.PReLU(num_channel)  # Activation function
        self.norm = nn.InstanceNorm2d(num_channel, affine=True)  # Normalization layer
        self.conv = nn.Conv2d(num_channel, 2, (1, 2))  # Convolution to produce the complex output

    def forward(self, x):
        """
        Defines the forward pass for the ComplexDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor containing real and imaginary parts.
        """
        x = self.dense_block(x)  # Feature extraction using the dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.conv(x)  # Generate the two-channel complex output
        return x
class SyncANet(nn.Module):
    """
    SyncANet: A synchronous attention network for speech enhancement.

    This network integrates dense encoding, synchronous attention blocks,
    and separate decoders for estimating masks and complex-valued outputs.

    Args:
        num_channel (int): The number of channels in the network (default is 64).
        num_features (int): The number of features for the mask decoder (default is 201).
    """

    def __init__(self, num_channel=64, num_features=201):
        """Initializes the SyncANet with specified parameters."""
        super(SyncANet, self).__init__()
        self.dense_encoder = DenseEncoder(in_channel=3, channels=num_channel)  # Dense encoder for the input
        self.n_layers = 6  # Number of synchronous attention layers
        self.blocks = nn.ModuleList([])  # List to hold the attention blocks
        # Initialize the attention blocks
        for _ in range(self.n_layers):
            self.blocks.append(
                SyncANetBlock(
                    emb_dim=num_channel,
                    emb_ks=2,
                    emb_hs=1,
                    n_freqs=int(num_features // 2) + 1,
                    hidden_channels=num_channel * 2,
                    n_head=4,
                    approx_qk_dim=512,
                    activation='prelu',
                    eps=1.0e-5,
                )
            )
        self.mask_decoder = MaskDecoder(num_features, num_channel=num_channel, out_channel=1)  # Mask decoder
        self.complex_decoder = ComplexDecoder(num_channel=num_channel)  # Complex decoder

    def forward(self, x):
        """
        Defines the forward pass for the SyncANet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time, freq]
                holding the real and imaginary parts of the noisy spectrogram.

        Returns:
            list: List containing the real and imaginary parts of the output tensor.
        """
        out_list = []  # List to store the outputs
        mag = torch.sqrt(x[:, 0, :, :]**2 + x[:, 1, :, :]**2).unsqueeze(1)  # Magnitude of the noisy input
        noisy_phase = torch.angle(torch.complex(x[:, 0, :, :], x[:, 1, :, :])).unsqueeze(1)  # Noisy phase
        x_in = torch.cat([mag, x], dim=1)  # Concatenate magnitude with real/imaginary parts (3 channels)
        x = self.dense_encoder(x_in)  # Feature extraction using the dense encoder
        for ii in range(self.n_layers):
            x = self.blocks[ii](x)  # Pass through the attention blocks
        mask = self.mask_decoder(x)  # Estimate the mask from the features
        out_mag = mask * mag  # Apply the mask to the magnitude
        complex_out = self.complex_decoder(x)  # Generate the complex residual
        mag_real = out_mag * torch.cos(noisy_phase)  # Real part from the masked magnitude
        mag_imag = out_mag * torch.sin(noisy_phase)  # Imaginary part from the masked magnitude
        final_real = mag_real + complex_out[:, 0, :, :].unsqueeze(1)  # Final real output
        final_imag = mag_imag + complex_out[:, 1, :, :].unsqueeze(1)  # Final imaginary output
        out_list.append(final_real)
        out_list.append(final_imag)
        return out_list
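
# --- Usage sketch (illustrative only) ---
# With num_features=201 (e.g. fft_len=400 upstream), SyncANet consumes a
# stacked real/imaginary spectrogram and returns the enhanced parts
# separately. The output shapes assume the companion modules (MossFormer,
# SELayer, ConvModule, UniDeepFsmn) preserve shapes as noted above:
#
#   net = SyncANet(num_channel=64, num_features=201)
#   x = torch.randn(1, 2, 100, 201)  # [B, real/imag, frames, bins]
#   real, imag = net(x)              # each: [1, 1, 100, 201]
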
class FFConvM(nn.Module):
    """
    FFConvM: A feedforward convolutional module combining linear layers, normalization,
    non-linear activation, and convolution operations.

    This module processes input tensors through a sequence of transformations:
    normalization, a linear layer with a SiLU activation, a convolutional
    operation, and dropout for regularization.

    Args:
        dim_in (int): The number of input features (dimensionality of input).
        dim_out (int): The number of output features (dimensionality of output).
        norm_klass (nn.Module): The normalization class to be applied (default is nn.LayerNorm).
        dropout (float): The dropout probability for regularization (default is 0.1).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        """Initializes the FFConvM with specified parameters."""
        super().__init__()
        # Define the sequential model
        self.mdl = nn.Sequential(
            norm_klass(dim_in),          # Apply normalization to the input
            nn.Linear(dim_in, dim_out),  # Linear transformation to dim_out
            nn.SiLU(),                   # Non-linear activation (Sigmoid Linear Unit)
            ConvModule(dim_out),         # Convolution operation on the output
            nn.Dropout(dropout)          # Dropout layer for regularization
        )

    def forward(self, x):
        """
        Defines the forward pass for the FFConvM.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, dim_in].

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, dim_out] after processing.
        """
        output = self.mdl(x)  # Pass the input through the sequential model
        return output
class SyncANetBlock(nn.Module):
    """
    SyncANetBlock implements a modified version of the MossFormer (GatedFormer) module,
    inspired by the TF-GridNet architecture (https://arxiv.org/abs/2211.12433).
    It combines gated triple-attention schemes and feedforward sequential memory
    network (FSMN) modules to enhance computational efficiency and overall
    performance in audio processing tasks.

    Attributes:
        emb_dim (int): Dimensionality of the embedding.
        emb_ks (int): Kernel size for embeddings.
        emb_hs (int): Stride size for embeddings.
        n_freqs (int): Number of frequency bands.
        hidden_channels (int): Number of hidden channels.
        n_head (int): Number of attention heads.
        approx_qk_dim (int): Approximate dimension for query-key matrices.
        activation (str): Activation function to use.
        eps (float): Small value to avoid division by zero in normalization layers.
    """

    def __getitem__(self, key):
        """
        Allows accessing module attributes using indexing.

        Args:
            key: Attribute name to retrieve.

        Returns:
            The requested attribute.
        """
        return getattr(self, key)
    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        """
        Initializes the SyncANetBlock with the specified parameters.

        Args:
            emb_dim (int): Dimensionality of the embedding.
            emb_ks (int): Kernel size for embeddings.
            emb_hs (int): Stride size for embeddings.
            n_freqs (int): Number of frequency bands.
            hidden_channels (int): Number of hidden channels.
            n_head (int): Number of attention heads. Default is 4.
            approx_qk_dim (int): Approximate dimension for query-key matrices. Default is 512.
            activation (str): Activation function to use. Default is "prelu".
            eps (float): Small value to avoid division by zero in normalization layers. Default is 1e-5.
        """
        super().__init__()
        in_channels = emb_dim * emb_ks  # Number of channels after unfolding

        ## Intra modules: processing along the frequency axis within each frame
        self.Fconv = nn.Conv2d(emb_dim, in_channels, kernel_size=(1, emb_ks), stride=(1, 1), groups=emb_dim)
        self.intra_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.intra_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.intra_mossformer = MossFormer(dim=emb_dim, group_size=n_freqs)  # MossFormer module
        # Linear transformation for the intra-module output
        self.intra_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.intra_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        ## Inter modules: processing along the time axis across frames
        self.inter_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.inter_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.inter_mossformer = MossFormer(dim=emb_dim, group_size=256)  # MossFormer module
        # Linear transformation for the inter-module output
        self.inter_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.inter_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        # Per-head query/key dimension, sized so that E * n_freqs ~ approx_qk_dim
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0  # Ensure emb_dim is divisible by n_head
        # Define attention convolution layers for each head
        for ii in range(n_head):
            self.add_module(
                f"attn_conv_Q_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_K_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_V_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        # Final projection after concatenating the attention heads
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        # Store parameters for use in forward()
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head
    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Constructs a sequence of UniDeepFsmn modules.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the memory filter.
            hidden_size (int): Hidden size for the FSMN.
            repeats (int): Number of times to repeat the module. Default is 1.

        Returns:
            nn.Sequential: A sequence of UniDeepFsmn modules.
        """
        repeats = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*repeats)
    def forward(self, x):
        """Performs a forward pass through the SyncANetBlock.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, Q] where
                B is batch size, C is the number of channels,
                T is the temporal dimension, and Q is the frequency dimension.

        Returns:
            torch.Tensor: Output tensor of the same shape [B, C, T, Q].
        """
        B, C, old_T, old_Q = x.shape
        # Calculate padded dimensions so unfolding with (emb_ks, emb_hs) tiles exactly
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        # Pad the input tensor to the new dimensions
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # Intra-process (along the frequency axis)
        input_ = x
        intra_rnn = self.intra_norm(input_)  # Normalize the input for the intra-process
        intra_rnn = self.Fconv(intra_rnn)  # Apply the grouped (depthwise-style) convolution
        intra_rnn = (
            intra_rnn.transpose(1, 2).contiguous().view(B * T, C * self.emb_ks, -1)
        )  # Fold the time axis into the batch axis
        intra_rnn = intra_rnn.transpose(1, 2)  # [B*T, Q', C*emb_ks]
        intra_rnn_u = self.intra_to_u(intra_rnn)  # Gate branch u
        intra_rnn_v = self.intra_to_v(intra_rnn)  # Gate branch v
        intra_rnn_u = self.intra_rnn(intra_rnn_u)  # Apply the FSMN to u
        intra_rnn = intra_rnn_v * intra_rnn_u  # Gated element-wise multiplication
        intra_rnn = intra_rnn.transpose(1, 2)  # Back to channel-first layout
        intra_rnn = self.intra_linear(intra_rnn)  # Transposed-convolution projection
        intra_rnn = intra_rnn.transpose(1, 2)  # Prepare for the MossFormer
        intra_rnn = intra_rnn.view([B, T, Q, C])  # Restore the 4D layout
        intra_rnn = self.intra_mossformer(intra_rnn)  # Apply the MossFormer
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape back
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = self.intra_se(intra_rnn)  # Squeeze-and-excitation layer
        intra_rnn = intra_rnn + input_  # Residual connection

        # Inter-process (along the time axis)
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # Normalize the input for the inter-process
        inter_rnn = (
            inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        )  # Fold the frequency axis into the batch axis
        inter_rnn = F.unfold(
            inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # Extract sliding windows along time
        inter_rnn = inter_rnn.transpose(1, 2)  # [B*Q, T', C*emb_ks]
        inter_rnn_u = self.inter_to_u(inter_rnn)  # Gate branch u
        inter_rnn_v = self.inter_to_v(inter_rnn)  # Gate branch v
        inter_rnn_u = self.inter_rnn(inter_rnn_u)  # Apply the FSMN to u
        inter_rnn = inter_rnn_v * inter_rnn_u  # Gated element-wise multiplication
        inter_rnn = inter_rnn.transpose(1, 2)  # Back to channel-first layout
        inter_rnn = self.inter_linear(inter_rnn)  # Transposed-convolution projection
        inter_rnn = inter_rnn.transpose(1, 2)  # Prepare for the MossFormer
        inter_rnn = inter_rnn.view([B, Q, T, C])  # Restore the 4D layout
        inter_rnn = self.inter_mossformer(inter_rnn)  # Apply the MossFormer
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape back
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = self.inter_se(inter_rnn)  # Squeeze-and-excitation layer
        inter_rnn = inter_rnn + input_  # Residual connection

        # Attention mechanism
        inter_rnn = inter_rnn[..., :old_T, :old_Q]  # Trim back to the original shape

        batch = inter_rnn
        all_Q, all_K, all_V = [], [], []
        # Compute query, key, and value for each attention head
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))  # Query
            all_K.append(self["attn_conv_K_%d" % ii](batch))  # Key
            all_V.append(self["attn_conv_V_%d" % ii](batch))  # Value

        Q = torch.cat(all_Q, dim=0)  # Stack all queries along the batch axis
        K = torch.cat(all_K, dim=0)  # Stack all keys along the batch axis
        V = torch.cat(all_V, dim=0)  # Stack all values along the batch axis

        # Reshape for the attention calculation
        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)  # [n_head*B, T, E*Q]
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)  # [n_head*B, T, E*Q]
        V = V.transpose(1, 2)
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # [n_head*B, T, (C/n_head)*Q]
        emb_dim = Q.shape[-1]

        # Compute scaled dot-product attention over the time axis
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # Attention matrix
        attn_mat = F.softmax(attn_mat, dim=2)  # Softmax over the attention scores
        V = torch.matmul(attn_mat, V)  # Weighted sum of values

        V = V.reshape(old_shape)  # Reshape back
        V = V.transpose(1, 2)  # [n_head*B, C/n_head, T, Q]
        emb_dim = V.shape[1]

        batch = V.view([self.n_head, B, emb_dim, old_T, -1])  # Unstack the heads
        batch = batch.transpose(0, 1)  # [B, n_head, C/n_head, T, Q]
        batch = batch.contiguous().view(
            [B, self.n_head * emb_dim, old_T, -1]
        )  # Concatenate the heads along the channel axis
        batch = self["attn_concat_proj"](batch)  # Final projection

        # Combine the inter-process result with the attention output
        out = batch + inter_rnn
        return out
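
# --- Multi-head batching trick (illustrative, pure-tensor sketch) ---
# The attention above avoids an explicit head dimension by stacking the
# n_head projections along the batch axis and unstacking them afterwards:
#
#   n_head, B, E, T, Q = 4, 2, 16, 10, 101
#   heads = [torch.randn(B, E, T, Q) for _ in range(n_head)]
#   stacked = torch.cat(heads, dim=0)                       # [n_head*B, E, T, Q]
#   regrouped = stacked.view(n_head, B, E, T, Q).transpose(0, 1)
#   merged = regrouped.contiguous().view(B, n_head * E, T, Q)
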
class LayerNormalization4D(nn.Module):
    """
    LayerNormalization4D applies layer normalization to 4D tensors
    (e.g., [B, C, T, F]), where B is the batch size, C is the number of channels,
    T is the temporal dimension, and F is the frequency dimension.

    Attributes:
        gamma (torch.nn.Parameter): Learnable scaling parameter.
        beta (torch.nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """

    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4D layer.

        Args:
            input_dimension (int): The number of channels in the input tensor.
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)  # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            _, C, _, _ = x.shape  # Extract the number of channels
            stat_dim = (1,)  # Compute statistics over the channel dimension
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Compute the mean and standard deviation along the channel dimension
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, F]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, F]
        # Normalize the input tensor and apply the learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
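
# --- Sanity check (illustrative only) ---
# At initialization (gamma=1, beta=0), every (b, t, f) location is normalized
# to zero mean and unit variance across the channel dimension:
#
#   ln = LayerNormalization4D(64)
#   y = ln(torch.randn(2, 64, 10, 101))
#   # y.mean(dim=1) ~ 0, y.var(dim=1, unbiased=False) ~ 1
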
class LayerNormalization4DCF(nn.Module):
    """
    LayerNormalization4DCF applies layer normalization to 4D tensors
    (e.g., [B, C, T, F]); the "CF" suffix indicates that statistics are
    computed jointly over the channel and frequency dimensions.

    Attributes:
        gamma (torch.nn.Parameter): Learnable scaling parameter.
        beta (torch.nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """

    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4DCF layer.

        Args:
            input_dimension (tuple): A tuple containing the dimensions of the input tensor
                (number of channels, frequency dimension).
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        assert len(input_dimension) == 2, "Input dimension must be a tuple of length 2."
        param_size = [1, input_dimension[0], 1, input_dimension[1]]  # Shape based on the input dimensions
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)  # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            stat_dim = (1, 3)  # Compute statistics over the channel and frequency dimensions
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Compute the mean and standard deviation along the specified dimensions
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, 1]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, 1]
        # Normalize the input tensor and apply the learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
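
if __name__ == "__main__":
    # Minimal smoke tests for the self-contained (pure-PyTorch) modules in this
    # file. The sizes chosen here (64 channels, 101 bins) are illustrative
    # assumptions, not values required by the model.
    x = torch.randn(2, 64, 10, 50)
    sp = SPConvTranspose2d(64, 64, (1, 3), r=2)
    assert sp(x).shape == (2, 64, 10, 100)  # width doubled by the sub-pixel shuffle

    z = torch.randn(2, 64, 10, 101)
    ln_cf = LayerNormalization4DCF((64, 101))
    y = ln_cf(z)
    # Statistics are taken jointly over channels and frequency (dims 1 and 3)
    assert torch.allclose(y.mean(dim=(1, 3)), torch.zeros(2, 10), atol=1e-4)
    print("smoke tests passed")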