Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import numpy as np | |
| import torch.nn.functional as F | |
| import torch | |
| import torch.nn as nn | |
| from models.mossformer_gan_se.conformer import ConformerBlock | |
class LearnableSigmoid(nn.Module):
    """Sigmoid activation with a learnable per-feature slope.

    Computes ``beta * sigmoid(slope * x)``, where ``slope`` is a trainable
    vector initialised to ones and ``beta`` is a fixed scale factor.

    Args:
        in_features (int): Length of the learnable ``slope`` vector. It is
            multiplied element-wise with the input, so it must broadcast
            against the input's trailing dimension.
        beta (float, optional): Constant multiplier applied to the sigmoid
            output. Default is 1.

    Attributes:
        beta (float): The constant output scale.
        slope (Parameter): Learnable slope, shape ``[in_features]``.
    """

    def __init__(self, in_features, beta=1):
        """Creates the slope parameter and stores the output scale.

        Args:
            in_features (int): Number of input features.
            beta (float, optional): Scaling factor for the sigmoid output.
        """
        super().__init__()
        self.beta = beta
        # nn.Parameter already has requires_grad=True. The previous code set a
        # misspelled ``requiresGrad`` attribute, which only attached a stray
        # Python attribute to the tensor and had no effect on autograd.
        self.slope = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        """Applies the scaled, slope-adjusted sigmoid.

        Args:
            x (torch.Tensor): Input tensor whose trailing dimension broadcasts
                against ``slope`` (e.g. ``[batch_size, in_features]``).

        Returns:
            torch.Tensor: ``beta * sigmoid(slope * x)``, same shape as the
            broadcast of ``x`` and ``slope``.
        """
        return self.beta * torch.sigmoid(self.slope * x)
| #%% Spectrograms | |
def segment_specs(y, seg_length=15, seg_hop=4, max_length=None):
    """Cuts a batch of spectrograms into overlapping fixed-width segments.

    A window of ``seg_length`` frames is slid along the last axis of every
    spectrogram with a stride of ``seg_hop`` frames. Each window keeps all
    rows of the spectrogram, so a segment has shape ``[H, seg_length]``.

    Args:
        y (torch.Tensor): Input of shape ``[B, H, W]``. Non-tensor input is
            converted with ``torch.tensor``.
        seg_length (int): Width of each segment in frames (must be odd).
            Default is 15.
        seg_hop (int): Stride between consecutive segment starts. Values of
            1 or less are treated as a stride of 1. Default is 4.
        max_length (int, optional): If given, every item is zero-padded along
            the segment axis to exactly ``max_length`` segments.

    Returns:
        torch.Tensor: Tensor of shape ``[B * n, 1, H, seg_length]`` where ``n``
        is the number of segments per item (``max_length`` when padding is
        requested). Segments of the same item are contiguous along dim 0.
        The result has the dtype and device of ``y``.

    Raises:
        ValueError: If ``seg_length`` is even, if the input has fewer than
            ``seg_length - 1`` frames, or if the number of segments exceeds
            ``max_length``.
    """
    if seg_length % 2 == 0:
        raise ValueError('seg_length must be odd! (seg_length={})'.format(seg_length))

    if not torch.is_tensor(y):
        y = torch.tensor(y)

    B, H, W = y.size()

    # Number of window positions at a stride of one frame.
    n_wins = W - (seg_length - 1)
    if n_wins < 0:
        raise ValueError(
            'input length {} is too short for seg_length {}'.format(W, seg_length))

    # Only build the windows that survive the hop, instead of materialising
    # every stride-1 window and discarding most of them afterwards.
    hop = seg_hop if seg_hop > 1 else 1
    starts = torch.arange(0, n_wins, hop, device=y.device)
    offsets = torch.arange(seg_length, device=y.device)
    frame_idx = starts.unsqueeze(1) + offsets.unsqueeze(0)  # [n, seg_length]
    n_wins = frame_idx.shape[0]

    # Gather all items at once:
    # [B, W, H] -> [B, n, seg_length, H] -> [B, n, 1, H, seg_length]
    z = y.transpose(2, 1)[:, frame_idx, :].unsqueeze(2).transpose(4, 3)

    if max_length is not None:
        if max_length < n_wins:
            raise ValueError('n_wins {} > max_length {}. Increase max window length max_segments!'.format(n_wins, max_length))
        # new_zeros keeps the padding on the input's dtype and device; a plain
        # torch.zeros would silently cast to float32 and move the data to CPU.
        z_padded = y.new_zeros((B, max_length, 1, H, seg_length))
        z_padded[:, :n_wins] = z
        z = z_padded

    # Merge the batch and segment axes (reshape copies if z is a strided view).
    return z.reshape(B * z.shape[1], 1, H, seg_length)
