File size: 9,305 Bytes
2bbade7
 
 
 
 
 
 
 
 
25bd370
2bbade7
 
25bd370
 
 
 
84ab9ea
2bbade7
 
25bd370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bbade7
 
 
 
 
25bd370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bbade7
25bd370
2bbade7
84ab9ea
2bbade7
 
 
84ab9ea
2bbade7
 
 
 
 
 
 
84ab9ea
2bbade7
 
 
 
 
 
 
 
 
 
 
 
 
 
25bd370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bbade7
 
 
 
25bd370
2bbade7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25bd370
2bbade7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import cv2
import io
import os
import logging
# Figures are rendered straight into in-memory PNG buffers (no display is
# involved), so select matplotlib's non-interactive Agg backend.
matplotlib.use(backend='Agg')

# Square input resolution, in pixels, that uploads are resized to; it matches
# the resolution the model was trained on.
IMG_SIZE = 256

# INFO-level logging via the root handler; this module logs through `logger`.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

def load_model(model_dir="."):
    """Load the CT classifier, trying several serialized formats in order.

    Formats are tried in this priority order. Each is attempted only when the
    file it needs exists in *model_dir*:

    1. TensorFlow SavedModel (gated on ``saved_model.pb``; the directory is loaded)
    2. Native Keras ``.keras`` file
    3. Legacy Keras ``.h5`` file
    4. Training checkpoint ``binary_model_densenet_checkpoint.keras``
    5. TFLite flatbuffer (an allocated ``tf.lite.Interpreter`` is returned)

    If an artifact exists but fails to load, a warning with its traceback is
    logged and the next format is tried.

    Args:
        model_dir: Directory holding the model artifacts. Defaults to the
            current working directory.

    Returns:
        ``(model, model_type)``, where ``model_type`` is one of
        ``"saved_model"``, ``"keras"``, ``"h5"``, ``"checkpoint"`` or
        ``"tflite"``. Returns ``(None, None)`` if no artifact could be loaded.
    """
    stem = "chest_ct_binary_classifier_densenet_20250427_182239"

    def _load_saved_model(_marker_path):
        # A SavedModel is loaded from its directory, not from saved_model.pb.
        return tf.saved_model.load(model_dir)

    def _load_keras(path):
        return tf.keras.models.load_model(path)

    def _load_tflite(path):
        interpreter = tf.lite.Interpreter(model_path=path)
        interpreter.allocate_tensors()
        return interpreter

    # (model_type, file whose presence gates the attempt, loader, log label)
    candidates = (
        ("saved_model", "saved_model.pb", _load_saved_model, "SavedModel format"),
        ("keras", f"{stem}.keras", _load_keras, ".keras format"),
        ("h5", f"{stem}.h5", _load_keras, ".h5 format"),
        ("checkpoint", "binary_model_densenet_checkpoint.keras", _load_keras,
         "checkpoint model"),
        ("tflite", f"{stem}.tflite", _load_tflite, "TFLite model"),
    )

    for model_type, filename, loader, label in candidates:
        logger.info("Attempting to load %s...", label)
        path = os.path.join(model_dir, filename)
        if not os.path.exists(path):
            continue
        try:
            loaded = loader(path)
        except Exception:
            # The file is present, so this failure is worth seeing (corrupt
            # artifact, TF/Keras version mismatch): warn with the traceback.
            logger.warning("Failed to load %s from %s", label, path, exc_info=True)
            continue
        logger.info("Successfully loaded %s", label)
        return loaded, model_type

    logger.warning("All model loading attempts failed")
    return None, None

# Load the model once at import time; the prediction code reads these
# module-level globals on every request.
model, model_type = load_model()
if model is None:
    # Don't report success when load_model() fell through to (None, None).
    logger.warning("Model not loaded: no usable model artifact was found")
else:
    logger.info("Model loaded, type: %s", model_type)

def preprocess_image(image):
    """Convert an uploaded image into a single-image model input batch.

    The image is converted to 3-channel RGB (so grayscale or RGBA uploads are
    accepted), resized to ``IMG_SIZE`` x ``IMG_SIZE`` and divided by 255. This
    assumes 8-bit pixel values, which maps them into [0, 1].

    Args:
        image: NumPy image array accepted by ``PIL.Image.fromarray``, as
            delivered by the ``gr.Image(type="numpy")`` input.

    Returns:
        ``float32`` array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``.
    """
    rgb = Image.fromarray(image).convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    # Always float32, for every backend: SavedModel signatures and TFLite
    # interpreters check the input dtype, and Keras computes in float32 anyway.
    scaled = (np.asarray(rgb) / 255.0).astype(np.float32)
    return np.expand_dims(scaled, axis=0)

def get_prediction(img_tensor):
    """Run the loaded model on a preprocessed batch and return its score.

    The call is dispatched on the module-level ``model_type``, which is set
    together with ``model`` by ``load_model()``.

    Args:
        img_tensor: Batch produced by ``preprocess_image``, shape
            ``(1, IMG_SIZE, IMG_SIZE, 3)``. It is cast to ``float32`` here.

    Returns:
        The model's first output value for the first image, as a Python float.
        ``predict_and_explain`` treats values above 0.5 as "Cancer".

    Raises:
        RuntimeError: If no model was loaded. Returning a made-up score would
            be shown to the user as a real diagnosis.
    """
    if model is None:
        raise RuntimeError("No model is loaded; cannot make a prediction")

    # SavedModel signatures and TFLite's set_tensor are dtype-strict.
    batch = np.asarray(img_tensor, dtype=np.float32)

    if model_type == "saved_model":
        infer = model.signatures["serving_default"]
        # Signatures take keyword arguments; use the first declared input/output.
        input_name = next(iter(infer.structured_input_signature[1]))
        output_name = next(iter(infer.structured_outputs))
        outputs = infer(**{input_name: batch})
        prediction = outputs[output_name].numpy()[0][0]
    elif model_type == "tflite":
        input_index = model.get_input_details()[0]['index']
        output_index = model.get_output_details()[0]['index']
        model.set_tensor(input_index, batch)
        model.invoke()
        prediction = model.get_tensor(output_index)[0][0]
    else:
        # Keras (.keras / .h5 / checkpoint). verbose=0 keeps the per-call
        # progress bar out of the server logs.
        prediction = model.predict(batch, verbose=0)[0][0]

    return float(prediction)

