# Saqib772's picture
# Deepfake detection with a CNN, a ViT, and ensembling
# dd98f48 verified
# app.py
import io
from collections import Counter

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from models import ModelA, ModelB, ModelC, transform_small, transform_large
# 1. spin up FastAPI
app = FastAPI()

# 2. load the saved weights; all tensors are mapped onto the CPU
device = torch.device('cpu')


def _load_model(model_cls, weights_path):
    """Instantiate *model_cls*, load its weights from *weights_path*, and switch it to eval mode.

    ``weights_only=True`` restricts ``torch.load`` to tensors and plain containers, so a
    tampered checkpoint file cannot run arbitrary code while being unpickled.
    """
    model = model_cls()
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    model.eval()  # inference only: disables dropout / uses running batch-norm stats
    return model


modelA = _load_model(ModelA, 'modelA.pth')
modelB = _load_model(ModelB, 'modelB.pth')
modelC = _load_model(ModelC, 'modelC.pth')
def _decode_image(data):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: status 400 when *data* is not a readable image, so a bad upload
            is reported as a client error rather than an unhandled 500.
    """
    try:
        return Image.open(io.BytesIO(data)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown format, empty body) subclasses OSError, as do
        # truncated-file errors raised when .convert() forces the actual decode.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from err


def _ensemble_votes(img):
    """Run every ensemble member on *img* and return their predicted class indices."""
    t_small = transform_small(img).unsqueeze(0)  # batch of 1 for models A & B
    t_large = transform_large(img).unsqueeze(0)  # batch of 1 for model C
    votes = []
    # Grad mode is thread-local, so no_grad must be entered in the worker thread itself.
    with torch.no_grad():
        for model, inp in ((modelA, t_small), (modelB, t_small), (modelC, t_large)):
            logits = model(inp)
            votes.append(int(logits.argmax(dim=1).item()))
    return votes


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as "Real" or "Fake" by majority vote of three models.

    Args:
        file: the uploaded image (any format Pillow can open).

    Returns:
        dict with "prediction" ("Real" when the winning class index is 1, otherwise
        "Fake") and "confidence", the share of models that voted for the winning
        class formatted as a percentage string. It is an agreement ratio, not a
        calibrated probability.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await file.read()
    # Decoding and the forward passes are CPU-bound; running them in the threadpool
    # keeps the event loop free to serve other requests meanwhile.
    img = await run_in_threadpool(_decode_image, data)
    votes = await run_in_threadpool(_ensemble_votes, img)

    # Majority vote. With three voters a strict majority always exists, presumably
    # because the models are binary classifiers (TODO confirm against models.py).
    final_label, n_agree = Counter(votes).most_common(1)[0]
    confidence = n_agree / len(votes)
    return {
        "prediction": "Real" if final_label == 1 else "Fake",
        "confidence": f"{confidence*100:.1f}%"
    }