kakasher commited on
Commit
b1e5ee7
·
1 Parent(s): c207faa

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from torchvision import transforms
4
+ import timm
5
+ import torch
6
+ import json
7
+ import os
8
+
9
+ # Load toxicity data
10
+ def load_toxicity_data(file_path='combined_plant_toxicity.json'):
11
+ try:
12
+ with open(file_path, 'r') as f:
13
+ return json.load(f)
14
+ except FileNotFoundError:
15
+ print(f"Warning: {file_path} not found. Toxicity information will not be available.")
16
+ return {}
17
+
18
+ # Load class labels
19
+ def load_class_labels(file_path='idx_to_class.json'):
20
+ try:
21
+ with open(file_path, 'r') as f:
22
+ return json.load(f)
23
+ except FileNotFoundError:
24
+ print(f"Error: {file_path} not found. Class labels will not be available.")
25
+ return {}
26
+
27
+ # Load model
28
+ def load_model(model_path):
29
+ try:
30
+ model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=47)
31
+ model.load_state_dict(torch.load(model_path, weights_only=True))
32
+ model.eval()
33
+ return model
34
+ except Exception as e:
35
+ print(f"Error loading model: {str(e)}")
36
+ return None
37
+
38
+ # Global variables
39
+ toxicity_data = load_toxicity_data()
40
+ idx_to_class = load_class_labels()
41
+ model_path = Path('vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar')
42
+ model = load_model(model_path)
43
+
44
+ # Define the transformation
45
+ transform = transforms.Compose([
46
+ transforms.Resize(256),
47
+ transforms.CenterCrop(224),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
50
+ ])
51
+
52
+ def classify_image(input_image):
53
+ if input_image is None:
54
+ return None, "Error: No image uploaded. Please choose an image and try again."
55
+
56
+ if model is None:
57
+ return None, "Error: Model could not be loaded. Please check the model path and try again."
58
+
59
+ try:
60
+ # Preprocess the image
61
+ input_tensor = transform(input_image).unsqueeze(0)
62
+
63
+ with torch.inference_mode():
64
+ output = model(input_tensor)
65
+ predictions = torch.nn.functional.softmax(output[0], dim=0)
66
+ confidences = {idx_to_class[str(i)]: float(predictions[i]) for i in range(47)}
67
+
68
+ # Sort confidences and get top 3
69
+ top_3 = sorted(confidences.items(), key=lambda x: x[1], reverse=True)[:3]
70
+
71
+ # Prepare the output for Label
72
+ label_output = {plant: conf for plant, conf in top_3}
73
+
74
+ # Prepare the toxicity information
75
+ toxicity_info = "### Toxicity Information\n\n"
76
+ for plant, _ in top_3:
77
+ toxicity = toxicity_data.get(plant, "Unknown")
78
+ toxicity_info += f"- **{plant}**: {toxicity}\n"
79
+
80
+ return label_output, toxicity_info
81
+
82
+ except Exception as e:
83
+ return None, f"An error occurred during image classification: {str(e)}"
84
+
85
+ demo = gr.Interface(
86
+ classify_image,
87
+ gr.Image(type="pil"),
88
+ [
89
+ gr.Label(num_top_classes=3, label="Top 3 Predictions"),
90
+ gr.Markdown(label="Toxicity Information")
91
+ ],
92
+ title="🌱 Cat-Safe Plant Classifier 🐱",
93
+ description="Upload an image of a plant to get its classification and toxicity information for cats. Includes 47 most popular house plants",
94
+ examples=[["examples/" + example] for example in os.listdir("examples")]
95
+ )
96
+
97
+ if __name__ == "__main__":
98
+ demo.launch()
combined_plant_toxicity.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "African Violet (Saintpaulia ionantha)": "Non-Toxic: Safe for cats",
3
+ "Aloe Vera": "Toxic: Causes digestive upset if ingested",
4
+ "Anthurium (Anthurium andraeanum)": "Toxic: Irritating to mouth and digestive tract",
5
+ "Areca Palm (Dypsis lutescens)": "Non-Toxic: Safe for cats",
6
+ "Asparagus Fern (Asparagus setaceus)": "Toxic: Can cause skin irritation and digestive issues",
7
+ "Begonia (Begonia spp.)": "Toxic: Can cause oral irritation and digestive upset",
8
+ "Bird of Paradise (Strelitzia reginae)": "Toxic: Can cause oral irritation and mild nausea",
9
+ "Bird's Nest Fern (Asplenium nidus)": "Non-Toxic: Safe for cats",
10
+ "Boston Fern (Nephrolepis exaltata)": "Non-Toxic: Safe for cats",
11
+ "Calathea": "Non-Toxic: Safe for cats",
12
+ "Cast Iron Plant (Aspidistra elatior)": "Non-Toxic: Safe for cats",
13
+ "Chinese Money Plant (Pilea peperomioides)": "Non-Toxic: Safe for cats",
14
+ "Chinese evergreen (Aglaonema)": "Toxic: Can cause oral irritation and swelling",
15
+ "Christmas Cactus (Schlumbergera bridgesii)": "Non-Toxic: Safe for cats",
16
+ "Chrysanthemum": "Toxic: Can cause gastrointestinal upset and skin irritation",
17
+ "Ctenanthe": "Non-Toxic: Safe for cats",
18
+ "Daffodils (Narcissus spp.)": "Toxic: Can cause severe gastrointestinal issues",
19
+ "Dracaena": "Toxic: Can cause vomiting and loss of appetite",
20
+ "Dumb Cane (Dieffenbachia spp.)": "Toxic: Causes severe mouth irritation and swelling",
21
+ "Elephant Ear (Alocasia spp.)": "Toxic: Can cause oral irritation and swelling",
22
+ "English Ivy (Hedera helix)": "Toxic: Can cause digestive upset and skin irritation",
23
+ "Hyacinth (Hyacinthus orientalis)": "Toxic: Can cause severe gastrointestinal issues",
24
+ "Iron Cross begonia (Begonia masoniana)": "Toxic: Can cause oral irritation and digestive upset",
25
+ "Jade plant (Crassula ovata)": "Toxic: Can cause vomiting and lethargy",
26
+ "Kalanchoe": "Toxic: Can cause digestive upset and heart rhythm abnormalities",
27
+ "Lilium (Hemerocallis)": "Highly Toxic: Extremely dangerous, can cause kidney failure",
28
+ "Lily of the valley (Convallaria majalis)": "Highly Toxic: Extremely dangerous, affects heart function",
29
+ "Money Tree (Pachira aquatica)": "Non-Toxic: Safe for cats",
30
+ "Monstera Deliciosa (Monstera deliciosa)": "Toxic: Can cause oral irritation and swelling",
31
+ "Orchid": "Non-Toxic: Safe for cats",
32
+ "Parlor Palm (Chamaedorea elegans)": "Non-Toxic: Safe for cats",
33
+ "Peace lily": "Toxic: Can cause oral irritation and digestive upset",
34
+ "Poinsettia (Euphorbia pulcherrima)": "Mildly Toxic: Mildly irritating to mouth and stomach",
35
+ "Polka Dot Plant (Hypoestes phyllostachya)": "Non-Toxic: Safe for cats",
36
+ "Ponytail Palm (Beaucarnea recurvata)": "Non-Toxic: Safe for cats",
37
+ "Pothos (Ivy arum)": "Toxic: Can cause oral irritation and swelling",
38
+ "Prayer Plant (Maranta leuconeura)": "Non-Toxic: Safe for cats",
39
+ "Rattlesnake Plant (Calathea lancifolia)": "Non-Toxic: Safe for cats",
40
+ "Rubber Plant (Ficus elastica)": "Toxic: Can cause skin and gastrointestinal irritation",
41
+ "Sago Palm (Cycas revoluta)": "Highly Toxic: Extremely toxic, can cause liver failure",
42
+ "Schefflera": "Toxic: Can cause oral irritation and digestive issues",
43
+ "Snake plant (Sanseviera)": "Mildly Toxic: Can cause nausea and vomiting",
44
+ "Tradescantia": "Mildly Toxic: Can cause mild skin irritation",
45
+ "Tulip": "Toxic: Can cause digestive issues and breathing problems",
46
+ "Venus Flytrap": "Non-Toxic: Safe for cats",
47
+ "Yucca": "Toxic: Can cause vomiting and diarrhea",
48
+ "ZZ Plant (Zamioculcas zamiifolia)": "Toxic: Can cause digestive upset if ingested"
49
+ }
examples/1.jpg ADDED
examples/2.jpg ADDED
examples/3.jpg ADDED
idx_to_class.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "African Violet (Saintpaulia ionantha)", "1": "Aloe Vera", "2": "Anthurium (Anthurium andraeanum)", "3": "Areca Palm (Dypsis lutescens)", "4": "Asparagus Fern (Asparagus setaceus)", "5": "Begonia (Begonia spp.)", "6": "Bird of Paradise (Strelitzia reginae)", "7": "Bird's Nest Fern (Asplenium nidus)", "8": "Boston Fern (Nephrolepis exaltata)", "9": "Calathea", "10": "Cast Iron Plant (Aspidistra elatior)", "11": "Chinese Money Plant (Pilea peperomioides)", "12": "Chinese evergreen (Aglaonema)", "13": "Christmas Cactus (Schlumbergera bridgesii)", "14": "Chrysanthemum", "15": "Ctenanthe", "16": "Daffodils (Narcissus spp.)", "17": "Dracaena", "18": "Dumb Cane (Dieffenbachia spp.)", "19": "Elephant Ear (Alocasia spp.)", "20": "English Ivy (Hedera helix)", "21": "Hyacinth (Hyacinthus orientalis)", "22": "Iron Cross begonia (Begonia masoniana)", "23": "Jade plant (Crassula ovata)", "24": "Kalanchoe", "25": "Lilium (Hemerocallis)", "26": "Lily of the valley (Convallaria majalis)", "27": "Money Tree (Pachira aquatica)", "28": "Monstera Deliciosa (Monstera deliciosa)", "29": "Orchid", "30": "Parlor Palm (Chamaedorea elegans)", "31": "Peace lily", "32": "Poinsettia (Euphorbia pulcherrima)", "33": "Polka Dot Plant (Hypoestes phyllostachya)", "34": "Ponytail Palm (Beaucarnea recurvata)", "35": "Pothos (Ivy arum)", "36": "Prayer Plant (Maranta leuconeura)", "37": "Rattlesnake Plant (Calathea lancifolia)", "38": "Rubber Plant (Ficus elastica)", "39": "Sago Palm (Cycas revoluta)", "40": "Schefflera", "41": "Snake plant (Sanseviera)", "42": "Tradescantia", "43": "Tulip", "44": "Venus Flytrap", "45": "Yucca", "46": "ZZ Plant (Zamioculcas zamiifolia)"}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ torchvision==0.19.0
3
+ timm==1.0.8
4
+ gradio==4.44.0
5
+
vit_b16_224_25e_256bs_0.001lr_adamW_transforms.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fab0263d78ff641545d234b20e3b06309dfb67fd13802cdda1cfda8d63c0c39
3
+ size 343403462