# Hugging Face Space app (Space status at capture time: Sleeping)
import json
import os
from functools import lru_cache

import gradio as gr
import torch
import torchvision.transforms.functional as TF

from model import NeuralNetwork
# Inference device: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Directory containing this script. Data files are resolved relative to it,
# the same way the example image paths use os.path.dirname(__file__), so the
# app does not depend on the current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))


@lru_cache(maxsize=1)
def _load_resources():
    """Load the classifier and its label mapping once and reuse them.

    Returns:
        A ``(model, labels)`` tuple: the ``NeuralNetwork`` with weights from
        ``model_best.pt``, on ``device`` and in eval mode; and the object
        parsed from ``labels.json``. It is indexed with ``str(class_index)``
        by the caller, so it is presumably a dict keyed by stringified class
        indices.
    """
    model = NeuralNetwork()
    state_dict = torch.load(
        os.path.join(_APP_DIR, "model_best.pt"),
        map_location=torch.device(device),
    )
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the module's existing parameters and
    # keeps their device, so the module itself must be moved to match the
    # input tensor. (Harmless if NeuralNetwork already placed itself there.)
    model.to(device)
    model.eval()
    with open(os.path.join(_APP_DIR, "labels.json"), encoding="utf-8") as f:
        labels = json.load(f)
    return model, labels


def pokemon_classifier(inp):
    """Predict which Gen 1 Pokemon is shown in an image.

    Args:
        inp: The PIL image delivered by ``gr.Image(type="pil")``, or ``None``
            when the user presses "Run" without providing an image.

    Returns:
        The name from ``labels.json`` for the highest-scoring class, or an
        empty string when no image was supplied.
    """
    if inp is None:
        return ""
    # Weights and labels are loaded on the first request only, not per call.
    model, labels = _load_resources()
    x = TF.to_tensor(inp)
    x = TF.resize(x, (64, 64), antialias=True)
    x = x.unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        y_pred = model(x)
    pokemon = torch.argmax(y_pred, dim=1).item()
    return labels[str(pokemon)]
# Example images shipped next to this script, shown as clickable examples.
_SCRIPT_DIR = os.path.dirname(__file__)
_EXAMPLE_IMAGES = [
    os.path.join(_SCRIPT_DIR, relative_path)
    for relative_path in (
        "images/Aerodactyl.jpg",
        "images/Bulbasaur.jpg",
        "images/Charizard.jpg",
    )
]

# UI: image upload -> "Run" button -> predicted Pokemon name in a textbox.
with gr.Blocks() as demo:
    gr.Markdown("# Gen 1 Pokemon classifier")
    with gr.Column(scale=4):
        inp = gr.Image(type="pil")
        out = gr.Textbox(label='Pokemon')
        # Clicking an example fills `inp`; with cache_examples=False nothing
        # is precomputed at startup.
        gr.Examples(
            examples=_EXAMPLE_IMAGES,
            inputs=inp,
            outputs=out,
            fn=pokemon_classifier,
            cache_examples=False,
        )
        btn = gr.Button("Run")
        btn.click(fn=pokemon_classifier, inputs=inp, outputs=out)
# Start the web server only when this file is executed as a script, so that
# importing the module (for tooling or tests) does not launch the app.
# NOTE(review): presumably the Space runs this file directly — confirm.
if __name__ == "__main__":
    demo.launch()