# CNN_VIT_DeepFake / models.py
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
from PIL import Image
# -------------------------------
# Transformations for different model inputs
# -------------------------------
# For Model A and Model B, we use small images (50x50)
transform_small = transforms.Compose([
transforms.Resize((50, 50)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
# For Model C, we use larger images (224x224), the input size the ViT backbone
# expects; the 0.5/0.5 normalization also matches timm's default for this ViT.
transform_large = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
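# Hedged helper (not in the original file): a minimal sketch of how the
# transforms above are meant to be applied at inference time. `load_face`
# is a hypothetical name.
def load_face(path, transform):
    """Load an RGB image from disk and return a (1, 3, H, W) batch tensor."""
    img = Image.open(path).convert('RGB')
    return transform(img).unsqueeze(0)

# Example (assumed paths):
#   x_small = load_face('face_crop.png', transform_small)   # for ModelA/ModelB
#   x_large = load_face('face_full.png', transform_large)   # for ModelC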
# --- Model A: CNN for eye and nose regions (six conv + two linear layers) ---
class ModelA(nn.Module):
def __init__(self, num_classes=2):
super(ModelA, self).__init__()
# Three convolutional blocks, each with 2 conv layers + BN, ReLU, pooling and dropout
self.block1 = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.3)
)
self.block2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.3)
)
self.block3 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.3)
)
        # After three MaxPool2d(2) stages a 50x50 input shrinks 50 -> 25 -> 12 -> 6
        # (odd sizes are floored), so the flattened features are 128 * 6 * 6 = 4608.
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 6 * 6, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.classifier(x)
return x
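# Quick shape check (illustrative, not part of the original file):
#   ModelA()(torch.randn(2, 3, 50, 50)).shape == torch.Size([2, 2])
# i.e. one logit per class; apply softmax to get probabilities.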
# --- Model B: simpler CNN (three conv + two linear layers) ---
class ModelB(nn.Module):
def __init__(self, num_classes=2):
super(ModelB, self).__init__()
# A lighter CNN architecture: three conv layers with pooling and dropout
self.features = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout(0.3),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout(0.3),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout(0.3)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 6 * 6, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
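# ModelB consumes the same 50x50 crops as ModelA, so the pooling arithmetic is
# identical (50 -> 25 -> 12 -> 6) and the classifier input is again 128 * 6 * 6.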
# --- Model C: CNN + ViT based network for the entire face ---
class ModelC(nn.Module):
def __init__(self, num_classes=2):
super(ModelC, self).__init__()
        # Feature-learning (FL) module: a small CNN stand-in for the deep CNN
        # described in the design. Note that in this version it is defined but
        # not used in forward(); it is kept for future CNN/ViT feature fusion.
self.cnn_feature_extractor = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
)
        # For a 224x224 input, the three MaxPool2d(2) stages above would yield a
        # 256-channel 28x28 feature map (currently unused; see forward()).
        # Vision-transformer backbone from the timm library (pip install timm).
self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
# Replace the head of ViT to match our number of classes.
in_features = self.vit.head.in_features
self.vit.head = nn.Linear(in_features, num_classes)
    def forward(self, x):
        # This version feeds the original 224x224 image straight to the ViT and
        # skips the CNN extractor; a more advanced variant could fuse the CNN
        # features with the ViT input instead.
        return self.vit(x)
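# -------------------------------
# Ensembling (hedged sketch)
# -------------------------------
# The repository name mentions ensembling, but no ensemble code appears in this
# file. The function below is a minimal soft-voting sketch under that assumption,
# not the author's confirmed method: it averages the softmax probabilities of the
# three models. `x_small` is a (N, 3, 50, 50) batch built with transform_small
# for Models A and B; `x_large` is the matching (N, 3, 224, 224) batch built
# with transform_large for Model C.
def ensemble_predict(model_a, model_b, model_c, x_small, x_large):
    """Soft-voting ensemble: average class probabilities across the models."""
    with torch.no_grad():
        probs = (
            torch.softmax(model_a(x_small), dim=1)
            + torch.softmax(model_b(x_small), dim=1)
            + torch.softmax(model_c(x_large), dim=1)
        ) / 3.0
    return probs.argmax(dim=1), probs

# Assumed usage (models in eval mode, inputs from load_face above):
#   preds, probs = ensemble_predict(ModelA().eval(), ModelB().eval(),
#                                   ModelC().eval(), x_small, x_large)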