|
|
|
from fastapi import FastAPI, File, UploadFile
|
|
from PIL import Image
|
|
import io, torch
|
|
from collections import Counter
|
|
|
|
from models import ModelA, ModelB, ModelC, transform_small, transform_large
|
|
|
|
|
|
app = FastAPI()

# Checkpoints are remapped onto this device when they are loaded.
device = torch.device('cpu')


def _load_model(model_cls, weights_path):
    """Instantiate ``model_cls``, load its weights, and call ``.eval()`` on it.

    Args:
        model_cls: Zero-argument callable returning the model object.
        weights_path: Path of the state-dict checkpoint to load.

    Returns:
        The model instance after ``load_state_dict`` and ``eval()``.
    """
    model = model_cls()
    # weights_only=True restricts what torch.load will unpickle, so a
    # tampered checkpoint cannot run arbitrary code at import time.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# The three ensemble members used by the /predict/ endpoint.
modelA = _load_model(ModelA, 'modelA.pth')
modelB = _load_model(ModelB, 'modelB.pth')
modelC = _load_model(ModelC, 'modelC.pth')
|
|
|
|
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image by majority vote of the three models.

    The image is decoded, converted to RGB, and run through ``modelA`` and
    ``modelB`` (using ``transform_small``) and ``modelC`` (using
    ``transform_large``). Each model contributes the index of its highest
    output as one vote.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``"prediction"`` (``"Real"`` when the winning label is 1,
        otherwise ``"Fake"``) and ``"confidence"``, the share of models that
        voted for the winning label formatted as a percentage string. This is
        a vote share, not a calibrated probability.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    # Imported locally so this handler is self-contained with respect to the
    # module's import block.
    from fastapi import HTTPException

    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        img = Image.open(io.BytesIO(data)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL reports unreadable/truncated input as OSError (including
        # UnidentifiedImageError); surface it as a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # unsqueeze(0) adds a leading dimension of size 1 before the model call.
    t_small = transform_small(img).unsqueeze(0)
    t_large = transform_large(img).unsqueeze(0)

    ensemble = ((modelA, t_small), (modelB, t_small), (modelC, t_large))
    with torch.no_grad():
        votes = [int(model(inp).argmax(dim=1).item()) for model, inp in ensemble]

    vote_count = Counter(votes)
    # On equal counts, most_common keeps first-seen order, so the earliest
    # model in the ensemble wins a tie.
    final_label = vote_count.most_common(1)[0][0]
    confidence = vote_count[final_label] / len(votes)

    return {
        "prediction": "Real" if final_label == 1 else "Fake",
        "confidence": f"{confidence*100:.1f}%"
    }
|
|
|