Upload folder using huggingface_hub
- .gitattributes +35 -35
- README.md +68 -0
- config.json +21 -0
- inference.py +69 -0
- model_87_acc_20_frames_final_data.pt +3 -0
- modeling.py +30 -0
- modeling_deepfake.py +139 -0
- processor_deepfake.py +136 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,71 @@
 ---
+language: en
 license: mit
+tags:
+- deepfake
+- video-classification
+- pytorch
+- computer-vision
+datasets:
+- custom
+pipeline_tag: video-classification
+widget:
+- example_title: Deepfake Detection
+  example_input: "Upload a video to detect if it's real or fake"
 ---
+
+# DeepFake Detection Model
+
+This model detects deepfake videos using a combination of ResNext50 and LSTM architecture. It analyzes video frames to determine if a video is authentic or manipulated.
+
+## Model Details
+
+- **Model Type:** Video Classification
+- **Task:** Deepfake Detection
+- **Framework:** PyTorch
+- **Training Data:** Deepfake video datasets
+- **Accuracy:** 87% on test datasets
+- **Output:** Binary classification (real/fake) with confidence score
+
+## Usage
+
+```python
+from transformers import pipeline
+
+# Load the model
+detector = pipeline("video-classification", model="tayyabimam/Deepfake")
+
+# Analyze a video
+result = detector("path/to/video.mp4")
+print(result)
+```
+
+## API Usage
+
+You can also use this model through the Hugging Face Inference API:
+
+```python
+import requests
+
+API_URL = "https://api-inference.huggingface.co/models/tayyabimam/Deepfake"
+headers = {"Authorization": "Bearer YOUR_API_TOKEN"}
+
+def query(video_path):
+    with open(video_path, "rb") as f:
+        data = f.read()
+    response = requests.post(API_URL, headers=headers, data=data)
+    return response.json()
+
+result = query("path/to/video.mp4")
+print(result)
+```
+
+## Model Architecture
+
+The model uses a ResNext50 backbone to extract features from video frames, followed by an LSTM to capture temporal relationships between frames. This architecture is particularly effective for detecting manipulation artifacts that appear across multiple frames in deepfake videos.
+
+## Limitations
+
+- The model works best with videos that include human faces
+- Performance may vary with different video qualities and resolutions
+- The model is designed for 20-frame sequences for optimal performance
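As a complement to the card's pipeline snippet, here is a minimal local-inference sketch that uses only the modules added in this commit (modeling.py and processor_deepfake.py, shown below). The video path is a placeholder, and the checkpoint is assumed to sit in the working directory next to the scripts.

```python
# Local-inference sketch using the files in this commit (not an official API):
# modeling.load_model() reads model_87_acc_20_frames_final_data.pt from the
# working directory, and DeepFakeProcessor turns a video into a 20-frame tensor.
import torch
import torch.nn.functional as F

from modeling import load_model
from processor_deepfake import DeepFakeProcessor

processor = DeepFakeProcessor(im_size=112)
model = load_model()

inputs = processor(video_path="path/to/video.mp4", sequence_length=20)

with torch.no_grad():
    _, logits = model(inputs["pixel_values"])  # forward returns (feature_map, logits)
    probs = F.softmax(logits, dim=1)[0]

prediction = "REAL" if int(probs.argmax()) == 1 else "FAKE"
print(prediction, f"{probs.max().item() * 100:.1f}%")
```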
config.json
ADDED
@@ -0,0 +1,21 @@
+{
+  "architectures": [
+    "DeepFakeDetector"
+  ],
+  "model_type": "deepfake_detector",
+  "hidden_size": 512,
+  "num_hidden_layers": 2,
+  "num_attention_heads": 8,
+  "intermediate_size": 2048,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "attention_probs_dropout_prob": 0.1,
+  "max_position_embeddings": 20,
+  "initializer_range": 0.02,
+  "layer_norm_eps": 1e-12,
+  "pad_token_id": 0,
+  "backbone": "resnext50",
+  "sequence_length": 20,
+  "num_classes": 2,
+  "pipeline_tag": "video-classification"
+}
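A small sketch of reading this config into the custom config class defined in modeling_deepfake.py (assuming both files are present locally); keys the class does not model, such as "hidden_size", are swallowed by its **kwargs and ignored.

```python
# Sketch: round-trip config.json through DeepFakeDetectorConfig.
import json

from modeling_deepfake import DeepFakeDetectorConfig

with open("config.json") as f:
    raw = json.load(f)

config = DeepFakeDetectorConfig.from_dict(raw)
print(config.to_dict())  # num_classes=2 and sequence_length=20 from the file, the rest defaults
```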
inference.py
ADDED
@@ -0,0 +1,69 @@
+
+from transformers import AutoProcessor, AutoModelForVideoClassification
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import cv2
+import numpy as np
+import io
+import base64
+
+# Import the model definition
+from modeling import DeepFakeDetector, load_model
+
+# Constants
+im_size = 112
+mean = [0.485, 0.456, 0.406]
+std = [0.229, 0.224, 0.225]
+
+def preprocess_frame(frame):
+    # Convert to PIL Image if it's a numpy array
+    if isinstance(frame, np.ndarray):
+        frame = Image.fromarray(frame)
+
+    # Resize
+    frame = frame.resize((im_size, im_size))
+
+    # Convert to tensor
+    frame = np.array(frame).astype(np.float32) / 255.0
+    frame = (frame - np.array(mean)) / np.array(std)
+    frame = frame.transpose(2, 0, 1)  # HWC -> CHW
+    frame = torch.tensor(frame, dtype=torch.float32)
+
+    return frame
+
+def inference(model_inputs):
+    # Load the model
+    model = load_model()
+
+    # Process inputs
+    if "frames" in model_inputs:
+        # Process frames from base64
+        frames = []
+        for frame_b64 in model_inputs["frames"]:
+            img_data = base64.b64decode(frame_b64)
+            frame = Image.open(io.BytesIO(img_data))
+            frame = np.array(frame)
+            frames.append(preprocess_frame(frame))
+
+        # Stack frames
+        frames = torch.stack(frames)
+        frames = frames.unsqueeze(0)  # Add batch dimension
+
+        # Run inference
+        with torch.no_grad():
+            _, outputs = model(frames)
+            probs = F.softmax(outputs, dim=1).cpu().numpy()[0]
+
+        # Get prediction (0: fake, 1: real)
+        prediction = int(np.argmax(probs))
+        confidence = float(probs[prediction]) * 100
+
+        return {
+            "prediction": "REAL" if prediction == 1 else "FAKE",
+            "confidence": round(confidence, 1),
+            "prediction_code": prediction
+        }
+
+    # Default response if no valid input
+    return {"error": "Invalid input format"}
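inference() expects a dict with a "frames" key holding base64-encoded images. A sketch of building that payload from a video; the even frame sampling and the 20-frame count below are illustrative choices that mirror the processor's defaults, not part of this commit.

```python
# Sketch: build the {"frames": [...]} payload that inference.inference() consumes.
# Frames are sampled evenly, converted BGR -> RGB, JPEG-encoded, then base64-encoded.
import base64
import io

import cv2
from PIL import Image

from inference import inference

def encode_frames(video_path, num_frames=20):
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(1, total // num_frames)
    frames, idx = [], 0
    while len(frames) < num_frames:
        ok, bgr = cap.read()
        if not ok:
            break
        if idx % step == 0:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            buf = io.BytesIO()
            Image.fromarray(rgb).save(buf, format="JPEG")
            frames.append(base64.b64encode(buf.getvalue()).decode("utf-8"))
        idx += 1
    cap.release()
    return frames

print(inference({"frames": encode_frames("path/to/video.mp4")}))
```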
model_87_acc_20_frames_final_data.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f1b48a51886f487e7ce9880a59ddf6c44e7da0d5b5265ed772871e1e9b1c578
+size 226547455
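The checkpoint itself is stored in Git LFS. A short sketch for fetching it with huggingface_hub; the repo id is taken from the README snippets, and load_model() still expects the file in the working directory.

```python
# Sketch: download the LFS-tracked checkpoint (about 226 MB) to the local cache.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="tayyabimam/Deepfake",
    filename="model_87_acc_20_frames_final_data.pt",
)
print(ckpt_path)  # copy or symlink this next to modeling.py before calling load_model()
```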
modeling.py
ADDED
@@ -0,0 +1,30 @@
+
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+class DeepFakeDetector(nn.Module):
+    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
+        super(DeepFakeDetector, self).__init__()
+        model = models.resnext50_32x4d(pretrained=True)
+        self.model = nn.Sequential(*list(model.children())[:-2])
+        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
+        self.relu = nn.LeakyReLU()
+        self.dp = nn.Dropout(0.4)
+        self.linear1 = nn.Linear(2048, num_classes)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+    def forward(self, x):
+        batch_size, seq_length, c, h, w = x.shape
+        x = x.view(batch_size * seq_length, c, h, w)
+        fmap = self.model(x)
+        x = self.avgpool(fmap)
+        x = x.view(batch_size, seq_length, 2048)
+        x_lstm, _ = self.lstm(x, None)
+        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
+
+def load_model():
+    model = DeepFakeDetector(2)
+    model.load_state_dict(torch.load("model_87_acc_20_frames_final_data.pt", map_location=torch.device('cpu')))
+    model.eval()
+    return model
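A quick shape-check sketch for DeepFakeDetector with a random tensor (no checkpoint needed), showing the (batch, frames, channels, height, width) layout that inference.py and the processor assume.

```python
# Sketch: verify DeepFakeDetector's expected input and output shapes with random
# data (downloads the ImageNet ResNeXt50 weights on first run).
import torch

from modeling import DeepFakeDetector

model = DeepFakeDetector(num_classes=2)
model.eval()

dummy = torch.randn(1, 20, 3, 112, 112)  # 1 video, 20 frames, 112x112 RGB
with torch.no_grad():
    fmap, logits = model(dummy)

print(fmap.shape)    # torch.Size([20, 2048, 4, 4]) - per-frame feature maps
print(logits.shape)  # torch.Size([1, 2]) - real/fake logits for the clip
```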
modeling_deepfake.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+class DeepFakeDetectorConfig:
+    """Configuration class for DeepFakeDetector."""
+
+    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048,
+                 bidirectional=False, sequence_length=20, im_size=112, **kwargs):
+        self.num_classes = num_classes
+        self.latent_dim = latent_dim
+        self.lstm_layers = lstm_layers
+        self.hidden_dim = hidden_dim
+        self.bidirectional = bidirectional
+        self.sequence_length = sequence_length
+        self.im_size = im_size
+
+    @classmethod
+    def from_dict(cls, config_dict):
+        """Create a configuration from a dictionary."""
+        return cls(**config_dict)
+
+    def to_dict(self):
+        """Convert configuration to a dictionary."""
+        return {
+            "num_classes": self.num_classes,
+            "latent_dim": self.latent_dim,
+            "lstm_layers": self.lstm_layers,
+            "hidden_dim": self.hidden_dim,
+            "bidirectional": self.bidirectional,
+            "sequence_length": self.sequence_length,
+            "im_size": self.im_size
+        }
+
+
+class DeepFakeDetectorModel(PreTrainedModel):
+    """DeepFake detection model using ResNext50 and LSTM."""
+
+    config_class = DeepFakeDetectorConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_classes = config.num_classes
+        self.latent_dim = config.latent_dim
+        self.lstm_layers = config.lstm_layers
+        self.hidden_dim = config.hidden_dim
+        self.bidirectional = config.bidirectional
+
+        # Initialize ResNext50 backbone
+        resnext = models.resnext50_32x4d(pretrained=True)
+        self.backbone = nn.Sequential(*list(resnext.children())[:-2])
+
+        # Initialize LSTM
+        self.lstm = nn.LSTM(
+            self.latent_dim,
+            self.hidden_dim,
+            self.lstm_layers,
+            bidirectional=self.bidirectional
+        )
+
+        # Additional layers
+        self.relu = nn.LeakyReLU()
+        self.dropout = nn.Dropout(0.4)
+        self.classifier = nn.Linear(self.hidden_dim, self.num_classes)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+    def forward(self, x, labels=None):
+        """
+        Forward pass of the model.
+
+        Args:
+            x: Input tensor of shape (batch_size, sequence_length, channels, height, width)
+            labels: Optional labels for computing loss
+
+        Returns:
+            SequenceClassifierOutput: Model outputs including loss and logits
+        """
+        batch_size, seq_length, c, h, w = x.shape
+
+        # Reshape for ResNext processing
+        x = x.view(batch_size * seq_length, c, h, w)
+
+        # Extract features using ResNext
+        features = self.backbone(x)
+
+        # Apply average pooling
+        pooled = self.avgpool(features)
+
+        # Reshape for LSTM processing
+        pooled = pooled.view(batch_size, seq_length, self.latent_dim)
+
+        # Process with LSTM
+        lstm_out, _ = self.lstm(pooled, None)
+
+        # Get the final time step output
+        final = lstm_out[:, -1, :]
+
+        # Apply dropout and classification
+        logits = self.classifier(self.dropout(final))
+
+        # Compute loss if labels are provided
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """Load a pretrained model."""
+        # Load config
+        config_dict = kwargs.pop("config", None)
+        if config_dict is None:
+            config_dict = {
+                "num_classes": 2,
+                "latent_dim": 2048,
+                "lstm_layers": 1,
+                "hidden_dim": 2048,
+                "bidirectional": False,
+                "sequence_length": 20,
+                "im_size": 112
+            }
+
+        config = DeepFakeDetectorConfig.from_dict(config_dict)
+
+        # Create model
+        model = cls(config, *model_args, **kwargs)
+
+        # Load weights
+        state_dict = torch.load(pretrained_model_name_or_path, map_location="cpu")
+        model.load_state_dict(state_dict)
+
+        return model
processor_deepfake.py
ADDED
@@ -0,0 +1,136 @@
+import numpy as np
+import torch
+from PIL import Image
+import cv2
+import face_recognition
+from transformers import ProcessorMixin
+
+class DeepFakeProcessor(ProcessorMixin):
+    """Processor for DeepFake detection model."""
+
+    def __init__(self, im_size=112, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+        self.im_size = im_size
+        self.mean = mean
+        self.std = std
+
+    def preprocess_frame(self, frame):
+        """
+        Preprocess a single frame.
+
+        Args:
+            frame: PIL Image or numpy array
+
+        Returns:
+            torch.Tensor: Processed frame tensor
+        """
+        # Convert to PIL Image if it's a numpy array
+        if isinstance(frame, np.ndarray):
+            frame = Image.fromarray(frame)
+
+        # Resize
+        frame = frame.resize((self.im_size, self.im_size))
+
+        # Convert to tensor
+        frame = np.array(frame).astype(np.float32) / 255.0
+        frame = (frame - np.array(self.mean)) / np.array(self.std)
+        frame = frame.transpose(2, 0, 1)  # HWC -> CHW
+        frame = torch.tensor(frame, dtype=torch.float32)
+
+        return frame
+
+    def extract_frames(self, video_path, sequence_length=20, extract_faces=True):
+        """
+        Extract frames from a video file.
+
+        Args:
+            video_path: Path to the video file
+            sequence_length: Number of frames to extract
+            extract_faces: Whether to extract faces from frames
+
+        Returns:
+            torch.Tensor: Tensor of shape (1, sequence_length, 3, im_size, im_size)
+        """
+        frames = []
+
+        # Open video file
+        vidObj = cv2.VideoCapture(video_path)
+
+        # Calculate frame interval
+        total_frames = int(vidObj.get(cv2.CAP_PROP_FRAME_COUNT))
+        interval = max(1, total_frames // sequence_length)
+
+        # Extract frames
+        count = 0
+        success = True
+
+        while success and len(frames) < sequence_length:
+            success, image = vidObj.read()
+
+            if success and count % interval == 0:
+                # Convert BGR to RGB
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+                # Extract face if requested
+                if extract_faces:
+                    face_locations = face_recognition.face_locations(image)
+                    if face_locations:
+                        top, right, bottom, left = face_locations[0]
+                        # Add padding
+                        padding = 40
+                        h, w = image.shape[:2]
+                        top = max(0, top - padding)
+                        bottom = min(h, bottom + padding)
+                        left = max(0, left - padding)
+                        right = min(w, right + padding)
+                        image = image[top:bottom, left:right]
+
+                # Preprocess frame
+                processed_frame = self.preprocess_frame(image)
+                frames.append(processed_frame)
+
+            count += 1
+
+        # If we couldn't extract enough frames, duplicate the last one
+        while len(frames) < sequence_length:
+            frames.append(frames[-1] if frames else torch.zeros((3, self.im_size, self.im_size)))
+
+        # Stack frames
+        frames = torch.stack(frames)
+
+        # Add batch dimension
+        frames = frames.unsqueeze(0)
+
+        return frames
+
+    def __call__(self, video_path=None, frames=None, return_tensors="pt", **kwargs):
+        """
+        Process video for the model.
+
+        Args:
+            video_path: Path to the video file
+            frames: List of frames (PIL Images or numpy arrays)
+            return_tensors: Return format (only "pt" supported)
+
+        Returns:
+            dict: Processed inputs for the model
+        """
+        if return_tensors != "pt":
+            raise ValueError("Only 'pt' return tensors are supported")
+
+        if video_path is not None:
+            # Extract frames from video
+            sequence_length = kwargs.get("sequence_length", 20)
+            extract_faces = kwargs.get("extract_faces", True)
+            processed_frames = self.extract_frames(
+                video_path,
+                sequence_length=sequence_length,
+                extract_faces=extract_faces
+            )
+        elif frames is not None:
+            # Process provided frames
+            processed_frames = torch.stack([self.preprocess_frame(frame) for frame in frames])
+            processed_frames = processed_frames.unsqueeze(0)  # Add batch dimension
+        else:
+            raise ValueError("Either video_path or frames must be provided")
+
+        return {"pixel_values": processed_frames}
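Besides video paths, the processor also accepts pre-decoded frames. A sketch with synthetic frames (random uint8 arrays standing in for decoded video) to show the tensor it returns:

```python
# Sketch: DeepFakeProcessor on an in-memory list of frames.
import numpy as np

from processor_deepfake import DeepFakeProcessor

processor = DeepFakeProcessor(im_size=112)
frames = [np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8) for _ in range(20)]

inputs = processor(frames=frames)
print(inputs["pixel_values"].shape)  # torch.Size([1, 20, 3, 112, 112])
```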
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch>=2.0.0
+torchvision>=0.15.0
+transformers>=4.30.0
+numpy>=1.24.0
+Pillow>=9.0.0
+opencv-python>=4.7.0
+face-recognition>=1.3.0
+safetensors>=0.3.1