# Deep-fake-detection / modeling.py
# Uploaded by Naman712 via huggingface_hub (commit e16b8cb, verified).
import torch
import torch.nn as nn
import torchvision.models as models
class DeepFakeDetector(nn.Module):
    """Clip classifier: per-frame ResNeXt-50 features followed by an LSTM.

    Every frame of every clip is encoded by a torchvision
    ``resnext50_32x4d`` backbone with its last two child modules removed, so
    it yields a spatial feature map. Each map is globally average-pooled to a
    vector. The vectors are regrouped per clip and run through an LSTM. The
    LSTM output at the last position of dim 1 is mapped to ``num_classes``
    logits.

    Args:
        num_classes: Number of output logits.
        latent_dim: LSTM input size. It must equal the backbone's channel
            count (2048 for ResNeXt-50).
        lstm_layers: Number of stacked LSTM layers.
        hidden_dim: LSTM hidden size per direction.
        bidirectional: Whether the LSTM runs in both directions.
        lstm_bias: Whether the LSTM has bias terms. Defaults to ``False``,
            which is the layout the original code produced for the default
            arguments.
        batch_first: Forwarded to ``nn.LSTM``. Defaults to ``False`` to
            reproduce the original behaviour (see the note in ``__init__``).
        pretrained_backbone: Load ImageNet weights into the backbone. Pass
            ``False`` when a full checkpoint is loaded afterwards anyway.
    """

    def __init__(
        self,
        num_classes=2,
        latent_dim=2048,
        lstm_layers=1,
        hidden_dim=2048,
        bidirectional=False,
        *,
        lstm_bias=False,
        batch_first=False,
        pretrained_backbone=True,
    ):
        super().__init__()
        backbone = models.resnext50_32x4d(pretrained=pretrained_backbone)
        # Keep only the convolutional trunk so forward() can return the
        # spatial feature map; pooling is done separately by self.avgpool.
        self.model = nn.Sequential(*list(backbone.children())[:-2])
        # The original passed ``bidirectional`` as nn.LSTM's 4th positional
        # argument, which is ``bias``. The LSTM was therefore always
        # unidirectional and, by default, bias-free. load_model() strictly
        # loads a checkpoint into the default configuration, so bias stays
        # off by default to keep the state_dict layout identical.
        self.lstm = nn.LSTM(
            latent_dim,
            hidden_dim,
            num_layers=lstm_layers,
            bias=lstm_bias,
            batch_first=batch_first,
            bidirectional=bidirectional,
        )
        # NOTE(review): with batch_first=False the (batch, seq, feat) tensor
        # built in forward() is read by the LSTM as (seq, batch, feat). The
        # recurrence then runs over dim 0 rather than over frames. It is kept
        # as the default because the shipped weights were presumably fit
        # under it; consider batch_first=True when training from scratch.
        self.relu = nn.LeakyReLU()  # unused in forward(); kept as an attribute
        self.dp = nn.Dropout(0.4)
        num_directions = 2 if bidirectional else 1
        # The LSTM emits hidden_dim features per direction.
        self.linear1 = nn.Linear(hidden_dim * num_directions, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(batch, seq, channels, height, width)``.

        Returns:
            A tuple ``(fmap, logits)``.

            - ``fmap``: the backbone feature maps for all frames, with shape
              ``(batch * seq, C, h', w')``; C is 2048 for ResNeXt-50.
            - ``logits``: shape ``(batch, num_classes)``.
        """
        batch_size, seq_length, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted clips,
        # are accepted; for contiguous inputs it is the same zero-copy view.
        frames = x.reshape(batch_size * seq_length, c, h, w)
        fmap = self.model(frames)
        # Pool each frame to a C-dim vector, then regroup the vectors per clip.
        features = self.avgpool(fmap).reshape(batch_size, seq_length, -1)
        x_lstm, _ = self.lstm(features, None)
        # Dropout is applied to the logits (a no-op in eval mode). The order
        # is kept from the original so training behaviour is unchanged.
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
def load_model(path="model_87_acc_20_frames_final_data.pt", device="cpu"):
    """Build a two-class DeepFakeDetector and load trained weights into it.

    Args:
        path: File containing the model's ``state_dict``. Defaults to the
            original hard-coded file name, which is resolved relative to the
            current working directory.
        device: Device that the checkpoint tensors are mapped to and that the
            model is moved to. Defaults to ``"cpu"``, as before.

    Returns:
        The model with the loaded weights, switched to eval mode.
    """
    model = DeepFakeDetector(2)
    # torch.load unpickles the file. Only load checkpoints from trusted
    # sources; on torch versions that support it, weights_only=True is safer.
    state_dict = torch.load(path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    model.to(device)  # no-op for the default CPU device
    model.eval()
    return model