# Retina-AI-OCT / app.py
# Uploaded by ferferefer ("Upload app.py", commit 17b0b49).
import streamlit as st
import google.generativeai as genai
from PIL import Image
import os
from dotenv import load_dotenv
import PyPDF2
import io
from datetime import datetime
# Page configuration must be the first Streamlit command
st.set_page_config(
    page_title="OCT Retina Analysis Assistant",
    page_icon="👁️",  # eye emoji (U+1F441 + VS16); was a mis-encoded byte sequence
    layout="wide",
    initial_sidebar_state="expanded"
)
# Load environment variables (e.g. GOOGLE_API_KEY) from a .env file, if present
load_dotenv()
# Configure Gemini API.
# The model id can be overridden with the GEMINI_MODEL environment variable,
# so a different (e.g. non-experimental) Gemini model can be used without
# editing the code. The default keeps the original model.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model = genai.GenerativeModel(os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp"))
# Custom CSS
# Page-wide style rules injected at module level. Class names such as
# .header-box and .credit-box are used by the HTML snippets rendered in main().
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
background-color: #FF4B4B;
color: white;
padding: 0.5rem;
margin-top: 1rem;
}
.credit-box {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
margin: 1rem 0;
}
.header-box {
background-color: #FF4B4B;
padding: 2rem;
border-radius: 0.5rem;
color: white;
margin-bottom: 2rem;
text-align: center;
}
.image-container {
margin: 1rem 0;
padding: 1rem;
border-radius: 0.5rem;
background-color: #f0f2f6;
}
.analysis-container {
margin-top: 1rem;
padding: 1rem;
border-radius: 0.5rem;
background-color: #f0f2f6;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# System prompts
# Prompt used by analyze_oct_images() for scans taken at a single visit.
# Bullets are a real "•" character (previously mis-encoded as mojibake, which
# was sent to the model verbatim).
SINGLE_TIMEPOINT_PROMPT = """You are an expert ophthalmologist specializing in interpreting macular Optical Coherence Tomography (OCT) scans. Your goal is to provide an accurate description and a possible diagnosis, supported by clear medical reasoning. These scans are from the same patient at a single timepoint. Please provide a comprehensive analysis.
Step 1: Image Quality Assessment
For each scan, describe the overall image quality, noting any artifacts or limitations that may affect your analysis.
Step 2: Layer-by-Layer Analysis Across All Scans
Analyze each of the retinal layers across all provided scans, describing:
• Thickness patterns: Note any variations or consistencies in layer thickness
• Morphological changes: Compare layer appearance across scans
• Reflectivity patterns: Identify any recurring patterns or changes
• Document abnormalities and their distribution across scans
Step 3: Comprehensive Foveal Analysis
Analyze the foveal region across all scans:
• Compare foveal contour and thickness
• Note any consistent or varying abnormalities
• Identify patterns of foveal involvement
Step 4: Integrated Abnormality Assessment
Provide a unified analysis of abnormalities across all scans:
• Distribution patterns
• Progression or variation in appearance
• Relationship between findings in different scans
Step 5: Differential Diagnoses
Based on the comprehensive analysis:
• List potential diagnoses supported by findings across multiple scans
• Explain how the pattern of findings supports each diagnosis
• Note any temporal or spatial progression that helps narrow the diagnosis
Step 6: Most Likely Diagnosis
Provide a unified diagnosis considering all scans:
• Explain how the combined findings support this diagnosis
• Discuss any progression or pattern that confirms the diagnosis
• Address any variations or inconsistencies
Step 7: Recommendations
Suggest:
• Additional tests or imaging if needed
• Follow-up scanning recommendations
• Treatment considerations based on the comprehensive analysis"""
# Prompt used by compare_oct_timepoints() for scans from two visits.
# Bullets are a real "•" character (previously mis-encoded as mojibake).
COMPARISON_PROMPT = """You are an expert ophthalmologist specializing in interpreting macular Optical Coherence Tomography (OCT) scans. Your goal is to provide an accurate description and a possible diagnosis, supported by clear medical reasoning. These scans are from the same patient at two different timepoints. Please provide a comprehensive analysis and comparison.
Step 1: Image Quality Assessment
For each set of scans, describe the overall image quality, noting any artifacts or limitations that may affect your analysis.
Step 2: Layer-by-Layer Comparison
Compare the retinal layers between timepoints:
• Changes in thickness patterns
• Evolution of morphological features
• Alterations in reflectivity patterns
• Progression or regression of abnormalities
Step 3: Foveal Evolution Analysis
Compare the foveal region between timepoints:
• Changes in contour and thickness
• Evolution of abnormalities
• Progression or improvement patterns
Step 4: Disease Progression Assessment
Analyze changes between timepoints:
• Quantify and describe changes in abnormalities
• Identify new or resolved findings
• Assess overall disease progression or improvement
Step 5: Treatment Response Evaluation
If treatment was administered:
• Evaluate effectiveness
• Identify areas of improvement
• Note resistant or worsening areas
Step 6: Updated Diagnosis and Prognosis
Based on the temporal comparison:
• Confirm or revise previous diagnosis
• Assess disease trajectory
• Provide prognostic insights
Step 7: Recommendations
Suggest:
• Treatment modifications if needed
• Follow-up interval
• Additional testing if required
• Preventive measures"""
# Prompt used by get_treatment_recommendations(); the diagnosis and findings
# are appended after it. Bullets are a real "•" (previously mojibake).
TREATMENT_GUIDELINES_PROMPT = """Based on the current diagnosis and findings, please provide evidence-based treatment recommendations following established ophthalmological guidelines. Consider:
1. Standard of Care
• First-line treatments
• Alternative options
• Contraindications
2. Treatment Plan
• Immediate interventions
• Long-term management
• Follow-up schedule
3. Monitoring Parameters
• Key metrics to track
• Warning signs
• Success indicators
4. Patient Education
• Lifestyle modifications
• Self-monitoring instructions
• Prevention strategies"""
def extract_pdf_text(pdf_file):
    """Extract the text of every page of an uploaded PDF file.

    Args:
        pdf_file: A binary file-like object (such as a Streamlit upload) or
            any other source accepted by ``PyPDF2.PdfReader``.

    Returns:
        The text of all pages, with pages separated by a newline. Pages that
        yield no text contribute an empty string, so the result can be ``""``.

    Raises:
        Whatever ``PyPDF2.PdfReader`` raises for unreadable input; callers
        should handle it.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # Separate pages with a newline so the last word of one page does not fuse
    # with the first word of the next. ``or ""`` defensively treats a falsy
    # extract_text() result as empty instead of failing the concatenation.
    # join() also avoids quadratic repeated string concatenation.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
def analyze_oct_images(images, timepoint=None, patient_data=None):
    """Ask the model for a single-timepoint analysis of one or more OCT scans.

    Args:
        images: List of scan images, sent after the text prompt.
        timepoint: Optional scan date/label; written into the prompt as
            ``Timepoint: ...`` when truthy.
        patient_data: Optional free-text patient information; included only
            when truthy.

    Returns:
        The ``text`` of the model response.
    """
    timepoint_part = f"\n\nTimepoint: {timepoint}\n" if timepoint else "\n"
    patient_part = (
        f"\nPatient Information:\n{patient_data}\n" if patient_data else ""
    )
    full_prompt = "".join(
        (
            SINGLE_TIMEPOINT_PROMPT,
            timepoint_part,
            patient_part,
            "\nPlease analyze these OCT scans:",
        )
    )
    return model.generate_content([full_prompt] + images).text
def compare_oct_timepoints(images1, date1, images2, date2, patient_data=None):
    """Ask the model to compare OCT scans taken at two timepoints.

    Args:
        images1: Scan images from the first timepoint.
        date1: Date/label of the first timepoint.
        images2: Scan images from the second timepoint.
        date2: Date/label of the second timepoint.
        patient_data: Optional free-text patient information; included only
            when truthy.

    Returns:
        The ``text`` of the model response.
    """
    prompt = f"{COMPARISON_PROMPT}\n\nTimepoint 1: {date1}\nTimepoint 2: {date2}\n"
    if patient_data:
        prompt += f"\nPatient Information:\n{patient_data}\n"
    prompt += "\nPlease compare these OCT scans:"
    # Put a text label in front of each group of images. Previously both groups
    # were concatenated into one undifferentiated run, so the model had no way
    # to tell which scans belonged to which timepoint.
    content = [
        prompt,
        f"Timepoint 1 ({date1}) scans:",
        *images1,
        f"Timepoint 2 ({date2}) scans:",
        *images2,
    ]
    response = model.generate_content(content)
    return response.text
def get_treatment_recommendations(diagnosis, findings):
    """Ask the model for guideline-based treatment recommendations.

    The request is text-only and standalone: no images and no earlier
    conversation are sent. Everything the model should consider must
    therefore be passed in ``diagnosis`` and ``findings``.

    Args:
        diagnosis: Diagnosis text appended after the guidelines prompt.
        findings: Findings text appended after the diagnosis.

    Returns:
        The ``text`` of the model response.
    """
    case_details = f"Diagnosis: {diagnosis}\nFindings: {findings}"
    request_text = "\n\n".join((TREATMENT_GUIDELINES_PROMPT, case_details))
    reply = model.generate_content(request_text)
    return reply.text
def _render_result(title, text):
    """Show a model-generated report under a ``### {title}`` heading.

    The report is rendered as ordinary Markdown inside a bordered container,
    so headings, lists and emphasis in the model output display correctly.
    ``unsafe_allow_html`` is deliberately left off. The text is produced by
    the model from user-uploaded content (scans and patient PDF text), so any
    HTML it contains is escaped rather than injected into the page.
    """
    st.markdown(f"### {title}")
    with st.container(border=True):
        st.markdown(text)


def _load_patient_data():
    """Render the patient-PDF uploader and extract the PDF's text.

    Returns:
        ``(patient_pdf, patient_data)``: the uploaded file (falsy when nothing
        was uploaded), and its extracted text. The text is ``None`` when no
        PDF was uploaded, it could not be read, or it contained no text.
    """
    st.markdown("### Patient Information")
    patient_pdf = st.file_uploader("Upload Patient Data (PDF)", type=['pdf'])
    if not patient_pdf:
        return patient_pdf, None
    try:
        patient_data = extract_pdf_text(patient_pdf)
    except Exception as e:  # UI boundary: an unreadable PDF must not crash the page
        st.error(f"Could not read the patient PDF: {e}")
        return patient_pdf, None
    if not patient_data or not patient_data.strip():
        st.warning("No text could be extracted from the patient PDF; it will be ignored.")
        return patient_pdf, None
    with st.expander("View Patient Data"):
        st.text(patient_data)
    return patient_pdf, patient_data


def _load_images(files, title_suffix="", expanded=True, with_caption=True):
    """Open and preview uploaded scan files, returning the decoded images.

    Each file is shown in its own expander titled ``OCT Scan {n}{title_suffix}``.
    Files that Pillow cannot open or decode are reported with ``st.error``
    and skipped, instead of aborting the whole page.
    """
    images = []
    for idx, uploaded_file in enumerate(files, start=1):
        with st.expander(f"OCT Scan {idx}{title_suffix}", expanded=expanded):
            try:
                image = Image.open(uploaded_file)
                image.load()  # Image.open is lazy; force decoding so bad data fails here
            except OSError as e:
                st.error(f"Could not open {uploaded_file.name} as an image: {e}")
                continue
            images.append(image)
            st.image(
                image,
                use_container_width=True,
                caption=f"OCT Scan {idx}" if with_caption else None,
            )
    return images


def _single_timepoint_section(col2, patient_data):
    """Upload, preview and analyse OCT scans from a single visit.

    Results are written into ``col2``.

    Returns:
        True if any scan files have been uploaded. main() uses this to decide
        whether to show the instructions panel.
    """
    uploaded_files = st.file_uploader(
        "Choose OCT scans",
        type=['png', 'jpg', 'jpeg'],
        accept_multiple_files=True,
        key="single_timepoint"
    )
    if not uploaded_files:
        return False
    scan_date = st.date_input("Scan Date")
    st.markdown("### Uploaded Scans")
    images = _load_images(uploaded_files)
    if images and st.button("🔍 Analyze Scans"):
        with st.spinner("Analyzing OCT scans... Please wait."):
            try:
                analysis = analyze_oct_images(images, scan_date, patient_data)
                # get_treatment_recommendations sends a fresh, text-only request.
                # It only "sees" the analysis if the analysis text is passed in.
                treatment_recs = get_treatment_recommendations(
                    "See the most likely diagnosis stated in the findings below",
                    analysis,
                )
            except Exception as e:
                st.error(f"An error occurred during analysis: {e}")
                return True
        with col2:
            _render_result("Analysis Results", analysis)
            _render_result("Treatment Recommendations", treatment_recs)
    return True


def _comparison_section(col2, patient_data):
    """Upload, preview and compare OCT scans from two visits.

    Results are written into ``col2``.

    Returns:
        True if scan files have been uploaded for both timepoints.
    """
    st.markdown("#### First Timepoint")
    files1 = st.file_uploader(
        "Choose first set of OCT scans",
        type=['png', 'jpg', 'jpeg'],
        accept_multiple_files=True,
        key="timepoint1"
    )
    date1 = st.date_input("First Scan Date")
    st.markdown("#### Second Timepoint")
    files2 = st.file_uploader(
        "Choose second set of OCT scans",
        type=['png', 'jpg', 'jpeg'],
        accept_multiple_files=True,
        key="timepoint2"
    )
    date2 = st.date_input("Second Scan Date")
    if date2 < date1:
        st.warning("The second scan date is earlier than the first; check the timepoint order.")
    if not (files1 and files2):
        return False
    st.markdown("### Uploaded Scans")
    st.markdown("#### First Timepoint Scans")
    images1 = _load_images(files1, " - First Timepoint", expanded=False, with_caption=False)
    st.markdown("#### Second Timepoint Scans")
    images2 = _load_images(files2, " - Second Timepoint", expanded=False, with_caption=False)
    if images1 and images2 and st.button("🔍 Compare Timepoints"):
        with st.spinner("Analyzing and comparing OCT scans... Please wait."):
            try:
                comparison = compare_oct_timepoints(
                    images1, date1,
                    images2, date2,
                    patient_data
                )
                # Fresh text-only request: pass the comparison text so the
                # recommendations are actually based on it.
                treatment_recs = get_treatment_recommendations(
                    "See the updated diagnosis stated in the findings below",
                    comparison,
                )
            except Exception as e:
                st.error(f"An error occurred during analysis: {e}")
                return True
        with col2:
            _render_result("Comparison Results", comparison)
            _render_result("Treatment Recommendations", treatment_recs)
    return True


def main():
    """Render the OCT Retina Analysis Assistant page.

    Layout: a header and an "About" box, then two columns. The left column
    holds the inputs (patient PDF, scan uploads); the right column shows the
    results, or usage instructions when nothing has been uploaded.
    """
    # Header with custom styling
    st.markdown("""
<div class="header-box">
<h1>OCT Retina Analysis Assistant</h1>
</div>
""", unsafe_allow_html=True)
    # Credits
    st.markdown("""
<div class="credit-box">
<h3>About</h3>
<p>Developed by Dr. Fernando Ly</p>
<p>This tool assists in the analysis of OCT retina scans using advanced AI technology.
It provides detailed layer analysis and potential diagnoses to support clinical decision-making.</p>
<p><strong>Note:</strong> This tool is for assistance only and should not replace professional medical judgment.</p>
</div>
""", unsafe_allow_html=True)
    # Main content
    col1, col2 = st.columns([1, 1])
    with col1:
        # Patient Data Section
        patient_pdf, patient_data = _load_patient_data()
        # Scan Upload Section
        st.markdown("### Upload OCT Scans")
        timepoint_option = st.radio(
            "Select scan type:",
            ["Single Timepoint", "Two Timepoints for Comparison"]
        )
        if timepoint_option == "Single Timepoint":
            scans_uploaded = _single_timepoint_section(col2, patient_data)
        else:  # Two Timepoints
            scans_uploaded = _comparison_section(col2, patient_data)
    # Instructions in col2 if nothing has been uploaded yet.
    # The blank line before "Supported formats:" keeps it from being folded
    # into list item 6 as a lazy continuation line.
    if not (patient_pdf or scans_uploaded):
        with col2:
            st.markdown("### Instructions")
            st.markdown("""
1. Upload patient data PDF (optional)
2. Choose analysis type (single timepoint or comparison)
3. Upload OCT scans for selected timepoint(s)
4. Set scan dates
5. Click analyze/compare button
6. Review analysis and treatment recommendations

Supported formats:
- Patient Data: PDF
- OCT Scans: PNG, JPG, JPEG
""")
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()