# BrainTech.ai / app.py
# arittrabag's picture
# Added app.py
# c7bd69d verified
import io
import os

import google.generativeai as genai
import numpy as np
import streamlit as st
import tensorflow as tf
import torch
from PIL import Image
from torchvision import transforms
from transformers import SegformerForImageClassification
# Initialize Gemini API.
# The key is read from the environment (GEMINI_API_KEY, falling back to
# GOOGLE_API_KEY) so that no credential lives in the source tree. Any key that
# was ever committed to this file must be treated as leaked and revoked.
_gemini_api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not _gemini_api_key:
    st.error("Gemini API key is not configured. Set the GEMINI_API_KEY environment variable.")
    st.stop()
genai.configure(api_key=_gemini_api_key)
gemini_model = genai.GenerativeModel('gemini-2.0-flash')


# Streamlit re-executes this script on every widget interaction, so the model
# loaders are wrapped in st.cache_resource to keep one loaded copy per server
# process instead of re-reading the weights on each rerun.
@st.cache_resource
def _load_mri_classifier():
    """Load the Keras model used as the MRI vs. non-MRI gate."""
    # NOTE(review): the file name says "alzheimers_detection" although the model
    # is used as the MRI / non-MRI gate — confirm this is the intended artifact.
    return tf.keras.models.load_model("alzheimers_detection_model.h5")


@st.cache_resource
def _load_segformer_classifier(weights_path, num_classes=4):
    """Build a SegFormer (mit-b1) classifier and load fine-tuned weights.

    Args:
        weights_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Width of the replacement linear classification head.

    Returns:
        The model on CPU, switched to eval mode.
    """
    model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
    # Swap the pretrained head for one sized to the fine-tuned label set before
    # loading the state dict, otherwise the head shapes would not match.
    model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model.eval()
    return model


# Load the MRI vs Non-MRI model
mri_classifier = _load_mri_classifier()

# Load Alzheimer's and Brain Tumor models
alzheimers_model = _load_segformer_classifier('alzheimers_model.pth')
brain_tumor_model = _load_segformer_classifier('brain_tumor_model.pth')

# Define class labels
# Presumably ordered to match the output indices each model was trained with —
# verify against the training code before reordering.
mri_classes = ["Brain MRI", "Not a Brain MRI"]
alzheimers_classes = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
brain_tumor_classes = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Define transformations
# Resize to 224x224, convert to a float tensor, then normalize each of the
# three channels with the commonly used ImageNet mean / std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def generate_medical_report(diagnosis):
    """Ask the Gemini model to write a report for the given diagnosis.

    Args:
        diagnosis: Label text that is interpolated into the prompt.

    Returns:
        The text of the model's reply with surrounding whitespace removed.
    """
    # The prompt body is intentionally flush-left: everything between the
    # triple quotes is sent to the model verbatim.
    request_text = f"""
Generate a detailed medical report for a patient diagnosed with {diagnosis}.
Include possible causes, symptoms, treatment options, and prognosis.
Conclude the report with the signature: Team BrainTech.ai.
"""
    return gemini_model.generate_content(request_text).text.strip()
def predict_pipeline(image, model_type):
    """Gate an image through the MRI check, classify it, and draft a report.

    Args:
        image: A ``PIL.Image.Image``. Any mode is accepted; it is converted to
            RGB before being fed to either model.
        model_type: ``"Alzheimer's"`` or ``"Brain Tumor"``; selects which
            fine-tuned classifier and label list are used in step 2.

    Returns:
        A ``(label, confidence, report)`` tuple. When the first-stage model
        selects "Not a Brain MRI" the result is
        ``("Not a Brain MRI", None, None)``. Otherwise ``confidence`` is the
        highest softmax probability expressed as a percentage and ``report``
        is the text returned by ``generate_medical_report`` for ``label``.

    Raises:
        ValueError: If the image passes the MRI gate and ``model_type`` is not
            one of the two supported values.
    """
    # Uploads may be grayscale, RGBA or palette images. The normalisation in
    # `transform` carries three channel statistics, so force a 3-channel image
    # once and use it for both stages.
    rgb_image = image.convert("RGB")

    # Step 1: Check if it's an MRI
    gate_input = np.array(rgb_image.resize((224, 224))) / 255.0
    gate_input = np.expand_dims(gate_input, axis=0)  # add the batch axis
    mri_prediction = mri_classifier.predict(gate_input)
    # NOTE(review): argmax assumes the gate model emits one score per entry of
    # mri_classes — confirm against the model's output layer.
    if mri_classes[np.argmax(mri_prediction)] == "Not a Brain MRI":
        return "Not a Brain MRI", None, None

    # Step 2: Classify MRI
    if model_type == "Alzheimer's":
        model, class_names = alzheimers_model, alzheimers_classes
    elif model_type == "Brain Tumor":
        model, class_names = brain_tumor_model, brain_tumor_classes
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    image_tensor = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        logits = model(image_tensor).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    # Softmax is monotonic, so the top probability and the top logit share an index.
    top_probability, top_index = torch.max(probabilities, dim=1)
    confidence = top_probability.item() * 100  # Confidence in %
    predicted_class = class_names[top_index.item()]

    # Step 3: Generate medical report
    report = generate_medical_report(predicted_class)
    return predicted_class, confidence, report
def download_report(report_text):
    """Wrap the report text in an in-memory binary stream.

    Args:
        report_text: The report as a ``str``.

    Returns:
        An ``io.BytesIO`` holding the encoded text, positioned at the start
        so it can be read from the beginning.
    """
    encoded_report = report_text.encode()
    # Constructing BytesIO from the bytes leaves the read position at offset 0.
    return io.BytesIO(encoded_report)
# Streamlit UI
def _render_prediction(image, model_type):
    """Run the pipeline on ``image`` and render its outcome on the page."""
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("Classifying...")

    # Run the prediction pipeline
    result, confidence, report = predict_pipeline(image, model_type)

    # Display results
    st.write(f"**Prediction:** {result}")
    if confidence is not None:
        st.write(f"**Confidence Score:** {confidence:.2f}%")

    # Display AI-Generated Report (absent when the image was rejected as non-MRI)
    if report:
        st.subheader("AI-Generated Medical Report")
        st.write(report)

        # Download Report Button
        report_buffer = download_report(report)
        st.download_button(
            label="Download Medical Report",
            data=report_buffer,
            file_name=f"medical_report_{result.replace(' ', '_')}.txt",
            mime="text/plain"
        )

        # Warning Banner
        st.warning("⚠️ Please consult a doctor before taking any medical decisions based on this report.")


st.title("MRI Scan Classification Pipeline with Gemini AI")
st.write("Upload an image to check if it's an MRI, classify it, view confidence scores, and get an AI-generated medical report.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
model_type = st.selectbox("Select Model Type", ["Alzheimer's", "Brain Tumor"])

if st.button("Predict"):
    if uploaded_file is None:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please upload an image before clicking Predict.")
    else:
        try:
            image = Image.open(uploaded_file)
            # Image.open is lazy; decode now so a corrupt file fails here
            # rather than somewhere inside the pipeline.
            image.load()
        except OSError:
            # Pillow's UnidentifiedImageError is an OSError subclass, so this
            # covers both unrecognised and truncated files.
            st.error("The uploaded file could not be read as an image.")
        else:
            _render_prediction(image, model_type)