# IsmatS's picture
# Update app.py to use local model files
# 9281224
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import os
# Load the model with proper error handling
def load_model():
    """Load a YOLO model, preferring local custom weights over the stock one.

    Each candidate weights file that exists on disk is tried in order; the
    first one that loads successfully wins. If none loads, the standard
    ``yolov8s.pt`` weights are tried as a fallback.

    Returns:
        A ``(model, status)`` tuple. ``model`` is the loaded YOLO instance, or
        ``None`` when every attempt failed. ``status`` is a human-readable
        label describing which model (if any) was loaded.
    """
    candidate_weights = (
        'best_model.pt',
        'tree_disease_detector.pt',
        './best_model.pt',
        './tree_disease_detector.pt',
    )

    # Local custom weights take priority.
    for weights_file in candidate_weights:
        if not os.path.exists(weights_file):
            continue
        try:
            print(f"Loading model from {weights_file}")
            return YOLO(weights_file), f"Tree Disease Detection Model ({weights_file})"
        except Exception as load_error:
            # Best effort: report the failure and move on to the next candidate.
            print(f"Error loading {weights_file}: {load_error}")

    # No usable local weights: fall back to the standard YOLOv8s model.
    try:
        print("Loading standard YOLOv8s model...")
        return YOLO('yolov8s.pt'), "Standard YOLOv8s Model (Fallback)"
    except Exception as load_error:
        print(f"Error loading YOLOv8s: {load_error}")

    return None, "No model available"
# Load the model once at module level. `model` is None when every load attempt
# in load_model() failed; `model_status` is the label string it returned.
model, model_status = load_model()
def _format_detection_summary(detections, status):
    """Build the text summary shown next to the annotated image."""
    if "Tree Disease Detection Model" in status:
        # Custom model: every detection is an unhealthy tree.
        lines = [f"Detected {len(detections)} unhealthy tree(s)\n\n"]
        item_label = "Tree"
    else:
        lines = [
            f"Using {status}\n",
            f"Detected {len(detections)} object(s)\n\n",
        ]
        item_label = "Object"

    lines.extend(
        f"{item_label} {index}: Confidence {det['confidence']:.2f}\n"
        for index, det in enumerate(detections, 1)
    )
    lines.append(f"\nModel Status: {status}")
    return "".join(lines)


def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
    """Detect unhealthy trees in the uploaded image.

    Args:
        image: The input image, normally a PIL image from the Gradio input
            component. ``None`` when the user has not uploaded anything.
        conf_threshold: Confidence threshold forwarded to the model.
        iou_threshold: IoU threshold forwarded to the model.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` is a PIL
        image with the detections drawn on it, and ``summary`` is a text
        report. On error the first element is the unmodified input (or
        ``None`` when no image was given) and the second is an error message.
    """
    # Validate the input first so an empty submit yields a message, not a crash.
    if image is None:
        return None, "Error: No image provided. Please upload an image first."
    if model is None:
        return image, "Error: No model available"

    # Normalise to 3-channel RGB so the colour conversion below is well defined
    # (also covers RGBA / greyscale inputs).
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    image_rgb = np.array(image.convert("RGB"))

    # Ultralytics interprets numpy inputs as BGR (OpenCV convention), whereas
    # PIL yields RGB. Convert explicitly so the model sees correct colours and
    # the annotated output is not red/blue swapped.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    # Run inference
    results = model(image_bgr, conf=conf_threshold, iou=iou_threshold)

    # plot() returns a BGR array; convert back to RGB for PIL / Gradio display.
    annotated_bgr = results[0].plot()
    annotated_img = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))

    # Collect one record per detected box across all results.
    detections = [
        {
            'confidence': float(box.conf[0]),
            'bbox': box.xyxy[0].tolist(),
            'class': 'unhealthy',
        }
        for result in results
        if result.boxes is not None
        for box in result.boxes
    ]

    return annotated_img, _format_detection_summary(detections, model_status)
# Example rows for the demo (tree images). Each row is
# [image URL, confidence threshold, IoU threshold].
example_images = [
    [image_url, 0.25, 0.45]
    for image_url in (
        "https://images.pexels.com/photos/1632790/pexels-photo-1632790.jpeg",
        "https://images.pexels.com/photos/38537/woodland-road-falling-leaf-natural-38537.jpeg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Ash_Tree_-_geograph.org.uk_-_590710.jpg/640px-Ash_Tree_-_geograph.org.uk_-_590710.jpg",
    )
]
def _threshold_slider(label, default):
    """Create a 0.0-1.0 slider with 0.05 steps for a detection threshold."""
    return gr.Slider(minimum=0.0, maximum=1.0, value=default, step=0.05, label=label)


# Build the Gradio interface. The markdown strings are kept flush-left on
# purpose: indenting their content would change the rendered text.
with gr.Blocks(title="Tree Disease Detection") as demo:
    gr.Markdown(f"""
# 🌳 Tree Disease Detection with YOLOv8
This model detects unhealthy/diseased trees in aerial UAV imagery.
Upload an image or use one of the examples below to detect diseased trees.
**Current Model**: {model_status}
""")

    # Warn the user when the stock model is in use instead of the custom one.
    if "Fallback" in model_status:
        gr.Markdown("""
⚠️ **Note**: Using a fallback model. Detection will work but won't be specific to tree diseases.
""")

    with gr.Row():
        # Left column: image upload and detection controls.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            conf_threshold = _threshold_slider("Confidence Threshold", 0.25)
            iou_threshold = _threshold_slider("IoU Threshold", 0.45)
            detect_button = gr.Button("Detect Tree Disease", variant="primary")

        # Right column: annotated image and text summary.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Detection Results")
            detection_summary = gr.Textbox(label="Detection Summary", lines=10)

    # The button and the examples feed the same components into the detector.
    detector_inputs = [input_image, conf_threshold, iou_threshold]
    detector_outputs = [output_image, detection_summary]

    # Run detection when the button is clicked.
    detect_button.click(
        fn=detect_tree_disease,
        inputs=detector_inputs,
        outputs=detector_outputs,
    )

    # Clickable example rows (not pre-computed).
    gr.Examples(
        examples=example_images,
        inputs=detector_inputs,
        outputs=detector_outputs,
        fn=detect_tree_disease,
        cache_examples=False,
    )

    gr.Markdown("""
## About this Model
- **Architecture**: YOLOv8s
- **Dataset**: [PDT Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
- **mAP50**: 0.933
- **mAP50-95**: 0.659
- **Precision**: 0.878
- **Recall**: 0.863
- **Classes**: 1 (unhealthy trees)
## Usage Tips
- This model works best with aerial/UAV imagery
- Optimal input resolution: 640x640 pixels
- Adjust confidence threshold to filter detections
- Lower IoU threshold for overlapping trees
[Model Card](https://huggingface.co/IsmatS/crop_desease_detection) |
[Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
""")

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()