# NOTE(review): removed stray paste/log artifacts ("Spaces:", "Runtime error")
# that preceded the code and made the file syntactically invalid.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms as taT, functional as taF

DEVICE = "mps"
class ShortChunkCNN(nn.Module):
    """Short-chunk CNN music tagger.

    A stack of seven 2-D residual blocks over a one-channel spectrogram,
    global max pooling over the remaining time axis, then a two-layer
    classifier with sigmoid outputs (multi-label tagging).

    Args:
        n_channels: base width of the residual stack (doubled/quadrupled
            in the deeper blocks).
        sample_rate: audio sample rate; unused in this block — presumably
            consumed by an external spectrogram front end. TODO confirm.
        n_class: number of output tags.
    """

    def __init__(self,
                 n_channels=128,
                 sample_rate=16000,
                 n_class=50):
        super().__init__()
        # Input normalization. BatchNorm2d(1) implies input shaped
        # (B, 1, F, T) — a single-channel spectrogram; confirm with caller.
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN: each block halves both spatial dimensions (stride=2).
        self.res_layers = nn.Sequential(
            Res_2d(1, n_channels, stride=2),
            Res_2d(n_channels, n_channels, stride=2),
            Res_2d(n_channels, n_channels*2, stride=2),
            Res_2d(n_channels*2, n_channels*2, stride=2),
            Res_2d(n_channels*2, n_channels*2, stride=2),
            Res_2d(n_channels*2, n_channels*2, stride=2),
            Res_2d(n_channels*2, n_channels*4, stride=2)
        )

        # Dense classifier head.
        self.dense1 = nn.Linear(n_channels*4, n_channels*4)
        self.bn = nn.BatchNorm1d(n_channels*4)
        self.dense2 = nn.Linear(n_channels*4, n_class)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return per-class sigmoid scores of shape (B, n_class)."""
        x = self.spec_bn(x)

        # CNN. After seven stride-2 blocks the frequency axis is expected
        # to be reduced to size 1, so squeeze(2) leaves (B, C, T').
        x = self.res_layers(x)
        x = x.squeeze(2)

        # Global max pooling over time.
        # Fix: use the functional form rather than constructing a fresh
        # nn.MaxPool1d module on every forward pass.
        if x.size(-1) != 1:
            x = F.max_pool1d(x, x.size(-1))
        x = x.squeeze(2)

        # Dense head: linear -> BN -> ReLU -> dropout -> linear -> sigmoid.
        x = self.dense1(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        # Fix: torch.sigmoid instead of allocating an nn.Sigmoid per call.
        x = torch.sigmoid(x)
        return x
class Res_2d(nn.Module):
    """2-D residual block: conv-BN-ReLU-conv-BN plus a shortcut, with a
    final ReLU on the sum.

    When the stride is not 1 or the channel count changes, the identity
    shortcut is replaced by a strided conv + BN projection (note: the
    projection uses the same kernel size as the main path, not 1x1).
    """

    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        pad = shape // 2  # "same"-style padding for odd kernel sizes

        # Main path: two convolutions, each followed by batch norm.
        self.conv_1 = nn.Conv2d(input_channels, output_channels, shape,
                                stride=stride, padding=pad)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape,
                                padding=pad)
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # Shortcut path: project only when shape/channels would mismatch.
        self.diff = (stride != 1) or (input_channels != output_channels)
        if self.diff:
            self.conv_3 = nn.Conv2d(input_channels, output_channels, shape,
                                    stride=stride, padding=pad)
            self.bn_3 = nn.BatchNorm2d(output_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Main path.
        h = self.conv_1(x)
        h = self.relu(self.bn_1(h))
        h = self.bn_2(self.conv_2(h))
        # Shortcut (projected when geometry or channel count changes).
        shortcut = self.bn_3(self.conv_3(x)) if self.diff else x
        # Sum then final activation.
        return self.relu(h + shortcut)