File size: 1,712 Bytes
bb5cad1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from torchvision import transforms
from cnn import CNN
# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# and finally the CPU when no accelerator is reported as available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Human-readable label names. The position of each name is used as the index
# into the model's output vector when predictions are reported.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Instantiate the network and restore its trained parameters from disk.
model = CNN()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location remaps the stored tensors onto the device selected above.
model.load_state_dict(
    torch.load("cnn/model.pt", map_location=device, weights_only=True)
)
model.to(device)
# Switch to evaluation mode; this script only runs inference.
model.eval()
# Preprocessing pipeline applied to each input image before inference:
# resize to 32x32, convert to a tensor, then normalize every channel with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
def predict(image):
    """Classify an image and return the probability of every class.

    Args:
        image: The input image as a NumPy array (the format delivered by
            ``gr.Image(type="numpy")``), or ``None`` when nothing was
            provided.

    Returns:
        A dict mapping each class name to its softmax probability as a
        Python float, or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # Force three channels so grayscale or RGBA input matches the
    # three-channel normalization in ``transform``.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        # Convert the whole probability vector in one call instead of
        # reading each element from the tensor individually.
        probabilities = F.softmax(logits, dim=1)[0].tolist()

    return dict(zip(classes, probabilities))
# Input and output widgets: the image is handed to ``predict`` as a NumPy
# array, and the label widget displays up to ten class scores.
_image_input = gr.Image(type="numpy")
_label_output = gr.Label(num_top_classes=10)

# Sample images offered in the UI; each example row holds a single input.
_example_paths = [
    "examples/1.png",
    "examples/2.png",
    "examples/3.png",
    "examples/4.png",
    "examples/5.png",
    "examples/6.png",
    "examples/7.png",
]

# Assemble the Gradio interface around ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="CNN Classifier",
    description="Upload an image to classify it into one of 10 CIFAR-10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck",
    examples=[[path] for path in _example_paths],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True, pwa=True)
|