Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
import os | |
from dotenv import load_dotenv | |
from groq import Groq | |
# Load environment variables (python-dotenv reads a .env file if one is found).
load_dotenv()

# Page settings
st.set_page_config(page_title="🌿 Leaf Disease Detector", layout="wide")
st.markdown("<h1 style='text-align: center;'>🌿 Plant Leaf Disease Detection</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center;'>Upload a leaf image to detect plant diseases and get treatment guidance.</p>", unsafe_allow_html=True)
st.markdown("---")

# Initialize Groq client. On failure `client` is left as None; get_disease_info
# checks for that and returns a static message instead of calling the API.
try:
    api_key = os.getenv("GROQ_API_KEY")
    client = Groq(api_key=api_key)
except Exception as e:
    st.error(f"Failed to initialize Groq client: {str(e)}")
    client = None
# CNN classifier (untrained: weights are randomly initialised).
class PlantDiseaseModel(nn.Module):
    """Three-stage convolutional classifier for leaf images.

    Each stage is Conv(3x3, padding=1) -> ReLU -> MaxPool(2). The stages take
    the channel count 3 -> 32 -> 64 -> 128 and halve the spatial size each
    time. The resulting map is pooled to a fixed 32x32 grid, so the head
    always receives ``128 * 32 * 32`` features regardless of input
    resolution. For the 256x256 images produced by preprocess_image the map
    is already 32x32, and that pooling is an exact identity.

    Weights come from PyTorch's default initialisation; this class loads no
    trained parameters.

    Args:
        num_classes: Number of output logits (default 28).
    """

    def __init__(self, num_classes=28):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        # Decouples the head from the input size. It has no parameters, so
        # state_dicts saved with the features.*/classifier.* keys still load.
        self.pool = nn.AdaptiveAvgPool2d((32, 32))
        self.classifier = nn.Sequential(
            nn.Linear(128 * 32 * 32, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map an image batch of shape (N, 3, H, W) to logits of shape (N, num_classes)."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
@st.cache_resource
def load_model(weights_path=None):
    """Build the classifier once per server process and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction. Without
    caching, the network (its first Linear layer alone holds
    128*32*32*512 ≈ 67M weights) would be re-allocated on every rerun. The
    cached model is only used in eval mode under ``torch.no_grad``, so
    sharing one instance is safe.

    Args:
        weights_path: Optional path to a saved ``state_dict``. When omitted
            (the default), the network keeps its random initialisation and
            its predictions carry no diagnostic meaning.

    Returns:
        The model, switched to eval mode.
    """
    net = PlantDiseaseModel()
    if weights_path is not None:
        # weights_only=True refuses arbitrary pickled objects in the file.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model()
# Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalised single-image batch for the model.

    The image is forced to RGB first. Greyscale ("L"), palette ("P") or RGBA
    images would otherwise yield 1 or 4 channels after ``ToTensor``. That
    fails in ``Normalize``, which is given exactly three per-channel
    statistics.

    Returns:
        A float tensor of shape (1, 3, 256, 256).
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # The commonly used ImageNet per-channel mean and std.
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the leading batch dimension.
    return transform(image.convert("RGB")).unsqueeze(0)
# Disease classes.
# classify_disease looks labels up by the model's predicted output index, so
# the order below matters. 28 labels, matching PlantDiseaseModel's default
# num_classes.
disease_classes = [
    "Healthy",
    # Apple
    "Apple Scab",
    "Apple Black Rot",
    "Apple Cedar Rust",
    # Cherry / Corn
    "Cherry Powdery Mildew",
    "Corn Gray Leaf Spot",
    "Corn Common Rust",
    # Grape
    "Grape Black Rot",
    "Grape Esca",
    "Grape Leaf Blight",
    # Orange / Peach / Pepper
    "Orange Huanglongbing",
    "Peach Bacterial Spot",
    "Pepper Bacterial Spot",
    # Potato
    "Potato Early Blight",
    "Potato Late Blight",
    # Raspberry / Soybean / Squash / Strawberry
    "Raspberry Healthy",
    "Soybean Healthy",
    "Squash Powdery Mildew",
    "Strawberry Leaf Scorch",
    # Tomato
    "Tomato Bacterial Spot",
    "Tomato Early Blight",
    "Tomato Late Blight",
    "Tomato Leaf Mold",
    "Tomato Septoria Leaf Spot",
    "Tomato Spider Mites",
    "Tomato Target Spot",
    "Tomato Yellow Leaf Curl Virus",
    "Tomato Mosaic Virus",
]
# Predict class
def classify_disease(image):
    """Return the label the model predicts for a PIL image.

    Any failure (preprocessing, inference, or a predicted index with no
    matching label) is shown in the UI via ``st.error`` instead of being
    raised. In that case the string "Unknown" is returned so the page keeps
    rendering.
    """
    try:
        batch = preprocess_image(image)
        with torch.no_grad():
            logits = model(batch)
        class_idx = int(logits.argmax(dim=1).item())
        # Index directly rather than wrapping with a modulo. A wrapped index
        # would silently attach an unrelated label if the model's class count
        # and the label list ever drifted apart.
        if class_idx >= len(disease_classes):
            raise IndexError(
                f"model predicted class {class_idx}, "
                f"but only {len(disease_classes)} labels are defined"
            )
        return disease_classes[class_idx]
    except Exception as e:
        st.error(f"Error during classification: {str(e)}")
        return "Unknown"
# Fetch disease info
def get_disease_info(disease_name):
    """Return ``{"description": <markdown text>}`` describing *disease_name*.

    Labels that need no lookup are answered locally, before the Groq client
    is consulted, so they work even when the API is unavailable:

    * any label containing the word "healthy": the generic "Healthy" and
      crop-specific ones such as "Raspberry Healthy";
    * "Unknown", the value classify_disease returns when classification fails.

    Otherwise the Groq chat-completions API is asked for symptoms, treatment
    and prevention. API errors are shown with ``st.error`` and replaced by a
    fallback message rather than raised.
    """
    # Word-level match: "Raspberry Healthy" counts, but a label that merely
    # contains the substring (e.g. "Unhealthy ...") would not.
    if "healthy" in disease_name.lower().split():
        return {
            "description": "The plant appears to be healthy. No treatment is needed.",
        }
    if disease_name == "Unknown":
        return {
            "description": "The disease could not be identified, so no information is available.",
        }
    if not client:
        return {
            "description": "API not available. Check GROQ_API_KEY.",
        }
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a plant pathologist assistant."},
                {"role": "user", "content": f"Describe {disease_name} in plants including symptoms, treatment, and prevention."},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.3,
            max_tokens=1024,
        )
        content = response.choices[0].message.content
    except Exception as e:
        st.error(f"Error fetching disease info: {str(e)}")
        return {
            "description": "Unable to fetch disease info. Please try again later.",
        }
    # message.content may be None or empty; don't hand that to st.markdown.
    if not content:
        return {
            "description": "Unable to fetch disease info. Please try again later.",
        }
    return {"description": content}
# Main app
def _status_label(disease_name):
    """Map a predicted label to the status text shown beside it."""
    # Word-level match so crop-specific labels like "Soybean Healthy" count.
    if "healthy" in disease_name.lower().split():
        return "✅ Healthy"
    # "Unknown" is classify_disease's failure fallback, not a diagnosis.
    if disease_name == "Unknown":
        return "❓ Unknown"
    return "⚠️ Diseased"


def _render_results(disease_name, info):
    """Draw the prediction summary (left column) and the disease description (right)."""
    st.markdown("---")
    st.subheader("🔬 Prediction Results")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.markdown(f"**Status:** {_status_label(disease_name)}")
        st.markdown(f"**Detected Disease:** `{disease_name}`")
    with col2:
        st.markdown("**📋 Disease Info:**")
        st.markdown(info["description"])


def main():
    """Streamlit entry point: upload a leaf image, classify it on demand, show results."""
    uploaded_file = st.file_uploader("📷 Upload a leaf image", type=["jpeg", "png", "jpg"])
    if not uploaded_file:
        return
    # Defensive re-check of the extension, in addition to file_uploader's `type` filter.
    if not uploaded_file.name.lower().endswith((".jpg", ".jpeg", ".png")):
        st.error("Only JPG, JPEG, and PNG files are allowed.")
        st.stop()
    try:
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Leaf Image", width=400)
        if st.button("🔍 Predict Disease"):
            with st.spinner("Analyzing..."):
                disease_name = classify_disease(image)
                info = get_disease_info(disease_name)
            _render_results(disease_name, info)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the page.
        st.error(f"Error processing image: {str(e)}")


if __name__ == "__main__":
    main()