| import torch | |
| from torch import nn | |
| from transformers.activations import ACT2FN | |
class Conv2dFeatureExtractor(nn.Module):
    """Extract features from a spectrogram-like input with a stack of 2D convolutions.

    The input ``(batch, time, freq)`` is treated as a 1-channel image, passed
    through ``len(config.conv_dim)`` Conv2d + activation pairs (each conv uses a
    square kernel/stride taken from ``config.conv_kernel``/``config.conv_stride``),
    and the surviving frequency axis is folded into the channel axis and projected
    to ``config.hidden_size``.

    Config attributes read: ``conv_dim``, ``conv_kernel``, ``conv_stride``,
    ``feat_extract_activation``, ``second_dim_input_size`` (size of the ``freq``
    axis of the input), ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(
                        in_channels,
                        out_channels=out_channels,
                        kernel_size=(kernel, kernel),
                        stride=(stride, stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
                # First conv sees the single input channel; each subsequent conv
                # consumes the previous layer's output channels.
                for in_channels, out_channels, kernel, stride in zip(
                    [1, *config.conv_dim], config.conv_dim, config.conv_kernel, config.conv_stride
                )
            ],
        )
        # Track the frequency dimension through every conv layer with the
        # standard Conv2d output-size formula ((d - kernel) // stride + 1)
        # instead of hard-coding two kernel-3/stride-2 layers, so the linear
        # input size stays correct for any conv_kernel/conv_stride config.
        # For the previous default (two layers, kernel=3, stride=2) this
        # reduces exactly to the old ((d - 1) // 2 - 1) // 2.
        freq_dim = config.second_dim_input_size
        for kernel, stride in zip(config.conv_kernel, config.conv_stride):
            freq_dim = (freq_dim - kernel) // stride + 1
        self.out = nn.Linear(config.conv_dim[-1] * freq_dim, config.hidden_size, bias=True)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Project ``(batch, time, freq)`` inputs to ``(batch, hidden_size, time')``.

        Args:
            input_values: float tensor of shape ``(batch, time, freq)`` with
                ``freq == config.second_dim_input_size``.

        Returns:
            Tensor of shape ``(batch, hidden_size, time')`` where ``time'`` is
            the time axis after conv downsampling.
        """
        # Insert a channel axis: (batch, 1, time, freq) for Conv2d.
        hidden_states = self.conv(input_values[:, None, ...])
        # (batch, ch, time', freq') -> (batch, time', ch, freq') -> fold
        # channels and frequency together, then project to hidden_size.
        hidden_states = self.out(hidden_states.transpose(1, 2).flatten(2, 3))
        return hidden_states.transpose(1, 2)