Spaces: Running
Commit · 5e014de
1 Parent(s): c329af9
first commit

Files changed:
- .gitignore +46 -0
- Examples/DeepFakes_10.png +0 -0
- Examples/DeepFakes_2.png +0 -0
- Examples/DeepFakes_4.png +0 -0
- Examples/DeepFakes_8.png +0 -0
- Examples/DeepFakes_9.png +0 -0
- Examples/SimSwap_8.png +0 -0
- Examples/StyleGAN_7.png +0 -0
- Examples/o_11.jpg +0 -0
- Examples/o_3.jpg +0 -0
- Examples/o_5.jpg +0 -0
- Examples/o_6.jpg +0 -0
- Examples/o_7.jpg +0 -0
- app.py +206 -0
- dataset/real_n_fake_dataloader.py +119 -0
- face_cropper.py +99 -0
- net/Multimodalmodel.py +41 -0
- test_image_fusion.py +182 -0
- utils/__init__.py +1 -0
- utils/basicblocks.py +32 -0
- utils/classifier.py +32 -0
- utils/config.py +38 -0
- utils/data_transforms.py +33 -0
- utils/feature_fusion_block.py +46 -0
- weights/faceswap-fft-best_model.pth +3 -0
- weights/faceswap-hh-best_model.pth +3 -0
.gitignore
ADDED
@@ -0,0 +1,46 @@
working.ipynb
training.py

# Compiled source #
###################
*.com
*.class
*.dll
*.exe
*.o
*.so

# Packages #
############
# it's better to unpack these files and commit the raw source because
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

# Logs and databases #
######################
*.log
*.sql
*.sqlite

# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db


__pycache__/
test_image.py
*.pyc
Examples/DeepFakes_10.png
ADDED
Examples/DeepFakes_2.png
ADDED
Examples/DeepFakes_4.png
ADDED
Examples/DeepFakes_8.png
ADDED
Examples/DeepFakes_9.png
ADDED
Examples/SimSwap_8.png
ADDED
Examples/StyleGAN_7.png
ADDED
Examples/o_11.jpg
ADDED
Examples/o_3.jpg
ADDED
Examples/o_5.jpg
ADDED
Examples/o_6.jpg
ADDED
Examples/o_7.jpg
ADDED
app.py
ADDED
@@ -0,0 +1,206 @@
import gradio as gr
from PIL import Image
import numpy as np
import os
from face_cropper import detect_and_label_faces

# Define a custom function to convert an image to grayscale
def to_grayscale(input_image):
    grayscale_image = Image.fromarray(np.array(input_image).mean(axis=-1).astype(np.uint8))
    return grayscale_image


description_markdown = """
# Fake Face Detection tool from TrustWorthy BiometraVision Lab IISER Bhopal

## Usage
This tool expects a face image as input. Upon submission, it processes the image and returns it with bounding boxes drawn around the detected faces. Alongside these visual markers, the tool reports a detection result indicating whether each face is fake or real.

## Disclaimer
Please note that this tool is for research purposes only and may not always be 100% accurate. Users are advised to exercise discretion and supervise the tool's usage accordingly.

## Licensing and Permissions
This tool has been developed solely for research and demonstrative purposes. Any commercial use of this tool is strictly prohibited unless explicit permission has been obtained from the developers.

## Developer Contact
For further inquiries or permissions, you can reach out to the developer through the following channels:
- [LAB Webpage](https://sites.google.com/iiitd.ac.in/agarwalakshay/labiiserb?authuser=0)
- [LinkedIn](https://www.linkedin.com/in/shivam-shukla-0a50ab1a2/)
- [GitHub](https://github.com/SaShukla090)
"""


# Create the Gradio app
app = gr.Interface(
    fn=detect_and_label_faces,
    inputs=gr.Image(type="pil"),
    outputs="image",
    # examples=[
    #     "path_to_example_image_1.jpg",
    #     "path_to_example_image_2.jpg"
    # ]
    examples=[
        os.path.join("Examples", image_name) for image_name in os.listdir("Examples")
    ],
    title="Fake Face Detection",
    description=description_markdown,
)

# Run the app
app.launch()


# import torch.nn.functional as F
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from tqdm import tqdm
# import warnings
# warnings.filterwarnings("ignore")

# from utils.config import cfg
# from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
# from utils.data_transforms import get_transforms_train, get_transforms_val
# from net.Multimodalmodel import Image_n_DCT
# import gradio as gr


# import os
# import json
# import torch
# from torchvision import transforms
# from torch.utils.data import DataLoader, Dataset
# from PIL import Image
# import numpy as np
# import pandas as pd
# import cv2
# import argparse


# from sklearn.metrics import classification_report, confusion_matrix
# import matplotlib.pyplot as plt
# import seaborn as sns


# class Test_Dataset(Dataset):
#     def __init__(self, test_data_path = None, transform = None, image = None):
#         """
#         Args:
#         returns:
#         """
#         if test_data_path is None and image is not None:
#             self.dataset = [(image, 2)]
#             self.transform = transform

#     def __len__(self):
#         return len(self.dataset)

#     def __getitem__(self, idx):
#         sample_input = self.get_sample_input(idx)
#         return sample_input

#     def get_sample_input(self, idx):
#         rgb_image = self.get_rgb_image(self.dataset[idx][0])
#         dct_image = self.compute_dct_color(self.dataset[idx][0])
#         # label = self.get_label(idx)
#         sample_input = {"rgb_image": rgb_image, "dct_image": dct_image}
#         return sample_input

#     def get_rgb_image(self, rgb_image):
#         # rgb_image_path = self.dataset[idx][0]
#         # rgb_image = Image.open(rgb_image_path)
#         if self.transform:
#             rgb_image = self.transform(rgb_image)
#         return rgb_image

#     def get_dct_image(self, idx):
#         rgb_image_path = self.dataset[idx][0]
#         rgb_image = cv2.imread(rgb_image_path)
#         dct_image = self.compute_dct_color(rgb_image)
#         if self.transform:
#             dct_image = self.transform(dct_image)
#         return dct_image

#     def get_label(self, idx):
#         return self.dataset[idx][1]

#     def compute_dct_color(self, image):
#         image_float = np.float32(image)
#         dct_image = np.zeros_like(image_float)
#         for i in range(3):
#             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
#         if self.transform:
#             dct_image = self.transform(dct_image)
#         return dct_image


# device = torch.device("cpu")
# # print(device)
# model = Image_n_DCT()
# model.load_state_dict(torch.load('weights/best_model.pth', map_location = device))
# model.to(device)
# model.eval()


# def classify(image):
#     test_dataset = Test_Dataset(transform = get_transforms_val(), image = image)
#     inputs = test_dataset[0]
#     rgb_image, dct_image = inputs['rgb_image'].to(device), inputs['dct_image'].to(device)
#     output = model(rgb_image.unsqueeze(0), dct_image.unsqueeze(0))
#     # _, predicted = torch.max(output.data, 1)
#     # print(f"the face is {'real' if predicted==1 else 'fake'}")
#     return {'Fake': output[0][0], 'Real': output[0][1]}

# iface = gr.Interface(fn=classify, inputs="image", outputs="label")
# if __name__ == "__main__":
#     iface.launch()
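Since the Gradio interface above simply wraps `detect_and_label_faces`, the handler can also be exercised without the browser UI. A minimal smoke-test sketch, assuming the `Examples/` images and `face_cropper.py` from this commit are available locally (the output filename is illustrative only):

import os
from PIL import Image
from face_cropper import detect_and_label_faces

# Pick any bundled example image and run it through the same function
# that gr.Interface calls on submission.
example_path = os.path.join("Examples", sorted(os.listdir("Examples"))[0])
labeled = detect_and_label_faces(Image.open(example_path).convert("RGB"))
labeled.save("labeled_example.png")  # PIL image with colored boxes and real/fake labels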
dataset/real_n_fake_dataloader.py
ADDED
@@ -0,0 +1,119 @@
# We will use this file to create a dataloader for the real and fake dataset
import os
import json
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import pywt

class Extracted_Frames_Dataset(Dataset):
    def __init__(self, root_dir, split = "train", transform = None, extend = 'None', multi_modal = "dct"):
        """
        Loads (image_path, label) rows from the split CSV under root_dir and
        returns the RGB image plus a second modality (dct, fft, or hh) per sample.
        """
        assert split in ["train", "val", "test"], "Split must be one of (train, val, test)"
        self.multi_modal = multi_modal
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        if extend == 'faceswap':
            self.dataset = pd.read_csv(os.path.join(root_dir, f"faceswap_extended_{self.split}.csv"))
        elif extend == 'fsgan':
            self.dataset = pd.read_csv(os.path.join(root_dir, f"fsgan_extended_{self.split}.csv"))
        else:
            self.dataset = pd.read_csv(os.path.join(root_dir, f"{self.split}.csv"))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_input = self.get_sample_input(idx)
        return sample_input

    def get_sample_input(self, idx):
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar HH sub-band)")

        return sample_input

    def get_fft_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
        return magnitude_spectrum

    def get_hh_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH

    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        rgb_image = Image.open(rgb_image_path)
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        return self.dataset.iloc[idx, 1]

    def compute_dct_color(self, image):
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image
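The dataset class above is only defined in this commit, not wired into a training loop. A minimal training-side sketch of how it would typically be consumed, assuming a hypothetical `dataset/` root containing a `train.csv` with (image_path, label) rows as expected by `get_rgb_image` and `get_label`:

from torch.utils.data import DataLoader
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train

# "dataset" as root_dir and the CSV layout are assumptions for illustration only.
train_set = Extracted_Frames_Dataset(root_dir="dataset", split="train",
                                     transform=get_transforms_train(), multi_modal="fft")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

batch = next(iter(train_loader))
# rgb_image: (B, 3, 224, 224), dct_image: (B, 1, 224, 224), label: (B,)
print(batch["rgb_image"].shape, batch["dct_image"].shape, batch["label"].shape)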
face_cropper.py
ADDED
@@ -0,0 +1,99 @@
import cv2
import mediapipe as mp
import os
from gradio_client import Client
from test_image_fusion import Test
import numpy as np
from PIL import Image

# client = Client("https://tbvl-real-and-fake-face-detection.hf.space/--replicas/40d41jxhhx/")

data = 'faceswap'
dct = 'fft'

testet = Test(model_paths = [f"weights/{data}-hh-best_model.pth",
                             f"weights/{data}-fft-best_model.pth"],
              multi_modal = ['hh', 'fft'])

# Initialize MediaPipe Face Detection
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils
face_detection = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.35)

# Create a directory to save the cropped face images if it does not exist
save_dir = "cropped_faces"
os.makedirs(save_dir, exist_ok=True)

# def detect_and_label_faces(image_path):

# Function to crop faces from a video and save them as images
# def crop_faces_from_video(video_path):
#     # Read the video
#     cap = cv2.VideoCapture(video_path)
#     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
#     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
#     fps = int(cap.get(cv2.CAP_PROP_FPS))
#     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#
#     # Define the codec and create VideoWriter object
#     out = cv2.VideoWriter(f'output_{real}_{data}_fusion.avi', cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_width, frame_height))
#
#     if not cap.isOpened():
#         print("Error: Could not open video.")
#         return

# Convert PIL Image to NumPy array for OpenCV
def pil_to_opencv(pil_image):
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR for OpenCV
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image

# Convert OpenCV NumPy array to PIL Image
def opencv_to_pil(opencv_image):
    # Convert BGR to RGB
    pil_image = Image.fromarray(opencv_image[:, :, ::-1])
    return pil_image


def detect_and_label_faces(frame):
    frame = pil_to_opencv(frame)

    # Convert the frame to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Perform face detection
    results = face_detection.process(frame_rgb)

    # If faces are detected, crop and classify each face, then draw the result
    if results.detections:
        for face_count, detection in enumerate(results.detections):
            bboxC = detection.location_data.relative_bounding_box
            ih, iw, _ = frame.shape
            x, y, w, h = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
            # Crop the face region and make sure the bounding box is within the frame dimensions
            crop_img = frame[max(0, y):min(ih, y+h), max(0, x):min(iw, x+w)]
            if crop_img.size > 0:
                face_filename = os.path.join(save_dir, f'face_{face_count}.jpg')
                cv2.imwrite(face_filename, crop_img)

                label = testet.testimage(face_filename)

                if os.path.exists(face_filename):
                    os.remove(face_filename)

                color = (0, 0, 255) if label == 'fake' else (0, 255, 0)
                cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
                cv2.putText(frame, label, (x, y + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return opencv_to_pil(frame)
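For reference, MediaPipe reports the bounding box in coordinates relative to the frame size, which the loop above scales back to pixels. A worked example with made-up numbers:

# Hypothetical detection on a 640x480 frame (iw=640, ih=480) with
# bboxC.xmin=0.25, bboxC.ymin=0.20, bboxC.width=0.30, bboxC.height=0.40
x = int(0.25 * 640)   # 160
y = int(0.20 * 480)   # 96
w = int(0.30 * 640)   # 192
h = int(0.40 * 480)   # 192
# crop_img = frame[96:288, 160:352]; the rectangle is drawn from (160, 96) to (352, 288)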
net/Multimodalmodel.py
ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.config import cfg
from utils.basicblocks import BasicBlock
from utils.feature_fusion_block import DCT_Attention_Fusion_Conv
from utils.classifier import ClassifierModel

class Image_n_DCT(nn.Module):
    def __init__(self,):
        super(Image_n_DCT, self).__init__()
        self.Img_Block = nn.ModuleList()
        self.DCT_Block = nn.ModuleList()
        self.RGB_n_DCT_Fusion = nn.ModuleList()
        self.num_classes = len(cfg.CLASSES)

        for i in range(len(cfg.MULTIMODAL_FUSION.IMG_CHANNELS) - 1):
            self.Img_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
            self.DCT_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.DCT_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
            self.RGB_n_DCT_Fusion.append(DCT_Attention_Fusion_Conv(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1]))

        self.classifier = ClassifierModel(self.num_classes)

    def forward(self, rgb_image, dct_image):
        image = [rgb_image]
        dct_image = [dct_image]

        for i in range(len(self.Img_Block)):
            image.append(self.Img_Block[i](image[-1]))
            dct_image.append(self.DCT_Block[i](dct_image[-1]))
            image[-1] = self.RGB_n_DCT_Fusion[i](image[-1], dct_image[-1])
            dct_image[-1] = image[-1]
        out = self.classifier(image[-1])

        return out
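A quick shape check for the fusion network above, assuming the default channel lists in `utils/config.py` (RGB branch 3 to 512 channels, second modality 1 to 512) and 224x224 inputs:

import torch
from net.Multimodalmodel import Image_n_DCT

model = Image_n_DCT().eval()
rgb = torch.randn(1, 3, 224, 224)   # RGB branch input
dct = torch.randn(1, 1, 224, 224)   # DCT / FFT / HH branch input
with torch.no_grad():
    out = model(rgb, dct)
print(out.shape)  # torch.Size([1, 2]): softmax scores over the two classes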
test_image_fusion.py
ADDED
@@ -0,0 +1,182 @@
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pywt

from utils.config import cfg
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train, get_transforms_val
from net.Multimodalmodel import Image_n_DCT

import os
import json
from torchvision import transforms
from PIL import Image
import pandas as pd
import argparse

class Test_Dataset(Dataset):
    def __init__(self, test_data_path = None, transform = None, image_path = None, multi_modal = "dct"):
        """
        Wraps either a single image path or a test directory with real/ and
        fake/ subfolders, returning the RGB image plus a second modality per sample.
        """
        self.multi_modal = multi_modal
        if test_data_path is None and image_path is not None:
            self.dataset = [[image_path, 2]]
            self.transform = transform

        else:
            self.transform = transform

            self.real_data = os.listdir(test_data_path + "/real")
            self.fake_data = os.listdir(test_data_path + "/fake")
            self.dataset = []
            for image in self.real_data:
                self.dataset.append([test_data_path + "/real/" + image, 1])

            for image in self.fake_data:
                self.dataset.append([test_data_path + "/fake/" + image, 0])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_input = self.get_sample_input(idx)
        return sample_input

    def get_sample_input(self, idx):
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar HH sub-band)")

        return sample_input

    def get_fft_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)

        return fft_image

    def compute_fft(self, image):
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
        return magnitude_spectrum

    def get_hh_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH

    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = Image.open(rgb_image_path)
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)

        return dct_image

    def get_label(self, idx):
        return self.dataset[idx][1]

    def compute_dct_color(self, image):
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image


class Test:
    def __init__(self, model_paths = ['weights/faceswap-hh-best_model.pth',
                                      'weights/faceswap-fft-best_model.pth',
                                      ],
                 multi_modal = ["hh", "fft"]):
        self.model_path = model_paths
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        # Load the two single-modality models
        self.model1 = Image_n_DCT()
        self.model1.load_state_dict(torch.load(self.model_path[0], map_location = self.device))
        self.model1.to(self.device)
        self.model1.eval()

        self.model2 = Image_n_DCT()
        self.model2.load_state_dict(torch.load(self.model_path[1], map_location = self.device))
        self.model2.to(self.device)
        self.model2.eval()

        self.multi_modal = multi_modal

    def testimage(self, image_path):
        test_dataset1 = Test_Dataset(transform = get_transforms_val(), image_path = image_path, multi_modal = self.multi_modal[0])
        test_dataset2 = Test_Dataset(transform = get_transforms_val(), image_path = image_path, multi_modal = self.multi_modal[1])

        inputs1 = test_dataset1[0]
        rgb_image1, dct_image1 = inputs1['rgb_image'].to(self.device), inputs1['dct_image'].to(self.device)

        inputs2 = test_dataset2[0]
        rgb_image2, dct_image2 = inputs2['rgb_image'].to(self.device), inputs2['dct_image'].to(self.device)

        output1 = self.model1(rgb_image1.unsqueeze(0), dct_image1.unsqueeze(0))

        output2 = self.model2(rgb_image2.unsqueeze(0), dct_image2.unsqueeze(0))

        # Average the two models' softmax outputs and take the argmax
        output = (output1 + output2) / 2
        # print(output.shape)
        _, predicted = torch.max(output.data, 1)
        return 'real' if predicted == 1 else 'fake'
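A minimal sketch of using the two-model fusion directly on one image, assuming the faceswap hh/fft checkpoints shipped under `weights/` in this commit and one of the bundled example images:

from test_image_fusion import Test

tester = Test(model_paths=["weights/faceswap-hh-best_model.pth",
                           "weights/faceswap-fft-best_model.pth"],
              multi_modal=["hh", "fft"])
print(tester.testimage("Examples/o_3.jpg"))  # prints 'real' or 'fake'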
utils/__init__.py
ADDED
@@ -0,0 +1 @@
import os
utils/basicblocks.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


BatchNorm2d = nn.BatchNorm2d

def conv3x3(in_planes, out_planes, stride = 1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size = 3, stride = stride,
                     padding = 1, bias = False)

def conv1x1(in_planes, out_planes, stride = 1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size = 1, stride = stride,
                     padding = 0, bias = False)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, outplanes, stride = 1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = BatchNorm2d(outplanes)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = conv3x3(outplanes, outplanes, 2*stride)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out
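Each `BasicBlock` above applies a stride-1 conv followed by a stride-2 conv, so it halves the spatial resolution while remapping channels. A small shape check with illustrative sizes:

import torch
from utils.basicblocks import BasicBlock

block = BasicBlock(3, 64, stride=1)
out = block(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 64, 112, 112])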
utils/classifier.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassifierModel(nn.Module):
    def __init__(self, num_classes):
        super(ClassifierModel, self).__init__()
        # Apply adaptive average pooling to convert (512, 14, 14) to (512)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Define multiple fully connected layers
        self.fc1 = nn.Linear(512, 256)  # First FC layer, reducing to 256 features
        self.fc2 = nn.Linear(256, 128)  # Second FC layer, reducing to 128 features
        self.fc3 = nn.Linear(128, num_classes)  # Final FC layer, outputting num_classes for classification

        # Dropout for regularization
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Pool and flatten the incoming feature map
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)

        # Pass through the fully connected layers with ReLU activations and dropout
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)  # Raw class scores (logits)
        x = F.softmax(x, dim=1)  # Converted to class probabilities

        return x
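The classifier head above expects a 512-channel feature map (for example 512x14x14 after the four fused blocks) and returns per-class probabilities. A small sanity sketch:

import torch
from utils.classifier import ClassifierModel

head = ClassifierModel(num_classes=2).eval()
with torch.no_grad():
    probs = head(torch.randn(1, 512, 14, 14))
print(probs.shape, float(probs.sum()))  # torch.Size([1, 2]) 1.0, since softmax rows sum to 1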
utils/config.py
ADDED
@@ -0,0 +1,38 @@
from easydict import EasyDict as edict
import numpy as np

__C = edict()
cfg = __C

# 0. basic config
__C.TAG = 'default'
__C.CLASSES = ['Real', 'Fake']


# config of network input
__C.MULTIMODAL_FUSION = edict()
__C.MULTIMODAL_FUSION.IMG_CHANNELS = [3, 64, 128, 256, 512]
__C.MULTIMODAL_FUSION.DCT_CHANNELS = [1, 64, 128, 256, 512]


__C.NUM_EPOCHS = 100

__C.BATCH_SIZE = 64

__C.NUM_WORKERS = 4

__C.LEARNING_RATE = 0.0001

__C.PRETRAINED = False

__C.PRETRAINED_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"


__C.TEST_BATCH_SIZE = 512

__C.TEST_CSV = "/home/user/Documents/Real_and_DeepFake/src/dataset/extended_val.csv"

__C.MODEL_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"
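The config is a plain EasyDict, so downstream modules read it by attribute access. A brief illustration:

from utils.config import cfg

print(cfg.CLASSES)                         # ['Real', 'Fake']
print(cfg.MULTIMODAL_FUSION.IMG_CHANNELS)  # [3, 64, 128, 256, 512]
print(cfg.BATCH_SIZE, cfg.LEARNING_RATE)   # 64 0.0001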
utils/data_transforms.py
ADDED
@@ -0,0 +1,33 @@
from torchvision import transforms


def get_transforms_train():
    # Training-time augmentation and normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.float()),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Normalize(mean=[(0.485+0.456+0.406)/3], std=[(0.229+0.224+0.225)/3]),
    ])

    return transform


def get_transforms_val():
    # Validation/test-time preprocessing (no augmentation)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.float()),
        transforms.Resize((224, 224)),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Normalize(mean=[(0.485+0.456+0.406)/3], std=[(0.229+0.224+0.225)/3]),
    ])

    return transform
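Both pipelines convert the input to a float tensor, resize it to 224x224, and normalize with a single channel-averaged mean/std, which PyTorch broadcasts across however many channels the input has. A minimal sketch of applying the validation transform to one of the bundled example images:

from PIL import Image
from utils.data_transforms import get_transforms_val

# "Examples/o_3.jpg" is one of the sample images added in this commit.
tensor = get_transforms_val()(Image.open("Examples/o_3.jpg").convert("RGB"))
print(tensor.shape, tensor.dtype)  # torch.Size([3, 224, 224]) torch.float32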
utils/feature_fusion_block.py
ADDED
@@ -0,0 +1,46 @@
from torch import nn
from torch.nn import functional as F

class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Calculate attention scores
        attention_scores = self.conv1(x)
        attention_scores = F.softmax(attention_scores, dim=2)

        # Apply attention to input features
        attended_features = x * attention_scores

        return attended_features

class DCT_Attention_Fusion_Conv(nn.Module):
    def __init__(self, channels):
        super(DCT_Attention_Fusion_Conv, self).__init__()
        self.rgb_attention = SpatialAttention(channels)
        self.depth_attention = SpatialAttention(channels)
        self.rgb_pooling = nn.AdaptiveAvgPool2d(1)
        self.depth_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, rgb_features, DCT_features):
        # Spatial attention for both modalities
        rgb_attended_features = self.rgb_attention(rgb_features)
        depth_attended_features = self.depth_attention(DCT_features)

        # Adaptive pooling for both modalities
        rgb_pooled = self.rgb_pooling(rgb_attended_features)
        depth_pooled = self.depth_pooling(depth_attended_features)

        # Upsample the pooled features back to the original spatial size
        rgb_upsampled = F.interpolate(rgb_pooled, size=rgb_features.size()[2:], mode='bilinear', align_corners=False)
        depth_upsampled = F.interpolate(depth_pooled, size=DCT_features.size()[2:], mode='bilinear', align_corners=False)

        # Sum the upsampled features and apply ReLU
        fused_features = F.relu(rgb_upsampled + depth_upsampled)
        # fused_features = fused_features.sum(dim=1)

        return fused_features
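The fusion block keeps the RGB feature map's shape: both modalities are attention-weighted, globally pooled, upsampled back, summed, and passed through ReLU. A quick shape check with illustrative sizes:

import torch
from utils.feature_fusion_block import DCT_Attention_Fusion_Conv

fuse = DCT_Attention_Fusion_Conv(channels=64)
rgb_feat = torch.randn(1, 64, 112, 112)
dct_feat = torch.randn(1, 64, 112, 112)
print(fuse(rgb_feat, dct_feat).shape)  # torch.Size([1, 64, 112, 112])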
weights/faceswap-fft-best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c42f82049bed6db4edb5e933ffe4ce6e3612e7fbf351c29327d9cfe81f8c5ff
size 38189260
weights/faceswap-hh-best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15272d1439ef629566cf43b3d4d1bc4f2091f3db1c0d0430038b56880c7ef385
size 38189178