import torch import torch.nn.functional as F import gradio as gr from PIL import Image from torchvision import transforms from cnn import CNN device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) classes = [ "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ] model = CNN() model.load_state_dict(torch.load("cnn/model.pt", map_location=device)) model.to(device) model.eval() transform = transforms.Compose( [ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) def predict(image): if image is None: return {} image = Image.fromarray(image).convert("RGB") image_tensor = transform(image) image_tensor = (image_tensor).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(image_tensor) probabilities = F.softmax(outputs, dim=1)[0] return {classes[i]: float(probabilities[i]) for i in range(len(classes))} demo = gr.Interface( fn=predict, inputs=gr.Image(type="numpy"), outputs=gr.Label(num_top_classes=10), title="CNN Classifier", description="Upload an image to classify it into one of 10 CIFAR-10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck", examples=[ ["examples/1.png"], ["examples/2.png"], ["examples/3.png"], ["examples/4.png"], ["examples/5.png"], ["examples/6.png"], ["examples/7.png"], ], ) if __name__ == "__main__": demo.launch(share=True, pwa=True)