---
license: cc-by-4.0
---

# This model doesn't inherit from huggingface/transformers, so the checkpoint needs to be downloaded manually

Use the `resolve` URL: the `blob` URL returns the Hub's HTML page, not the checkpoint file.

```bash
wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/best_model_89.pth
```

# Inference Code

```python
import json  # only needed if you load the mapping from id_2_class.json instead

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# Class index (as a string) -> icon label.
id_2_class = {
    "0": "back", "1": "Briefcase", "2": "Call", "3": "Camera",
    "4": "Circle", "5": "Cloud", "6": "delete", "7": "Down",
    "8": "edit", "9": "Export", "10": "Face", "11": "Folder",
    "12": "Globe", "13": "Google", "14": "Heart", "15": "Home",
    "16": "Image", "17": "Import", "18": "Info", "19": "Link",
    "20": "Location", "21": "Mail", "22": "menu", "23": "Merge",
    "24": "Message", "25": "Microphone", "26": "more", "27": "Music",
    "28": "Mute", "29": "Person", "30": "Phone", "31": "plus",
    "32": "QRCODE", "33": "Refresh", "34": "search", "35": "settings",
    "36": "share", "37": "Star", "38": "Tick", "39": "Up",
    "40": "vidCam", "41": "Video", "42": "Volume",
}

# Reverse lookup: icon label -> class index string.
class_2_id = {label: idx for idx, label in id_2_class.items()}

# Preprocessing applied to every image before it is fed to the network:
# resize to 224x224, convert to a tensor, then map each channel to [-1, 1].
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


class MaxViT(nn.Module):
    """torchvision MaxViT-T with its final layer resized to the icon classes."""

    def __init__(self):
        super().__init__()
        # No pretrained weights are requested here: every parameter is
        # overwritten by the fine-tuned checkpoint loaded below, so fetching
        # the ImageNet weights would only be a wasted download.
        model = models.maxvit_t(weights=None)
        # classifier[5] is the last Linear layer; replace it so the output
        # has one logit per icon class.
        num_ftrs = model.classifier[5].in_features
        model.classifier[5] = nn.Linear(num_ftrs, len(class_2_id))
        self.model = model

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        return self.model(x)


# Instantiate the model and load the fine-tuned weights.
# map_location="cpu" lets the checkpoint load on machines without a GPU.
model = MaxViT()
model.load_state_dict(torch.load('best_model_89.pth', map_location="cpu"))
model.eval()


def inference(image_path, CONFIDENT_THRESHOLD=None):
    """Classify a single icon image.

    Args:
        image_path: Path to the image file to classify.
        CONFIDENT_THRESHOLD: Optional minimum softmax probability. When given
            and the top prediction falls below it, the label
            "UNKNOWN_CLASS" is returned instead of a class name.

    Returns:
        A ``(label, confidence)`` tuple, where ``confidence`` is the softmax
        probability of the top class as a Python float.
    """
    # Grayscale first, then back to three channels: colour information is
    # discarded while the 3-channel shape the network expects is kept.
    img = Image.open(image_path).convert("L").convert("RGB")
    batch = test_transform(img).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)
        confidence, predicted = probs.max(dim=1)

    score = confidence.item()
    if CONFIDENT_THRESHOLD is not None and score < CONFIDENT_THRESHOLD:
        return "UNKNOWN_CLASS", score
    return id_2_class[str(predicted.item())], score


print(inference("images/7820.jpg", 0.9))  # 0.9 should be good enough
```

# Training

Check the repo

# Dataset

Trained on Lancelot53/android_icon_dataset