Deepfake detection with CNN, ViT, and ensembling
This research presents a novel ensemble-based approach for detecting deepfake images using a combination of Convolutional Neural Networks (CNNs) and Vision Transformers (ViT). The system achieves 94.87% accuracy by leveraging three complementary architectures: a 12-layer CNN, a lightweight 6-layer CNN, and a hybrid CNN-ViT model. Our approach demonstrates robust performance in distinguishing between real and manipulated facial images.
- README.md +127 -3
- app.py +51 -0
- cnn-vit-transformer-deepfake.ipynb +0 -0
- ensemble_config.json +1 -0
- modelA.pth +3 -0
- modelB.pth +3 -0
- modelC.pth +3 -0
- models.py +153 -0
README.md
CHANGED
@@ -1,3 +1,127 @@
# Ensemble-Based Deep Learning Architecture for Deepfake Detection

## Abstract
This research presents a novel ensemble-based approach for detecting deepfake images using a combination of Convolutional Neural Networks (CNNs) and Vision Transformers (ViT). The system achieves 94.87% accuracy by leveraging three complementary architectures: a 12-layer CNN, a lightweight 6-layer CNN, and a hybrid CNN-ViT model. Our approach demonstrates robust performance in distinguishing between real and manipulated facial images.

## 1. Introduction
With the increasing sophistication of deepfake technology, detecting manipulated images has become crucial for maintaining digital media integrity. This work introduces an ensemble method that combines traditional CNN architectures with modern Vision Transformers to create a robust detection system.

## 2. Architecture

### 2.1 Model Components
The system consists of three distinct models (a shape-check sketch follows the list):

1. **Model A (12-layer CNN)**
   - Three convolutional blocks
   - Each block: 2 conv layers + BatchNorm + ReLU + pooling
   - Input size: 50x50 pixels
   - Dropout rate: 0.3

2. **Model B (6-layer CNN)**
   - Lightweight architecture
   - Three simple conv layers with pooling
   - Input size: 50x50 pixels
   - Dropout rate: 0.3

3. **Model C (CNN-ViT Hybrid)**
   - CNN feature extractor
   - Vision Transformer (ViT-Base/16 architecture)
   - Input size: 224x224 pixels
   - Pretrained ViT backbone
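
The 50x50 input size and the three pooling stages determine the classifier dimensions used in `models.py` later in this commit. A minimal sanity-check sketch, assuming `models.py` is importable from the repository root:

```python
import torch

from models import ModelA, ModelB  # definitions appear in models.py below

model_a, model_b = ModelA().eval(), ModelB().eval()
dummy = torch.randn(1, 3, 50, 50)                     # one 50x50 RGB image

# 50 -> 25 -> 12 -> 6 after the three pooling stages
feat = model_a.block3(model_a.block2(model_a.block1(dummy)))
print(feat.shape)        # torch.Size([1, 128, 6, 6]) -> matches Linear(128*6*6, 512)
print(model_b(dummy).shape)  # torch.Size([1, 2]) class logits
```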

### 2.2 Ensemble Strategy
The final prediction is determined through majority voting among the three models, enhancing robustness and reducing individual model biases.
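
A minimal sketch of the voting step; it mirrors the inference logic in `app.py` later in this commit, with `preds` standing in for the three per-model class predictions:

```python
from collections import Counter

preds = [1, 1, 0]                                  # per-model predicted classes for one image

vote_count = Counter(preds)
final_label = vote_count.most_common(1)[0][0]      # majority class
agreement = vote_count[final_label] / len(preds)   # fraction of models that agree

print(final_label, f"{agreement:.0%}")             # -> 1 67%
```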

## 3. Implementation Details

### 3.1 Dataset
- Dataset: Hemg/deepfake-and-real-images
- Split: 80% training, 20% testing
- Preprocessing: resize and normalization
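
A hedged sketch of loading and splitting the data; it assumes the Hugging Face `datasets` library and the dataset's default `train` split, and the notebook's actual loading code and column names may differ:

```python
from datasets import load_dataset

ds = load_dataset("Hemg/deepfake-and-real-images", split="train")
splits = ds.train_test_split(test_size=0.2, seed=42)   # 80% train / 20% test
train_ds, test_ds = splits["train"], splits["test"]
print(len(train_ds), len(test_ds))

# Applying the 50x50 transform from models.py to one sample would look like
# (assumes an "image" column; adjust to the dataset's actual schema):
#   from models import transform_small
#   x = transform_small(train_ds[0]["image"].convert("RGB"))
```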

### 3.2 Training Parameters
- Optimizer: Adam
- Learning rate: 1e-4
- Batch size: 32
- Epochs: 10
- Loss function: Cross-Entropy
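
A minimal training-loop sketch with these settings. The placeholder tensors and `train_loader` are assumptions standing in for the preprocessed dataset; the notebook's actual loop may differ:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from models import ModelA  # the same loop is reused for ModelB / ModelC with their input sizes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelA().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Placeholder data for illustration; in practice the DataLoader wraps the real train split
# preprocessed with transform_small (transform_large for ModelC).
dummy_x = torch.randn(64, 3, 50, 50)
dummy_y = torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=32, shuffle=True)

for epoch in range(10):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch + 1}: loss {running_loss / len(train_loader):.4f}")
```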

## 4. Results

### 4.1 Performance Metrics
Based on the test set evaluation:

- **Overall Accuracy**: 94.87%
- **Classification Report**:
  - Real Images:
    - Precision: 0.95
    - Recall: 0.94
    - F1-score: 0.94
  - Fake Images:
    - Precision: 0.94
    - Recall: 0.95
    - F1-score: 0.95
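
These per-class numbers follow the standard scikit-learn report. A sketch of how they can be reproduced from the ensemble's outputs, assuming lists of ground-truth labels and majority-vote predictions with the 0 = Fake / 1 = Real convention used in `app.py` (the values below are placeholders):

```python
from sklearn.metrics import accuracy_score, classification_report

y_true = [1, 0, 1, 1, 0]   # placeholder ground-truth labels
y_pred = [1, 0, 1, 0, 0]   # placeholder majority-vote predictions

print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
print(classification_report(y_true, y_pred, target_names=["Fake", "Real"]))
```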

### 4.2 Deployment
The system is deployed as a FastAPI service, providing real-time inference with confidence scores.

### 4.3 Visuals


*Figure 1: Findings of the proposed ensemble-based deepfake detection system*

#### Performance Visualization

*Figure 2: Confusion matrix showing model performance on the test set*

#### Loss vs. Epochs

*Figure 3: Loss vs. epochs for the individual models*

#### Accuracy vs. Epochs

*Figure 4: Accuracy vs. epochs for the individual models*

## 5. Technical Requirements
- Python 3.x
- PyTorch
- timm
- FastAPI
- Pillow (PIL)
- scikit-learn

## 6. Usage

### 6.1 API Endpoint
```
POST /predict/
Input: image file
Output: {
    "prediction": "Real" or "Fake",
    "confidence": "percentage"
}
```
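
To exercise the endpoint locally, start the service with uvicorn and query it from Python. A sketch: the host, port, and image file name are placeholders, and the `requests` package is an extra assumption not listed in the requirements:

```python
# Start the service first, e.g.:  uvicorn app:app --port 8000
import requests

url = "http://localhost:8000/predict/"        # placeholder host/port
with open("face.jpg", "rb") as f:             # placeholder image file
    resp = requests.post(url, files={"file": f})

print(resp.json())                            # e.g. {"prediction": "Real", "confidence": "100.0%"}
```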

### 6.2 Model Training
Run the `cnn-vit-transformer-deepfake.ipynb` notebook to train the three models on your own dataset.

## 7. Conclusions
The ensemble approach demonstrates strong performance in deepfake detection, with the combination of traditional CNNs and modern Vision Transformers providing robust and reliable results. The system's high accuracy and balanced precision-recall metrics make it suitable for real-world applications. Although Model C's training degrades at epoch 10, the overall ensemble result remains strong.

## 8. Future Work
- Integration of attention mechanisms in the CNN models
- Exploration of different ensemble strategies
- Extension to video deepfake detection
- Investigation of model compression techniques

## References
1. Vision Transformer (ViT) - Dosovitskiy et al., 2020
2. timm library - Ross Wightman
3. FastAPI - Sebastián Ramírez

## License
MIT License
app.py
ADDED
@@ -0,0 +1,51 @@
# app.py
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import io, torch
from collections import Counter

from models import ModelA, ModelB, ModelC, transform_small, transform_large

# 1. spin up FastAPI
app = FastAPI()

# 2. load your saved weights
device = torch.device('cpu')
modelA = ModelA()
modelA.load_state_dict(torch.load('modelA.pth', map_location=device, weights_only=True))
modelA.eval()

modelB = ModelB()
modelB.load_state_dict(torch.load('modelB.pth', map_location=device, weights_only=True))
modelB.eval()

modelC = ModelC()
modelC.load_state_dict(torch.load('modelC.pth', map_location=device, weights_only=True))
modelC.eval()

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    # read image bytes → PIL
    data = await file.read()
    img = Image.open(io.BytesIO(data)).convert('RGB')

    # preprocess
    t_small = transform_small(img).unsqueeze(0)  # for A & B
    t_large = transform_large(img).unsqueeze(0)  # for C

    # run inference
    votes = []
    with torch.no_grad():
        for model, inp in [(modelA, t_small), (modelB, t_small), (modelC, t_large)]:
            out = model(inp)
            _, pred = out.max(1)
            votes.append(int(pred.item()))

    # majority vote + confidence
    vote_count = Counter(votes)
    final_label = vote_count.most_common(1)[0][0]
    confidence = vote_count[final_label] / len(votes)

    return {
        "prediction": "Real" if final_label == 1 else "Fake",
        "confidence": f"{confidence*100:.1f}%"
    }
cnn-vit-transformer-deepfake.ipynb
ADDED
The diff for this file is too large to render.
ensemble_config.json
ADDED
@@ -0,0 +1 @@
{"models": ["modelA.pth", "modelB.pth", "modelC.pth"], "accuracy": 0.9487482596474637}
modelA.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:570d72aef91239c7ac636ddd14fdc00ddb75eee0a8f994f902df271cf2943b37
size 10603798
modelB.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86b1ff337ca94507916e756f602a8917ad04c719e87827a89e7cdebb65df7cda
size 9820010
modelC.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff3dcdc52e47b3b3fd8ba4b67366f540a7bf6e0c6d2b046a299655ac8e77c1ad
size 344739732
models.py
ADDED
@@ -0,0 +1,153 @@
import os
import numpy as np
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms


import timm


# -------------------------------
# Transformations for different model inputs
# -------------------------------
# For Model A and Model B, we use small images (50x50)
transform_small = transforms.Compose([
    transforms.Resize((50, 50)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# For Model C, we use larger images (224x224)
transform_large = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])


# --- Model A: 12-layer CNN-based network ---
class ModelA(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelA, self).__init__()
        # Three convolutional blocks, each with 2 conv layers + BN, ReLU, pooling and dropout
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3)
        )
        # After three blocks, feature map size for a 50x50 input: 50 -> 25 -> 12 -> 6
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.classifier(x)
        return x


# --- Model B: Simpler CNN-based network (6 layers) ---
class ModelB(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelB, self).__init__()
        # A lighter CNN architecture: three conv layers with pooling and dropout
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


# --- Model C: CNN + ViT based network for the entire face ---
class ModelC(nn.Module):
    def __init__(self, num_classes=2):
        super(ModelC, self).__init__()
        # Feature learning (FL) module: a deep CNN.
        # For demonstration, we use a simpler CNN here.
        self.cnn_feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # For a 224x224 input, the CNN feature map is reduced to roughly 28x28.
        # Vision transformer backbone from the timm library (pip install timm).
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        # Replace the head of ViT to match our number of classes.
        in_features = self.vit.head.in_features
        self.vit.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # Extract lower-level CNN features (currently unused; kept for optional fusion with the ViT)
        features = self.cnn_feature_extractor(x)
        # For this demonstration, the original image is fed directly to the ViT.
        # In a more advanced implementation, the CNN features could be fused with the ViT input.
        out = self.vit(x)
        return out