# sanidhya2810 - "Update app.py" (commit 0d11331, verified)
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms.functional as TF
import CNN
import numpy as np
import pandas as pd
# Load model and data.
# The CNN is built for 39 outputs (presumably the number of disease classes,
# one per CSV row - TODO confirm), its trained weights are restored onto the
# CPU, and it is switched to evaluation mode before serving predictions.
model = CNN.CNN(39)
model.load_state_dict(
    torch.load(
        "plant_disease_model_1_latest.pt",
        map_location='cpu',
    )
)
model.eval()

# Lookup tables indexed by predicted class; both files are cp1252-encoded.
disease_info, supplement_info = (
    pd.read_csv(csv_path, encoding='cp1252')
    for csv_path in ('disease_info.csv', 'supplement_info.csv')
)
# Prediction logic
def predict(image):
    """Classify a plant-leaf image and look up its disease and supplement info.

    The image is converted to RGB, resized to 224x224, turned into a
    (1, 3, 224, 224) tensor and passed to the module-level ``model``. The
    arg-max class index selects the matching row of ``disease_info`` and
    ``supplement_info``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; may be
            ``None`` if the user submits without uploading anything.

    Returns:
        A 7-tuple ``(disease name, description, prevention steps,
        supplement name, buy link, disease image URL, supplement image URL)``
        in the same order as the interface's ``outputs`` list.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image of a plant leaf.")

    # The network takes exactly 3 channels; RGBA, greyscale or palette images
    # would otherwise make the reshape to (1, 3, 224, 224) fail.
    image = image.convert("RGB").resize((224, 224))
    input_data = TF.to_tensor(image).unsqueeze(0)

    # Inference only: don't build an autograd graph for every request.
    with torch.no_grad():
        output = model(input_data)
    index = int(np.argmax(output.numpy()))

    # NOTE(review): assumes row i of both CSVs corresponds to model class i -
    # confirm the CSV row order matches the training label order.
    return (
        disease_info['disease_name'].iloc[index],
        disease_info['description'].iloc[index],
        disease_info['Possible Steps'].iloc[index],
        supplement_info['supplement name'].iloc[index],
        supplement_info['buy link'].iloc[index],
        disease_info['image_url'].iloc[index],
        supplement_info['supplement image'].iloc[index],
    )
# Gradio interface: one leaf-image input and seven outputs, listed in the same
# order as the values predict() returns (five text fields, then two images).
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Plant Leaf"),
    outputs=[
        # gr.Text is an alias of gr.Textbox, so every text field uses Textbox.
        *(
            gr.Textbox(label=text_label)
            for text_label in (
                "Disease Name",
                "Description",
                "Prevention Steps",
                "Recommended Supplement",
                "Buy Link",
            )
        ),
        gr.Image(label="Disease Image"),
        gr.Image(label="Supplement Image"),
    ],
    title="Plant Disease Diagnosis",
    description="Upload an image of the plant leaf to detect disease and get treatment info.",
    css="static/styles.css",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    interface.launch()