|
import gradio as gr |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
|
|
|
|
def load_model(model_paths=None, *, loader=None):
    """Load the tree-disease detector, falling back to stock YOLOv8s weights.

    Candidate weight files are tried in order; the first one that exists and
    loads successfully is returned. If none does, the generic ``yolov8s.pt``
    weights are loaded instead.

    Args:
        model_paths: Candidate weight files, in priority order. Defaults to
            ``best_model.pt`` and ``tree_disease_detector.pt`` (bare and
            ``./``-prefixed spellings).
        loader: Callable mapping a weights path to a model object. Defaults to
            ``ultralytics.YOLO``; injectable so the search logic can be
            exercised without ultralytics installed.

    Returns:
        ``(model, status)`` where ``status`` is a human-readable description
        of what was loaded. If every attempt fails, returns
        ``(None, "No model available")``.
    """
    if model_paths is None:
        model_paths = [
            'best_model.pt',
            'tree_disease_detector.pt',
            './best_model.pt',
            './tree_disease_detector.pt',
        ]
    if loader is None:
        loader = YOLO

    # 'best_model.pt' and './best_model.pt' refer to the same file; retrying a
    # file that already failed would only repeat the same error, so skip it.
    seen = set()
    for path in model_paths:
        key = os.path.normcase(os.path.abspath(path))
        if key in seen or not os.path.exists(path):
            continue
        seen.add(key)
        try:
            print(f"Loading model from {path}")
            model = loader(path)
        except Exception as e:  # best-effort: log and try the next candidate
            print(f"Error loading {path}: {e}")
            continue
        return model, f"Tree Disease Detection Model ({path})"

    try:
        print("Loading standard YOLOv8s model...")
        model = loader('yolov8s.pt')
    except Exception as e:  # last resort failed; caller must handle model=None
        print(f"Error loading YOLOv8s: {e}")
        return None, "No model available"
    return model, "Standard YOLOv8s Model (Fallback)"
|
|
|
|
|
# Load the model once at import time; every request below reuses this instance.
model, model_status = load_model()


def _to_bgr_array(image):
    """Convert an RGB image into the 3-channel uint8 BGR array Ultralytics expects.

    Ultralytics interprets raw numpy input as OpenCV-style BGR, whereas PIL
    images are RGB. Passing ``np.array(pil_image)`` straight in swaps the red
    and blue channels, both for inference and in the plotted output.
    ``convert("RGB")`` also normalises grayscale, palette and RGBA uploads to
    three channels.
    """
    if isinstance(image, np.ndarray):
        # assumes numpy input is RGB uint8 (Gradio's convention) — TODO confirm
        image = Image.fromarray(image)
    rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _extract_detections(results):
    """Flatten Ultralytics results into dicts with confidence, bbox and class name."""
    detections = []
    for result in results:
        boxes = result.boxes
        if boxes is None:
            continue
        names = getattr(result, "names", None) or {}
        for box in boxes:
            class_id = int(box.cls[0])
            detections.append({
                'confidence': float(box.conf[0]),
                'bbox': box.xyxy[0].tolist(),
                # Use the model's own label: the fallback model is not a
                # tree-disease model, so "unhealthy" would be wrong there.
                'class': names.get(class_id, str(class_id)),
            })
    return detections


def _format_summary(detections, status):
    """Render the detections as the plain-text summary shown in the UI."""
    if "Tree Disease Detection Model" in status:
        header = f"Detected {len(detections)} unhealthy tree(s)\n\n"
        rows = (
            f"Tree {i}: Confidence {det['confidence']:.2f}\n"
            for i, det in enumerate(detections, 1)
        )
    else:
        header = f"Using {status}\nDetected {len(detections)} object(s)\n\n"
        rows = (
            f"Object {i} ({det['class']}): Confidence {det['confidence']:.2f}\n"
            for i, det in enumerate(detections, 1)
        )
    return header + "".join(rows) + f"\nModel Status: {status}"


