# Scraped page header (kept as comments so the module parses):
# okeowo1014's picture
# Create main.py
# c6896d2 verified
# raw
# history blame
# 1.27 kB
import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import models, transforms

app = FastAPI()

# Build the bare VGG16 architecture (no pretrained weights are requested here)
# and replace its last classifier layer with a single-logit head; predict_image
# applies a sigmoid to that one output and treats > 0.5 as "Dog".
model = models.vgg16()
num_features_in = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(num_features_in, 1)
# map_location="cpu": inference in this module runs on CPU tensors, and a
# checkpoint saved from a CUDA device would otherwise fail to load on a
# CPU-only host. Loading into CPU memory changes nothing when a GPU exists,
# because load_state_dict copies the tensors into the (CPU) model anyway.
# NOTE(review): the path is resolved against the current working directory,
# not this file's directory — confirm the server is launched from here.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (weights_only=True is an option on torch versions that support it).
model.load_state_dict(torch.load('cat_dog_classifier.pt', map_location="cpu"))
model.eval()  # switch layers such as dropout to inference behaviour
def preprocess_image(image):
    """Convert a PIL image into a normalized single-image batch tensor.

    The image is forced to RGB, resized to 224x224, converted to a float
    tensor, normalized per channel and given a leading batch dimension.

    Args:
        image: A PIL ``Image`` in any mode that PIL can convert to RGB.

    Returns:
        A ``torch.Tensor`` of shape (1, 3, 224, 224).
    """
    # Normalize() below uses three-channel mean/std, so grayscale ("L"),
    # palette ("P") or RGBA images would otherwise fail with a channel-count
    # mismatch. Skip the conversion (which copies the image) when already RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet channel statistics as used by torchvision's pretrained
        # models; presumably what the checkpoint was trained with — confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return img_transform(image).unsqueeze(0)  # add a batch dimension
def _decode_image(contents):
    """Decode raw upload bytes into a fully loaded RGB PIL image.

    ``Image.open`` is lazy; ``convert("RGB")`` forces the pixel data to be
    decoded here, so corrupt or truncated uploads fail at this point (where
    the caller maps them to a 400) instead of later inside the transform
    pipeline. It also gives the three channels that preprocess_image's
    normalization expects.
    """
    with Image.open(io.BytesIO(contents)) as img:
        return img.convert("RGB")


def _predict_dog_probability(image):
    """Run the classifier on one RGB image and return its sigmoid score.

    The score is what predict_image compares against 0.5 to choose "Dog".
    This is synchronous, CPU-bound work meant to run in a worker thread.
    ``torch.no_grad`` is thread-local, so it is entered here, inside the
    thread that actually runs the forward pass.
    """
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(image_tensor)
    return torch.sigmoid(output.squeeze()).item()


@app.post("/predict/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Cat" or "Dog".

    Returns:
        ``{"class": "Dog"}`` when the model's sigmoid score is above 0.5,
        otherwise ``{"class": "Cat"}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
            Failures after decoding (e.g. inside the model) are server
            faults; they propagate so the client receives a 500 rather than
            being blamed with a 400.
    """
    contents = await file.read()
    try:
        image = await run_in_threadpool(_decode_image, contents)
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and truncated-image errors are OSError
        # subclasses; ValueError covers unsupported mode conversions.
        raise HTTPException(status_code=400, detail=str(e)) from e
    # VGG16 inference is CPU-bound; awaiting it in a worker thread keeps the
    # event loop free to serve other requests meanwhile.
    probability = await run_in_threadpool(_predict_dog_probability, image)
    predicted_class = "Dog" if probability > 0.5 else "Cat"
    return {"class": predicted_class}