# Scraped file header: author IsmatS, commit "Update app.py" (25bd370, verified)
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import cv2
import io
import os
import logging
# Render with the non-interactive Agg backend: figures are only written to
# in-memory PNG buffers, never displayed in a window.
matplotlib.use(backend='Agg')

# Module-wide logging configuration; INFO and above are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Side length, in pixels, of the square image fed to the classifier
# (intended to match the size the model was trained on).
IMG_SIZE = 256
def load_model():
    """Load the classifier, trying several serialized formats in priority order.

    Candidates (the first one whose file exists and loads successfully wins):

    1. TensorFlow SavedModel in the current directory (``saved_model.pb``)
    2. Native Keras archive (``<base>.keras``)
    3. Legacy HDF5 Keras file (``<base>.h5``)
    4. Training checkpoint (``binary_model_densenet_checkpoint.keras``)
    5. TFLite flatbuffer (``<base>.tflite``), with tensors pre-allocated

    Returns:
        ``(model, model_type)`` where ``model_type`` is one of
        ``"saved_model"``, ``"keras"``, ``"h5"``, ``"checkpoint"`` or
        ``"tflite"``; ``(None, None)`` if no candidate could be loaded.
    """
    base = "chest_ct_binary_classifier_densenet_20250427_182239"
    candidates = (
        # (model_type, human-readable label, file that must exist)
        ("saved_model", "SavedModel format", "saved_model.pb"),
        ("keras", ".keras format", f"{base}.keras"),
        ("h5", ".h5 format", f"{base}.h5"),
        ("checkpoint", "checkpoint model", "binary_model_densenet_checkpoint.keras"),
        ("tflite", "TFLite model", f"{base}.tflite"),
    )
    for model_type, label, path in candidates:
        if not os.path.exists(path):
            logger.info("Skipping %s: %s not found", label, path)
            continue
        logger.info("Attempting to load %s from %s...", label, path)
        try:
            model = _load_model_file(model_type, path)
        except Exception:
            # Keep falling back to the next format, but make the failure
            # visible: a file that exists yet will not load is a real problem.
            logger.warning("Failed to load %s from %s", label, path, exc_info=True)
            continue
        logger.info("Successfully loaded %s", label)
        return model, model_type

    logger.warning("All model loading attempts failed")
    return None, None


def _load_model_file(model_type, path):
    """Deserialize one model artifact; propagates any loader exception."""
    if model_type == "saved_model":
        # saved_model.pb marks the current directory as the SavedModel root.
        return tf.saved_model.load('.')
    if model_type == "tflite":
        interpreter = tf.lite.Interpreter(model_path=path)
        interpreter.allocate_tensors()
        return interpreter
    # .keras, .h5 and checkpoint files all go through the Keras loader.
    return tf.keras.models.load_model(path)
# Load the model once at import time; get_prediction reads these globals.
model, model_type = load_model()
if model is None:
    # Previously logged "Model loaded, type: None", which hid the failure.
    logger.warning("No model loaded; predictions will use a mock value")
else:
    logger.info(f"Model loaded, type: {model_type}")
def preprocess_image(image):
    """Convert an uploaded image into a model-ready batch of one.

    Args:
        image: Image as a NumPy array (as delivered by ``gr.Image`` with
            ``type="numpy"``); any array ``PIL.Image.fromarray`` accepts,
            e.g. grayscale, RGB or RGBA.

    Returns:
        ``float32`` array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)`` holding RGB
        values scaled to ``[0, 1]``.
    """
    img = Image.fromarray(image).convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    # Always float32. The TFLite branch already required it, a SavedModel
    # serving signature declared as float32 will not take float64 tensors,
    # and Keras models work with float32 as well.
    img_array = (np.asarray(img) / 255.0).astype(np.float32)
    return np.expand_dims(img_array, axis=0)
def get_prediction(img_tensor):
    """Run the loaded model on a preprocessed batch and return its score.

    Args:
        img_tensor: Batch of one image as produced by ``preprocess_image``.

    Returns:
        The model's first output value for the first sample, as a Python
        float; callers treat it as the cancer probability. When no model was
        loaded, the fixed mock value ``0.75`` is returned.
    """
    if model is None:
        # NOTE(review): deliberate demo fallback, but 0.75 > 0.5 means every
        # image is reported as "Cancer". Consider failing loudly instead.
        logger.warning("Using mock prediction since no model was loaded")
        return 0.75  # Mock cancer probability

    if model_type == "saved_model":
        infer = model.signatures["serving_default"]
        input_specs = infer.structured_input_signature[1]
        input_name = next(iter(input_specs))
        output_name = next(iter(infer.structured_outputs))
        # Signature functions may reject a mismatched dtype, so cast to the
        # dtype the serving signature declares.
        input_tensor = tf.cast(img_tensor, input_specs[input_name].dtype)
        raw = infer(**{input_name: input_tensor})[output_name].numpy()
    elif model_type == "tflite":
        input_details = model.get_input_details()[0]
        output_details = model.get_output_details()[0]
        model.set_tensor(input_details['index'], img_tensor)
        model.invoke()
        raw = model.get_tensor(output_details['index'])
    else:
        # Keras model (.keras, .h5 or checkpoint). verbose=0 keeps a progress
        # bar from being printed to the server log on every request.
        raw = model.predict(img_tensor, verbose=0)

    # Outputs are expected as (batch, 1); flattening also accepts (batch,).
    return float(np.asarray(raw).reshape(-1)[0])
