Spaces:
Sleeping
Sleeping
import json
import os
from pathlib import Path

import gradio as gr
import timm
import torch
from torchvision import transforms
# Load toxicity data
def load_toxicity_data(file_path='combined_plant_toxicity.json'):
    """Load the plant-toxicity lookup table from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing
        or does not contain valid JSON (a warning is printed in both cases).
    """
    try:
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: {file_path} not found. Toxicity information will not be available.")
        return {}
    except json.JSONDecodeError as e:
        # Treat a corrupt file like a missing one instead of crashing at import.
        print(f"Warning: {file_path} is not valid JSON ({e}). Toxicity information will not be available.")
        return {}
# Load class labels
def load_class_labels(file_path='idx_to_class.json'):
    """Load the class-index-to-label mapping from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing
        or does not contain valid JSON (an error message is printed in both
        cases).
    """
    try:
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: {file_path} not found. Class labels will not be available.")
        return {}
    except json.JSONDecodeError as e:
        # Treat a corrupt file like a missing one instead of crashing at import.
        print(f"Error: {file_path} is not valid JSON ({e}). Class labels will not be available.")
        return {}
# Global variables
# Both lookup tables are read once at import time; each loader prints a
# message and returns an empty dict when its file is not found.
toxicity_data = load_toxicity_data()
idx_to_class = load_class_labels()
model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
# Remains None when the model cannot be loaded; classify_image checks for this.
model = None
# Load model
try:
    model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which is only safe for a trusted checkpoint. If the file holds
    # a plain state dict, weights_only=True should work -- confirm and switch.
    state_dict = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
except Exception as e:
    # Best-effort start-up: keep the app running without a model. Reset to
    # None so a network whose weights failed to load (and is therefore still
    # randomly initialised) is never used for predictions.
    model = None
    print(f"Error loading model: {e}")
# Define the transformation
# Resize, centre-crop to 224x224, convert to a tensor and normalise each of
# the three channels with mean 0.5 / std 0.5.
# NOTE(review): presumably these match the preprocessing used when the
# checkpoint was trained -- confirm against the training code.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def classify_image(input_image):
    """Classify a plant image and describe the toxicity of the top matches.

    Args:
        input_image: The uploaded image, or None when nothing was uploaded.
            It is passed straight to the module-level ``transform``.

    Returns:
        A ``(label_output, toxicity_info)`` tuple. ``label_output`` maps the
        (up to) three most probable class names to their softmax confidence;
        ``toxicity_info`` is a Markdown string with one bullet per class.
        On any failure ``label_output`` is None and the second element is an
        error message.
    """
    if input_image is None:
        return None, "Error: No image uploaded. Please choose an image and try again."
    if model is None:
        return None, "Error: Model could not be loaded. Please check the model path and try again."
    try:
        # Preprocess the image and add the batch dimension the model expects.
        input_tensor = transform(input_image).unsqueeze(0)
        with torch.inference_mode():
            output = model(input_tensor)
        predictions = torch.nn.functional.softmax(output[0], dim=0)

        # Pick the best classes straight from the tensor instead of assuming
        # a fixed class count; clamp k in case there are fewer than 3 classes.
        k = min(3, predictions.numel())
        top_probs, top_indices = torch.topk(predictions, k)
        top_3 = [
            (idx_to_class[str(index)], prob)
            for index, prob in zip(top_indices.tolist(), top_probs.tolist())
        ]

        # Prepare the output for Label
        label_output = dict(top_3)

        # Prepare the toxicity information
        toxicity_lines = [
            f"- **{plant}**: {toxicity_data.get(plant, 'Unknown')}\n"
            for plant, _ in top_3
        ]
        toxicity_info = "### Toxicity Information\n\n" + "".join(toxicity_lines)
        return label_output, toxicity_info
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return None, f"An error occurred during image classification: {str(e)}"
# Collect example images. A missing "examples" directory must not crash the
# app at import time, and the listing is sorted so the order is stable.
# Hidden files and sub-directories are skipped.
_examples_dir = Path("examples")
if _examples_dir.is_dir():
    _example_files = [
        [path.as_posix()]
        for path in sorted(_examples_dir.iterdir())
        if path.is_file() and not path.name.startswith(".")
    ]
else:
    _example_files = []

# NOTE(review): the "π±" in the title looks like a mis-encoded emoji; the
# text is kept as found -- confirm the intended character and restore it.
demo = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    [
        gr.Label(num_top_classes=3, label="Top 3 Predictions"),
        gr.Markdown(label="Toxicity Information"),
    ],
    title="π± Cat-Safe Plant Classifier π±",
    description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
    # Pass None rather than an empty list when there is nothing to show.
    examples=_example_files or None,
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()