import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import streamlit as st
from PIL import Image
import io
import zipfile
import pandas as pd
from datetime import datetime
import os
import tempfile
import base64

# NOTE(review): many user-facing strings in this file contain sequences such as
# "ā€¢" or "āœ…" that look like UTF-8 emoji/bullets decoded with a Windows code
# page. They are kept unchanged here; confirm the file's intended encoding.

# Selectable segmentation checkpoints: sidebar display name -> Hugging Face hub id.
MODEL_OPTIONS = {
    "Default (ferferefer/segformer)": "ferferefer/segformer",
    "Pamixsun": "pamixsun/segformer_for_optic_disc_cup_segmentation"
}

# Label ids in the segmentation label map as this file interprets them:
# 0 = background, 1 = optic disc (pixels not assigned to the cup), 2 = optic cup.
# NOTE(review): assumes every checkpoint in MODEL_OPTIONS uses this label order - confirm.
DISC_LABEL = 1
CUP_LABEL = 2


# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma classifier plus optic disc/cup segmenter for fundus images.

    Wraps a Swinv2 image-classification model and a SegFormer semantic
    segmentation model, both loaded with ``from_pretrained``, moved to
    ``device`` and put in eval mode.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path=None,  # optional; see below
                 device=torch.device('cpu')):
        """Load the classification and segmentation models.

        Args:
            cls_model_path: hub id or local path of the classification checkpoint.
            seg_model_path: hub id or local path of the segmentation checkpoint;
                when falsy, the "Pamixsun" entry of MODEL_OPTIONS is used.
            device: torch device both models are moved to.
        """
        self.device = device
        # Classification model for glaucoma (always the same)
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model - use provided path or default to Pamixsun
        seg_path = seg_model_path or MODEL_OPTIONS["Pamixsun"]
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_path).to(device).eval()
        # Mapping from class index to human-readable label
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify ``image``.

        Returns:
            (disease_idx, confidence): the argmax class index (int) and its
            softmax probability expressed as a percentage.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1).cpu()[0]
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100
        return disease_idx, confidence

    @staticmethod
    def _mean_label_confidence(seg_probs, labels, label):
        """Mean probability (in %) of ``label`` over pixels predicted as ``label``; 0 if none are."""
        mask = labels == label
        if not mask.any():
            return 0
        return seg_probs[0, label, mask].mean().item() * 100

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in ``image``.

        Args:
            image: numpy image array; ``image.shape[:2]`` is used as the output
                (height, width).

        Returns:
            (label_map, cup_confidence, disc_confidence): ``label_map`` is a
            uint8 array of per-pixel argmax labels at the input resolution.
            Each confidence is the mean softmax probability (in %) over the
            pixels predicted as that label (cup = CUP_LABEL, disc = DISC_LABEL),
            or 0 when no pixel has that label.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # The model's logits are not at the input resolution; resize them back.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._mean_label_confidence(seg_probs, pred_disc_cup, CUP_LABEL)
        disc_confidence = self._mean_label_confidence(seg_probs, pred_disc_cup, DISC_LABEL)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        """Run classification and segmentation on one image.

        Returns:
            (disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence,
             disc_confidence, cropped_image). ``disc_cup_image`` is ``image`` with
            the disc tinted green and the cup tinted red at 20% opacity; ``vcdr``
            is NaN when no disc is segmented; ``cropped_image`` is a padded crop
            around the segmented region, or a full copy of ``image`` if none.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        # simple_vcdr already returns NaN for an empty segmentation.
        vcdr = simple_vcdr(disc_cup)
        cropped_image = _crop_roi(image, disc_cup > 0)
        # add_mask returns (blended, opaque_overlay); display the blended one so
        # the requested 0.2 opacity actually applies.
        disc_cup_image, _ = add_mask(image, disc_cup, [DISC_LABEL, CUP_LABEL],
                                     [[0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image


# --- Utility Functions ---
def _crop_roi(image, roi_mask, min_size=50):
    """Return a padded crop of ``image`` around the non-zero pixels of ``roi_mask``.

    Padding is max(50, 20% of the longer bounding-box side) on every side and is
    clipped to the image borders. A full copy of ``image`` is returned when the
    mask is empty or the clipped crop is narrower/shorter than ``min_size``.
    """
    x, y, w, h = cv2.boundingRect(roi_mask.astype(np.uint8))
    if w == 0 or h == 0:
        # Nothing segmented: there is no region of interest to zoom into.
        return image.copy()
    padding = max(50, int(0.2 * max(w, h)))
    x0 = max(x - padding, 0)
    y0 = max(y - padding, 0)
    x1 = min(x + w + padding, image.shape[1])
    y1 = min(y + h + padding, image.shape[0])
    if x1 - x0 < min_size or y1 - y0 < min_size:
        return image.copy()
    return image[y0:y1, x0:x1]


def simple_vcdr(mask):
    """Return the vertical cup-to-disc ratio (VCDR) of a disc/cup label map.

    The disc is every labelled pixel (label > 0, so it includes the cup) and the
    cup is every CUP_LABEL pixel. Each vertical size is the largest number of
    such pixels found in a single column. Returns NaN when no disc is present.
    """
    mask = np.asarray(mask)
    disc_height = np.sum(mask > 0, axis=0).max(initial=0)
    if disc_height == 0:
        return np.nan
    cup_height = np.sum(mask == CUP_LABEL, axis=0).max(initial=0)
    return cup_height / disc_height


def add_mask(image, mask, classes, colors, alpha=0.5):
    """Color the pixels of each class in ``mask``.

    Returns:
        (output, overlay): ``overlay`` is a copy of ``image`` with each class's
        pixels painted opaquely in its color; ``output`` is
        ``alpha * overlay + (1 - alpha) * image``.
    """
    overlay = image.copy()
    for class_id, color in zip(classes, colors):
        overlay[mask == class_id] = color
    output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    return output, overlay


def get_confidence_level(confidence):
    """Map a confidence percentage (classification or segmentation) to a description."""
    if confidence >= 90:
        return "Excellent (Model is very certain about the detected boundaries)"
    elif confidence >= 75:
        return "Good (Model is confident about most of the detected area)"
    elif confidence >= 60:
        return "Fair (Model has some uncertainty in parts of the detection)"
    elif confidence >= 45:
        return "Poor (Model is uncertain about many detected areas)"
    else:
        return "Very Poor (Model's detection is highly uncertain)"


def process_batch(model, images_data, progress_bar=None):
    """Run ``model.process`` on each ``(file_name, image)`` pair.

    Args:
        model: a GlaucomaModel (anything providing ``process`` and ``cls_id2label``).
        images_data: sequence of ``(file_name, image)`` pairs.
        progress_bar: optional Streamlit progress bar, advanced after every
            image, including images that failed.

    Returns:
        A list with one result dict per successfully processed image. Failures
        are reported with ``st.error`` and skipped. Each dict also stores its
        input under 'original_image', so dropped failures cannot misalign
        results with their source images later.
    """
    results = []
    for idx, (file_name, image) in enumerate(images_data):
        try:
            disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image,
                'original_image': image,
            })
        except Exception as e:
            # Report and skip this image; the rest of the batch still runs.
            st.error(f"Error processing {file_name}: {str(e)}")
        if progress_bar is not None:
            progress_bar.progress((idx + 1) / len(images_data))
    return results


def _encode_jpeg(array):
    """Encode an image array as JPEG bytes using PIL."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format='JPEG')
    return buffer.getvalue()


