import os import gradio as gr import tempfile from fastapi import FastAPI, HTTPException, File, UploadFile from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from langchain_community.vectorstores import FAISS from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint from langchain.chains import RetrievalQA from langchain_core.prompts import PromptTemplate from langchain_community.document_loaders import PyPDFLoader from collections import OrderedDict import re import shutil # Retrieve HF_TOKEN from environment HF_TOKEN = os.environ.get("HF_TOKEN") # Constants DATA_PATH = "medical_knowledge/" DB_FAISS_PATH = "/tmp/vectorstore/db_faiss" HUGGINGFACE_REPO_ID = "microsoft/Phi-3-mini-4k-instruct" UPLOAD_DIR = "/tmp/uploads/" # Create necessary directories CACHE_DIR = "/tmp/models_cache" os.makedirs(CACHE_DIR, exist_ok=True) os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True) os.makedirs(UPLOAD_DIR, exist_ok=True) # Load the embedding model embedding_model = HuggingFaceEmbeddings( model_name="rishi002/all-MiniLM-L6-v2", cache_folder=CACHE_DIR ) # Initialize FastAPI app app = FastAPI() # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global variables to track user report data and conversation history user_report_data = None conversation_history = [] # Load or create FAISS database from knowledge base PDFs def load_or_create_faiss(): if not os.path.exists(DB_FAISS_PATH): print("🔄 Creating FAISS Database...") from embeddings import load_pdf_files, create_chunks # Import functions from embeddings.py documents = load_pdf_files(DATA_PATH) # Load PDFs text_chunks = create_chunks(documents) # Split into Chunks db = FAISS.from_documents(text_chunks, embedding_model) db.save_local(DB_FAISS_PATH) else: print("✅ FAISS Database Exists. Loading...") return FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True) # Load the knowledge base db = load_or_create_faiss() # Load LLM def load_llm(): return HuggingFaceEndpoint( repo_id=HUGGINGFACE_REPO_ID, task="text-generation", temperature=0.5, model_kwargs={"token": HF_TOKEN, "max_length": 512} ) # Function to extract medical parameters from PDF text def extract_medical_parameters(text): # This is a simplified extraction function # In a real-world scenario, you'd want more sophisticated extraction logic parameters = {} # Look for common medical parameters with regex # Blood pressure: systolic/diastolic bp_match = re.search(r'blood pressure[:\s]*([\d]+)[\s\/]*([\d]+)', text, re.IGNORECASE) if bp_match: parameters['blood_pressure'] = f"{bp_match.group(1)}/{bp_match.group(2)}" # Heart rate hr_match = re.search(r'heart rate[:\s]*([\d]+)', text, re.IGNORECASE) if hr_match: parameters['heart_rate'] = hr_match.group(1) # Blood glucose glucose_match = re.search(r'glucose[:\s]*([\d\.]+)', text, re.IGNORECASE) if glucose_match: parameters['glucose'] = glucose_match.group(1) # Hemoglobin hb_match = re.search(r'h(?:a|e)moglobin[:\s]*([\d\.]+)', text, re.IGNORECASE) if hb_match: parameters['hemoglobin'] = hb_match.group(1) # White blood cell count wbc_match = re.search(r'white blood cell[s]?[:\s]*([\d\.]+)', text, re.IGNORECASE) if wbc_match: parameters['wbc_count'] = wbc_match.group(1) # Cholesterol cholesterol_match = re.search(r'cholesterol[:\s]*([\d\.]+)', text, re.IGNORECASE) if cholesterol_match: parameters['cholesterol'] = cholesterol_match.group(1) # Add more parameter extraction as needed # If no specific parameters were found, store the whole text for context if not parameters: # Simplify by taking first 1000 chars if text is too long parameters['report_summary'] = text[:1000] if len(text) > 1000 else text return parameters # Function to process uploaded PDF file def process_pdf_file(file_path): try: # Load the PDF loader = PyPDFLoader(file_path) documents = loader.load() # Extract text from all pages all_text = " ".join([doc.page_content for doc in documents]) # Extract medical parameters from the text global user_report_data user_report_data = extract_medical_parameters(all_text) return True, user_report_data except Exception as e: print(f"Error processing PDF: {str(e)}") return False, str(e) # Custom prompt template that includes medical parameters MEDICAL_REPORT_PROMPT = """ Use the following information to answer the user's question about their medical report. If you don't know the answer, just say that you don't know. Don't make up an answer. Keep your answer concise and avoid repeating the same information. Explain medical terms in a way that's easy for patients to understand. Do not mention the source of information in your answer. User's Medical Parameters: {parameters} Knowledge Base Context: {context} Question: {question} Start the answer directly. """ # Create the QA chain def create_qa_chain(): prompt = PromptTemplate( template=MEDICAL_REPORT_PROMPT, input_variables=["parameters", "context", "question"] ) return RetrievalQA.from_chain_type( llm=load_llm(), chain_type="stuff", retriever=db.as_retriever(search_kwargs={'k': 3}), return_source_documents=False, chain_type_kwargs={'prompt': prompt} ) qa_chain = create_qa_chain() # API Models class Question(BaseModel): query: str # API endpoint to process an uploaded PDF file @app.post("/api/upload-report") async def upload_report(file: UploadFile = File(...)): # Save the uploaded file file_path = os.path.join(UPLOAD_DIR, file.filename) with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) # Process the PDF file success, data = process_pdf_file(file_path) # Clean up the file os.remove(file_path) if success: return { "status": "success", "message": "Medical report data extracted successfully", "processed": True, "parameters_found": len(data) > 0 } else: return { "status": "error", "message": f"Failed to process the medical report: {data}", "processed": False } # API endpoint to ask questions about the processed report @app.post("/api/ask-question") async def ask_question(question_data: Question): global user_report_data, conversation_history if user_report_data is None: raise HTTPException(status_code=400, detail="No medical report has been processed yet") try: # Format the parameters for the prompt parameters_text = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in user_report_data.items()]) # Get answer from the QA chain with user parameters included response = qa_chain.invoke({ 'query': question_data.query, 'parameters': parameters_text }) # Get the raw result result = response["result"] # Remove duplicates by splitting into sentences and keeping only unique ones sentences = [s.strip() for s in result.split('.') if s.strip()] # Use OrderedDict to preserve order while removing duplicates unique_sentences = list(OrderedDict.fromkeys(sentences)) # Rejoin with periods cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else "" # Add to conversation history conversation_history.append({"user": question_data.query, "bot": cleaned_result}) return {"answer": cleaned_result} except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}") # Gradio Interface Components def process_file_upload(file): if file is None: return None, "Please upload a PDF file", [] success, data = process_pdf_file(file.name) if success: parameters = [f"**{k.replace('_', ' ').title()}**: {v}" for k, v in data.items()] parameters_markdown = "\n".join(parameters) return file.name, f"✅ Report processed successfully!\n\n### Extracted Parameters:\n{parameters_markdown}", [] else: return None, f"❌ Failed to process report: {data}", [] def ask_question_gradio(question, history): global user_report_data, conversation_history if user_report_data is None: history.append((question, "No medical report has been processed yet. Please upload a report first.")) return "", history try: # Format the parameters for the prompt parameters_text = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in user_report_data.items()]) # Get answer from the QA chain with user parameters included response = qa_chain.invoke({ 'query': question, 'parameters': parameters_text }) # Get the raw result result = response["result"] # Remove duplicates by splitting into sentences and keeping only unique ones sentences = [s.strip() for s in result.split('.') if s.strip()] # Use OrderedDict to preserve order while removing duplicates unique_sentences = list(OrderedDict.fromkeys(sentences)) # Rejoin with periods cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else "" history.append((question, cleaned_result)) return "", history except Exception as e: history.append((question, f"Error: {str(e)}")) return "", history def clear_conversation(): return [], None, "Upload your medical report PDF to get started", [] # Improved Gradio Interface with gr.Blocks(theme=gr.themes.Soft()) as iface: gr.Markdown( """ # 🏥 Medical Report Analyzer Upload your medical report and ask questions to understand it better. Our AI assistant will help explain your results in plain language. """ ) with gr.Row(): with gr.Column(scale=1): with gr.Box(): gr.Markdown("### 1️⃣ Upload Your Report") file_upload = gr.File( file_types=[".pdf"], label="Upload Medical Report (PDF)", ) uploaded_file = gr.Textbox( label="Current Report", interactive=False, visible=False ) upload_status = gr.Markdown( "Upload your medical report PDF to get started" ) upload_button = gr.Button("Process Report", variant="primary") clear_button = gr.Button("Clear & Start Over", variant="secondary") with gr.Column(scale=2): with gr.Box(): gr.Markdown("### 2️⃣ Ask Questions About Your Report") chat_interface = gr.Chatbot( label="Conversation", height=400, show_copy_button=True, ) question_input = gr.Textbox( label="Ask a question about your report", placeholder="e.g., What does my blood pressure mean?", ) with gr.Row(): submit_button = gr.Button("Submit Question", variant="primary") clear_chat_button = gr.Button("Clear Chat", variant="secondary") parameter_display = gr.JSON( label="Extracted Parameters", visible=False ) # Set up interactions upload_button.click( fn=process_file_upload, inputs=[file_upload], outputs=[uploaded_file, upload_status, parameter_display] ) submit_button.click( fn=ask_question_gradio, inputs=[question_input, chat_interface], outputs=[question_input, chat_interface] ) question_input.submit( fn=ask_question_gradio, inputs=[question_input, chat_interface], outputs=[question_input, chat_interface] ) clear_button.click( fn=clear_conversation, inputs=[], outputs=[chat_interface, uploaded_file, upload_status, parameter_display] ) clear_chat_button.click( fn=lambda: ([], ""), inputs=[], outputs=[chat_interface, question_input] ) # Mount the Gradio app to FastAPI app = gr.mount_gradio_app(app, iface, path="/") # Run the app with uvicorn if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)