IsmatS commited on
Commit
25bd370
·
verified ·
1 Parent(s): 1a539a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -50
app.py CHANGED
@@ -7,19 +7,123 @@ import matplotlib
7
  import cv2
8
  import io
9
  import os
 
10
  matplotlib.use('Agg') # Use non-interactive backend
11
 
 
 
 
 
12
  # Image size - matching what the model was trained on
13
  IMG_SIZE = 256
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Function for preprocessing
16
  def preprocess_image(image):
17
  img = Image.fromarray(image).convert('RGB')
18
  img = img.resize((IMG_SIZE, IMG_SIZE))
19
  img_array = np.array(img) / 255.0
20
- return np.expand_dims(img_array, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Generate attention map using edge detection (simplified)
23
  def generate_attention_map(img_array, prediction):
24
  # Convert to grayscale
25
  gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
@@ -48,57 +152,66 @@ def predict_and_explain(image):
48
  if image is None:
49
  return None, "Please upload an image.", 0.0
50
 
51
- # Preprocess the image
52
- preprocessed = preprocess_image(image)
53
-
54
- # For demo, use a fixed prediction
55
- # In a real deployment, you would load and use your model
56
- prediction = 0.75 # Simulated cancer probability
57
-
58
- # Generate attention map
59
- heatmap, attention = generate_attention_map(preprocessed, prediction)
60
-
61
- # Create overlay
62
- original_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
63
- superimposed = (0.6 * original_resized) + (0.4 * heatmap)
64
- superimposed = superimposed.astype(np.uint8)
65
-
66
- # Create visualization with matplotlib
67
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
68
-
69
- axes[0].imshow(original_resized)
70
- axes[0].set_title("Original CT Scan")
71
- axes[0].axis('off')
72
-
73
- axes[1].imshow(heatmap)
74
- axes[1].set_title("Feature Map")
75
- axes[1].axis('off')
76
-
77
- axes[2].imshow(superimposed)
78
- axes[2].set_title(f"Overlay")
79
- axes[2].axis('off')
80
-
81
- # Add prediction information
82
- result_text = f"{'Cancer' if prediction > 0.5 else 'Normal'} (Confidence: {abs(prediction if prediction > 0.5 else 1-prediction):.2%})"
83
- fig.suptitle(result_text, fontsize=16)
84
-
85
- # Convert plot to image
86
- buf = io.BytesIO()
87
- plt.tight_layout()
88
- plt.savefig(buf, format='png')
89
- plt.close(fig)
90
- buf.seek(0)
91
- result_image = np.array(Image.open(buf))
92
-
93
- # Return prediction information
94
- prediction_class = "Cancer" if prediction > 0.5 else "Normal"
95
- confidence = float(prediction if prediction > 0.5 else 1-prediction)
96
-
97
- return result_image, prediction_class, confidence
 
 
 
 
 
 
 
 
98
 
99
  # Create Gradio interface
100
  with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
101
  gr.Markdown("# Chest CT Scan Cancer Detection")
 
102
  gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
103
 
104
  with gr.Row():
@@ -120,7 +233,7 @@ with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
120
  - Middle: Feature map highlighting areas with distinctive patterns
121
  - Right: Overlay of the feature map on the original image
122
 
123
- Note: This is a demonstration version without the full model due to size limitations.
124
  """)
125
 
126
  submit_btn.click(
 
7
  import cv2
8
  import io
9
  import os
10
+ import logging
11
  matplotlib.use('Agg') # Use non-interactive backend
12
 
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
  # Image size - matching what the model was trained on
18
  IMG_SIZE = 256
19
 
20
+ # Function to load model with multiple format fallbacks
21
+ def load_model():
22
+ # Try different model formats
23
+ model = None
24
+
25
+ # Try loading SavedModel format first
26
+ try:
27
+ logger.info("Attempting to load SavedModel format...")
28
+ if os.path.exists('saved_model.pb'):
29
+ model = tf.saved_model.load('.')
30
+ logger.info("Successfully loaded SavedModel format")
31
+ return model, "saved_model"
32
+ except Exception as e:
33
+ logger.info(f"Failed to load SavedModel: {e}")
34
+
35
+ # Try loading Keras model (.keras)
36
+ try:
37
+ logger.info("Attempting to load .keras format...")
38
+ if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.keras'):
39
+ model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.keras')
40
+ logger.info("Successfully loaded .keras format")
41
+ return model, "keras"
42
+ except Exception as e:
43
+ logger.info(f"Failed to load .keras model: {e}")
44
+
45
+ # Try loading H5 model
46
+ try:
47
+ logger.info("Attempting to load .h5 format...")
48
+ if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.h5'):
49
+ model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.h5')
50
+ logger.info("Successfully loaded .h5 format")
51
+ return model, "h5"
52
+ except Exception as e:
53
+ logger.info(f"Failed to load .h5 model: {e}")
54
+
55
+ # Try loading checkpoint model
56
+ try:
57
+ logger.info("Attempting to load checkpoint model...")
58
+ if os.path.exists('binary_model_densenet_checkpoint.keras'):
59
+ model = tf.keras.models.load_model('binary_model_densenet_checkpoint.keras')
60
+ logger.info("Successfully loaded checkpoint model")
61
+ return model, "checkpoint"
62
+ except Exception as e:
63
+ logger.info(f"Failed to load checkpoint model: {e}")
64
+
65
+ # Try loading TFLite model
66
+ try:
67
+ logger.info("Attempting to load TFLite model...")
68
+ if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.tflite'):
69
+ interpreter = tf.lite.Interpreter(model_path="chest_ct_binary_classifier_densenet_20250427_182239.tflite")
70
+ interpreter.allocate_tensors()
71
+ logger.info("Successfully loaded TFLite model")
72
+ return interpreter, "tflite"
73
+ except Exception as e:
74
+ logger.info(f"Failed to load TFLite model: {e}")
75
+
76
+ # If all attempts fail, return None
77
+ logger.warning("All model loading attempts failed")
78
+ return None, None
79
+
80
+ # Load the model
81
+ model, model_type = load_model()
82
+ logger.info(f"Model loaded, type: {model_type}")
83
+
84
  # Function for preprocessing
85
  def preprocess_image(image):
86
  img = Image.fromarray(image).convert('RGB')
87
  img = img.resize((IMG_SIZE, IMG_SIZE))
88
  img_array = np.array(img) / 255.0
89
+
90
+ # Different model formats might need different input shapes
91
+ if model_type == "tflite":
92
+ return np.expand_dims(img_array, axis=0).astype(np.float32)
93
+ else:
94
+ return np.expand_dims(img_array, axis=0)
95
+
96
+ # Prediction function for different model types
97
+ def get_prediction(img_tensor):
98
+ if model is None:
99
+ # If no model was loaded, use a mock prediction
100
+ logger.warning("Using mock prediction since no model was loaded")
101
+ return 0.75 # Mock cancer probability
102
+
103
+ if model_type == "saved_model":
104
+ # SavedModel format
105
+ infer = model.signatures["serving_default"]
106
+ input_tensor_name = list(infer.structured_input_signature[1].keys())[0]
107
+ output_tensor_name = list(infer.structured_outputs.keys())[0]
108
+ input_dict = {input_tensor_name: img_tensor}
109
+ output = infer(**input_dict)
110
+ prediction = output[output_tensor_name].numpy()[0][0]
111
+
112
+ elif model_type == "tflite":
113
+ # TFLite format
114
+ input_details = model.get_input_details()
115
+ output_details = model.get_output_details()
116
+ model.set_tensor(input_details[0]['index'], img_tensor)
117
+ model.invoke()
118
+ prediction = model.get_tensor(output_details[0]['index'])[0][0]
119
+
120
+ else:
121
+ # Keras or H5 format
122
+ prediction = model.predict(img_tensor)[0][0]
123
+
124
+ return float(prediction)
125
 
126
+ # Generate attention map
127
  def generate_attention_map(img_array, prediction):
128
  # Convert to grayscale
129
  gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
 
152
  if image is None:
153
  return None, "Please upload an image.", 0.0
154
 
155
+ try:
156
+ # Preprocess the image
157
+ preprocessed = preprocess_image(image)
158
+
159
+ # Get prediction
160
+ prediction = get_prediction(preprocessed)
161
+ logger.info(f"Prediction value: {prediction}")
162
+
163
+ # Generate attention map
164
+ heatmap, attention = generate_attention_map(preprocessed, prediction)
165
+
166
+ # Create overlay
167
+ original_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
168
+ superimposed = (0.6 * original_resized) + (0.4 * heatmap)
169
+ superimposed = superimposed.astype(np.uint8)
170
+
171
+ # Create visualization with matplotlib
172
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
173
+
174
+ axes[0].imshow(original_resized)
175
+ axes[0].set_title("Original CT Scan")
176
+ axes[0].axis('off')
177
+
178
+ axes[1].imshow(heatmap)
179
+ axes[1].set_title("Feature Map")
180
+ axes[1].axis('off')
181
+
182
+ axes[2].imshow(superimposed)
183
+ axes[2].set_title(f"Overlay")
184
+ axes[2].axis('off')
185
+
186
+ # Add prediction information
187
+ result_text = f"{'Cancer' if prediction > 0.5 else 'Normal'} (Confidence: {abs(prediction if prediction > 0.5 else 1-prediction):.2%})"
188
+ fig.suptitle(result_text, fontsize=16)
189
+
190
+ # Convert plot to image
191
+ buf = io.BytesIO()
192
+ plt.tight_layout()
193
+ plt.savefig(buf, format='png')
194
+ plt.close(fig)
195
+ buf.seek(0)
196
+ result_image = np.array(Image.open(buf))
197
+
198
+ # Return prediction information
199
+ prediction_class = "Cancer" if prediction > 0.5 else "Normal"
200
+ confidence = float(prediction if prediction > 0.5 else 1-prediction)
201
+
202
+ return result_image, prediction_class, confidence
203
+
204
+ except Exception as e:
205
+ logger.error(f"Error in prediction: {e}")
206
+ # Return a fallback image
207
+ fallback_img = np.ones((400, 800, 3), dtype=np.uint8) * 255
208
+ cv2.putText(fallback_img, f"Error: {str(e)}", (50, 200), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
209
+ return fallback_img, "Error", 0.0
210
 
211
  # Create Gradio interface
212
  with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
213
  gr.Markdown("# Chest CT Scan Cancer Detection")
214
+ gr.Markdown(f"### Model: {model_type if model_type else 'None'}")
215
  gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
216
 
217
  with gr.Row():
 
233
  - Middle: Feature map highlighting areas with distinctive patterns
234
  - Right: Overlay of the feature map on the original image
235
 
236
+ The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
237
  """)
238
 
239
  submit_btn.click(