import torch | |
from transformers import ViTFeatureExtractor | |
from config import UNTRAINED | |
# Image preprocessor shared by every predict() call. It is built once, at
# import time, from the checkpoint named by config.UNTRAINED.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=UNTRAINED,
)
def predict(model, image, extractor=None):
    """Classify a single image with ``model`` and return the predicted label.

    Args:
        model: An image-classification model, e.g. a Hugging Face
            ``ViTForImageClassification``. It is called with the
            preprocessed tensors as keyword arguments and must return an
            object exposing ``.logits``. ``model.config.id2label`` must map
            class indices to label names.
        image: A single image in any form the preprocessor accepts.
        extractor: Optional preprocessor, called as
            ``extractor(image, return_tensors="pt")``. Defaults to the
            module-level ``feature_extractor``.

    Returns:
        The label name of the highest-scoring class.

    Raises:
        KeyError: if the predicted index is missing from ``id2label``
            under both its integer and its string form.

    Only one image is supported. With a batch, the argmax has several
    elements and ``.item()`` fails.
    """
    if extractor is None:
        extractor = feature_extractor
    inputs = extractor(image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index of the highest-scoring class. The set of possible classes is
    # whatever the model's config defines, not a fixed label set.
    predicted_label = logits.argmax(-1).item()
    id2label = model.config.id2label
    try:
        # Hugging Face configs normalise id2label keys to int on load,
        # so the integer index is the normal key.
        return id2label[predicted_label]
    except KeyError:
        # Tolerate mappings that still carry JSON-style string keys.
        return id2label[str(predicted_label)]