first commit
Browse files- .gitattributes +1 -0
- app.py +98 -0
- combined_plant_toxicity.json +49 -0
- examples/1.jpg +0 -0
- examples/2.jpg +0 -0
- examples/3.jpg +0 -0
- idx_to_class.json +1 -0
- requirements.txt +5 -0
- vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
from torchvision import transforms
|
4 |
+
import timm
|
5 |
+
import torch
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
# Load toxicity data
|
10 |
+
def load_toxicity_data(file_path='combined_plant_toxicity.json'):
|
11 |
+
try:
|
12 |
+
with open(file_path, 'r') as f:
|
13 |
+
return json.load(f)
|
14 |
+
except FileNotFoundError:
|
15 |
+
print(f"Warning: {file_path} not found. Toxicity information will not be available.")
|
16 |
+
return {}
|
17 |
+
|
18 |
+
# Load class labels
|
19 |
+
def load_class_labels(file_path='idx_to_class.json'):
|
20 |
+
try:
|
21 |
+
with open(file_path, 'r') as f:
|
22 |
+
return json.load(f)
|
23 |
+
except FileNotFoundError:
|
24 |
+
print(f"Error: {file_path} not found. Class labels will not be available.")
|
25 |
+
return {}
|
26 |
+
|
27 |
+
# Load model
|
28 |
+
def load_model(model_path):
|
29 |
+
try:
|
30 |
+
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
|
31 |
+
model.load_state_dict(torch.load(model_path, weights_only=True))
|
32 |
+
model.eval()
|
33 |
+
return model
|
34 |
+
except Exception as e:
|
35 |
+
print(f"Error loading model: {str(e)}")
|
36 |
+
return None
|
37 |
+
|
38 |
+
# Global variables
|
39 |
+
toxicity_data = load_toxicity_data()
|
40 |
+
idx_to_class = load_class_labels()
|
41 |
+
model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
|
42 |
+
model = load_model(model_path)
|
43 |
+
|
44 |
+
# Define the transformation
|
45 |
+
transform = transforms.Compose([
|
46 |
+
transforms.Resize(256),
|
47 |
+
transforms.CenterCrop(224),
|
48 |
+
transforms.ToTensor(),
|
49 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
50 |
+
])
|
51 |
+
|
52 |
+
def classify_image(input_image):
|
53 |
+
if input_image is None:
|
54 |
+
return None, "Error: No image uploaded. Please choose an image and try again."
|
55 |
+
|
56 |
+
if model is None:
|
57 |
+
return None, "Error: Model could not be loaded. Please check the model path and try again."
|
58 |
+
|
59 |
+
try:
|
60 |
+
# Preprocess the image
|
61 |
+
input_tensor = transform(input_image).unsqueeze(0)
|
62 |
+
|
63 |
+
with torch.inference_mode():
|
64 |
+
output = model(input_tensor)
|
65 |
+
predictions = torch.nn.functional.softmax(output[0], dim=0)
|
66 |
+
confidences = {idx_to_class[str(i)]: float(predictions[i]) for i in range(47)}
|
67 |
+
|
68 |
+
# Sort confidences and get top 3
|
69 |
+
top_3 = sorted(confidences.items(), key=lambda x: x[1], reverse=True)[:3]
|
70 |
+
|
71 |
+
# Prepare the output for Label
|
72 |
+
label_output = {plant: conf for plant, conf in top_3}
|
73 |
+
|
74 |
+
# Prepare the toxicity information
|
75 |
+
toxicity_info = "### Toxicity Information\n\n"
|
76 |
+
for plant, _ in top_3:
|
77 |
+
toxicity = toxicity_data.get(plant, "Unknown")
|
78 |
+
toxicity_info += f"- **{plant}**: {toxicity}\n"
|
79 |
+
|
80 |
+
return label_output, toxicity_info
|
81 |
+
|
82 |
+
except Exception as e:
|
83 |
+
return None, f"An error occurred during image classification: {str(e)}"
|
84 |
+
|
85 |
+
demo = gr.Interface(
|
86 |
+
classify_image,
|
87 |
+
gr.Image(type="pil"),
|
88 |
+
[
|
89 |
+
gr.Label(num_top_classes=3, label="Top 3 Predictions"),
|
90 |
+
gr.Markdown(label="Toxicity Information")
|
91 |
+
],
|
92 |
+
title="🌱 Cat-Safe Plant Classifier 🐱",
|
93 |
+
description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
|
94 |
+
examples=[["examples/" + example] for example in os.listdir("examples")]
|
95 |
+
)
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
demo.launch()
|
combined_plant_toxicity.json
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"African Violet (Saintpaulia ionantha)": "Non-Toxic: Safe for cats",
|
3 |
+
"Aloe Vera": "Toxic: Causes digestive upset if ingested",
|
4 |
+
"Anthurium (Anthurium andraeanum)": "Toxic: Irritating to mouth and digestive tract",
|
5 |
+
"Areca Palm (Dypsis lutescens)": "Non-Toxic: Safe for cats",
|
6 |
+
"Asparagus Fern (Asparagus setaceus)": "Toxic: Can cause skin irritation and digestive issues",
|
7 |
+
"Begonia (Begonia spp.)": "Toxic: Can cause oral irritation and digestive upset",
|
8 |
+
"Bird of Paradise (Strelitzia reginae)": "Toxic: Can cause oral irritation and mild nausea",
|
9 |
+
"Bird's Nest Fern (Asplenium nidus)": "Non-Toxic: Safe for cats",
|
10 |
+
"Boston Fern (Nephrolepis exaltata)": "Non-Toxic: Safe for cats",
|
11 |
+
"Calathea": "Non-Toxic: Safe for cats",
|
12 |
+
"Cast Iron Plant (Aspidistra elatior)": "Non-Toxic: Safe for cats",
|
13 |
+
"Chinese Money Plant (Pilea peperomioides)": "Non-Toxic: Safe for cats",
|
14 |
+
"Chinese evergreen (Aglaonema)": "Toxic: Can cause oral irritation and swelling",
|
15 |
+
"Christmas Cactus (Schlumbergera bridgesii)": "Non-Toxic: Safe for cats",
|
16 |
+
"Chrysanthemum": "Toxic: Can cause gastrointestinal upset and skin irritation",
|
17 |
+
"Ctenanthe": "Non-Toxic: Safe for cats",
|
18 |
+
"Daffodils (Narcissus spp.)": "Toxic: Can cause severe gastrointestinal issues",
|
19 |
+
"Dracaena": "Toxic: Can cause vomiting and loss of appetite",
|
20 |
+
"Dumb Cane (Dieffenbachia spp.)": "Toxic: Causes severe mouth irritation and swelling",
|
21 |
+
"Elephant Ear (Alocasia spp.)": "Toxic: Can cause oral irritation and swelling",
|
22 |
+
"English Ivy (Hedera helix)": "Toxic: Can cause digestive upset and skin irritation",
|
23 |
+
"Hyacinth (Hyacinthus orientalis)": "Toxic: Can cause severe gastrointestinal issues",
|
24 |
+
"Iron Cross begonia (Begonia masoniana)": "Toxic: Can cause oral irritation and digestive upset",
|
25 |
+
"Jade plant (Crassula ovata)": "Toxic: Can cause vomiting and lethargy",
|
26 |
+
"Kalanchoe": "Toxic: Can cause digestive upset and heart rhythm abnormalities",
|
27 |
+
"Lilium (Hemerocallis)": "Highly Toxic: Extremely dangerous, can cause kidney failure",
|
28 |
+
"Lily of the valley (Convallaria majalis)": "Highly Toxic: Extremely dangerous, affects heart function",
|
29 |
+
"Money Tree (Pachira aquatica)": "Non-Toxic: Safe for cats",
|
30 |
+
"Monstera Deliciosa (Monstera deliciosa)": "Toxic: Can cause oral irritation and swelling",
|
31 |
+
"Orchid": "Non-Toxic: Safe for cats",
|
32 |
+
"Parlor Palm (Chamaedorea elegans)": "Non-Toxic: Safe for cats",
|
33 |
+
"Peace lily": "Toxic: Can cause oral irritation and digestive upset",
|
34 |
+
"Poinsettia (Euphorbia pulcherrima)": "Mildly Toxic: Mildly irritating to mouth and stomach",
|
35 |
+
"Polka Dot Plant (Hypoestes phyllostachya)": "Non-Toxic: Safe for cats",
|
36 |
+
"Ponytail Palm (Beaucarnea recurvata)": "Non-Toxic: Safe for cats",
|
37 |
+
"Pothos (Ivy arum)": "Toxic: Can cause oral irritation and swelling",
|
38 |
+
"Prayer Plant (Maranta leuconeura)": "Non-Toxic: Safe for cats",
|
39 |
+
"Rattlesnake Plant (Calathea lancifolia)": "Non-Toxic: Safe for cats",
|
40 |
+
"Rubber Plant (Ficus elastica)": "Toxic: Can cause skin and gastrointestinal irritation",
|
41 |
+
"Sago Palm (Cycas revoluta)": "Highly Toxic: Extremely toxic, can cause liver failure",
|
42 |
+
"Schefflera": "Toxic: Can cause oral irritation and digestive issues",
|
43 |
+
"Snake plant (Sanseviera)": "Mildly Toxic: Can cause nausea and vomiting",
|
44 |
+
"Tradescantia": "Mildly Toxic: Can cause mild skin irritation",
|
45 |
+
"Tulip": "Toxic: Can cause digestive issues and breathing problems",
|
46 |
+
"Venus Flytrap": "Non-Toxic: Safe for cats",
|
47 |
+
"Yucca": "Toxic: Can cause vomiting and diarrhea",
|
48 |
+
"ZZ Plant (Zamioculcas zamiifolia)": "Toxic: Can cause digestive upset if ingested"
|
49 |
+
}
|
examples/1.jpg
ADDED
![]() |
examples/2.jpg
ADDED
![]() |
examples/3.jpg
ADDED
![]() |
idx_to_class.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"0": "African Violet (Saintpaulia ionantha)", "1": "Aloe Vera", "2": "Anthurium (Anthurium andraeanum)", "3": "Areca Palm (Dypsis lutescens)", "4": "Asparagus Fern (Asparagus setaceus)", "5": "Begonia (Begonia spp.)", "6": "Bird of Paradise (Strelitzia reginae)", "7": "Bird's Nest Fern (Asplenium nidus)", "8": "Boston Fern (Nephrolepis exaltata)", "9": "Calathea", "10": "Cast Iron Plant (Aspidistra elatior)", "11": "Chinese Money Plant (Pilea peperomioides)", "12": "Chinese evergreen (Aglaonema)", "13": "Christmas Cactus (Schlumbergera bridgesii)", "14": "Chrysanthemum", "15": "Ctenanthe", "16": "Daffodils (Narcissus spp.)", "17": "Dracaena", "18": "Dumb Cane (Dieffenbachia spp.)", "19": "Elephant Ear (Alocasia spp.)", "20": "English Ivy (Hedera helix)", "21": "Hyacinth (Hyacinthus orientalis)", "22": "Iron Cross begonia (Begonia masoniana)", "23": "Jade plant (Crassula ovata)", "24": "Kalanchoe", "25": "Lilium (Hemerocallis)", "26": "Lily of the valley (Convallaria majalis)", "27": "Money Tree (Pachira aquatica)", "28": "Monstera Deliciosa (Monstera deliciosa)", "29": "Orchid", "30": "Parlor Palm (Chamaedorea elegans)", "31": "Peace lily", "32": "Poinsettia (Euphorbia pulcherrima)", "33": "Polka Dot Plant (Hypoestes phyllostachya)", "34": "Ponytail Palm (Beaucarnea recurvata)", "35": "Pothos (Ivy arum)", "36": "Prayer Plant (Maranta leuconeura)", "37": "Rattlesnake Plant (Calathea lancifolia)", "38": "Rubber Plant (Ficus elastica)", "39": "Sago Palm (Cycas revoluta)", "40": "Schefflera", "41": "Snake plant (Sanseviera)", "42": "Tradescantia", "43": "Tulip", "44": "Venus Flytrap", "45": "Yucca", "46": "ZZ Plant (Zamioculcas zamiifolia)"}
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.0
|
2 |
+
torchvision==0.19.0
|
3 |
+
timm==1.0.8
|
4 |
+
gradio==4.44.0
|
5 |
+
|
vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fab0263d78ff641545d234b20e3b06309dfb67fd13802cdda1cfda8d63c0c39
|
3 |
+
size 343403462
|