# CNN.py — created by sanidhya2810 (commit 0733d6e, verified)
import torch.nn as nn
class CNN(nn.Module):
    """VGG-style image classifier: four double-convolution stages plus a dense head.

    Every stage applies ``Conv3x3 -> ReLU -> BatchNorm`` twice and then a 2x2
    max-pool, so the channel count goes 3 -> 32 -> 64 -> 128 -> 256 while the
    spatial size is halved four times.  The dense head was sized for 224x224
    inputs (224 / 2**4 = 14, i.e. 256 * 14 * 14 flattened features).

    Inputs of other spatial sizes are accepted as well: the feature map is
    adaptively average-pooled to 14x14 before flattening.  For 224x224 inputs
    that pooling is an exact identity, and it owns no parameters, so
    checkpoints saved from the original layout still load unchanged.

    Args:
        K: Number of output classes (width of the final linear layer).
    """

    # Output channels of the four convolutional stages, in order.
    _STAGE_CHANNELS = (32, 64, 128, 256)
    # Side length of the feature map the dense head expects (224 / 2**4).
    _GRID = 14

    def __init__(self, K):
        super().__init__()

        # Build one flat Sequential (not nested per-stage containers) so the
        # layer indices 0..27 -- and therefore the state_dict keys such as
        # "conv_layers.0.weight" -- stay exactly as in the hand-written list.
        layers = []
        in_channels = 3  # the first convolution takes 3-channel input
        for out_channels in self._STAGE_CHANNELS:
            layers.extend(self._stage(in_channels, out_channels))
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)

        # Parameter-free: normalises any feature-map size to the 14x14 grid
        # the first linear layer was dimensioned for.
        self.feature_pool = nn.AdaptiveAvgPool2d((self._GRID, self._GRID))

        self.dense_layers = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_channels * self._GRID * self._GRID, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, K),
        )

    @staticmethod
    def _stage(in_channels, out_channels):
        """Return the seven layers of one stage: 2 x (conv, ReLU, BN) + max-pool."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(2),
        ]

    def forward(self, X):
        """Run the network on a batch and return one score per class.

        Args:
            X: 4-D image batch with 3 channels in dimension 1.

        Returns:
            Tensor with one row per batch element and ``K`` columns; no
            softmax is applied.
        """
        out = self.conv_layers(X)
        out = self.feature_pool(out)
        # flatten() falls back to a copy when the activations are not
        # contiguous (e.g. channels_last); .view() would raise there.
        out = out.flatten(start_dim=1)
        return self.dense_layers(out)