# Web-page capture residue, not part of the program: "Spaces: Sleeping / Sleeping"
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import cv2 | |
import io | |
import os | |
import logging | |
# Headless server: render every figure with the non-interactive Agg backend.
matplotlib.use(backend='Agg')

# Root logging at INFO so the informational messages emitted below are shown.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Side length (pixels) of the square input the classifier was trained on.
IMG_SIZE = 256
# Function to load model with multiple format fallbacks | |
def load_model(): | |
# Try different model formats | |
model = None | |
# Try loading SavedModel format first | |
try: | |
logger.info("Attempting to load SavedModel format...") | |
if os.path.exists('saved_model.pb'): | |
model = tf.saved_model.load('.') | |
logger.info("Successfully loaded SavedModel format") | |
return model, "saved_model" | |
except Exception as e: | |
logger.info(f"Failed to load SavedModel: {e}") | |
# Try loading Keras model (.keras) | |
try: | |
logger.info("Attempting to load .keras format...") | |
if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.keras'): | |
model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.keras') | |
logger.info("Successfully loaded .keras format") | |
return model, "keras" | |
except Exception as e: | |
logger.info(f"Failed to load .keras model: {e}") | |
# Try loading H5 model | |
try: | |
logger.info("Attempting to load .h5 format...") | |
if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.h5'): | |
model = tf.keras.models.load_model('chest_ct_binary_classifier_densenet_20250427_182239.h5') | |
logger.info("Successfully loaded .h5 format") | |
return model, "h5" | |
except Exception as e: | |
logger.info(f"Failed to load .h5 model: {e}") | |
# Try loading checkpoint model | |
try: | |
logger.info("Attempting to load checkpoint model...") | |
if os.path.exists('binary_model_densenet_checkpoint.keras'): | |
model = tf.keras.models.load_model('binary_model_densenet_checkpoint.keras') | |
logger.info("Successfully loaded checkpoint model") | |
return model, "checkpoint" | |
except Exception as e: | |
logger.info(f"Failed to load checkpoint model: {e}") | |
# Try loading TFLite model | |
try: | |
logger.info("Attempting to load TFLite model...") | |
if os.path.exists('chest_ct_binary_classifier_densenet_20250427_182239.tflite'): | |
interpreter = tf.lite.Interpreter(model_path="chest_ct_binary_classifier_densenet_20250427_182239.tflite") | |
interpreter.allocate_tensors() | |
logger.info("Successfully loaded TFLite model") | |
return interpreter, "tflite" | |
except Exception as e: | |
logger.info(f"Failed to load TFLite model: {e}") | |
# If all attempts fail, return None | |
logger.warning("All model loading attempts failed") | |
return None, None | |
# Load the model once at import time. The request-handling functions below
# read the module-level `model` / `model_type` globals set here.
model, model_type = load_model()
if model is None:
    # Don't report success when every loading attempt failed.
    logger.warning("No model could be loaded (model type: None)")
else:
    logger.info(f"Model loaded, type: {model_type}")
def preprocess_image(image):
    """Convert an uploaded image into a model-ready batch of one.

    Args:
        image: Image array from the ``gr.Image(type="numpy")`` input; any
            array ``PIL.Image.fromarray`` accepts. It is converted to RGB.

    Returns:
        np.ndarray: float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``
        with pixel values scaled to ``[0, 1]``.
    """
    img = Image.fromarray(image).convert('RGB')
    img = img.resize((IMG_SIZE, IMG_SIZE))
    # Always float32. TFLite's set_tensor needs the exact dtype, and a
    # SavedModel signature presumably declares float32 (TODO confirm), so the
    # float64 produced by `/ 255.0` may be rejected there. Keras casts to
    # float32 internally anyway.
    img_array = (np.asarray(img) / 255.0).astype(np.float32)
    return img_array[np.newaxis, ...]
def get_prediction(img_tensor):
    """Run the loaded model on one preprocessed image.

    Args:
        img_tensor: Batch of one image as produced by ``preprocess_image``
            (shape ``(1, IMG_SIZE, IMG_SIZE, 3)``, values in ``[0, 1]``).

    Returns:
        float: The first output value for the first batch item. The caller
        treats it as the cancer probability (> 0.5 means "Cancer").

    Raises:
        RuntimeError: If no model was loaded at startup.
    """
    if model is None:
        # Refuse rather than return a placeholder probability: any number
        # returned here is displayed to the user as a real diagnosis.
        raise RuntimeError("No model is loaded; cannot make a prediction")

    # Coerce to float32 for every backend: TFLite's set_tensor requires the
    # exact dtype, and a SavedModel signature presumably declares float32.
    img_tensor = np.asarray(img_tensor, dtype=np.float32)

    if model_type == "saved_model":
        infer = model.signatures["serving_default"]
        # Signature functions take keyword arguments; use the first declared
        # input and output names instead of hard-coding them.
        input_name = next(iter(infer.structured_input_signature[1]))
        output_name = next(iter(infer.structured_outputs))
        output = infer(**{input_name: tf.constant(img_tensor)})
        prediction = output[output_name].numpy()[0][0]
    elif model_type == "tflite":
        input_index = model.get_input_details()[0]['index']
        output_index = model.get_output_details()[0]['index']
        model.set_tensor(input_index, img_tensor)
        model.invoke()
        prediction = model.get_tensor(output_index)[0][0]
    else:
        # Keras model ("keras", "h5" or "checkpoint"). verbose=0 suppresses
        # the progress bar that predict() would print on every request.
        prediction = model.predict(img_tensor, verbose=0)[0][0]
    return float(prediction)
