# food_classifier / app.py
# Jonas Mullerio
# fix?
# 88f2a00
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoModelForImageClassification
import torch
import heapq
# Image-classification checkpoint used for every prediction in this app.
model = AutoModelForImageClassification.from_pretrained(
    'Mullerjo/food-101-finetuned-model'
)

# Per-channel statistics handed to Normalize (the commonly used ImageNet
# RGB mean / standard deviation values).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Steps applied, in order, to each PIL image before it reaches the model:
# resize to 256, centre-crop to 224, convert to a tensor, then normalise.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Class names listed in class-id order: the position of a name in this tuple
# is the numeric id it is mapped to below (0 through 100).
_FOOD_CLASS_NAMES = (
    "apple_pie",
    "baby_back_ribs",
    "baklava",
    "beef_carpaccio",
    "beef_tartare",
    "beet_salad",
    "beignets",
    "bibimbap",
    "bread_pudding",
    "breakfast_burrito",
    "bruschetta",
    "caesar_salad",
    "cannoli",
    "caprese_salad",
    "carrot_cake",
    "ceviche",
    "cheesecake",
    "cheese_plate",
    "chicken_curry",
    "chicken_quesadilla",
    "chicken_wings",
    "chocolate_cake",
    "chocolate_mousse",
    "churros",
    "clam_chowder",
    "club_sandwich",
    "crab_cakes",
    "creme_brulee",
    "croque_madame",
    "cup_cakes",
    "deviled_eggs",
    "donuts",
    "dumplings",
    "edamame",
    "eggs_benedict",
    "escargots",
    "falafel",
    "filet_mignon",
    "fish_and_chips",
    "foie_gras",
    "french_fries",
    "french_onion_soup",
    "french_toast",
    "fried_calamari",
    "fried_rice",
    "frozen_yogurt",
    "garlic_bread",
    "gnocchi",
    "greek_salad",
    "grilled_cheese_sandwich",
    "grilled_salmon",
    "guacamole",
    "gyoza",
    "hamburger",
    "hot_and_sour_soup",
    "hot_dog",
    "huevos_rancheros",
    "hummus",
    "ice_cream",
    "lasagna",
    "lobster_bisque",
    "lobster_roll_sandwich",
    "macaroni_and_cheese",
    "macarons",
    "miso_soup",
    "mussels",
    "nachos",
    "omelette",
    "onion_rings",
    "oysters",
    "pad_thai",
    "paella",
    "pancakes",
    "panna_cotta",
    "peking_duck",
    "pho",
    "pizza",
    "pork_chop",
    "poutine",
    "prime_rib",
    "pulled_pork_sandwich",
    "ramen",
    "ravioli",
    "red_velvet_cake",
    "risotto",
    "samosa",
    "sashimi",
    "scallops",
    "seaweed_salad",
    "shrimp_and_grits",
    "spaghetti_bolognese",
    "spaghetti_carbonara",
    "spring_rolls",
    "steak",
    "strawberry_shortcake",
    "sushi",
    "tacos",
    "takoyaki",
    "tiramisu",
    "tuna_tartare",
    "waffles",
)

# Lookup table: class name -> numeric class id.
label_to_class = {
    name: class_id for class_id, name in enumerate(_FOOD_CLASS_NAMES)
}
# Reverse lookup: numeric class id -> class name.
class_to_label = {class_id: name for name, class_id in label_to_class.items()}

# Label table stored in the checkpoint's configuration (config.id2label).
class_labels = model.config.id2label

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The input batch is moved to `device` before inference, so the model weights
# must live there as well; otherwise a CUDA host fails with a device-mismatch
# error on the first request.
model.to(device)
def _class_name(idx):
    """Translate a model output index into a readable class name.

    The entry found in ``class_labels`` may be a numeric class id, a class
    name, or a string ending in the id (``"12"`` or ``"LABEL_12"``). Each form
    is resolved against the local label tables; an entry that cannot be
    resolved is returned as text instead of raising ``KeyError``.
    """
    raw = class_labels[idx]

    # Numeric class id: look the name up directly.
    if raw in class_to_label:
        return class_to_label[raw]

    text = str(raw)
    # The entry is already one of the known class names.
    if text in label_to_class:
        return text

    # "12" or "LABEL_12": the trailing digits are the class id.
    suffix = text.rsplit("_", 1)[-1]
    if suffix.isdigit() and int(suffix) in class_to_label:
        return class_to_label[int(suffix)]

    return text


def classify_image(img):
    """Classify a food photo and describe the three most likely classes.

    Args:
        img: The uploaded picture as a NumPy array, or ``None`` when nothing
            was uploaded.

    Returns:
        A sentence naming the top prediction and the two runners-up, or a
        short prompt asking for an image when ``img`` is ``None``.
    """
    if img is None:
        return "Please upload an image of food first."

    # Normalize in the preprocessing pipeline is set up for three channels,
    # so grayscale or RGBA input is converted to RGB first.
    picture = Image.fromarray(img).convert("RGB")
    batch = preprocess(picture).unsqueeze(0)  # Add batch dimension

    # Inputs must sit on the same device as the model weights; follow the
    # model rather than assuming where it was placed.
    batch = batch.to(model.device)

    with torch.no_grad():
        outputs = model(batch)
    logits = outputs.logits.squeeze().tolist()  # Squeeze and convert to list

    # Indices of the three largest logits, best first.
    top3_indices = heapq.nlargest(3, range(len(logits)), key=logits.__getitem__)
    words = [_class_name(idx) for idx in top3_indices]

    return f"The most likely food is {words[0]}! If that seems unlikely it could also be {words[1]} or {words[2]} :)"
# Build the Gradio UI: one image upload in, one text box out.
_image_input = gr.Image(type="numpy")
_text_output = gr.Textbox(label="Output")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_text_output,
    title="Food Image Classifier",
    description="Upload an image of food and get the predicted category.",
)

# Start serving the interface.
iface.launch()