| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision import transforms | |
| from cnn import CNN | |
# Choose the compute backend in order of preference: CUDA first, then
# Apple's MPS backend, and finally the CPU when neither is available.
if torch.cuda.is_available():
    _backend_name = "cuda"
elif torch.backends.mps.is_available():
    _backend_name = "mps"
else:
    _backend_name = "cpu"
device = torch.device(_backend_name)
# CIFAR-10 label names. Order matters: predict() pairs the name at index i
# with the i-th probability produced for the input image.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Build the network, restore its trained parameters from disk, move it to the
# selected device, and switch it to evaluation mode for inference.
model = CNN()
model.load_state_dict(
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    # map_location remaps the stored tensors onto the selected device.
    torch.load("cnn/model.pt", map_location=device, weights_only=True)
)
model.to(device)
model.eval()
# Preprocessing applied to every input image before it reaches the model:
#   1. resize to 32x32 pixels,
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels as (x - 0.5) / 0.5.
_preprocessing_steps = [
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(image):
    """Classify an image and return a probability for every class label.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray`` (the Gradio
            interface in this file supplies it with ``type="numpy"``), or
            ``None`` when no image has been provided.

    Returns:
        A dict mapping each name in ``classes`` to its softmax probability
        as a Python float, or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # Force three channels so the three-channel normalization in `transform`
    # applies regardless of the mode of the incoming image.
    rgb_image = Image.fromarray(image).convert("RGB")
    # unsqueeze(0) adds a leading dimension so the model receives a batch of
    # one image.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
        probabilities = F.softmax(outputs, dim=1)[0]

    # tolist() copies all probabilities to the host in one transfer instead
    # of synchronizing with the device once per class.
    return dict(zip(classes, probabilities.tolist()))
# Example images offered in the UI: examples/1.png through examples/7.png,
# each wrapped in its own list because the interface has a single input.
_example_inputs = [[f"examples/{index}.png"] for index in range(1, 8)]

# Web UI: one image input (delivered to predict() as a NumPy array) and a
# label output showing up to ten class scores.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=10),
    title="CNN Classifier",
    description=(
        "Upload an image to classify it into one of 10 CIFAR-10 categories: "
        "airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck"
    ),
    examples=_example_inputs,
)

if __name__ == "__main__":
    # NOTE(review): share=True requests a publicly reachable Gradio link;
    # confirm that is intended outside local development.
    demo.launch(share=True, pwa=True)