def generate_attention_map(img_array, prediction):
    """Build an edge-intensity heatmap of the input image, scaled by the prediction.

    NOTE: despite the name, this does not use the model. It computes the
    Sobel gradient magnitude of the blurred grayscale input, so it
    highlights strong intensity edges, not the regions that drove the
    prediction.

    Args:
        img_array: Preprocessed batch; only ``img_array[0]`` (RGB, values
            in ``[0, 1]``) is used.
        prediction: Model output, expected in ``[0, 1]``; clipped to it.

    Returns:
        tuple: ``(heatmap, magnitude)``.
            ``heatmap``: RGB uint8 JET colour map of ``magnitude``.
            ``magnitude``: float map in ``[0, weight]``, with ``weight`` in
            ``[0.5, 1]``.
    """
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # Gradient magnitude from horizontal and vertical Sobel derivatives.
    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.hypot(sobelx, sobely)
    # Normalize to [0, 1); the epsilon avoids division by zero on a flat image.
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
    # Weight goes linearly from 0.5 (p=0) to 1.0 (p=1), as intended. The
    # former 0.5 + (p - 0.5) * 0.5 only spanned [0.25, 0.75].
    weight = 0.5 + 0.5 * float(np.clip(prediction, 0.0, 1.0))
    magnitude = magnitude * weight
    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; convert to RGB for matplotlib/Gradio display.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return heatmap, magnitude
def _render_figure(original, heatmap, superimposed, title):
    """Render the three panels side by side and return the PNG as an array."""
    # Use a standalone Figure with its own Agg canvas instead of pyplot.
    # plt.tight_layout()/plt.savefig() act on pyplot's global "current
    # figure", which is not safe if handlers run concurrently, and pyplot
    # keeps figures alive until explicitly closed. A bare Figure is freed
    # with its last reference.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 5))
    FigureCanvasAgg(fig)
    axes = fig.subplots(1, 3)
    panels = (
        (original, "Original CT Scan"),
        (heatmap, "Feature Map"),
        (superimposed, "Overlay"),
    )
    for ax, (panel, panel_title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(panel_title)
        ax.axis('off')
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return np.array(Image.open(buf))


def predict_and_explain(image):
    """Classify an uploaded CT image and build a visual summary.

    Args:
        image: Image array from the Gradio input, or ``None`` if nothing
            was uploaded.

    Returns:
        tuple: ``(result_image, prediction_class, confidence)``.
            ``result_image``: rendered three-panel PNG as an array.
            ``prediction_class``: ``"Cancer"`` or ``"Normal"``.
            ``confidence``: probability of that class.
        If ``image`` is ``None``, returns ``(None, prompt message, 0.0)``.
        On any error, returns a white error image, ``"Error"`` and ``0.0``.
    """
    if image is None:
        return None, "Please upload an image.", 0.0
    try:
        preprocessed = preprocess_image(image)
        prediction = get_prediction(preprocessed)
        logger.info(f"Prediction value: {prediction}")

        heatmap, _ = generate_attention_map(preprocessed, prediction)

        # Display the exact RGB pixels fed to the model. Resizing the raw
        # upload instead fails for grayscale (2-D) or RGBA uploads, whose
        # channel count does not match the RGB heatmap.
        original_resized = np.round(preprocessed[0] * 255.0).astype(np.uint8)
        superimposed = (0.6 * original_resized + 0.4 * heatmap).astype(np.uint8)

        # Compute the label and its confidence once; reuse for title and outputs.
        is_cancer = prediction > 0.5
        prediction_class = "Cancer" if is_cancer else "Normal"
        confidence = float(prediction if is_cancer else 1 - prediction)
        result_text = f"{prediction_class} (Confidence: {confidence:.2%})"

        result_image = _render_figure(original_resized, heatmap, superimposed, result_text)
        return result_image, prediction_class, confidence
    except Exception as e:
        # Boundary handler for the UI. logger.exception also records the traceback.
        logger.exception(f"Error in prediction: {e}")
        fallback_img = np.ones((400, 800, 3), dtype=np.uint8) * 255
        cv2.putText(fallback_img, f"Error: {str(e)}", (50, 200), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
        return fallback_img, "Error", 0.0
# Create Gradio interface
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown(f"### Model: {model_type if model_type else 'None'}")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")
    gr.Markdown("### How it works")
    # The middle panel is an edge-intensity map computed from the image with
    # image filters, not from the model; the text says so to avoid implying
    # it explains the prediction.
    gr.Markdown("""
This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.

The visualization shows:

- Left: Original CT scan
- Middle: Edge-intensity map of the image (computed with image filters, brighter for a higher cancer score). It is not derived from the model's internals, so it does not show which regions drove the prediction.
- Right: Overlay of the edge map on the original image

The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
""")
    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score],
    )

# Launch only when run as a script, so importing the module (e.g. by
# Gradio's reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()