def generate_attention_map(img_array, prediction):
    """Build an edge-magnitude heatmap to display next to the prediction.

    Despite the name, this does not inspect the model. It computes a Sobel
    gradient-magnitude map of the blurred grayscale input, then scales its
    intensity by ``prediction``.

    Args:
        img_array: Batch whose first element is an RGB image
            (``preprocess_image`` output, values in [0, 1]).
        prediction: Model score, expected in [0, 1]; clipped to that range.

    Returns:
        ``(heatmap, magnitude)``: an RGB ``uint8`` JET-colormapped image and
        the weighted float magnitude map it was drawn from.
    """
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # Edge detection highlights structured regions of the scan.
    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.hypot(sobelx, sobely)
    # Normalize to [0, 1]; the epsilon guards against a perfectly flat image.
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
    # Scale intensity into [0.5, 1] as the prediction goes from 0 to 1.
    weight = 0.5 + 0.5 * float(np.clip(prediction, 0.0, 1.0))
    magnitude = magnitude * weight
    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return heatmap, magnitude
def predict_and_explain(image):
    """Classify an uploaded CT image and render an explanatory figure.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` when nothing was
            uploaded.

    Returns:
        ``(figure_image, label, confidence)`` for the Gradio outputs:
        - ``figure_image``: a rendered figure array, an error image, or ``None``.
        - ``label``: ``"Cancer"`` or ``"Normal"``, or a message / ``"Error"``.
        - ``confidence``: the confidence of the reported label, or ``0.0``
          when no prediction was made.
    """
    if image is None:
        return None, "Please upload an image.", 0.0
    try:
        preprocessed = preprocess_image(image)
        prediction = get_prediction(preprocessed)
        logger.info("Prediction value: %s", prediction)

        heatmap, _ = generate_attention_map(preprocessed, prediction)
        # Rebuild the display image from the preprocessed tensor instead of
        # resizing the raw upload. It is already 3-channel RGB at IMG_SIZE,
        # so grayscale and RGBA uploads blend with the heatmap too.
        original_resized = np.round(preprocessed[0] * 255.0).astype(np.uint8)
        superimposed = ((0.6 * original_resized) + (0.4 * heatmap)).astype(np.uint8)

        prediction_class = "Cancer" if prediction > 0.5 else "Normal"
        # Confidence in the reported class, not the raw cancer score.
        confidence = float(prediction if prediction > 0.5 else 1 - prediction)
        result_text = f"{prediction_class} (Confidence: {confidence:.2%})"

        result_image = _render_figure(original_resized, heatmap, superimposed, result_text)
        return result_image, prediction_class, confidence
    except Exception as e:
        # Top-level UI boundary: log with traceback and show the error inline.
        logger.exception("Error in prediction: %s", e)
        fallback_img = np.full((400, 800, 3), 255, dtype=np.uint8)
        cv2.putText(fallback_img, f"Error: {str(e)}", (50, 200), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
        return fallback_img, "Error", 0.0


def _render_figure(original, heatmap, overlay, title):
    """Draw original / heatmap / overlay side by side; return the PNG as an array."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (original, "Original CT Scan"),
            (heatmap, "Feature Map"),
            (overlay, "Overlay"),
        )
        for ax, (img, panel_title) in zip(axes, panels):
            ax.imshow(img)
            ax.set_title(panel_title)
            ax.axis('off')
        fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every open figure alive until it is closed; close it
        # even if drawing failed, so a long-running server does not leak.
        plt.close(fig)
    buf.seek(0)
    return np.array(Image.open(buf))
# Create Gradio interface
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown(f"### Model: {model_type if model_type else 'None'}")
    if model is None:
        # get_prediction returns a fixed mock score when no model loaded;
        # make that impossible to mistake for a real result.
        gr.Markdown("**Warning:** no model file could be loaded, so results are placeholder values, not real predictions.")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")
    gr.Markdown("### How it works")
    # Blank lines are required so Markdown renders separate paragraphs and a
    # list, instead of folding the last paragraph into the final bullet.
    gr.Markdown("""
This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.

The visualization shows:

- Left: Original CT scan
- Middle: Feature map highlighting areas with distinctive patterns
- Right: Overlay of the feature map on the original image

The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
""")
    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score],
    )

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch()