|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.models as models
|
|
|
|
class DeepFakeDetector(nn.Module):
    """Frame-sequence classifier: ResNeXt-50 features -> LSTM -> linear head.

    The convolutional backbone is a torchvision ``resnext50_32x4d`` with its
    last two children (global pooling and classifier) removed. Each frame is
    encoded, spatially average-pooled to a vector, and the per-frame vectors
    are fed to an LSTM whose output is classified by a linear layer.

    Args:
        num_classes: Number of output logits.
        latent_dim: Input feature size of the LSTM. It must equal the number
            of channels produced by the backbone (the original code assumed
            2048).
        lstm_layers: Number of stacked LSTM layers.
        hidden_dim: Hidden size of the LSTM.
        bidirectional: Whether the LSTM is bidirectional.

    Note:
        Earlier revisions passed ``bidirectional`` positionally as the fourth
        argument of ``nn.LSTM``, which is ``bias``. With the default
        ``bidirectional=False`` this built an LSTM *without bias terms*, and
        the flag never affected directionality. The bias-free LSTM is kept
        here so that a default-constructed model has exactly the same
        parameters as before (existing checkpoints keep loading), while
        ``bidirectional`` now really controls the direction count.
    """

    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super().__init__()

        # NOTE(review): `pretrained=True` fetches ImageNet weights and is
        # deprecated in recent torchvision releases in favour of `weights=`;
        # kept as-is to stay compatible with older torchvision versions.
        backbone = models.resnext50_32x4d(pretrained=True)
        # Drop the final pooling + fully-connected children so the backbone
        # returns spatial feature maps instead of class scores.
        self.model = nn.Sequential(*list(backbone.children())[:-2])

        # Keyword arguments are used on purpose: the fourth positional
        # parameter of nn.LSTM is `bias`, not `bidirectional`.
        self.lstm = nn.LSTM(
            latent_dim,
            hidden_dim,
            lstm_layers,
            bias=False,
            bidirectional=bidirectional,
        )
        self.relu = nn.LeakyReLU()
        self.dp = nn.Dropout(0.4)

        # A bidirectional LSTM concatenates both directions, doubling the
        # feature size seen by the classifier head.
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.linear1 = nn.Linear(lstm_out_dim, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: 5-D tensor unpacked as ``(batch, seq_length, c, h, w)``.

        Returns:
            A tuple ``(fmap, logits)`` where ``fmap`` is the backbone output
            for all ``batch * seq_length`` frames and ``logits`` is the
            (dropout-regularised) output of the linear head.
        """
        batch_size, seq_length, c, h, w = x.shape

        # Fold the sequence axis into the batch axis so the 2-D backbone can
        # process every frame in one pass.
        frames = x.view(batch_size * seq_length, c, h, w)
        fmap = self.model(frames)

        # Global average pool to one vector per frame, then restore the
        # (batch, seq, features) layout; the feature width is inferred rather
        # than hard-coded.
        pooled = self.avgpool(fmap)
        features = pooled.view(batch_size, seq_length, -1)

        # NOTE(review): the LSTM is built without `batch_first=True`, so per
        # the nn.LSTM input layout it treats dim 0 of `features` as the time
        # axis. Left unchanged because altering it would change the outputs
        # of already-trained checkpoints — confirm the intended layout.
        x_lstm, _ = self.lstm(features, None)

        # Use the last entry along dim 1 as the sequence summary.
        logits = self.linear1(x_lstm[:, -1, :])
        return fmap, self.dp(logits)
|
|
|
|
def load_model(checkpoint_path="model_87_acc_20_frames_final_data.pt"):
    """Build a two-class ``DeepFakeDetector`` and load trained weights on CPU.

    Args:
        checkpoint_path: Path to a file containing a ``state_dict`` saved
            with ``torch.save``. Defaults to the checkpoint name that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The model with the weights loaded, switched to evaluation mode.
    """
    model = DeepFakeDetector(2)

    # NOTE: torch.load relies on pickle, which can execute arbitrary code;
    # only pass checkpoints from a trusted source.
    # map_location keeps every tensor on the CPU regardless of the device
    # the checkpoint was saved from.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Evaluation mode disables dropout for inference.
    model.eval()
    return model
|
|
|