rishi002 committed
Commit cc3e1c0 · verified · Parent: 16e6135

Update app.py

Files changed (1):
  1. app.py +274 -138
app.py CHANGED
@@ -1,8 +1,7 @@
 import os
 import gradio as gr
-import requests
 import tempfile
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException, File, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from langchain_community.vectorstores import FAISS
@@ -10,22 +9,30 @@ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
 from langchain.chains import RetrievalQA
 from langchain_core.prompts import PromptTemplate
 from langchain_community.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from collections import OrderedDict
+import re
+import shutil
 
 # Retrieve HF_TOKEN from environment
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
 # Constants
-CACHE_DIR = "/tmp/models_cache"
+DATA_PATH = "dataFolder/"
 DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
-USER_REPORT_DB_PATH = "/tmp/vectorstore/user_report_db"
 HUGGINGFACE_REPO_ID = "microsoft/Phi-3-mini-4k-instruct"
+UPLOAD_DIR = "/tmp/uploads/"
 
-# Create directories
+# Create necessary directories
+CACHE_DIR = "/tmp/models_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
 os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
-os.makedirs(os.path.dirname(USER_REPORT_DB_PATH), exist_ok=True)
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+# Load the embedding model
+embedding_model = HuggingFaceEmbeddings(
+    model_name="rishi002/all-MiniLM-L6-v2",
+    cache_folder=CACHE_DIR
+)
 
 # Initialize FastAPI app
 app = FastAPI()
@@ -39,15 +46,28 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Load the embedding model
-embedding_model = HuggingFaceEmbeddings(
-    model_name="rishi002/all-MiniLM-L6-v2",
-    cache_folder=CACHE_DIR
-)
-
-# Global variables to track report status and database
-user_report_processed = False
-user_report_db = None
+# Global variables to track user report data and conversation history
+user_report_data = None
+conversation_history = []
+
+# Load or create FAISS database from knowledge base PDFs
+def load_or_create_faiss():
+    if not os.path.exists(DB_FAISS_PATH):
+        print("🔄 Creating FAISS Database...")
+        from embeddings import load_pdf_files, create_chunks  # Import functions from embeddings.py
+
+        documents = load_pdf_files(DATA_PATH)  # Load PDFs
+        text_chunks = create_chunks(documents)  # Split into chunks
+
+        db = FAISS.from_documents(text_chunks, embedding_model)
+        db.save_local(DB_FAISS_PATH)
+    else:
+        print("✅ FAISS Database Exists. Loading...")
+
+    return FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
+
+# Load the knowledge base
+db = load_or_create_faiss()
 
 # Load LLM
 def load_llm():
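The load_pdf_files and create_chunks helpers are imported from embeddings.py, which is not part of this diff. A minimal sketch of what that module plausibly contains, assuming the usual DirectoryLoader pattern and the chunk_size=1000 / chunk_overlap=200 splitter settings the old code used (names and settings are assumptions, not the committed module):

# embeddings.py (hypothetical sketch; the real module is not shown in this commit)
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_pdf_files(data_path):
    # Load every PDF under data_path into LangChain Document objects
    loader = DirectoryLoader(data_path, glob="*.pdf", loader_cls=PyPDFLoader)
    return loader.load()

def create_chunks(documents):
    # Split documents into overlapping chunks sized for embedding
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return text_splitter.split_documents(documents)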
@@ -58,66 +78,97 @@ def load_llm():
         model_kwargs={"token": HF_TOKEN, "max_length": 512}
     )
 
-# Custom prompt template for medical report analysis
-MEDICAL_REPORT_PROMPT = """
-You are a helpful medical assistant analyzing a patient's medical report.
-Use only the information provided in the context to answer the user's question.
-If you don't know the answer based on the given context, simply state that you don't have enough information.
-Don't make up any medical information or conclusions not supported by the report.
-Provide concise, clear explanations in simple language that a patient can understand.
-Avoid using complex medical terminology unless necessary, and if used, briefly explain what it means.
-Keep your answer concise and focused on the question asked.
-
-Context: {context}
-Question: {question}
-
-Start the answer directly without repeating the question.
-"""
+# Function to extract medical parameters from PDF text
+def extract_medical_parameters(text):
+    # This is a simplified extraction function
+    # In a real-world scenario, you'd want more sophisticated extraction logic
+    parameters = {}
+
+    # Look for common medical parameters with regex
+    # Blood pressure: systolic/diastolic
+    bp_match = re.search(r'blood pressure[:\s]*([\d]+)[\s\/]*([\d]+)', text, re.IGNORECASE)
+    if bp_match:
+        parameters['blood_pressure'] = f"{bp_match.group(1)}/{bp_match.group(2)}"
+
+    # Heart rate
+    hr_match = re.search(r'heart rate[:\s]*([\d]+)', text, re.IGNORECASE)
+    if hr_match:
+        parameters['heart_rate'] = hr_match.group(1)
+
+    # Blood glucose
+    glucose_match = re.search(r'glucose[:\s]*([\d\.]+)', text, re.IGNORECASE)
+    if glucose_match:
+        parameters['glucose'] = glucose_match.group(1)
+
+    # Hemoglobin
+    hb_match = re.search(r'h(?:a|e)moglobin[:\s]*([\d\.]+)', text, re.IGNORECASE)
+    if hb_match:
+        parameters['hemoglobin'] = hb_match.group(1)
+
+    # White blood cell count
+    wbc_match = re.search(r'white blood cell[s]?[:\s]*([\d\.]+)', text, re.IGNORECASE)
+    if wbc_match:
+        parameters['wbc_count'] = wbc_match.group(1)
+
+    # Cholesterol
+    cholesterol_match = re.search(r'cholesterol[:\s]*([\d\.]+)', text, re.IGNORECASE)
+    if cholesterol_match:
+        parameters['cholesterol'] = cholesterol_match.group(1)
+
+    # Add more parameter extraction as needed
+
+    # If no specific parameters were found, store the whole text for context
+    if not parameters:
+        # Simplify by taking first 1000 chars if text is too long
+        parameters['report_summary'] = text[:1000] if len(text) > 1000 else text
+
+    return parameters
 
-# Function to download and process PDF from URL
-def process_pdf_from_url(pdf_url):
+# Function to process uploaded PDF file
+def process_pdf_file(file_path):
     try:
-        # Download the PDF from the URL
-        response = requests.get(pdf_url)
-        response.raise_for_status()  # Raise exception for bad status codes
-
-        # Create a temporary file to save the PDF
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
-            temp_pdf.write(response.content)
-            temp_path = temp_pdf.name
-
         # Load the PDF
-        loader = PyPDFLoader(temp_path)
+        loader = PyPDFLoader(file_path)
         documents = loader.load()
 
-        # Split documents into chunks
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200
-        )
-        text_chunks = text_splitter.split_documents(documents)
-
-        # Create vector database from the text chunks
-        db = FAISS.from_documents(text_chunks, embedding_model)
-        db.save_local(USER_REPORT_DB_PATH)
+        # Extract text from all pages
+        all_text = " ".join([doc.page_content for doc in documents])
 
-        # Clean up the temporary file
-        os.unlink(temp_path)
+        # Extract medical parameters from the text
+        global user_report_data
+        user_report_data = extract_medical_parameters(all_text)
 
-        return True
+        return True, user_report_data
 
     except Exception as e:
         print(f"Error processing PDF: {str(e)}")
-        return False
+        return False, str(e)
 
-# Create QA chain for user report
-def create_user_report_qa_chain():
-    if not os.path.exists(USER_REPORT_DB_PATH):
-        return None
-
-    db = FAISS.load_local(USER_REPORT_DB_PATH, embedding_model, allow_dangerous_deserialization=True)
-
-    prompt = PromptTemplate(template=MEDICAL_REPORT_PROMPT, input_variables=["context", "question"])
+# Custom prompt template that includes medical parameters
+MEDICAL_REPORT_PROMPT = """
+Use the following information to answer the user's question about their medical report.
+If you don't know the answer, just say that you don't know. Don't make up an answer.
+Keep your answer concise and avoid repeating the same information.
+Explain medical terms in a way that's easy for patients to understand.
+Do not mention the source of information in your answer.
+
+User's Medical Parameters:
+{parameters}
+
+Knowledge Base Context:
+{context}
+
+Question: {question}
+
+Start the answer directly.
+"""
+
+# Create the QA chain
+def create_qa_chain():
+    prompt = PromptTemplate(
+        template=MEDICAL_REPORT_PROMPT,
+        input_variables=["parameters", "context", "question"]
+    )
 
     return RetrievalQA.from_chain_type(
         llm=load_llm(),
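The extract_medical_parameters function added above can be smoke-tested in isolation; the sample report text and values below are invented for illustration:

# Quick check of the regex extraction on made-up report text
sample = "Blood Pressure: 120/80 mmHg. Hemoglobin: 13.5 g/dL. Glucose: 95 mg/dL."
print(extract_medical_parameters(sample))
# Expected: {'blood_pressure': '120/80', 'glucose': '95', 'hemoglobin': '13.5'}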
@@ -127,48 +178,57 @@ def create_user_report_qa_chain():
         chain_type_kwargs={'prompt': prompt}
     )
 
-# API Models
-class ReportURL(BaseModel):
-    url: str
+qa_chain = create_qa_chain()
 
+# API Models
 class Question(BaseModel):
     query: str
 
-# Combined API endpoint to process a PDF report from a URL and return status
-@app.post("/api/process-report")
-async def process_report(report_data: ReportURL):
-    global user_report_processed, user_report_db
-
-    # Process the PDF from the URL
-    success = process_pdf_from_url(report_data.url)
+# API endpoint to process an uploaded PDF file
+@app.post("/api/upload-report")
+async def upload_report(file: UploadFile = File(...)):
+    # Save the uploaded file
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+    with open(file_path, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+
+    # Process the PDF file
+    success, data = process_pdf_file(file_path)
+
+    # Clean up the file
+    os.remove(file_path)
 
     if success:
-        user_report_processed = True
-        user_report_db = create_user_report_qa_chain()
         return {
             "status": "success",
             "message": "Medical report data extracted successfully",
-            "processed": True
+            "processed": True,
+            "parameters_found": len(data) > 0
         }
     else:
-        user_report_processed = False
         return {
             "status": "error",
-            "message": "Failed to process the medical report",
+            "message": f"Failed to process the medical report: {data}",
            "processed": False
         }
 
 # API endpoint to ask questions about the processed report
 @app.post("/api/ask-question")
 async def ask_question(question_data: Question):
-    global user_report_db, user_report_processed
+    global user_report_data, conversation_history
 
-    if not user_report_processed or user_report_db is None:
+    if user_report_data is None:
         raise HTTPException(status_code=400, detail="No medical report has been processed yet")
 
     try:
-        # Get answer from the QA chain
-        response = user_report_db.invoke({'query': question_data.query})
+        # Format the parameters for the prompt
+        parameters_text = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in user_report_data.items()])
+
+        # Get answer from the QA chain with user parameters included
+        response = qa_chain.invoke({
+            'query': question_data.query,
+            'parameters': parameters_text
+        })
 
         # Get the raw result
         result = response["result"]
@@ -181,79 +241,155 @@ async def ask_question(question_data: Question):
         # Rejoin with periods
         cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else ""
 
+        # Add to conversation history
+        conversation_history.append({"user": question_data.query, "bot": cleaned_result})
+
         return {"answer": cleaned_result}
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
 
-# Gradio Interface
-with gr.Blocks() as iface:
-    gr.Markdown("# Medical Report Analysis")
-
-    with gr.Row():
-        with gr.Column():
-            pdf_url_input = gr.Textbox(label="Enter PDF Report URL")
-            process_button = gr.Button("Analyze Report")
-            status_text = gr.Textbox(label="Status", interactive=False)
-
-    with gr.Row():
-        with gr.Column():
-            query_input = gr.Textbox(label="Ask a question about your report")
-            query_button = gr.Button("Submit Question")
-            answer_output = gr.Textbox(label="Answer", interactive=False)
-
-    def process_report_gradio(url):
-        global user_report_processed, user_report_db
-
-        if not url:
-            return "Please enter a valid URL"
-
-        success = process_pdf_from_url(url)
-
-        if success:
-            user_report_processed = True
-            user_report_db = create_user_report_qa_chain()
-            return "Medical report data extracted successfully. You can now ask questions about your report."
-        else:
-            user_report_processed = False
-            return "Failed to process the medical report. Please check the URL and try again."
-
-    def ask_question_gradio(query):
-        global user_report_db, user_report_processed
-
-        if not user_report_processed or user_report_db is None:
-            return "No medical report has been processed yet. Please upload and analyze a report first."
-
-        try:
-            # Get answer from the QA chain
-            response = user_report_db.invoke({'query': query})
-
-            # Get the raw result
-            result = response["result"]
-
-            # Remove duplicates by splitting into sentences and keeping only unique ones
-            sentences = [s.strip() for s in result.split('.') if s.strip()]
-            # Use OrderedDict to preserve order while removing duplicates
-            unique_sentences = list(OrderedDict.fromkeys(sentences))
-
-            # Rejoin with periods
-            cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else ""
-
-            return cleaned_result
-
-        except Exception as e:
-            return f"Error: {str(e)}"
+# Gradio Interface Components
+def process_file_upload(file):
+    if file is None:
+        return None, "Please upload a PDF file", []
+
+    success, data = process_pdf_file(file.name)
+
+    if success:
+        parameters = [f"**{k.replace('_', ' ').title()}**: {v}" for k, v in data.items()]
+        parameters_markdown = "\n".join(parameters)
+
+        return file.name, f"✅ Report processed successfully!\n\n### Extracted Parameters:\n{parameters_markdown}", []
+    else:
+        return None, f"❌ Failed to process report: {data}", []
+
+def ask_question_gradio(question, history):
+    global user_report_data, conversation_history
+
+    if user_report_data is None:
+        history.append((question, "No medical report has been processed yet. Please upload a report first."))
+        return "", history
+
+    try:
+        # Format the parameters for the prompt
+        parameters_text = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in user_report_data.items()])
+
+        # Get answer from the QA chain with user parameters included
+        response = qa_chain.invoke({
+            'query': question,
+            'parameters': parameters_text
+        })
+
+        # Get the raw result
+        result = response["result"]
+
+        # Remove duplicates by splitting into sentences and keeping only unique ones
+        sentences = [s.strip() for s in result.split('.') if s.strip()]
+        # Use OrderedDict to preserve order while removing duplicates
+        unique_sentences = list(OrderedDict.fromkeys(sentences))
+
+        # Rejoin with periods
+        cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else ""
+
+        history.append((question, cleaned_result))
+        return "", history
+
+    except Exception as e:
+        history.append((question, f"Error: {str(e)}"))
+        return "", history
+
+def clear_conversation():
+    return [], None, "Upload your medical report PDF to get started", []
+
+# Improved Gradio Interface
+with gr.Blocks(theme=gr.themes.Soft()) as iface:
+    gr.Markdown(
+        """
+        # 🏥 Medical Report Analyzer
+
+        Upload your medical report and ask questions to understand it better.
+        Our AI assistant will help explain your results in plain language.
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            with gr.Box():
+                gr.Markdown("### 1️⃣ Upload Your Report")
+
+                file_upload = gr.File(
+                    file_types=[".pdf"],
+                    label="Upload Medical Report (PDF)",
+                )
+
+                uploaded_file = gr.Textbox(
+                    label="Current Report",
+                    interactive=False,
+                    visible=False
+                )
+
+                upload_status = gr.Markdown(
+                    "Upload your medical report PDF to get started"
+                )
+
+                upload_button = gr.Button("Process Report", variant="primary")
+
+                clear_button = gr.Button("Clear & Start Over", variant="secondary")
+
+        with gr.Column(scale=2):
+            with gr.Box():
+                gr.Markdown("### 2️⃣ Ask Questions About Your Report")
+
+                chat_interface = gr.Chatbot(
+                    label="Conversation",
+                    height=400,
+                    show_copy_button=True,
+                )
+
+                question_input = gr.Textbox(
+                    label="Ask a question about your report",
+                    placeholder="e.g., What does my blood pressure mean?",
+                )
+
+                with gr.Row():
+                    submit_button = gr.Button("Submit Question", variant="primary")
+                    clear_chat_button = gr.Button("Clear Chat", variant="secondary")
+
+                parameter_display = gr.JSON(
+                    label="Extracted Parameters",
+                    visible=False
+                )
 
-    process_button.click(
-        fn=process_report_gradio,
-        inputs=pdf_url_input,
-        outputs=status_text
+    # Set up interactions
+    upload_button.click(
+        fn=process_file_upload,
+        inputs=[file_upload],
+        outputs=[uploaded_file, upload_status, parameter_display]
     )
 
-    query_button.click(
+    submit_button.click(
         fn=ask_question_gradio,
-        inputs=query_input,
-        outputs=answer_output
+        inputs=[question_input, chat_interface],
+        outputs=[question_input, chat_interface]
+    )
+
+    question_input.submit(
+        fn=ask_question_gradio,
+        inputs=[question_input, chat_interface],
+        outputs=[question_input, chat_interface]
+    )
+
+    clear_button.click(
+        fn=clear_conversation,
+        inputs=[],
+        outputs=[chat_interface, uploaded_file, upload_status, parameter_display]
+    )
+
+    clear_chat_button.click(
+        fn=lambda: ([], ""),
+        inputs=[],
+        outputs=[chat_interface, question_input]
     )
 
 # Mount the Gradio app to FastAPI
 
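With this commit, the Space exposes a file-upload flow instead of the old URL-based one. A sketch of a client exercising both new endpoints, assuming the app is reachable at a local base URL (the URL and file name are placeholders):

import requests

BASE_URL = "http://localhost:7860"  # placeholder; use the deployed Space URL

# Upload a report as multipart/form-data, matching UploadFile = File(...)
with open("report.pdf", "rb") as f:
    resp = requests.post(f"{BASE_URL}/api/upload-report",
                         files={"file": ("report.pdf", f, "application/pdf")})
print(resp.json())  # e.g. {"status": "success", ..., "parameters_found": true}

# Ask a question about the processed report, matching the Question model
resp = requests.post(f"{BASE_URL}/api/ask-question",
                     json={"query": "Is my hemoglobin in the normal range?"})
print(resp.json())  # {"answer": "..."}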