class AdaptCNN(nn.Module):
    """Six-layer CNN with adaptive max pooling, usable as a framewise model.

    Each of the first five convolutions is followed by batch normalisation,
    ReLU and an adaptive max pooling stage that resizes the feature map to a
    fixed ``(height, width)``, so the network accepts inputs of varying
    spatial size. The sixth convolution uses a kernel whose width equals the
    width produced by the last pooling stage and no horizontal padding, so
    its output has width 1 and is flattened to
    ``c_out_3 * pool_5[0]`` features per input item.

    Args:
        input_channels (int): Number of input channels (default is 2).
        c_out_1 (int): Output channels of the first convolution (default is 16).
        c_out_2 (int): Output channels of the second convolution (default is 32).
        c_out_3 (int): Output channels of the third to sixth convolutions
            (default is 64).
        kernel_size (sequence or int): Convolution kernel size; an int is
            expanded to a square kernel (default is (3, 3)).
        dropout (float): Dropout2d probability (default is 0.2).
        pool_1 (sequence): Output size of the pooling after conv1 (default is (101, 7)).
        pool_2 (sequence): Output size of the pooling after conv2 (default is (50, 7)).
        pool_3 (sequence): Output size of the pooling after conv3 (default is (25, 5)).
        pool_4 (sequence): Output size of the pooling after conv4 (default is (12, 5)).
        pool_5 (sequence): Output size of the pooling after conv5 (default is (6, 3)).
        fc_out_h (int, optional): If set, a final linear layer maps the
            flattened features to this many units.

    Attributes:
        name (str): Name of the model.
        dropout (Dropout2d): Dropout layer shared by all stages.
        conv1 ... conv6 (Conv2d): Convolutional layers.
        bn1 ... bn6 (BatchNorm2d): Batch normalisation layers.
        fc (Linear, optional): Final linear layer, only present if ``fc_out_h`` is set.
        fan_out (int): Number of features per item returned by ``forward``.
    """

    def __init__(self,
                 input_channels=2,
                 c_out_1=16,
                 c_out_2=32,
                 c_out_3=64,
                 kernel_size=(3, 3),
                 dropout=0.2,
                 pool_1=(101, 7),
                 pool_2=(50, 7),
                 pool_3=(25, 5),
                 pool_4=(12, 5),
                 pool_5=(6, 3),
                 fc_out_h=None):
        """Initializes the AdaptCNN model with the specified parameters."""
        super().__init__()
        self.name = 'CNN_adapt'

        # Model parameters
        self.input_channels = input_channels
        self.c_out_1 = c_out_1
        self.c_out_2 = c_out_2
        self.c_out_3 = c_out_3
        self.kernel_size = kernel_size
        self.pool_1 = pool_1
        self.pool_2 = pool_2
        self.pool_3 = pool_3
        self.pool_4 = pool_4
        self.pool_5 = pool_5
        self.dropout_rate = dropout
        self.fc_out_h = fc_out_h

        self.dropout = nn.Dropout2d(p=self.dropout_rate)

        # Accept a single int as shorthand for a square kernel.
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        # The last convolution spans the full pooled width so that, without
        # horizontal padding, it collapses the width axis to 1.
        self.kernel_size_last = (self.kernel_size[0], self.pool_5[1])

        # No horizontal padding is needed when the kernel is one column wide.
        if self.kernel_size[1] == 1:
            self.cnn_pad = (1, 0)
        else:
            self.cnn_pad = (1, 1)

        self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
        self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
        self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
        self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0))
        self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)

        # forward() flattens to conv6.out_channels * pool_5[0] features, so the
        # linear layer and fan_out must use pool_5 (the last pooling stage).
        # Using pool_3 here made the fc layer's input size disagree with the
        # flattened tensor and reported a wrong fan_out.
        flat_features = self.conv6.out_channels * self.pool_5[0]
        if self.fc_out_h:
            self.fc = nn.Linear(flat_features, self.fc_out_h)
            self.fan_out = self.fc_out_h
        else:
            self.fan_out = flat_features

    def forward(self, x):
        """Runs the convolutional stack and flattens the result.

        Args:
            x (torch.Tensor): Input tensor of shape
                [batch_size, input_channels, height, width].

        Returns:
            torch.Tensor: Tensor of shape [batch_size, fan_out].
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.adaptive_max_pool2d(x, output_size=(self.pool_1))

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.adaptive_max_pool2d(x, output_size=(self.pool_2))
        x = self.dropout(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.adaptive_max_pool2d(x, output_size=(self.pool_3))
        x = self.dropout(x)

        x = F.relu(self.bn4(self.conv4(x)))
        x = F.adaptive_max_pool2d(x, output_size=(self.pool_4))
        x = self.dropout(x)

        x = F.relu(self.bn5(self.conv5(x)))
        x = F.adaptive_max_pool2d(x, output_size=(self.pool_5))
        x = self.dropout(x)

        x = F.relu(self.bn6(self.conv6(x)))

        # conv6 leaves a [N, C, pool_5[0], 1] map; flatten it per item.
        x = x.view(-1, self.conv6.out_channels * self.pool_5[0])

        if self.fc_out_h:
            x = self.fc(x)
        return x
class PoolAttFF(nn.Module):
    """Attention pooling over a sequence followed by a linear output layer.

    A small two-layer network scores every time step, the scores are turned
    into weights with a softmax over the sequence axis, and the input is
    averaged with those weights. The pooled vector is then projected to the
    requested output size.

    Args:
        d_input (int): Feature dimension of the input (default is 384).
        output_size (int): Size of the final output (default is 1).
        h (int): Hidden size of the scoring network (default is 128).
        dropout (float): Dropout probability inside the scoring network
            (default is 0.1).

    Attributes:
        linear1 (Linear): Maps input features to the hidden size.
        linear2 (Linear): Maps the hidden representation to one score per step.
        linear3 (Linear): Maps the pooled vector to the output.
        activation (function): Non-linearity of the scoring network (ReLU).
        dropout (Dropout): Dropout applied before the score projection.
    """

    def __init__(self, d_input=384, output_size=1, h=128, dropout=0.1):
        """Builds the scoring network, the output projection and dropout."""
        super().__init__()

        # Scoring network: d_input -> h -> 1
        self.linear1 = nn.Linear(d_input, h)
        self.linear2 = nn.Linear(h, 1)

        # Output projection applied to the pooled feature vector
        self.linear3 = nn.Linear(d_input, output_size)

        self.activation = F.relu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Pools the sequence with learned attention weights.

        Args:
            x (torch.Tensor): Input of shape [batch_size, seq_len, d_input].

        Returns:
            torch.Tensor: Output of shape [batch_size, output_size].
        """
        # One scalar score per time step: [batch, seq_len, 1]
        hidden = self.activation(self.linear1(x))
        scores = self.linear2(self.dropout(hidden))

        # Normalise over the sequence axis: [batch, 1, seq_len]
        weights = F.softmax(scores.transpose(2, 1), dim=2)

        # Weighted sum over time: [batch, 1, d_input] -> [batch, d_input]
        pooled = torch.bmm(weights, x).squeeze(1)

        return self.linear3(pooled)
