"""Streamlit app: AI assistant for medical reports.

Extracts text from an uploaded PDF report, then uses four fine-tuned T5
models (from the Hugging Face Hub) to produce a summary, a key medical
term, a plain-language explanation of that term, and a patient-friendly
question. All generated results are kept in session history.
"""

import streamlit as st
import fitz  # PyMuPDF
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Hugging Face model repository ids.
# NOTE: the "Exaplainer" misspelling matches the actual repo name on the
# Hub -- do not "fix" it or loading will fail.
MODEL_SUMMARIZER = "kshitij230/T5-Summarizer"
MODEL_TERM_GENERATOR = "kshitij230/T5-Term-Generation"
MODEL_EXPLAINER = "kshitij230/T5-Term-Exaplainer"
MODEL_QUESTION_GENERATOR = "kshitij230/T5-Question-Generation"


@st.cache_resource
def load_model_and_tokenizer(model_name):
    """Load a T5 model/tokenizer pair from the Hub.

    Cached with st.cache_resource so each model is downloaded and
    instantiated only once per server process.

    Args:
        model_name: Hugging Face repo id, e.g. "kshitij230/T5-Summarizer".

    Returns:
        Tuple of (T5ForConditionalGeneration, T5Tokenizer).
    """
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer


def _run_t5(model_name, prompt, max_length):
    """Run one cached T5 model on *prompt* and return the decoded text.

    Shared inference path for all four task-specific helpers below:
    tokenize (truncated to the 512-token T5 input limit), beam-search
    generate, decode, and strip surrounding whitespace.

    Args:
        model_name: Hub repo id of the model to use.
        prompt: Full task-prefixed input string (e.g. "summarize: ...").
        max_length: Maximum length of the generated sequence.

    Returns:
        The decoded, whitespace-stripped generation.
    """
    model, tokenizer = load_model_and_tokenizer(model_name)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    # Inference only -- disable autograd bookkeeping to save memory.
    with torch.no_grad():
        output_ids = model.generate(
            inputs["input_ids"],
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()


def summarize_report(text):
    """Summarize a clinical report with the summarizer model.

    Also removes a known boilerplate sentence the fine-tuned model
    sometimes emits as a training artifact.
    """
    summary = _run_t5(MODEL_SUMMARIZER, f"summarize: {text}", max_length=150)
    summary = summary.replace(
        "A [Extracted] Complaint came in with an extracted complaint.", ""
    ).strip()
    return summary


def generate_term(summary_text):
    """Extract a key medical term from a report summary."""
    return _run_t5(MODEL_TERM_GENERATOR, f"generate term: {summary_text}", max_length=50)


def explain_term(term_text):
    """Produce a plain-language explanation of a medical term."""
    return _run_t5(MODEL_EXPLAINER, f"explain: {term_text}", max_length=150)


def generate_question(summary_text):
    """Generate a patient-friendly question from a report summary."""
    # Output is now stripped like the other generators (previously the
    # only one returned without .strip()).
    return _run_t5(
        MODEL_QUESTION_GENERATOR, f"generate question: {summary_text}", max_length=100
    )


def extract_text_from_pdf(pdf_bytes):
    """Extract plain text from an in-memory PDF, one page per line group.

    Args:
        pdf_bytes: Raw bytes of the uploaded PDF file.

    Returns:
        Concatenated text of all pages, joined with newlines.
    """
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        return "\n".join(page.get_text() for page in doc)


def format_to_clinical_structure(raw_text):
    """Normalize extracted report text (currently just strips whitespace)."""
    return raw_text.strip()


# ---------- UI ---------- #
st.set_page_config(page_title="Elysium HealthAI", layout="wide")
st.title("📄 Medical Report AI Assistant")

# Initialize session state for history of generated results.
if "history" not in st.session_state:
    st.session_state.history = {
        "summary": [],
        "term": [],
        "explanation": [],
        "question": [],
    }

uploaded_file = st.file_uploader("Upload a medical report (PDF only)", type=["pdf"])

if uploaded_file:
    file_bytes = uploaded_file.read()
    with st.spinner("Extracting text from document..."):
        raw_text = extract_text_from_pdf(file_bytes)
        structured_text = format_to_clinical_structure(raw_text)

    st.subheader("📋 Extracted Report")
    st.text_area("Structured Clinical Format", value=structured_text, height=250)

    # Generate Summary
    if st.button("📝 Generate Summary"):
        with st.spinner("Summarizing..."):
            summary = summarize_report(structured_text)
            st.session_state.history["summary"].append(summary)

    # Generate Term -- requires an existing summary.
    if st.button("🔍 Generate Term"):
        with st.spinner("Extracting medical term..."):
            if st.session_state.history["summary"]:
                term = generate_term(st.session_state.history["summary"][-1])
                st.session_state.history["term"].append(term)
            else:
                st.warning("Please generate a summary first.")

    # Explain Term -- requires an existing term.
    if st.button("📘 Explain Term"):
        with st.spinner("Generating explanation..."):
            if st.session_state.history["term"]:
                explanation = explain_term(st.session_state.history["term"][-1])
                st.session_state.history["explanation"].append(explanation)
            else:
                st.warning("Please generate a term first.")

    # Generate Question -- requires an existing summary.
    if st.button("🗣️ Generate Patient-Friendly Question"):
        with st.spinner("Generating question..."):
            if st.session_state.history["summary"]:
                question = generate_question(st.session_state.history["summary"][-1])
                st.session_state.history["question"].append(question)
            else:
                st.warning("Please generate a summary first.")

    # ---------- Display All Results (History) ---------- #
    st.markdown("### 🕓 History of Results")

    if st.session_state.history["summary"]:
        st.subheader("📝 Summaries")
        for idx, summary in enumerate(st.session_state.history["summary"], 1):
            st.markdown(f"**{idx}.** {summary}")

    if st.session_state.history["term"]:
        st.subheader("🔍 Medical Terms")
        for idx, term in enumerate(st.session_state.history["term"], 1):
            st.markdown(f"**{idx}.** `{term}`")

    if st.session_state.history["explanation"]:
        st.subheader("📘 Explanations")
        for idx, explanation in enumerate(st.session_state.history["explanation"], 1):
            st.markdown(f"**{idx}.** {explanation}")

    if st.session_state.history["question"]:
        st.subheader("🗣️ Patient-Friendly Questions")
        for idx, question in enumerate(st.session_state.history["question"], 1):
            st.markdown(f"**{idx}.** {question}")