"""Streamlit app: gate uploads with a brain-MRI check, classify, and report.

An uploaded image is first run through a Keras classifier that decides
whether it is a brain MRI.  If so, one of two Segformer classifiers
(Alzheimer's or brain tumor) predicts a label, and Gemini is asked to write
a report for that label.
"""

import io
import os

import google.generativeai as genai
import numpy as np
import streamlit as st
import tensorflow as tf
import torch
from PIL import Image
from torchvision import transforms
from transformers import SegformerForImageClassification


def _require_gemini_api_key():
    """Return the Gemini API key from the environment, or halt the app.

    The key is looked up in ``GEMINI_API_KEY`` and then ``GOOGLE_API_KEY``.
    Credentials must never be committed to source control; any key that was
    previously hard-coded in this file should be revoked and rotated.
    """
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        st.error(
            "Gemini API key not found. Set the GEMINI_API_KEY environment "
            "variable and restart the app."
        )
        st.stop()
    return api_key


# Streamlit re-executes this script on every widget interaction, so the
# expensive loaders below are wrapped in st.cache_resource to run only once
# per server process instead of once per click.
@st.cache_resource
def _load_gemini_model(api_key):
    """Configure the Gemini client with *api_key* and return the model."""
    genai.configure(api_key=api_key)
    return genai.GenerativeModel('gemini-2.0-flash')


@st.cache_resource
def _load_mri_classifier(model_path):
    """Load the Keras model used for the MRI / non-MRI gate."""
    return tf.keras.models.load_model(model_path)


@st.cache_resource
def _load_segformer_classifier(weights_path, num_classes):
    """Build a Segformer classifier and load fine-tuned weights on CPU.

    The pretrained ``nvidia/mit-b1`` classification head is replaced with a
    fresh ``Linear`` layer of ``num_classes`` outputs before the state dict
    at ``weights_path`` is loaded, and the model is put in eval mode.
    """
    model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
    model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model.eval()
    return model


# Initialize Gemini API
gemini_model = _load_gemini_model(_require_gemini_api_key())

# Load the MRI vs Non-MRI model
mri_classifier = _load_mri_classifier("alzheimers_detection_model.h5")

# Load Alzheimer's and Brain Tumor models
alzheimers_model = _load_segformer_classifier('alzheimers_model.pth', 4)
brain_tumor_model = _load_segformer_classifier('brain_tumor_model.pth', 4)

# Define class labels
mri_classes = ["Brain MRI", "Not a Brain MRI"]
alzheimers_classes = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
brain_tumor_classes = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Maps the UI's model-type choice to (model, class labels).
_CLASSIFIERS = {
    "Alzheimer's": (alzheimers_model, alzheimers_classes),
    "Brain Tumor": (brain_tumor_model, brain_tumor_classes),
}

# Define transformations (mean/std are the commonly used ImageNet statistics)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def generate_medical_report(diagnosis):
    """Ask Gemini for a medical report about *diagnosis*.

    Args:
        diagnosis: Label to insert into the prompt.

    Returns:
        The text of the Gemini response with surrounding whitespace removed.
    """
    prompt = f"""
    Generate a detailed medical report for a patient diagnosed with {diagnosis}.
    Include possible causes, symptoms, treatment options, and prognosis.
    Conclude the report with the signature: Team BrainTech.ai.
    """
    response = gemini_model.generate_content(prompt)
    return response.text.strip()


def predict_pipeline(image, model_type):
    """Run the MRI gate, the selected classifier, and report generation.

    Args:
        image: A PIL image.  It is converted to RGB internally, so
            grayscale and RGBA uploads are accepted.
        model_type: ``"Alzheimer's"`` or ``"Brain Tumor"``.

    Returns:
        ``(predicted_class, confidence, report)`` where ``confidence`` is a
        percentage.  When the gate rejects the image the result is
        ``("Not a Brain MRI", None, None)``.

    Raises:
        ValueError: If the image passes the gate and ``model_type`` is not
            one of the supported choices.
    """
    # Both models are fed three channels below (3-value Normalize, RGB
    # Segformer backbone); MRI exports are frequently single-channel or
    # carry an alpha channel, which previously crashed the pipeline.
    rgb_image = image.convert("RGB")

    # Step 1: Check if it's an MRI
    resized = rgb_image.resize((224, 224))
    image_array = np.expand_dims(np.array(resized) / 255.0, axis=0)
    mri_prediction = mri_classifier.predict(image_array)
    # NOTE(review): assumes the gate model emits one score per entry of
    # mri_classes, in that order - confirm against the training code.
    mri_class = mri_classes[np.argmax(mri_prediction)]
    if mri_class == "Not a Brain MRI":
        return "Not a Brain MRI", None, None

    # Step 2: Classify MRI
    try:
        model, class_names = _CLASSIFIERS[model_type]
    except KeyError:
        supported = ", ".join(repr(name) for name in _CLASSIFIERS)
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of: {supported}"
        ) from None

    image_tensor = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        logits = model(image_tensor).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence = torch.max(probabilities).item() * 100  # Confidence in %
    predicted_class = class_names[torch.argmax(logits).item()]

    # Step 3: Generate medical report
    report = generate_medical_report(predicted_class)
    return predicted_class, confidence, report


def download_report(report_text):
    """Convert report text into a downloadable format.

    Returns:
        An ``io.BytesIO`` holding the UTF-8 encoded text, positioned at 0.
    """
    return io.BytesIO(report_text.encode("utf-8"))


# Streamlit UI
st.title("MRI Scan Classification Pipeline with Gemini AI")
st.write("Upload an image to check if it's an MRI, classify it, view confidence scores, and get an AI-generated medical report.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
model_type = st.selectbox("Select Model Type", ["Alzheimer's", "Brain Tumor"])

if st.button("Predict"):
    if uploaded_file is None:
        # Previously a click without a file did nothing at all.
        st.warning("Please upload an image before clicking Predict.")
    else:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)
        st.write("Classifying...")

        # Run the prediction pipeline
        result, confidence, report = predict_pipeline(image, model_type)

        # Display results
        st.write(f"**Prediction:** {result}")
        if confidence is not None:
            st.write(f"**Confidence Score:** {confidence:.2f}%")

        # Display AI-Generated Report
        if report:
            st.subheader("AI-Generated Medical Report")
            st.write(report)

            # Download Report Button
            report_buffer = download_report(report)
            st.download_button(
                label="Download Medical Report",
                data=report_buffer,
                file_name=f"medical_report_{result.replace(' ', '_')}.txt",
                mime="text/plain"
            )

            # Warning Banner (shown alongside the report it refers to)
            st.warning("⚠️ Please consult a doctor before taking any medical decisions based on this report.")