This model is not built on the Hugging Face `transformers` library, so the checkpoint file must be downloaded manually:
wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/best_model_89.pth
Inference Code
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import torch.nn.functional as F
#load id_2_class.json
import json
# Maps the classifier's output index (stored as a string key) to its icon label.
id_2_class = {
    "0": "back",
    "1": "Briefcase",
    "2": "Call",
    "3": "Camera",
    "4": "Circle",
    "5": "Cloud",
    "6": "delete",
    "7": "Down",
    "8": "edit",
    "9": "Export",
    "10": "Face",
    "11": "Folder",
    "12": "Globe",
    "13": "Google",
    "14": "Heart",
    "15": "Home",
    "16": "Image",
    "17": "Import",
    "18": "Info",
    "19": "Link",
    "20": "Location",
    "21": "Mail",
    "22": "menu",
    "23": "Merge",
    "24": "Message",
    "25": "Microphone",
    "26": "more",
    "27": "Music",
    "28": "Mute",
    "29": "Person",
    "30": "Phone",
    "31": "plus",
    "32": "QRCODE",
    "33": "Refresh",
    "34": "search",
    "35": "settings",
    "36": "share",
    "37": "Star",
    "38": "Tick",
    "39": "Up",
    "40": "vidCam",
    "41": "Video",
    "42": "Volume",
}

# Reverse lookup: icon label -> output index (still a string, as in id_2_class).
class_2_id = {label: index for index, label in id_2_class.items()}
# Preprocessing applied to each image before it is passed to the model.
test_transform = transforms.Compose(
    [
        # Fixed 224x224 spatial size.
        transforms.Resize((224, 224)),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5 over the three channels.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
class MaxViT(nn.Module):
    """Wrapper around torchvision's ``maxvit_t`` with a resized classifier output.

    The classifier layer at index 5 is replaced by a linear layer that has one
    output per entry in ``class_2_id``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.maxvit_t(weights="DEFAULT")
        # Keep the layer's input width; only the number of outputs changes.
        head_in_features = backbone.classifier[5].in_features
        backbone.classifier[5] = nn.Linear(head_in_features, len(class_2_id))
        # The attribute must stay named `model`: parameter names in the state
        # dict are prefixed with it.
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
# Instantiate the model and restore the weights from the downloaded checkpoint.
model = MaxViT()
# map_location="cpu" remaps tensors that were saved from another device, so the
# checkpoint also loads on machines without that device. Nothing in this script
# moves the model or its inputs off the CPU, so CPU is the right target.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('best_model_89.pth', map_location="cpu"))
# Put the module in evaluation mode before running predictions.
model.eval()
def inference(image_path, CONFIDENT_THRESHOLD=None):
    """Classify a single icon image.

    Args:
        image_path: Location of the image, passed straight to ``PIL.Image.open``.
        CONFIDENT_THRESHOLD: Optional minimum softmax probability. When given
            and the top prediction scores below it, the label
            ``"UNKNOWN_CLASS"`` is returned instead of the predicted class.

    Returns:
        A ``(label, confidence)`` tuple, where ``label`` is an entry of
        ``id_2_class`` (or ``"UNKNOWN_CLASS"``) and ``confidence`` is the
        softmax probability of the top prediction as a Python float.
    """
    # Open via a context manager so the file handle is released deterministically
    # instead of whenever the image object happens to be garbage-collected.
    with Image.open(image_path) as source:
        # Drop colour information ("L"), then expand back to three channels.
        img = source.convert("L").convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = test_transform(img).unsqueeze(0)

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)

    confidence, predicted = torch.max(probabilities, 1)
    score = confidence.item()

    if CONFIDENT_THRESHOLD is not None and score < CONFIDENT_THRESHOLD:
        return "UNKNOWN_CLASS", score
    return id_2_class[str(predicted.item())], score
# Example call: classify one image; a top prediction scoring below 0.9 is
# returned as "UNKNOWN_CLASS".
inference("images/7820.jpg", 0.9) #0.9 should be good enough
Training
Check the repo
Dataset
Trained on Lancelot53/android_icon_dataset
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
HF Inference deployability: The model has no library tag.