def generate_attention_map(img_array, prediction):
    """Build an edge-intensity heatmap to display alongside the prediction.

    NOTE: despite the name, this map is not derived from the model. It is not
    Grad-CAM or any activation-based attribution. It is the Sobel gradient
    magnitude of the blurred grayscale input, min-max normalized to [0, 1].
    It is then scaled by the prediction's confidence, ``max(p, 1 - p)``,
    which lies in [0.5, 1].

    Args:
        img_array: Preprocessed batch. Only ``img_array[0]`` is used and it is
            expected to be an RGB ``(H, W, 3)`` image with values in [0, 1].
        prediction: Model probability of the positive ("Cancer") class.

    Returns:
        ``(heatmap, magnitude)``:
        - ``heatmap``: an RGB ``uint8`` JET-colormapped image of shape ``(H, W, 3)``.
        - ``magnitude``: the weighted gradient-magnitude map of shape ``(H, W)``,
          with values in [0, 1].
    """
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Gradient magnitude highlights edges and texture in the image.
    grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

    # Min-max normalize to [0, 1]; the epsilon guards flat (constant) images.
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

    # Scale by confidence in the predicted class, in [0.5, 1]. The previous
    # 0.5 + (p - 0.5) * 0.5 produced only 0.25-0.75 and was faintest exactly
    # when the model was most confident the scan was Normal.
    p = float(np.clip(prediction, 0.0, 1.0))
    magnitude = magnitude * max(p, 1.0 - p)

    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert for matplotlib/Gradio, which expect RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    return heatmap, magnitude

def _render_analysis_figure(original, heatmap, overlay, title):
    """Draw the three panels side by side under *title*; return the PNG as an array."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (original, "Original CT Scan"),
            (heatmap, "Feature Map"),
            (overlay, "Overlay"),
        )
        for ax, (panel, panel_title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(panel_title)
            ax.axis('off')
        fig.suptitle(title, fontsize=16)
        # Use the Figure's own methods rather than plt.tight_layout()/plt.savefig(),
        # which act on whatever pyplot currently considers the "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even if rendering fails, so figures never accumulate.
        plt.close(fig)
    buf.seek(0)
    return np.array(Image.open(buf))


def _error_image(message):
    """Return a white 400x800 RGB image with *message* drawn on it in black."""
    canvas = np.full((400, 800, 3), 255, dtype=np.uint8)
    cv2.putText(canvas, message, (50, 200), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
    return canvas


def predict_and_explain(image):
    """Classify an uploaded CT image and build the visual report for the UI.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` if nothing was uploaded.

    Returns:
        ``(result_image, prediction_class, confidence)``, matching the Gradio
        outputs: the three-panel figure, ``"Cancer"`` or ``"Normal"``, and the
        probability of the predicted class.
        - With no image: ``(None, "Please upload an image.", 0.0)``.
        - On any failure: an error image, ``"Error"`` and ``0.0``.
    """
    if image is None:
        return None, "Please upload an image.", 0.0

    try:
        preprocessed = preprocess_image(image)
        prediction = get_prediction(preprocessed)
        logger.info("Prediction value: %s", prediction)

        heatmap, _ = generate_attention_map(preprocessed, prediction)

        # Apply the same RGB conversion and resize as preprocess_image. Grayscale
        # or RGBA uploads then match the 3-channel heatmap for blending.
        original_resized = np.array(
            Image.fromarray(image).convert('RGB').resize((IMG_SIZE, IMG_SIZE))
        )
        superimposed = (0.6 * original_resized + 0.4 * heatmap).astype(np.uint8)

        is_cancer = prediction > 0.5
        prediction_class = "Cancer" if is_cancer else "Normal"
        # Confidence is the probability assigned to the predicted class.
        confidence = float(prediction if is_cancer else 1 - prediction)
        result_text = f"{prediction_class} (Confidence: {confidence:.2%})"

        result_image = _render_analysis_figure(
            original_resized, heatmap, superimposed, result_text
        )
        return result_image, prediction_class, confidence

    except Exception as e:
        # UI boundary: keep the full traceback in the logs and show the error.
        logger.exception("Error in prediction: %s", e)
        return _error_image(f"Error: {str(e)}"), "Error", 0.0

# Gradio UI layout:
# - Left column: the image upload and the "Analyze Image" button.
# - Right column: the rendered analysis figure, the predicted class and the
#   confidence score.
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown(f"### Model: {model_type or 'None'}")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
    gr.Markdown("**Research demo only: not a medical device and not a substitute for diagnosis by a qualified clinician.**")
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")
        
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")
    
    gr.Markdown("### How it works")
    gr.Markdown("""
    This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.
    
    The visualization shows:
    - Left: Original CT scan
    - Middle: Feature map of image edge (gradient) intensity, scaled according to the prediction. It is computed from the image itself rather than from the model, so it does not show which regions drove the prediction.
    - Right: Overlay of the feature map on the original image
    
    The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
    """)
    
    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score]
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()