# CNN_VIT_DeepFake/models.py
# Deepfake detection models: two CNNs and a CNN+ViT network, with ensembling.
import os
import numpy as np
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import timm
# -------------------------------
# Transformations for different model inputs
# -------------------------------
# For Model A and Model B, we use small images (50x50)
transform_small = transforms.Compose([
    transforms.Resize((50, 50)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])
# For Model C, we use larger images (224x224)
transform_large = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])
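# Illustrative usage (a hedged sketch; the blank image below is a stand-in for
# a real face crop, not project data): both pipelines take a PIL RGB image and
# return a float tensor normalized channel-wise to roughly [-1, 1].
def _demo_transforms():
    img = Image.new("RGB", (256, 256))  # placeholder image
    print(transform_small(img).shape)   # torch.Size([3, 50, 50])
    print(transform_large(img).shape)   # torch.Size([3, 224, 224])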
# --- Model A: CNN-based network for eye and nose regions (12 layers) ---
class ModelA(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelA, self).__init__()
        # Three convolutional blocks, each with 2 conv layers plus BN, ReLU,
        # pooling and dropout
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        # After three 2x poolings, a 50x50 input shrinks 50 -> 25 -> 12 -> 6
        # (floor division), so the flattened size is 128 * 6 * 6.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.classifier(x)
        return x
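# Quick shape check for ModelA (an illustrative sketch with random data, not
# project code): a 50x50 input passes three 2x poolings down to 6x6, so the
# flatten feeds 128 * 6 * 6 = 4608 features into the classifier.
def _demo_model_a_shapes():
    model = ModelA().eval()
    x = torch.randn(1, 3, 50, 50)  # dummy batch: one 50x50 RGB crop
    with torch.no_grad():
        print(model(x).shape)      # expected: torch.Size([1, 2])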
# --- Model B: Simpler CNN-based network (6 layers) ---
class ModelB(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelB, self).__init__()
        # A lighter CNN architecture: three conv layers, each followed by
        # pooling and dropout
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3)
        )
        # As in ModelA, a 50x50 input is pooled down to a 6x6 feature map.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
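# Illustrative size comparison of the two CNNs (default instantiations):
# ModelB has half as many conv layers and no BatchNorm, though both models'
# parameter counts are dominated by the first 4608x512 Linear layer.
def _demo_param_counts():
    def count(m):
        return sum(p.numel() for p in m.parameters())
    print(f"ModelA parameters: {count(ModelA()):,}")
    print(f"ModelB parameters: {count(ModelB()):,}")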
# --- Model C: CNN + ViT based network for the entire face ---
class ModelC(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelC, self).__init__()
        # Feature learning (FL) module: a deep CNN.
        # For demonstration, a simpler CNN is used here.
        self.cnn_feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # After three 2x poolings, a 224x224 input yields a 28x28 feature map.
        # Vision transformer backbone from the timm library
        # (install with: pip install timm).
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        # Replace the ViT head to match our number of classes.
        in_features = self.vit.head.in_features
        self.vit.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # The CNN features are computed but not yet fused with the ViT stream;
        # in this version the original image is fed directly to the ViT.
        # A more advanced implementation could fuse the two representations.
        _ = self.cnn_feature_extractor(x)
        out = self.vit(x)
        return out
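# --- Ensembling sketch ---
# The repository describes ensembling these models, but no fusion code appears
# in this file, so the soft-voting scheme below (equal-weight averaging of
# softmax probabilities) is an illustrative assumption, not the project's
# confirmed method. Assumes the models are already trained and in eval() mode.
@torch.no_grad()
def ensemble_predict(model_a, model_b, model_c, img):
    """Return the ensembled class index (0 or 1) for a single PIL image."""
    x_small = transform_small(img).unsqueeze(0)  # [1, 3, 50, 50] for A and B
    x_large = transform_large(img).unsqueeze(0)  # [1, 3, 224, 224] for C
    probs = (
        torch.softmax(model_a(x_small), dim=1)
        + torch.softmax(model_b(x_small), dim=1)
        + torch.softmax(model_c(x_large), dim=1)
    ) / 3.0
    return probs.argmax(dim=1).item()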