"""Gradio demo for brain-tumor segmentation on MRI slices.

Loads the Keras model saved in ``resunet_brain_segmentation.h5``, shows sample
MRIs from the ``grayscale/`` directory, and for an uploaded image renders a
confidence heatmap overlay together with a tumor-coverage percentage and a
coarse severity label.
"""

import os
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import skimage.filters
import tensorflow as tf
from matplotlib.colors import LinearSegmentedColormap
from skimage import io

# --- Load your trained model (adjust path if needed) ---
model = tf.keras.models.load_model("resunet_brain_segmentation.h5", compile=False)

# Side length the image is resized to for the model and for the overlay, so the
# predicted mask and the background slice always have the same shape.
MODEL_INPUT_SIZE = 256

# File extensions (matched case-insensitively) picked up for the sample gallery.
GALLERY_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")


# --- Grayscale conversion and contrast stretching ---
def to_grayscale_float(img):
    """Convert an RGB(A) image to a single float luminance channel.

    Uses the 0.2989/0.5870/0.1140 luma weights on the first three channels;
    any further channel (e.g. alpha) is ignored. The value scale of the input
    is preserved (0-255 input gives 0-255 output).
    """
    return 0.2989 * img[..., 0] + 0.5870 * img[..., 1] + 0.1140 * img[..., 2]


def stretch_contrast(img, low=2, high=98):
    """Linearly rescale *img* so its *low*/*high* percentiles map to 0 and 1.

    Values outside the percentile window are clipped to [0, 1]. A flat image
    (both percentiles equal) yields all zeros instead of NaN from a 0/0.
    """
    p_low, p_high = np.percentile(img, (low, high))
    span = p_high - p_low
    if span <= 0:
        return np.zeros(np.shape(img), dtype=np.float64)
    return np.clip((img - p_low) / span, 0, 1)


# --- GYR colormap: green (low confidence) -> yellow -> red (high confidence) ---
cmap_gyr = LinearSegmentedColormap.from_list(
    "gyr", [(0, "green"), (0.5, "yellow"), (1, "red")]
)


# --- Preprocess a single image for prediction ---
def preprocess_single_image(image, img_h=MODEL_INPUT_SIZE, img_w=MODEL_INPUT_SIZE):
    """Resize and z-score normalize one image into a batch of size 1.

    Args:
        image: HxW or HxWxC array.
        img_h, img_w: target size (note cv2.resize takes ``(width, height)``).

    Returns:
        Float64 array of shape ``(1, img_h, img_w[, C])`` with zero mean and
        approximately unit standard deviation.
    """
    img = cv2.resize(image, (img_w, img_h)).astype(np.float64)
    img -= img.mean()
    img /= img.std() + 1e-8  # epsilon guards against a constant image
    return np.expand_dims(img, axis=0)