def save_results(results, original_images):
    """Bundle a CSV report and per-image JPEGs into a ZIP archive.

    Args:
        results: result dicts as produced by process_batch.
        original_images: original images; only used (paired positionally with
            ``results``) when some result lacks an 'original_image' entry.

    Returns:
        The ZIP archive as bytes, containing 'report.csv' plus
        '<name>_original.jpg', '<name>_segmentation.jpg' and '<name>_roi.jpg'
        for each result, all at the archive root.
    """
    # Prefer the image stored on each result: positional pairing breaks as
    # soon as process_batch drops a failed image.
    if all('original_image' in r for r in results):
        pairs = [(r, r['original_image']) for r in results]
    else:
        pairs = list(zip(results, original_images))

    df = pd.DataFrame([{
        'File': r['file_name'],
        'Diagnosis': r['diagnosis'],
        'Confidence (%)': f"{r['confidence']:.1f}",
        'VCDR': f"{r['vcdr']:.3f}",
        'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
        'Disc Confidence (%)': f"{r['disc_conf']:.1f}"
    } for r in results])

    # Archive member name -> payload. A repeated name keeps a single member
    # (last write wins) instead of producing duplicate ZIP entries.
    entries = {'report.csv': df.to_csv(index=False).encode('utf-8')}
    for result, orig_img in pairs:
        base_name = os.path.splitext(result['file_name'])[0]
        entries[f"{base_name}_original.jpg"] = _encode_jpeg(orig_img)
        entries[f"{base_name}_segmentation.jpg"] = _encode_jpeg(result['processed_image'])
        entries[f"{base_name}_roi.jpg"] = _encode_jpeg(result['cropped_image'])

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w') as zipf:
        for arcname, payload in entries.items():
            zipf.writestr(arcname, payload)
    return buffer.getvalue()


