tayyabimam committed
Commit b985dea · verified · 1 parent: 9097bc3

Update app.py

Files changed (1)
  1. app.py +282 -51
app.py CHANGED
@@ -1,51 +1,282 @@
- import os
- import sys
- from fastapi import FastAPI, Request
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.staticfiles import StaticFiles
- from fastapi.responses import FileResponse, HTMLResponse
- import uvicorn
-
- # Add server directory to path
- sys.path.insert(0, 'server')
-
- # Import the original app
- from server.app import app as server_app
-
- # Create the main app
- app = FastAPI()
-
- # Configure CORS
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Mount the original server app
- app.mount("/api", server_app)
-
- # Mount static directories
- app.mount("/uploaded_images", StaticFiles(directory="server/uploaded_images"), name="uploaded_images")
- app.mount("/static", StaticFiles(directory="server/static"), name="static")
- app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")
-
- # Serve frontend
- @app.get("/{path:path}")
- async def serve_frontend(path: str):
-     # First check if the path exists in the frontend dist
-     if os.path.exists(f"frontend/dist/{path}"):
-         return FileResponse(f"frontend/dist/{path}")
-
-     # Otherwise return the index.html
-     return FileResponse("frontend/dist/index.html")
-
- @app.get("/", response_class=HTMLResponse)
- async def root():
-     return FileResponse("frontend/dist/index.html")
-
- if __name__ == "__main__":
-     # Use port 7860 for Hugging Face Spaces
-     uvicorn.run("app:app", host="0.0.0.0", port=7860)
+ import os
+ import sys
+ from fastapi import FastAPI, Request, UploadFile, File, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
+ import uvicorn
+ import time
+ import shutil
+ import glob
+ import datetime
+ from random import choice
+ import torch
+ import torchvision
+ from torchvision import transforms
+ from torch import nn
+ import numpy as np
+ import cv2
+ import face_recognition
+ from PIL import Image as pImage
+ import matplotlib
+ matplotlib.use('Agg')  # Use non-GUI backend; set before importing pyplot
+ import matplotlib.pyplot as plt
+ from typing import List
+ import base64
+ import io
+
+ app = FastAPI()
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Create directories if they don't exist
+ os.makedirs("uploaded_images", exist_ok=True)
+ os.makedirs("static", exist_ok=True)
+ os.makedirs("uploaded_videos", exist_ok=True)
+ os.makedirs("models", exist_ok=True)
+
+ # Mount static files
+ app.mount("/uploaded_images", StaticFiles(directory="uploaded_images"), name="uploaded_images")
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+ app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")
+
+ # Configuration
+ im_size = 112
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ sm = nn.Softmax(dim=1)
+ inv_normalize = transforms.Normalize(
+     mean=-1*np.divide(mean, std), std=np.divide([1, 1, 1], std))
+
+ train_transforms = transforms.Compose([
+     transforms.ToPILImage(),
+     transforms.Resize((im_size, im_size)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean, std)])
+
+ ALLOWED_VIDEO_EXTENSIONS = {'mp4', 'gif', 'webm', 'avi', '3gp', 'wmv', 'flv', 'mkv'}
+
+ # Use GPU if available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class Model(nn.Module):
+     def __init__(self, num_classes, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
+         super(Model, self).__init__()
+         model = torchvision.models.resnext50_32x4d(weights=torchvision.models.ResNeXt50_32X4D_Weights.DEFAULT)
+         self.model = nn.Sequential(*list(model.children())[:-2])
+         self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
+         self.relu = nn.LeakyReLU()
+         self.dp = nn.Dropout(0.4)
+         self.linear1 = nn.Linear(2048, num_classes)
+         self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+     def forward(self, x):
+         batch_size, seq_length, c, h, w = x.shape
+         x = x.view(batch_size * seq_length, c, h, w)
+         fmap = self.model(x)
+         x = self.avgpool(fmap)
+         x = x.view(batch_size, seq_length, 2048)
+         x_lstm, _ = self.lstm(x, None)
+         return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
+
+ class ValidationDataset(torch.utils.data.Dataset):
+     def __init__(self, video_names, sequence_length=60, transform=None):
+         self.video_names = video_names
+         self.transform = transform
+         self.count = sequence_length
+
+     def __len__(self):
+         return len(self.video_names)
+
+     def __getitem__(self, idx):
+         video_path = self.video_names[idx]
+         frames = []
+         a = int(100/self.count)
+         first_frame = np.random.randint(0, a)
+         for i, frame in enumerate(self.frame_extract(video_path)):
+             faces = face_recognition.face_locations(frame)
+             try:
+                 top, right, bottom, left = faces[0]
+                 frame = frame[top:bottom, left:right, :]
+             except:
+                 pass
+             frames.append(self.transform(frame))
+             if (len(frames) == self.count):
+                 break
+         frames = torch.stack(frames)
+         frames = frames[:self.count]
+         return frames.unsqueeze(0)  # Shape: (1, seq_len, C, H, W)
+
+     def frame_extract(self, path):
+         vidObj = cv2.VideoCapture(path)
+         success = 1
+         while success:
+             success, image = vidObj.read()
+             if success:
+                 yield image
+
+ def allowed_video_file(filename):
+     return filename.split('.')[-1].lower() in ALLOWED_VIDEO_EXTENSIONS
+
+ def load_model(sequence_length=20):
+     """Load the model from Hugging Face Hub if not available locally."""
+     model_path = os.path.join("models", "model.pt")
+
+     if not os.path.exists(model_path):
+         try:
+             from huggingface_hub import hf_hub_download
+             model_path = hf_hub_download(repo_id="tayyabimam/Deepfake",
+                                          filename="model.pt",
+                                          local_dir="models")
+         except Exception as e:
+             raise Exception(f"Failed to download model: {str(e)}")
+
+     # Load model
+     model = Model(2).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model
+
+ def im_convert(tensor, video_file_name=""):
+     """Convert tensor to image for visualization."""
+     image = tensor.to("cpu").clone().detach()
+     image = image.squeeze()
+     image = inv_normalize(image)
+     image = image.numpy()
+     image = image.transpose(1, 2, 0)
+     image = image.clip(0, 1)
+     return image
+
+ def generate_gradcam_heatmap(model, img, video_file_name=""):
+     """Generate GradCAM heatmap showing areas of focus for deepfake detection."""
+     # Forward pass
+     fmap, logits = model(img)
+
+     # Softmax on logits
+     logits_softmax = sm(logits)
+     confidence, prediction = torch.max(logits_softmax, 1)
+     confidence_val = confidence.item() * 100
+     pred_idx = prediction.item()
+
+     # Get weights and feature maps
+     weight_softmax = model.linear1.weight.detach().cpu().numpy()
+     fmap_last = fmap[-1].detach().cpu().numpy()
+     nc, h, w = fmap_last.shape
+     fmap_reshaped = fmap_last.reshape(nc, h*w)
+
+     # Compute GradCAM heatmap
+     heatmap_raw = np.dot(fmap_reshaped.T, weight_softmax[pred_idx, :].T)
+     heatmap_raw -= heatmap_raw.min()
+     heatmap_raw /= heatmap_raw.max()
+     heatmap_img = np.uint8(255 * heatmap_raw.reshape(h, w))
+
+     # Resize heatmap to model input size
+     heatmap_resized = cv2.resize(heatmap_img, (im_size, im_size))
+     heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
+
+     # Convert original image tensor to numpy
+     original_img = im_convert(img[:, -1, :, :, :])
+     original_img_uint8 = (original_img * 255).astype(np.uint8)
+
+     # Overlay heatmap on original image
+     overlay = cv2.addWeighted(original_img_uint8, 0.6, heatmap_colored, 0.4, 0)
+
+     # Save overlayed image
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     result_filename = f"result_{timestamp}.jpg"
+     save_path = os.path.join("static", result_filename)
+     plt.figure(figsize=(10, 5))
+
+     # Plot original and heatmap
+     plt.subplot(1, 2, 1)
+     plt.imshow(original_img)
+     plt.title("Original")
+     plt.axis('off')
+
+     plt.subplot(1, 2, 2)
+     plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
+     plt.title(f"{'FAKE' if pred_idx == 1 else 'REAL'} ({confidence_val:.2f}%)")
+     plt.axis('off')
+
+     plt.tight_layout()
+     plt.savefig(save_path)
+     plt.close()
+
+     return {
+         "prediction": "FAKE" if pred_idx == 1 else "REAL",
+         "confidence": confidence_val,
+         "heatmap_url": f"/static/{result_filename}",
+         "original_filename": video_file_name
+     }
+
+ def predict_with_gradcam(model, img, video_file_name=""):
+     """Predict with GradCAM visualization."""
+     return generate_gradcam_heatmap(model, img, video_file_name)
+
+ @app.post("/api/upload-video")
+ async def api_upload_video(file: UploadFile = File(...), sequence_length: int = 20):
+     """API endpoint for video upload and analysis."""
+     if not allowed_video_file(file.filename):
+         raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: mp4, gif, webm, avi, 3gp, wmv, flv, mkv")
+
+     # Save uploaded file
+     temp_file = f"uploaded_videos/{file.filename}"
+     with open(temp_file, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+
+     try:
+         # Process the video
+         result = process_video(temp_file, sequence_length)
+         return result
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ def process_video(video_file, sequence_length):
+     """Process video for deepfake detection."""
+     # Load model
+     model = load_model(sequence_length)
+
+     # Prepare dataset
+     test_dataset = ValidationDataset(video_names=[video_file],
+                                      sequence_length=sequence_length,
+                                      transform=train_transforms)
+
+     # Get frames
+     frames = test_dataset[0]
+     frames = frames.to(device)
+
+     # Make prediction with GradCAM
+     result = predict_with_gradcam(model, frames, os.path.basename(video_file))
+
+     return result
+
+ @app.get("/api")
+ async def api_root():
+     """Root endpoint with API documentation."""
+     return {
+         "message": "Welcome to DeepSight DeepFake Detection API",
+         "usage": "POST /api/upload-video with a video file to detect deepfakes"
+     }
+
+ @app.get("/", response_class=HTMLResponse)
+ async def root():
+     return FileResponse("frontend/dist/index.html")
+
+ # Catch-all route is registered last so it cannot shadow /api and / above
+ @app.get("/{path:path}")
+ async def serve_frontend(path: str):
+     # First check if the path exists in the frontend dist
+     if os.path.exists(f"frontend/dist/{path}"):
+         return FileResponse(f"frontend/dist/{path}")
+
+     # Otherwise return the index.html
+     return FileResponse("frontend/dist/index.html")
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
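
For reference, a minimal client sketch for exercising the new upload endpoint. This is not part of the commit: the host and port assume the local uvicorn settings from the __main__ block, sample.mp4 is a hypothetical file, and the requests package is an extra dependency.

    import requests

    # Upload a video for analysis; sequence_length is a query parameter
    # (FastAPI treats the non-file scalar argument as a query param)
    with open("sample.mp4", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/api/upload-video",
            files={"file": ("sample.mp4", f, "video/mp4")},
            params={"sequence_length": 20},  # matches the endpoint default
        )
    resp.raise_for_status()

    result = resp.json()
    # Response keys per generate_gradcam_heatmap(): prediction ("REAL"/"FAKE"),
    # confidence, heatmap_url, original_filename
    print(result["prediction"], f"{result['confidence']:.2f}%")

The returned heatmap_url points at the Matplotlib figure saved under static/, which the app serves through the /static mount.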