def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
    """Detect unhealthy trees in the uploaded image.

    Args:
        image: Uploaded image, a PIL image from the ``gr.Image(type="pil")``
            input; an RGB numpy array is also accepted. ``None`` when the user
            clicks Detect before uploading anything.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's non-maximum
            suppression.

    Returns:
        ``(annotated_image, summary)``: an RGB PIL image with the detections
        drawn on it, and a plain-text summary. On error, the first element is
        the input image (or ``None``) and the summary describes the problem.
    """
    if model is None:
        return image, "Error: No model available"
    if image is None:
        return None, "Error: Please upload an image first"

    results = model(_to_bgr_array(image), conf=conf_threshold, iou=iou_threshold)

    # Results.plot() draws on the (BGR) input and returns a BGR array.
    annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

    detections = _extract_detections(results)
    return Image.fromarray(annotated), _format_summary(detections, model_status)
|
|
|
|
|
# Example rows for gr.Examples. Each row is
# [image URL, confidence threshold, IoU threshold]; the thresholds match the
# slider defaults.
example_images = [
    [url, 0.25, 0.45]
    for url in (
        "https://images.pexels.com/photos/1632790/pexels-photo-1632790.jpeg",
        "https://images.pexels.com/photos/38537/woodland-road-falling-leaf-natural-38537.jpeg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Ash_Tree_-_geograph.org.uk_-_590710.jpg/640px-Ash_Tree_-_geograph.org.uk_-_590710.jpg",
    )
]
|
|
|
|
|
# Build the Gradio UI. `demo` is the module-level app object that
# `demo.launch()` below serves.
with gr.Blocks(title="Tree Disease Detection") as demo:
    gr.Markdown(f"""
    # 🌳 Tree Disease Detection with YOLOv8

    This model detects unhealthy/diseased trees in aerial UAV imagery.
    Upload an image or use one of the examples below to detect diseased trees.

    **Current Model**: {model_status}
    """)

    # Warn when the app is not running the tree-disease model.
    if "Fallback" in model_status:
        gr.Markdown("""
        ⚠️ **Note**: Using a fallback model. Detection will work but won't be specific to tree diseases.
        """)
    elif model is None:
        gr.Markdown("""
        ❌ **Error**: No model could be loaded, so detection is unavailable. Check the server logs.
        """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            conf_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Confidence Threshold",
            )
            iou_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.45,
                step=0.05,
                label="IoU Threshold",
            )
            detect_button = gr.Button("Detect Tree Disease", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Detection Results")
            detection_summary = gr.Textbox(label="Detection Summary", lines=10)

    detect_button.click(
        fn=detect_tree_disease,
        inputs=[input_image, conf_threshold, iou_threshold],
        outputs=[output_image, detection_summary],
    )

    gr.Examples(
        examples=example_images,
        inputs=[input_image, conf_threshold, iou_threshold],
        outputs=[output_image, detection_summary],
        fn=detect_tree_disease,
        cache_examples=False,
    )

    # NMS discards a box whose IoU with a higher-scoring box exceeds the
    # threshold, so keeping overlapping trees needs a HIGHER threshold.
    gr.Markdown("""
    ## About this Model

    - **Architecture**: YOLOv8s
    - **Dataset**: [PDT Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
    - **mAP50**: 0.933
    - **mAP50-95**: 0.659
    - **Precision**: 0.878
    - **Recall**: 0.863
    - **Classes**: 1 (unhealthy trees)

    ## Usage Tips

    - This model works best with aerial/UAV imagery
    - Optimal input resolution: 640x640 pixels
    - Adjust confidence threshold to filter detections
    - Raise the IoU threshold to keep detections of closely overlapping trees; lower it to suppress duplicate boxes

    [Model Card](https://huggingface.co/IsmatS/crop_desease_detection) | [Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
    """)


if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()
|
|