# --- Streamlit Interface ---
def main():
    """Streamlit entry point: upload fundus images and show per-image analysis."""
    from streamlit.errors import StreamlitAPIException

    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write("Upload retinal images for automated glaucoma detection and optic disc/cup segmentation")

    # Model selection in the sidebar, before the file upload
    st.sidebar.title("Model Settings")
    selected_model = st.sidebar.selectbox(
        "Select Segmentation Model",
        list(MODEL_OPTIONS.keys()),
        index=1  # Default to Pamixsun
    )

    st.sidebar.title("Upload Images")
    # Needed only on old Streamlit versions; newer releases no longer accept
    # this option and st.set_option raises, so treat it as best-effort.
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        pass
    uploaded_files = st.sidebar.file_uploader(
        "Upload retinal images",
        type=['png', 'jpeg', 'jpg'],
        accept_multiple_files=True
    )

    # Simple explanation in sidebar
    st.sidebar.markdown(
        "### Understanding Results:\n"
        "- Diagnosis Confidence: AI certainty level\n"
        "- VCDR: Cup to disc ratio (>0.7 high risk)\n"
        "- Segmentation: Accuracy of detection\n"
    )

    if not uploaded_files:
        return

    try:
        st.write("šŸ¤– Initializing AI models...")
        st.write("ā€¢ Loading classification model: pamixsun/swinv2_tiny_for_glaucoma_classification")
        st.write(f"ā€¢ Loading segmentation model: {selected_model}")

        model = GlaucomaModel(
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
            seg_model_path=MODEL_OPTIONS[selected_model]
        )

        st.write("āœ… Models loaded successfully")
        st.write(f"šŸ–„ļø Using: {'GPU' if torch.cuda.is_available() else 'CPU'} for processing")
        st.write("---")

        for file in uploaded_files:
            try:
                st.write(f"šŸ“ø Processing image: {file.name}")
                image = Image.open(file).convert('RGB')
                image_np = np.array(image)

                disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)

                st.write("---")
                st.write(f"Results for {file.name}")

                # Diagnosis section
                st.write("šŸ“Š **Diagnosis Results:**")
                st.write(f"ā€¢ Finding: {model.cls_id2label[disease_idx]}")
                st.write(f"ā€¢ AI Confidence: {cls_conf:.1f}% ({get_confidence_level(cls_conf)})")

                # Segmentation confidence explanation
                st.write("\nšŸ” **Understanding Segmentation Confidence:**")
                st.write(
                    "Segmentation confidence shows how certain the AI is about each pixel it classified:\n"
                    "\n"
                    "ā€¢ For the Optic Cup (central depression):\n"
                    "- Measures the AI's certainty that the red-colored pixels are truly part of the cup\n"
                    "- Higher confidence means clearer cup boundaries and more reliable VCDR\n"
                    "\n"
                    "ā€¢ For the Optic Disc (entire circular area):\n"
                    "- Indicates how sure the AI is about the green-outlined disc boundary\n"
                    "- Higher confidence suggests better disc margin visibility\n"
                    "\n"
                    "Confidence scores are calculated by averaging the model's certainty for each pixel it "
                    "identified as cup or disc. A score of 100% would mean the model is absolutely certain "
                    "about every pixel's classification.\n"
                )

                st.write("\nšŸ“Š **Current Segmentation Confidence Scores:**")
                st.write(f"ā€¢ Optic Cup Detection: {cup_conf:.1f}% - {get_confidence_level(cup_conf)}")
                st.write(f"ā€¢ Optic Disc Detection: {disc_conf:.1f}% - {get_confidence_level(disc_conf)}")

                # Interpretation guidance
                if cup_conf >= 75 and disc_conf >= 75:
                    st.write("āœ… High confidence scores indicate reliable measurements")
                elif cup_conf < 60 or disc_conf < 60:
                    st.write(
                        "āš ļø Lower confidence scores might be due to:\n"
                        "ā€¢ Image quality issues (blur, poor contrast)\n"
                        "ā€¢ Unusual anatomical variations\n"
                        "ā€¢ Pathological changes affecting visibility\n"
                        "ā€¢ Poor image centering or focus\n"
                        "\n"
                        "Consider retaking the image if possible.\n"
                    )

                # Clinical metrics
                st.write("\nšŸ“ **Clinical Measurements:**")
                if np.isnan(vcdr):
                    # No disc was segmented, so no ratio (and no risk grading) exists.
                    st.write("ā€¢ Cup-to-Disc Ratio (VCDR): unavailable (no optic disc detected)")
                else:
                    st.write(f"ā€¢ Cup-to-Disc Ratio (VCDR): {vcdr:.3f}")
                    if vcdr > 0.7:
                        st.write(" āš ļø High VCDR - Potential risk indicator")
                    elif vcdr > 0.5:
                        st.write(" ā„¹ļø Borderline VCDR - Follow-up recommended")
                    else:
                        st.write(" āœ… Normal VCDR range")

                # Images
                st.write("\nšŸ–¼ļø **Visual Analysis:**")
                st.image(disc_cup_image, caption=(
                    "Segmentation Overlay\n"
                    "ā€¢ Green outline: Optic Disc boundary\n"
                    "ā€¢ Red area: Optic Cup region\n"
                    "ā€¢ Transparency shows underlying retina"
                ))
                st.image(cropped_image, caption="Zoomed Region of Interest")

                if cup_conf < 60 or disc_conf < 60:
                    st.write("\nāš ļø Note: Low segmentation confidence. Image quality might affect measurements.")

            except Exception as e:
                # Report per-image failures and continue with the next upload.
                st.error(f"Error processing {file.name}: {str(e)}")
                continue

        st.write("---")
        st.write("Processing complete!")

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()