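"""Streamlit app that checks whether an uploaded image is a brain MRI, classifies it with an
Alzheimer's or brain-tumor model, and generates a medical report via the Gemini API."""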
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import SegformerForImageClassification
import google.generativeai as genai
import io
import os

# Initialize Gemini API (read the key from an environment variable instead of hardcoding a secret)
genai.configure(api_key=os.environ["GEMINI_API_KEY"])  # set GEMINI_API_KEY before running
gemini_model = genai.GenerativeModel('gemini-2.0-flash')

# Load the MRI vs Non-MRI model
mri_classifier = tf.keras.models.load_model("alzheimers_detection_model.h5")

# Load Alzheimer's and Brain Tumor models
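# Both classifiers share a SegFormer (nvidia/mit-b1) backbone; the default head is swapped
# for a 4-class linear layer before the fine-tuned weights are loaded from local .pth files.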
alzheimers_model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
alzheimers_model.classifier = torch.nn.Linear(alzheimers_model.classifier.in_features, 4)
alzheimers_model.load_state_dict(torch.load('alzheimers_model.pth', map_location=torch.device('cpu')))
alzheimers_model.eval()

brain_tumor_model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
brain_tumor_model.classifier = torch.nn.Linear(brain_tumor_model.classifier.in_features, 4)
brain_tumor_model.load_state_dict(torch.load('brain_tumor_model.pth', map_location=torch.device('cpu')))
brain_tumor_model.eval()

# Define class labels
mri_classes = ["Brain MRI", "Not a Brain MRI"]
alzheimers_classes = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
brain_tumor_classes = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Define transformations
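# Resize to 224x224 and apply standard ImageNet mean/std normalization for the mit-b1 backbone.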
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def generate_medical_report(diagnosis):
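    """Ask Gemini to draft a plain-text report for the predicted diagnosis."""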
    prompt = f"""
    Generate a detailed medical report for a patient diagnosed with {diagnosis}.
    Include possible causes, symptoms, treatment options, and prognosis.
    Conclude the report with the signature: Team BrainTech.ai.
    """
    response = gemini_model.generate_content(prompt)
    return response.text.strip()

def predict_pipeline(image, model_type):
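    """Run the full pipeline: MRI gate check, disease classification, then report generation."""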
    # Step 1: Check whether the upload is a brain MRI at all
    image = image.convert("RGB")  # normalize grayscale/RGBA uploads to 3 channels
    image_resized = image.resize((224, 224))
    image_array = np.array(image_resized) / 255.0
    image_array = np.expand_dims(image_array, axis=0)
    mri_prediction = mri_classifier.predict(image_array)
    mri_class = mri_classes[np.argmax(mri_prediction)]
    mri_confidence = np.max(mri_prediction) * 100  # Confidence score in %

    if mri_class == "Not a Brain MRI":
        return "Not a Brain MRI", mri_confidence, None  # surface the gate model's confidence

    # Step 2: Classify MRI
    image_tensor = transform(image).unsqueeze(0)
    if model_type == "Alzheimer's":
        with torch.no_grad():
            outputs = alzheimers_model(image_tensor).logits
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence = torch.max(probabilities).item() * 100  # Confidence in %
        predicted_class = alzheimers_classes[torch.argmax(outputs).item()]
    elif model_type == "Brain Tumor":
        with torch.no_grad():
            outputs = brain_tumor_model(image_tensor).logits
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence = torch.max(probabilities).item() * 100  # Confidence in %
        predicted_class = brain_tumor_classes[torch.argmax(outputs).item()]
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    
    # Step 3: Generate medical report
    report = generate_medical_report(predicted_class)

    return predicted_class, confidence, report

def download_report(report_text):
    """Convert report text into a downloadable format."""
    buffer = io.BytesIO()
    buffer.write(report_text.encode())
    buffer.seek(0)
    return buffer

# Streamlit UI
st.title("MRI Scan Classification Pipeline with Gemini AI")
st.write("Upload an image to check if it's an MRI, classify it, view confidence scores, and get an AI-generated medical report.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
model_type = st.selectbox("Select Model Type", ["Alzheimer's", "Brain Tumor"])

if st.button("Predict") and uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_container_width=True)
    st.write("Classifying...")

    # Run the prediction pipeline
    result, confidence, report = predict_pipeline(image, model_type)

    # Display results
    st.write(f"**Prediction:** {result}")
    if confidence is not None:
        st.write(f"**Confidence Score:** {confidence:.2f}%")

    # Display AI-Generated Report
    if report:
        st.subheader("AI-Generated Medical Report")
        st.write(report)

        # Download Report Button
        report_buffer = download_report(report)
        st.download_button(
            label="Download Medical Report",
            data=report_buffer,
            file_name=f"medical_report_{result.replace(' ', '_')}.txt",
            mime="text/plain"
        )

        # Warning Banner
        st.warning("⚠️ Please consult a doctor before making any medical decisions based on this report.")