def _ensure_rgb(image):
    """Return *image* as an HxWx3 array: gray is replicated, alpha is dropped."""
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 1:
        return np.concatenate([image] * 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 4:
        return image[..., :3]
    return image


def _colored_version_path(filename):
    """Return the path of a same-named file in ``colored/``, or None.

    Only the basename of *filename* is used: the value comes from the client,
    so a name like ``../../secret.png`` must not escape the directory. An
    empty name, or a name that resolves to a directory, yields None.
    """
    if not filename:
        return None
    candidate = os.path.join("colored", os.path.basename(filename))
    return candidate if os.path.isfile(candidate) else None


def _severity_label(coverage):
    """Map a tumor-coverage percentage to a coarse severity category."""
    if coverage > 25:
        return "Severe"
    if coverage > 10:
        return "Moderate"
    if coverage > 1:
        return "Mild"
    return "None"


def _render_overlay(gray_adj, vis_mask):
    """Draw *vis_mask* as a heatmap over *gray_adj* and return it as an image array.

    The figure is rendered to a private temporary PNG that is always removed,
    and the figure is closed even if rendering fails.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(gray_adj, cmap="gray", vmin=0, vmax=1)
        # NaN pixels (confidence < 0.2) are left transparent by imshow.
        ax.imshow(vis_mask, cmap=cmap_gyr, alpha=0.7, vmin=0, vmax=1)
        ax.axis("off")
        fig.tight_layout()
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)  # savefig reopens the path itself
        try:
            fig.savefig(temp_path, bbox_inches="tight", pad_inches=0)
            return io.imread(temp_path)
        finally:
            os.remove(temp_path)
    finally:
        plt.close(fig)


# --- Predict & overlay with confidence heatmap ---
def predict_and_overlay(image, filename, coverage_offset=0.0):
    """Segment a tumor in *image* and build the heatmap overlay.

    Args:
        image: uploaded MRI as a numpy array (grayscale, RGB, or RGBA).
        filename: name used to look up a pre-colored version of the same
            slice in ``colored/``. If such a file exists, it replaces *image*.
        coverage_offset: percentage points added to the computed coverage.
            Defaults to 0. The previous code hard-coded +3.5, which made a
            tumor-free scan report "3.50%" / "Mild" and left the "None"
            severity unreachable.

    Returns:
        ``(overlay_image_array, "<coverage>%" string, severity label)``.
    """
    color_path = _colored_version_path(filename)
    if color_path is not None:
        image = io.imread(color_path)
    image = _ensure_rgb(np.asarray(image))

    img_input = preprocess_single_image(image)
    pred = model.predict(img_input)
    # NOTE(review): assumes the model emits per-pixel probabilities shaped
    # (1, H, W) or (1, H, W, 1) - confirm against the training code.
    pred_mask = pred[0].squeeze()

    # Background slice at the model's resolution, contrast-stretched for display.
    resized_img = cv2.resize(image, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    gray_img = to_grayscale_float(resized_img / 255.0)
    gray_adj = stretch_contrast(gray_img)

    # Otsu threshold separates brain tissue from the darker background.
    threshold = skimage.filters.threshold_otsu(gray_img)
    brain_mask = gray_img > threshold

    # Float copy so low-confidence pixels can be hidden with NaN.
    vis_mask = np.array(pred_mask, dtype=np.float64)
    vis_mask[vis_mask < 0.2] = np.nan

    # Tumor area counted only inside the brain region.
    tumor_area = np.sum((pred_mask > 0.5) & brain_mask)
    brain_area = np.sum(brain_mask)
    coverage = (tumor_area / brain_area) * 100 if brain_area > 0 else 0.0
    coverage += coverage_offset

    severity = _severity_label(coverage)
    overlay_img = _render_overlay(gray_adj, vis_mask)
    return overlay_img, f"{coverage:.2f}%", severity


# --- Sample Gallery Setup ---
def load_gallery(grayscale_dir="grayscale"):
    """Load contrast-stretched sample images from *grayscale_dir*.

    Returns ``(images, filenames)`` in sorted filename order. A missing
    directory yields two empty lists, so the app still starts without samples.
    """
    gallery_images = []
    filenames = []
    if not os.path.isdir(grayscale_dir):
        return gallery_images, filenames
    for fname in sorted(os.listdir(grayscale_dir)):
        if not fname.lower().endswith(GALLERY_EXTENSIONS):
            continue
        img = io.imread(os.path.join(grayscale_dir, fname))
        if img.ndim == 3:
            img = to_grayscale_float(img)
        gallery_images.append(stretch_contrast(img))
        filenames.append(fname)
    return gallery_images, filenames


gallery_imgs, gallery_filenames = load_gallery()

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Brain Tumor Segmentation - MRI Viewer")
    gr.Markdown("### Sample MRIs (Drag and Drop Below to Predict)")
    with gr.Row():
        for img in gallery_imgs:
            gr.Image(value=img, image_mode="L", label="", show_label=False, show_download_button=False)

    gr.Markdown("### Upload an MRI to Detect Tumor")
    with gr.Row():
        input_img = gr.Image(label="Upload or Drag Sample MRI", type="numpy")
        output_img = gr.Image(label="Tumor Heatmap Output")
    with gr.Row():
        output_coverage = gr.Textbox(label="Tumor Coverage")
        output_severity = gr.Textbox(label="Severity")
    filename_box = gr.Textbox(visible=False)

    def wrapper(img, filename):
        """Validate the inputs, then run prediction. Shows a UI error if no image."""
        if img is None:
            raise gr.Error("Please upload an MRI image first.")
        # The hidden textbox holds "" (not None) when no upload event has fired.
        if not filename:
            filename = f"uploaded_{np.random.randint(10000)}.png"
        return predict_and_overlay(img, filename)

    submit_btn = gr.Button("Run Tumor Segmentation")
    submit_btn.click(
        fn=wrapper,
        inputs=[input_img, filename_box],
        outputs=[output_img, output_coverage, output_severity],
    )

    def capture_filename(img):
        """Generate a placeholder filename for a freshly uploaded image.

        NOTE(review): these random names only hit ``colored/`` by coincidence,
        so the colored-version lookup is effectively unused for uploads -
        confirm the intended mapping from gallery samples to colored files.
        """
        return f"upload_{np.random.randint(10000)}.png"

    input_img.upload(capture_filename, inputs=input_img, outputs=filename_box)

# --- Launch ---
demo.launch()