import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import cv2
import io
import os
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Image size - matching what the model was trained on
IMG_SIZE = 256


# Function to load the model with multiple format fallbacks
def load_model():
    # Try different model formats
    model = None

    # Try loading SavedModel format first
    try:
        logger.info("Attempting to load SavedModel format...")
        if os.path.exists('saved_model.pb'):
            model = tf.saved_model.load('.')
            logger.info("Successfully loaded SavedModel format")
            return model, "saved_model"
    except Exception as e:
        logger.info(f"Failed to load SavedModel: {e}")

    # Try loading Keras model (.keras)
    try:
        logger.info("Attempting to load .keras format...")
        if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.keras'):
            model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.keras')
            logger.info("Successfully loaded .keras format")
            return model, "keras"
    except Exception as e:
        logger.info(f"Failed to load .keras model: {e}")

    # Try loading H5 model
    try:
        logger.info("Attempting to load .h5 format...")
        if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.h5'):
            model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.h5')
            logger.info("Successfully loaded .h5 format")
            return model, "h5"
    except Exception as e:
        logger.info(f"Failed to load .h5 model: {e}")

    # Try loading checkpoint model
    try:
        logger.info("Attempting to load checkpoint model...")
        if os.path.exists('binary_model_densenet_checkpoint.keras'):
            model = tf.keras.models.load_model('binary_model_densenet_checkpoint.keras')
            logger.info("Successfully loaded checkpoint model")
            return model, "checkpoint"
    except Exception as e:
        logger.info(f"Failed to load checkpoint model: {e}")

    # Try loading TFLite model
    try:
        logger.info("Attempting to load TFLite model...")
        if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.tflite'):
            interpreter = tf.lite.Interpreter(model_path="chest_ct_binary_classifier_densenet_20250427_182239.tflite")
            interpreter.allocate_tensors()
            logger.info("Successfully loaded TFLite model")
            return interpreter, "tflite"
    except Exception as e:
        logger.info(f"Failed to load TFLite model: {e}")

    # If all attempts fail, return None
    logger.warning("All model loading attempts failed")
    return None, None


# Load the model
model, model_type = load_model()
logger.info(f"Model loaded, type: {model_type}")


# Function for preprocessing
def preprocess_image(image):
    img = Image.fromarray(image).convert('RGB')
    img = img.resize((IMG_SIZE, IMG_SIZE))
    img_array = np.array(img) / 255.0

    # Different model formats might need different input dtypes
    if model_type == "tflite":
        return np.expand_dims(img_array, axis=0).astype(np.float32)
    else:
        return np.expand_dims(img_array, axis=0)


# Prediction function for different model types
def get_prediction(img_tensor):
    if model is None:
        # If no model was loaded, use a mock prediction
        logger.warning("Using mock prediction since no model was loaded")
        return 0.75  # Mock cancer probability

    if model_type == "saved_model":
        # SavedModel format; serving signatures generally expect a float32 tensor
        infer = model.signatures["serving_default"]
        input_tensor_name = list(infer.structured_input_signature[1].keys())[0]
        output_tensor_name = list(infer.structured_outputs.keys())[0]
        input_dict = {input_tensor_name: tf.constant(img_tensor, dtype=tf.float32)}
        output = infer(**input_dict)
        prediction = output[output_tensor_name].numpy()[0][0]
    elif model_type == "tflite":
        # TFLite format
        input_details = model.get_input_details()
        output_details = model.get_output_details()
        model.set_tensor(input_details[0]['index'], img_tensor)
        model.invoke()
        prediction = model.get_tensor(output_details[0]['index'])[0][0]
    else:
        # Keras or H5 format
        prediction = model.predict(img_tensor)[0][0]

    return float(prediction)


# Generate an attention map from image gradients
def generate_attention_map(img_array, prediction):
    # Convert to grayscale
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    # Use edge detection to find interesting regions
    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobelx**2 + sobely**2)

    # Normalize to 0-1
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

    # Weight by prediction confidence (maps a prediction in [0, 1] to a weight in [0.25, 0.75])
    weight = 0.5 + (prediction - 0.5) * 0.5
    magnitude = magnitude * weight

    # Apply colormap
    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    return heatmap, magnitude


# Prediction function with visualization
def predict_and_explain(image):
    if image is None:
        return None, "Please upload an image.", 0.0

    try:
        # Preprocess the image
        preprocessed = preprocess_image(image)

        # Get prediction
        prediction = get_prediction(preprocessed)
        logger.info(f"Prediction value: {prediction}")

        # Generate attention map
        heatmap, attention = generate_attention_map(preprocessed, prediction)

        # Create overlay
        original_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        superimposed = (0.6 * original_resized) + (0.4 * heatmap)
        superimposed = superimposed.astype(np.uint8)

        # Create visualization with matplotlib
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(original_resized)
        axes[0].set_title("Original CT Scan")
        axes[0].axis('off')

        axes[1].imshow(heatmap)
        axes[1].set_title("Feature Map")
        axes[1].axis('off')

        axes[2].imshow(superimposed)
        axes[2].set_title("Overlay")
        axes[2].axis('off')

        # Add prediction information
        prediction_class = "Cancer" if prediction > 0.5 else "Normal"
        confidence = float(prediction if prediction > 0.5 else 1 - prediction)
        result_text = f"{prediction_class} (Confidence: {confidence:.2%})"
        fig.suptitle(result_text, fontsize=16)

        # Convert plot to image
        buf = io.BytesIO()
        plt.tight_layout()
        plt.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        result_image = np.array(Image.open(buf))

        return result_image, prediction_class, confidence

    except Exception as e:
        logger.error(f"Error in prediction: {e}")
        # Return a fallback image
        fallback_img = np.ones((400, 800, 3), dtype=np.uint8) * 255
        cv2.putText(fallback_img, f"Error: {str(e)}", (50, 200),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
        return fallback_img, "Error", 0.0


# Create Gradio interface
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown(f"### Model: {model_type if model_type else 'None'}")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")

        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")

    gr.Markdown("### How it works")
    gr.Markdown("""
    This application uses a deep learning model based on the DenseNet121 architecture to classify
    chest CT scans as either 'Normal' or 'Cancer'.

    The visualization shows:
    - Left: Original CT scan
    - Middle: Feature map highlighting areas with distinctive patterns
    - Right: Overlay of the feature map on the original image

    The model was trained on a dataset of chest CT scans containing normal images and various types
    of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
    """)

    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score]
    )

demo.launch()