IsmatS commited on
Commit
9281224
·
1 Parent(s): 265a095

Update app.py to use local model files

Browse files
Files changed (1) hide show
  1. app.py +50 -38
app.py CHANGED
@@ -2,33 +2,37 @@ import gradio as gr
2
  from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
- from PIL import Image, ImageDraw, ImageFont
6
  import os
7
- import requests
8
- from io import BytesIO
9
 
10
- # Load the model with error handling
11
  def load_model():
12
- try:
13
- # Try to load from HuggingFace first
14
- model = YOLO('IsmatS/crop_desease_detection')
15
- return model, "Custom Tree Disease Detection Model"
16
- except:
17
- try:
18
- # Try direct URL
19
- model = YOLO('https://huggingface.co/IsmatS/crop_desease_detection/resolve/main/best.pt')
20
- return model, "Custom Tree Disease Detection Model"
21
- except:
22
  try:
23
- # Try local file if exists
24
- if os.path.exists('best.pt'):
25
- model = YOLO('best.pt')
26
- return model, "Custom Tree Disease Detection Model (Local)"
27
- except:
28
- # Fallback to standard YOLOv8s
29
- print("Loading standard YOLOv8s model as fallback...")
30
- model = YOLO('yolov8s.pt')
31
- return model, "Standard YOLOv8s Model (Fallback)"
 
 
 
 
 
 
32
 
33
  # Load model and get status
34
  model, model_status = load_model()
@@ -36,12 +40,20 @@ model, model_status = load_model()
36
  def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
37
  """Detect unhealthy trees in the uploaded image"""
38
 
 
 
 
39
  # Convert PIL image to numpy array
40
  image_np = np.array(image)
41
 
42
  # Run inference
43
  results = model(image_np, conf=conf_threshold, iou=iou_threshold)
44
 
 
 
 
 
 
45
  # Extract detections
46
  detections = []
47
  for result in results:
@@ -51,17 +63,14 @@ def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
51
  detection = {
52
  'confidence': float(box.conf[0]),
53
  'bbox': box.xyxy[0].tolist(),
54
- 'class': 'unhealthy' if model_status.startswith("Custom") else 'object'
55
  }
56
  detections.append(detection)
57
 
58
- # Get annotated image directly from results
59
- annotated_img = results[0].plot()
60
- annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
61
- annotated_img = Image.fromarray(annotated_img)
62
-
63
  # Create detection summary
64
- if model_status.startswith("Custom"):
 
 
65
  summary = f"Detected {len(detections)} unhealthy tree(s)\n\n"
66
  for i, det in enumerate(detections, 1):
67
  summary += f"Tree {i}: Confidence {det['confidence']:.2f}\n"
@@ -75,10 +84,11 @@ def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
75
 
76
  return annotated_img, summary
77
 
78
- # Create example images
79
  example_images = [
80
- ["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1841066-1691513468.jpg", 0.25, 0.45], # Tree image
81
- ["https://www.fs.usda.gov/Internet/FSE_MEDIA/fseprd1115588.jpg", 0.25, 0.45], # Another tree
 
82
  ]
83
 
84
  # Create Gradio interface
@@ -92,10 +102,9 @@ with gr.Blocks(title="Tree Disease Detection") as demo:
92
  **Current Model**: {model_status}
93
  """)
94
 
95
- if not model_status.startswith("Custom"):
96
  gr.Markdown("""
97
- ⚠️ **Note**: Currently using a fallback model. The specialized tree disease model is being updated.
98
- Detection will work but won't be specific to tree diseases.
99
  """)
100
 
101
  with gr.Row():
@@ -134,7 +143,7 @@ with gr.Blocks(title="Tree Disease Detection") as demo:
134
  inputs=[input_image, conf_threshold, iou_threshold],
135
  outputs=[output_image, detection_summary],
136
  fn=detect_tree_disease,
137
- cache_examples=False, # Disable caching to avoid initialization issues
138
  )
139
 
140
  gr.Markdown("""
