File size: 3,249 Bytes
b1e5ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from pathlib import Path
from torchvision import transforms
import timm
import torch
import json
import os

def load_toxicity_data(file_path='combined_plant_toxicity.json'):
    """Load the plant-name -> cat-toxicity mapping from a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON object. If the file is missing or cannot be decoded,
        a warning is printed and an empty dict is returned so the app can
        still run without toxicity information.
    """
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can mis-decode non-ASCII plant names.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: {file_path} not found. Toxicity information will not be available.")
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt file is treated like a missing one instead of crashing at import.
        print(f"Warning: could not parse {file_path} ({e}). Toxicity information will not be available.")
        return {}

def load_class_labels(file_path='idx_to_class.json'):
    """Load the class-index -> class-name mapping from a JSON file.

    Keys are expected to be stringified integer indices ("0", "1", ...),
    matching how JSON serializes dict keys.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON object. If the file is missing or cannot be decoded,
        an error is printed and an empty dict is returned.
    """
    try:
        # Explicit UTF-8 so class names decode the same on every platform.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: {file_path} not found. Class labels will not be available.")
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Treat a corrupt file like a missing one rather than crashing at import.
        print(f"Error: could not parse {file_path} ({e}). Class labels will not be available.")
        return {}

def load_model(model_path, num_classes=47):
    """Build a timm ViT-B/16 (224px) classifier and load trained weights into it.

    Args:
        model_path: Path to the checkpoint. It is presumably a state dict
            written by ``torch.save``; verify against the training script.
        num_classes: Size of the classification head. It must match the
            checkpoint. Defaults to 47.

    Returns:
        The model in eval mode, or None if construction or loading fails.
        The error is printed in that case.
    """
    try:
        model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only host; weights_only=True refuses to unpickle arbitrary objects.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        # Best effort: callers check for None and report the failure in the UI.
        print(f"Error loading model: {str(e)}")
        return None

# Module-level state read by classify_image. The two JSON loaders print a
# message and fall back to an empty dict when their file is missing, and
# load_model prints and yields None when loading fails.
toxicity_data = load_toxicity_data(file_path='combined_plant_toxicity.json')
idx_to_class = load_class_labels(file_path='idx_to_class.json')
model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
model = load_model(model_path=model_path)

# Preprocessing for each uploaded image before inference:
# 1. resize so the shorter side is 256 px;
# 2. center-crop to 224x224;
# 3. convert to a float tensor in [0, 1];
# 4. map each channel to [-1, 1] via (x - 0.5) / 0.5.
_IMAGE_MEAN = [0.5, 0.5, 0.5]
_IMAGE_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGE_MEAN, std=_IMAGE_STD),
    ]
)

def _format_toxicity_markdown(plant_names, toxicity_lookup):
    """Render the toxicity section as Markdown: a heading plus one bullet per plant.

    Plants missing from *toxicity_lookup* are listed as "Unknown".
    """
    lines = ["### Toxicity Information", ""]
    lines.extend(
        f"- **{plant}**: {toxicity_lookup.get(plant, 'Unknown')}" for plant in plant_names
    )
    return "\n".join(lines) + "\n"


def classify_image(input_image):
    """Classify a plant image and report cat toxicity for the top predictions.

    Args:
        input_image: The uploaded image (a PIL image given
            ``gr.Image(type="pil")``), or None when nothing was uploaded.

    Returns:
        A ``(label_output, toxicity_markdown)`` tuple. ``label_output`` maps
        up to three class names to their softmax probabilities. On any
        error it is None, and the second element holds the error message.
    """
    if input_image is None:
        return None, "Error: No image uploaded. Please choose an image and try again."

    if model is None:
        return None, "Error: Model could not be loaded. Please check the model path and try again."

    if not idx_to_class:
        # Without labels every lookup would fail with an opaque KeyError.
        return None, "Error: Class labels could not be loaded. Please check the class label file and try again."

    try:
        # Normalize is configured for three channels; convert RGBA/grayscale
        # PIL images defensively.
        if getattr(input_image, "mode", "RGB") != "RGB":
            input_image = input_image.convert("RGB")

        # Add a batch dimension of one before the forward pass.
        input_tensor = transform(input_image).unsqueeze(0)

        with torch.inference_mode():
            logits = model(input_tensor)[0]
            probabilities = torch.nn.functional.softmax(logits, dim=0)

        # Derive the class count from the model output instead of hard-coding 47.
        k = min(3, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        top_predictions = [
            (idx_to_class[str(idx)], prob)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

        # topk returns values in descending order, so the dict keeps that order.
        label_output = dict(top_predictions)
        toxicity_info = _format_toxicity_markdown(
            [plant for plant, _ in top_predictions], toxicity_data
        )
        return label_output, toxicity_info

    except Exception as e:
        # UI boundary: surface the failure in the app instead of crashing the request.
        return None, f"An error occurred during image classification: {str(e)}"

_EXAMPLE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif", ".tif", ".tiff"}


def _collect_examples(directory="examples"):
    """Return Gradio example rows (one image path per row) found in *directory*.

    Only regular files with a common image extension are kept, so stray
    files such as ``.DS_Store`` or a README are skipped. Rows are sorted so
    the example gallery is stable across runs. Returns None when the
    directory is missing or has no images, so the app still starts.
    """
    example_dir = Path(directory)
    if not example_dir.is_dir():
        return None
    rows = [
        [str(path)]
        for path in sorted(example_dir.iterdir())
        if path.is_file() and path.suffix.lower() in _EXAMPLE_SUFFIXES
    ]
    return rows or None


demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Top 3 Predictions"),
        gr.Markdown(label="Toxicity Information"),
    ],
    title="🌱 Cat-Safe Plant Classifier 🐱",
    description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
    examples=_collect_examples(),
)

def _main():
    """Launch the Gradio app when this file is run as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()