class Discriminator(nn.Module):
    """Scores a predicted input against a reference input.

    Both inputs are cut into segments, stacked as two channels, encoded
    segment-wise by a CNN, processed by two Conformer blocks, reduced to one
    vector by attention pooling and finally squashed by a learnable sigmoid.
    According to the original description, the output is meant to predict a
    normalised PESQ value between the two inputs.

    Args:
        ndf (int): Accepted for interface compatibility; not used here.
        in_channel (int): Accepted for interface compatibility; not used here
            (default is 2).

    Attributes:
        dim (int): Feature dimension passed to the Conformer blocks (384).
        cnn (AdaptCNN): Segment-wise feature extractor, built with defaults.
        att (Sequential): Two stacked Conformer blocks.
        pool (PoolAttFF): Attention pooling module, built with defaults.
        sigmoid (LearnableSigmoid): Output activation with a single slope.
    """

    def __init__(self, ndf, in_channel=2):
        """Builds the CNN encoder, Conformer stack, pooling and output activation."""
        super().__init__()
        self.dim = 384
        self.cnn = AdaptCNN()

        # Two identical Conformer blocks applied one after the other.
        conformer_blocks = [
            ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
                           conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2)
            for _ in range(2)
        ]
        self.att = nn.Sequential(*conformer_blocks)

        self.pool = PoolAttFF()
        self.sigmoid = LearnableSigmoid(1)

    def forward(self, x, y):
        """Computes the score for a (prediction, reference) pair.

        Args:
            x (torch.Tensor): Predicted input of shape [batch_size, 1, height, width].
            y (torch.Tensor): Reference input of shape [batch_size, 1, height, width].

        Returns:
            torch.Tensor: Output of the learnable sigmoid applied to the
            pooled features.
        """
        batch_size, _, _, _ = x.size()

        # Drop the singleton channel axis and cut both inputs into segments.
        pred_segments = segment_specs(x.squeeze(1))
        ref_segments = segment_specs(y.squeeze(1))

        # Prediction and reference become the two input channels of the CNN.
        stacked = torch.cat([pred_segments, ref_segments], dim=1)
        features = self.cnn(stacked)

        # Regroup the per-segment features into one sequence per batch item.
        _, feat_dim = features.size()
        sequence = features.view(batch_size, -1, feat_dim)

        attended = self.att(sequence)
        pooled = self.pool(attended)
        return self.sigmoid(pooled)