|
import gradio as gr |
|
from pathlib import Path |
|
from torchvision import transforms |
|
import timm |
|
import torch |
|
import json |
|
import os |
|
|
|
|
|
def load_toxicity_data(file_path='combined_plant_toxicity.json'):
    """Load the toxicity lookup data from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content. An empty dict is returned (after printing a
        warning) when the file does not exist or cannot be decoded, so callers
        can keep running without toxicity details.
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the platform's
        # default text encoding, which is not UTF-8 everywhere.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: {file_path} not found. Toxicity information will not be available.")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt file is treated like a missing one instead of crashing.
        print(f"Warning: {file_path} could not be parsed ({e}). Toxicity information will not be available.")
    return {}
|
|
|
|
|
def load_class_labels(file_path='idx_to_class.json'):
    """Load the class-index-to-label mapping from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content. An empty dict is returned (after printing an
        error message) when the file does not exist or cannot be decoded.
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the platform's
        # default text encoding, which is not UTF-8 everywhere.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: {file_path} not found. Class labels will not be available.")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt file is treated like a missing one instead of crashing.
        print(f"Error: {file_path} could not be parsed ({e}). Class labels will not be available.")
    return {}
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint location. `model` starts out as None; the loading code further
# down rebinds it.
model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
model = None

# Lookup tables, read once at import time. Each loader prints a message and
# returns {} when its file is not found.
toxicity_data = load_toxicity_data()
idx_to_class = load_class_labels()
|
|
|
|
|
# Build the ViT classifier and load its trained weights on the CPU.
try:
    model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code embedded in the file. Only load
    # checkpoints from a trusted source; consider weights_only=True if this
    # checkpoint is a plain state_dict.
    state_dict = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
except Exception as e:
    # create_model may already have succeeded when the weights failed to load.
    # Discard that randomly initialised network so that `model is None`
    # reliably means "no usable classifier".
    model = None
    print(f"Error loading model: {str(e)}")
|
|
|
|
|
|
|
# Preprocessing applied to every input image before it is fed to the model:
# resize, take the central 224x224 crop, convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
def classify_image(input_image):
    """Classify a plant image and report toxicity notes for the top matches.

    Args:
        input_image: Image to classify, in a form accepted by the module-level
            ``transform`` pipeline, or None when nothing was uploaded.

    Returns:
        A ``(label_output, toxicity_info)`` tuple. ``label_output`` maps the
        three highest-scoring class names to their softmax probabilities, and
        ``toxicity_info`` is a Markdown bullet list with the ``toxicity_data``
        entry ("Unknown" when absent) for each of those classes. On any
        failure ``label_output`` is None and ``toxicity_info`` carries an
        error message instead.
    """
    if input_image is None:
        return None, "Error: No image uploaded. Please choose an image and try again."

    if model is None:
        return None, "Error: Model could not be loaded. Please check the model path and try again."

    if not idx_to_class:
        # Without labels every lookup below would fail with a bare KeyError;
        # report the real cause up front and skip the forward pass.
        return None, "Error: Class labels could not be loaded. Please check the labels file and try again."

    try:
        # unsqueeze(0) adds the leading batch dimension.
        input_tensor = transform(input_image).unsqueeze(0)

        with torch.inference_mode():
            output = model(input_tensor)
            predictions = torch.nn.functional.softmax(output[0], dim=0)

        # Iterate over the scores the model actually produced rather than a
        # hard-coded class count.
        confidences = {
            idx_to_class[str(i)]: score
            for i, score in enumerate(predictions.tolist())
        }

        top_3 = sorted(confidences.items(), key=lambda x: x[1], reverse=True)[:3]
        label_output = dict(top_3)

        bullet_lines = [
            f"- **{plant}**: {toxicity_data.get(plant, 'Unknown')}\n"
            for plant, _ in top_3
        ]
        toxicity_info = "### Toxicity Information\n\n" + "".join(bullet_lines)

        return label_output, toxicity_info

    except Exception as e:
        # UI boundary: surface the failure as text instead of raising.
        return None, f"An error occurred during image classification: {str(e)}"
|
|
|
# Example images offered in the UI. Skip them when the folder is absent
# instead of crashing at import time, ignore sub-directories, and sort because
# os.listdir returns entries in arbitrary order.
_example_inputs = None
if os.path.isdir("examples"):
    _example_inputs = [
        ["examples/" + example]
        for example in sorted(os.listdir("examples"))
        if os.path.isfile(os.path.join("examples", example))
    ] or None

# Gradio UI: one image input, a top-3 label widget and a Markdown panel that
# receive the two values returned by classify_image.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Top 3 Predictions"),
        gr.Markdown(label="Toxicity Information"),
    ],
    title="π± Cat-Safe Plant Classifier π±",
    description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
    examples=_example_inputs,
)
|
|
|
# Launch the interface only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|
|