kakasher's picture
changed device to cpu
8531995
raw
history blame
3.18 kB
import gradio as gr
from pathlib import Path
from torchvision import transforms
import timm
import torch
import json
import os
# Load toxicity data
def load_toxicity_data(file_path='combined_plant_toxicity.json'):
    """Load the toxicity lookup table from a JSON file.

    Loading is best-effort: the app is expected to keep running without
    toxicity details, so problems are reported on stdout instead of raised.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, or an empty dict when the file is missing
        or cannot be decoded as UTF-8 JSON.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (which is not UTF-8 on every system).
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: {file_path} not found. Toxicity information will not be available.")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt file should degrade the feature, not crash the app at import.
        print(f"Warning: {file_path} could not be parsed ({e}). Toxicity information will not be available.")
    return {}
# Load class labels
def load_class_labels(file_path='idx_to_class.json'):
    """Load the class-index to class-name table from a JSON file.

    Loading is best-effort: problems are reported on stdout and an empty
    table is returned so the module can still be imported.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, or an empty dict when the file is missing
        or cannot be decoded as UTF-8 JSON.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (which is not UTF-8 on every system).
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: {file_path} not found. Class labels will not be available.")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt file should be reported, not crash the app at import.
        print(f"Error: {file_path} could not be parsed ({e}). Class labels will not be available.")
    return {}
# Global variables
toxicity_data = load_toxicity_data()
idx_to_class = load_class_labels()
model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
# Stays None unless the trained weights were loaded successfully;
# classify_image relies on this to refuse to run without a usable model.
model = None
# Load model
try:
    # Build into a temporary name: binding `model` before the weights are
    # loaded would leave a randomly initialised network in place (and the
    # `model is None` guard bypassed) whenever loading the checkpoint fails.
    _model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
    # NOTE(security): weights_only=False unpickles arbitrary objects, so the
    # checkpoint file must come from a trusted source. Switch to
    # weights_only=True if the file is confirmed to hold a plain state dict.
    _state_dict = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))
    _model.load_state_dict(_state_dict)
    _model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}")
else:
    model = _model
# Preprocessing pipeline applied to every uploaded image before inference.
# Geometry first: resize to 256, then take the central 224x224 crop.
_geometry_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
# Then convert to a tensor and normalise each of the three channels
# as (x - 0.5) / 0.5.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_geometry_steps + _tensor_steps)
def classify_image(input_image):
    """Classify a plant photo and describe the toxicity of the top matches.

    Args:
        input_image: PIL image from the Gradio image input, or None when
            nothing was uploaded.

    Returns:
        A ``(label_output, toxicity_info)`` tuple. ``label_output`` maps up
        to three class names to their softmax confidence, or is None on
        failure. ``toxicity_info`` is Markdown text listing the toxicity
        entry for each of those classes, or an error message on failure.
    """
    if input_image is None:
        return None, "Error: No image uploaded. Please choose an image and try again."
    if model is None:
        return None, "Error: Model could not be loaded. Please check the model path and try again."
    try:
        # The normalisation step is defined for exactly three channels, so
        # RGBA, palette and greyscale uploads must be converted first.
        rgb_image = input_image.convert("RGB")
        input_tensor = transform(rgb_image).unsqueeze(0)
        with torch.inference_mode():
            output = model(input_tensor)
        predictions = torch.nn.functional.softmax(output[0], dim=0)
        # Take the class count from the model output instead of hard-coding
        # it, and fall back to a placeholder name when a label is missing so
        # a gap in the label file does not fail the whole request.
        confidences = {
            idx_to_class.get(str(i), f"class_{i}"): score
            for i, score in enumerate(predictions.tolist())
        }
        # Sort confidences and get top 3
        top_3 = sorted(confidences.items(), key=lambda x: x[1], reverse=True)[:3]
        # Prepare the output for Label
        label_output = dict(top_3)
        # Prepare the toxicity information
        toxicity_lines = [
            f"- **{plant}**: {toxicity_data.get(plant, 'Unknown')}\n"
            for plant, _ in top_3
        ]
        toxicity_info = "### Toxicity Information\n\n" + "".join(toxicity_lines)
        return label_output, toxicity_info
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return None, f"An error occurred during image classification: {str(e)}"
def _example_rows(directory="examples"):
    """Return Gradio example rows for the files in *directory*.

    Each row is a one-element list holding a forward-slash path, matching the
    single image input of the interface. Returns None (no examples) when the
    directory is missing or contains no files, so the app still starts.
    """
    root = Path(directory)
    if not root.is_dir():
        return None
    # Sorted for a stable order; sub-directories are skipped because they
    # cannot be loaded as images.
    rows = [[entry.as_posix()] for entry in sorted(root.iterdir()) if entry.is_file()]
    return rows or None


# Web UI: one image input, a top-3 label output and a Markdown panel that
# shows the toxicity text returned by classify_image.
demo = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    [
        gr.Label(num_top_classes=3, label="Top 3 Predictions"),
        gr.Markdown(label="Toxicity Information")
    ],
    title="🌱 Cat-Safe Plant Classifier 🐱",
    description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
    examples=_example_rows(),
)
if __name__ == "__main__":
    demo.launch()