|
import torch.nn as nn |
|
|
|
class CNN(nn.Module):
    """VGG-style convolutional classifier for 3-channel images.

    The feature extractor is four stages with 32, 64, 128 and 256 output
    channels. Each stage is two ``3x3`` convolutions (``padding=1``), each
    followed by ReLU and batch normalisation, and ends with a ``2x2``
    max-pool. The classifier head is dropout (0.4), a 1024-unit linear
    layer with ReLU, dropout (0.4) and a final linear layer with ``K``
    outputs. No softmax is applied, so ``forward`` returns raw scores.

    Args:
        K: Number of output classes (width of the final linear layer).
        input_size: Spatial size of the images the model will receive,
            either a single int for square inputs or a ``(height, width)``
            pair. Defaults to 224, which gives the ``256 * 14 * 14``
            flattened feature size.

    Raises:
        ValueError: If ``input_size`` is too small to survive the four
            pooling stages (either side shorter than 16 pixels).
    """

    def __init__(self, K, input_size=224):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The padded 3x3 convolutions keep H and W unchanged; each of the
        # four MaxPool2d(2) layers halves them (rounding down), so the final
        # feature map is (H // 16, W // 16).
        feat_h, feat_w = height // 16, width // 16
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input_size must be at least 16 in each dimension, "
                f"got {input_size!r}"
            )

        # The stages are unpacked into one flat Sequential so the module
        # indices (and therefore the state_dict keys) stay 0..27.
        self.conv_layers = nn.Sequential(
            *self._conv_block(3, 32),
            *self._conv_block(32, 64),
            *self._conv_block(64, 128),
            *self._conv_block(128, 256),
        )

        self.dense_layers = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(256 * feat_h * feat_w, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, K),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the layers of one (conv-ReLU-BN) x 2 + max-pool stage."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(2),
        ]

    def forward(self, X):
        """Run the network on a batch of images.

        Args:
            X: Tensor of shape ``(batch, 3, height, width)`` whose spatial
                size matches the ``input_size`` given at construction.

        Returns:
            Tensor of shape ``(batch, K)`` with one raw score per class.
        """
        out = self.conv_layers(X)
        # flatten() rather than view(): view() raises when the feature map
        # is not contiguous in NCHW order (e.g. channels_last inputs), while
        # flatten() yields the same element order and copies only if needed.
        out = out.flatten(1)
        out = self.dense_layers(out)
        return out
|
|
|
|