@@ -144,6 +153,8 @@ with gr.Blocks(title="Tree Disease Detection") as demo:
144
  - **Dataset**: [PDT Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
145
  - **mAP50**: 0.933
146
  - **mAP50-95**: 0.659
 
 
147
  - **Classes**: 1 (unhealthy trees)
148
 
149
  ## Usage Tips
@@ -158,4 +169,5 @@ with gr.Blocks(title="Tree Disease Detection") as demo:
158
  """)
159
 
160
  # Launch the app
161
- demo.launch()
 
 
2
  from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
+ from PIL import Image
6
  import os
 
 
7
 
8
+ # Load the model with proper error handling
9
  def load_model():
10
+ model_paths = [
11
+ 'best_model.pt',
12
+ 'tree_disease_detector.pt',
13
+ './best_model.pt',
14
+ './tree_disease_detector.pt'
15
+ ]
16
+
17
+ # Try to load from local files first
18
+ for path in model_paths:
19
+ if os.path.exists(path):
20
  try:
21
+ print(f"Loading model from {path}")
22
+ model = YOLO(path)
23
+ return model, f"Tree Disease Detection Model ({path})"
24
+ except Exception as e:
25
+ print(f"Error loading {path}: {e}")
26
+ continue
27
+
28
+ # Fallback to standard YOLOv8s
29
+ try:
30
+ print("Loading standard YOLOv8s model...")
31
+ model = YOLO('yolov8s.pt')
32
+ return model, "Standard YOLOv8s Model (Fallback)"
33
+ except Exception as e:
34
+ print(f"Error loading YOLOv8s: {e}")
35
+ return None, "No model available"
36
 
37
  # Load model and get status
38
  model, model_status = load_model()
 
40
  def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
41
  """Detect unhealthy trees in the uploaded image"""
42
 
43
+ if model is None:
44
+ return image, "Error: No model available"
45
+
46
  # Convert PIL image to numpy array
47
  image_np = np.array(image)
48
 
49
  # Run inference
50
  results = model(image_np, conf=conf_threshold, iou=iou_threshold)
51
 
52
+ # Get annotated image directly from results
53
+ annotated_img = results[0].plot()
54
+ annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
55
+ annotated_img = Image.fromarray(annotated_img)
56
+
57
  # Extract detections
58
  detections = []
59
  for result in results:
 
63
  detection = {
64
  'confidence': float(box.conf[0]),
65
  'bbox': box.xyxy[0].tolist(),
66
+ 'class': 'unhealthy'
67
  }
68
  detections.append(detection)
69
 
 
 
 
 
 
70
  # Create detection summary
71
+ is_custom_model = "Tree Disease Detection Model" in model_status
72
+
73
+ if is_custom_model:
74
  summary = f"Detected {len(detections)} unhealthy tree(s)\n\n"
75
  for i, det in enumerate(detections, 1):
76
  summary += f"Tree {i}: Confidence {det['confidence']:.2f}\n"
 
84
 
85
  return annotated_img, summary
86
 
87
+ # Create example images (tree images)
88
  example_images = [
89
+ ["https://images.pexels.com/photos/1632790/pexels-photo-1632790.jpeg", 0.25, 0.45],
90
+ ["https://images.pexels.com/photos/38537/woodland-road-falling-leaf-natural-38537.jpeg", 0.25, 0.45],
91
+ ["https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Ash_Tree_-_geograph.org.uk_-_590710.jpg/640px-Ash_Tree_-_geograph.org.uk_-_590710.jpg", 0.25, 0.45],
92
  ]
93
 
94
  # Create Gradio interface
 
102
  **Current Model**: {model_status}
103
  """)
104
 
105
+ if "Fallback" in model_status:
106
  gr.Markdown("""
107
+ ⚠️ **Note**: Using a fallback model. Detection will work but won't be specific to tree diseases.
 
108
  """)
109
 
110
  with gr.Row():
 
143
  inputs=[input_image, conf_threshold, iou_threshold],
144
  outputs=[output_image, detection_summary],
145
  fn=detect_tree_disease,
146
+ cache_examples=False,
147
  )
148
 
149
  gr.Markdown("""
 
153
  - **Dataset**: [PDT Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
154
  - **mAP50**: 0.933
155
  - **mAP50-95**: 0.659
156
+ - **Precision**: 0.878
157
+ - **Recall**: 0.863
158
  - **Classes**: 1 (unhealthy trees)
159
 
160
  ## Usage Tips
 
169
  """)
170
 
171
  # Launch the app
172
+ if __name__ == "__main__":
173
+ demo.launch()