rishi002 committed
Commit 16e6135 · verified · 1 Parent(s): 4326e7e

Update app.py

Files changed (1): app.py +224 -441
app.py CHANGED
@@ -1,482 +1,265 @@
  import os
- import shutil
- import tempfile
- import io
- import re
- from pathlib import Path
  import gradio as gr
- import torch
- from langchain.chains import RetrievalQA
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
- from langchain.prompts import PromptTemplate
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import FAISS
  from collections import OrderedDict
- import fitz  # PyMuPDF for more robust PDF handling
-
- from fastapi import FastAPI, File, UploadFile, HTTPException, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware

  # Constants
- KNOWLEDGE_DIR = "medical_knowledge"
- VECTOR_STORE_PATH = "vectorstore"
- MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"  # Gated model requiring authentication
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- EMBEDDING_MODEL = "rishi002/all-MiniLM-L6-v2"  # Using the embedding model from chatbot code
  CACHE_DIR = "/tmp/models_cache"

- # Get HF token from environment variables (set in HF Spaces secrets)
- HF_TOKEN = os.environ.get("HF_TOKEN")
- if not HF_TOKEN:
-     print("Warning: HF_TOKEN not found in environment variables. You may not be able to access gated models.")
-
- # Create cache directory
  os.makedirs(CACHE_DIR, exist_ok=True)

- class MedicalReportAnalyzer:
-     def __init__(self):
-         self.vector_store = None
-         self.llm = None
-         self.qa_chain = None
-         self.user_report_data = "No report data available."  # Default value
-         self.original_report_data = "No original report data available."  # Store original data
-         # Initialize everything
-         self._load_or_create_vector_store()
-         self._initialize_llm()
-         self._setup_qa_chain()
-
-     def _load_or_create_vector_store(self):
-         """Load existing vector store or create a new one from knowledge documents"""
-         embeddings = HuggingFaceEmbeddings(
-             model_name=EMBEDDING_MODEL,
-             cache_folder=CACHE_DIR
-         )
-
-         # Check if the vector store exists
-         if os.path.exists(VECTOR_STORE_PATH):
-             print("Loading existing vector store...")
-             self.vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True)
-         else:
-             print("Creating new vector store from documents...")
-             # Create the knowledge directory if it doesn't exist
-             os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
-
-             # Check if there are documents to process
-             if len(os.listdir(KNOWLEDGE_DIR)) == 0:
-                 print(f"Warning: No documents found in {KNOWLEDGE_DIR}. Please add medical PDFs.")
-                 # Initialize an empty vector store
-                 self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
-                 self.vector_store.save_local(VECTOR_STORE_PATH)
-                 return
-
-             # Load all PDFs from the knowledge directory
-             try:
-                 # First try with DirectoryLoader
-                 loader = DirectoryLoader(KNOWLEDGE_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
-                 documents = loader.load()
-
-                 # Split documents into chunks
-                 text_splitter = RecursiveCharacterTextSplitter(
-                     chunk_size=1000,
-                     chunk_overlap=200,
-                     length_function=len
-                 )
-                 chunks = text_splitter.split_documents(documents)
-
-                 # Create and save the vector store
-                 self.vector_store = FAISS.from_documents(chunks, embeddings)
-                 self.vector_store.save_local(VECTOR_STORE_PATH)
-             except Exception as e:
-                 print(f"Error loading documents with DirectoryLoader: {str(e)}")
-                 # Initialize with minimal data
-                 self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
-                 self.vector_store.save_local(VECTOR_STORE_PATH)
-
-     def _initialize_llm(self):
-         """Initialize the language model using HuggingFaceEndpoint"""
-         print(f"Initializing LLM with {MODEL_NAME}...")
-         try:
-             self.llm = HuggingFaceEndpoint(
-                 repo_id=MODEL_NAME,
-                 task="text-generation",
-                 temperature=0.5,
-                 token=HF_TOKEN,
-                 model_kwargs={"max_length": 512}
-             )
-         except Exception as e:
-             print(f"Error initializing HuggingFaceEndpoint: {str(e)}")
-             # Fall back to a simpler model if needed
-             fallback_model = "google/flan-t5-large"
-             print(f"Falling back to {fallback_model}")
-             self.llm = HuggingFaceEndpoint(
-                 repo_id=fallback_model,
-                 task="text-generation",
-                 temperature=0.5
-             )
-
-     def _setup_qa_chain(self):
-         """Set up the question-answering chain"""
-         # Define a custom prompt template for medical analysis
-         template = """
-         You are a medical assistant analyzing patient medical reports. Use the following pieces of context to answer the question. If you don't know the answer, just say that you don't know; don't try to make up an answer.
-         Also summarize your answer strictly in not more than 350 words and keep the language of your answer simple and easy to understand. Make sure you use easy and simple terms for explanation. Each important point should be stated only once.
-         Patient Report Summary: {patient_data}
-
-         Context from medical knowledge base: {context}
-
-         Question: {question}
-
-         Start the answer directly:
-         """
-
-         # Create the prompt with the correct variable names
-         prompt = PromptTemplate(
-             template=template,
-             input_variables=["context", "question", "patient_data"]
-         )
-
-         # Set up the retriever
-         retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})
-
-         # Create the QA chain with fixed parameters
-         self.qa_chain = RetrievalQA.from_chain_type(
-             llm=self.llm,
-             chain_type="stuff",
-             retriever=retriever,
-             return_source_documents=False,
-             chain_type_kwargs={"prompt": prompt}
          )
-
-     def remove_header_information(self, text):
-         """Remove header information from the report text"""
-         # Store the original text
-         self.original_report_data = text
-
-         # Split the text into lines to analyze
-         lines = text.split('\n')
-
-         # Define patterns to identify header information
-         header_patterns = [
-             r'(Name\s*:)',
-             r'(Patient\s*Name\s*:)',
-             r'(DOB|Date of Birth\s*:)',
-             r'(Age\s*:)',
-             r'(Gender\s*:)',
-             r'(Lab No\.|Laboratory Number\s*:)',
-             r'(Patient ID\s*:)',
-             r'(Report Status\s*:)',
-             r'(Ref By|Referred By\s*:)',
-             r'(Collected\s*:)',
-             r'(Reported\s*:)',
-             r'(A/c Status\s*:)',
-             r'(Processed at\s*:)',
-             r'(Collected at\s*:)',
-             r'(Address\s*:)',
-             r'(Phone|Mobile|Mob\s*:)',
-         ]
-
-         # Create a regex pattern that matches any of the header patterns
-         combined_pattern = '|'.join(header_patterns)
-
-         # Find where the actual test results begin
-         test_results_start = -1
-         for i, line in enumerate(lines):
-             if re.search(r'(Test\s*Report|Test\s*Name|Test\s*Results|Results|HEMOGRAM|ROUTINE|EXAMINATION)', line, re.IGNORECASE):
-                 test_results_start = i
-                 break
-
-         # If we couldn't find the start of test results, look for key medical terms
-         if test_results_start == -1:
-             for i, line in enumerate(lines):
-                 # Look for common test result sections
-                 if re.search(r'(Hemoglobin|Blood|Urine|CBC|Glucose|Cholesterol|Protein|RBC|WBC)', line, re.IGNORECASE):
-                     test_results_start = max(0, i - 3)  # Start a few lines before the first test result
-                     break
-
-         # If we still couldn't find the start of test results, use a heuristic:
-         # skip the first few lines, which usually contain header information
-         if test_results_start == -1:
-             # Count lines with patient-identifiable information
-             header_count = 0
-             for i, line in enumerate(lines):
-                 if re.search(combined_pattern, line, re.IGNORECASE):
-                     header_count += 1
-
-             # If we found several header lines, skip those plus a few more
-             if header_count > 0:
-                 test_results_start = min(header_count + 5, len(lines) // 3)
-             else:
-                 # If no clear header pattern was found, just skip the first 10% of lines as a fallback
-                 test_results_start = max(1, len(lines) // 10)
-
-         # Return text from the determined start point
-         clean_text = '\n'.join(lines[test_results_start:])
-
-         # If this dramatically shortened the text, use a less aggressive approach
-         if len(clean_text) < len(text) * 0.5:
-             print("Warning: Header removal may have removed too much content. Using alternative approach.")
-             # Alternative approach: just remove lines that match header patterns
-             filtered_lines = []
-             for line in lines:
-                 if not re.search(combined_pattern, line, re.IGNORECASE):
-                     filtered_lines.append(line)
-             clean_text = '\n'.join(filtered_lines)
-
-         return clean_text
-
-     def extract_text_from_pdf_pymupdf(self, pdf_path):
-         """Extract text from a PDF using PyMuPDF (more robust than PyPDF)"""
-         text = ""
-         try:
-             doc = fitz.open(pdf_path)
-             for page in doc:
-                 text += page.get_text()
-             doc.close()
-             return text
-         except Exception as e:
-             print(f"PyMuPDF extraction error: {str(e)}")
-             return None
-
-     def extract_text_from_pdf_pypdf(self, pdf_path):
-         """Extract text using PyPDF as a backup method"""
-         try:
-             loader = PyPDFLoader(pdf_path)
-             pages = loader.load()
-             return "\n".join([page.page_content for page in pages])
-         except Exception as e:
-             print(f"PyPDF extraction error: {str(e)}")
-             return None
-
-     def process_user_report(self, report_file):
-         """Process the uploaded medical report with multiple fallback methods"""
-         if report_file is None:
-             return "No file uploaded. Please upload a medical report."
-
-         # Ensure the uploaded file is read as bytes
-         temp_dir = tempfile.mkdtemp()
-         try:
-             # Copy the uploaded file to the temp directory
-             temp_file_path = os.path.join(temp_dir, "user_report.pdf")
-
-             # Handle the file based on its type
-             try:
-                 if isinstance(report_file, str):  # If it's a file path
-                     shutil.copy(report_file, temp_file_path)
-                 elif hasattr(report_file, 'name'):  # Gradio file object
-                     with open(temp_file_path, 'wb') as f:
-                         with open(report_file.name, 'rb') as source:
-                             f.write(source.read())
-                 else:  # Try to handle as bytes or a file-like object
-                     with open(temp_file_path, 'wb') as f:
-                         f.write(report_file.read() if hasattr(report_file, 'read') else report_file)
-             except Exception as e:
-                 print(f"Error saving file: {str(e)}")
-                 return f"Error saving the uploaded file: {str(e)}"
-
-             # Try multiple methods to extract text from the PDF
-             text = None
-
-             # Method 1: PyMuPDF
-             text = self.extract_text_from_pdf_pymupdf(temp_file_path)
-
-             # Method 2: PyPDF as a fallback
-             if not text:
-                 text = self.extract_text_from_pdf_pypdf(temp_file_path)
-
-             # Method 3: last resort, try to read as raw text
-             if not text:
-                 try:
-                     with open(temp_file_path, 'r', errors='ignore') as f:
-                         text = f.read()
-                 except Exception as e:
-                     print(f"Raw text reading error: {str(e)}")
-
-             # If we got text, process it
-             if text and len(text.strip()) > 0:
-                 # Remove header information from the text
-                 cleaned_text = self.remove_header_information(text)
-
-                 # Store the cleaned text
-                 self.user_report_data = cleaned_text
-
-                 # Split into chunks if needed
-                 text_splitter = RecursiveCharacterTextSplitter(
-                     chunk_size=1000,
-                     chunk_overlap=200,
-                     length_function=len
-                 )
-                 chunks = text_splitter.split_text(cleaned_text)
-
-                 # Check whether too much text was removed
-                 original_length = len(text.strip())
-                 cleaned_length = len(cleaned_text.strip())
-                 removal_percentage = (original_length - cleaned_length) / original_length * 100
-
-                 if removal_percentage > 80:
-                     return f"Report processed successfully, but significant content may have been filtered. Original length: {original_length} chars. Cleaned length: {cleaned_length} chars. Extracted approximately {len(chunks)} text chunks."
-                 else:
-                     return f"Report processed successfully. Removed approximately {removal_percentage:.1f}% of header content. Extracted {len(chunks)} text chunks."
-             else:
-                 self.user_report_data = "Unable to extract text from the provided PDF. This is an empty report placeholder."
-                 return "Warning: Could not extract text from the PDF. The file may be corrupted, password-protected, or contain only images. Processing will continue with limited data."
-
-         finally:
-             # Clean up the temporary directory and file
-             shutil.rmtree(temp_dir)
-
-     def answer_question(self, question):
-         """Answer a question based on the uploaded report and the knowledge base"""
-         if not self.user_report_data or self.user_report_data == "No report data available.":
-             return "No report has been processed or text extraction failed. Please upload a medical report first."
-
-         # Check whether the question is about patient demographics or identification
-         demographic_patterns = [
-             r'(patient|name|age|gender|birth|dob|address|phone|contact|id|identification)',
-             r'(doctor|physician|referring|referred by)',
-             r'(date|time|collected|processed|reported)',
-             r'(lab|laboratory|number|id)'
-         ]
-
-         combined_demo_pattern = '|'.join(demographic_patterns)
-
-         # If the question might be about demographics, check whether we need the original data
-         if re.search(combined_demo_pattern, question, re.IGNORECASE):
-             # For demographic questions, check whether it asks for specific identification
-             specific_id_patterns = [
-                 r'(name of|patient name|who is|what is the name)',
-                 r'(exact age|age of|how old)',
-                 r'(address of|where|location|contact details)',
-                 r'(doctor name|name of doctor|referring doctor|who referred)',
-                 r'(date of|when was|time of|report date)',
-                 r'(lab number|patient id|identification number)'
-             ]
-
-             specific_id_pattern = '|'.join(specific_id_patterns)
-
-             # If it's a direct question about patient identity, don't answer
-             if re.search(specific_id_pattern, question, re.IGNORECASE):
-                 return "I'm unable to provide specific patient identification information. This feature is disabled to protect patient privacy. Please ask about medical test results or interpretations instead."
-
-         # Try using the QA chain with proper error handling
          try:
-             # Pass the query to the qa_chain along with the patient data
-             response = self.qa_chain({"query": question, "patient_data": self.user_report_data})
-
-             # Extract the answer from the response
-             if isinstance(response, dict) and "result" in response:
-                 # Get the raw result
-                 result = response["result"]
-
-                 # Process like in the chatbot code: remove duplicate sentences
-                 sentences = [s.strip() for s in result.split('.') if s.strip()]
-                 unique_sentences = list(OrderedDict.fromkeys(sentences))
-                 cleaned_result = '. '.join(unique_sentences)
-
-                 # Add a period if needed
-                 if cleaned_result and not cleaned_result.endswith('.'):
-                     cleaned_result += '.'
-
-                 return cleaned_result
-             else:
-                 return str(response)
-
-         except Exception as e:
-             print(f"Error in QA chain: {str(e)}")
-             # Log details about the error for debugging
-             print(f"QA chain type: {type(self.qa_chain).__name__}")
-
-             # Fall back to a direct LLM call
-             try:
-                 direct_prompt = f"""
-                 Act as an expert doctor who performs medical report analysis accurately. Analyze the given patient data and provide the answer to the question asked about the medical report in strictly less than 200 words.
-                 NOTE: I ONLY WANT THE ANSWER FROM YOU; DO NOT GIVE ME THE PATIENT REPORT DETAILS OR THE QUESTIONS I ASKED IN YOUR ANSWERS.
-                 Also use simple, easy-to-understand terms and keep your answer in easy-to-understand language.
-                 Question about medical report: {question}
-
-                 Patient data available: {self.user_report_data[:800]}... (truncated)
-
-                 Please answer based on this information:
-                 """
-
-                 direct_result = self.llm(direct_prompt)
-                 return f"Fallback answer: {direct_result}"
-             except Exception as fallback_error:
-                 print(f"Fallback also failed: {str(fallback_error)}")
-                 return "Error processing your question. Please try a different question or report."
-
- # Initialize the analyzer
- analyzer = MedicalReportAnalyzer()
-
- # FastAPI app
- app = FastAPI()
-
- # CORS support for frontend testing
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- @app.post("/process_user_report")
- async def process_user_report(report_file: UploadFile = File(...)):
-     try:
-         temp_dir = tempfile.mkdtemp()
-         temp_file_path = os.path.join(temp_dir, report_file.filename)
-
-         with open(temp_file_path, "wb") as f:
-             shutil.copyfileobj(report_file.file, f)
-
-         result = analyzer.process_user_report(temp_file_path)
-         return {"status": "success", "message": result}
-
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-     finally:
-         report_file.file.close()
-         shutil.rmtree(temp_dir, ignore_errors=True)
-
- @app.post("/answer_question")
- async def answer_question(request: Request):
-     try:
-         data = await request.json()
-         question = data.get("question", "").strip()
-
-         if not question:
-             raise HTTPException(status_code=400, detail="Question is required")
-
-         answer = analyzer.answer_question(question)
-         return {"status": "success", "answer": answer}
-
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
- # Optional: keep the Gradio interface for debugging or UI testing
  if __name__ == "__main__":
-     with gr.Blocks(title="Medical Report Analyzer") as demo:
-         gr.Markdown("# Medical Report Analyzer")
-         gr.Markdown("Upload your medical report and ask questions about it. The system will analyze your report and provide answers based on medical knowledge.")
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 report_file = gr.File(label="Upload Medical Report (PDF)")
-                 upload_button = gr.Button("Process Report")
-                 upload_output = gr.Textbox(label="Processing Status")
-
-             with gr.Column(scale=2):
-                 question_input = gr.Textbox(label="Ask a question about your report")
-                 answer_button = gr.Button("Get Answer")
-                 answer_output = gr.Textbox(label="Answer")
-
-         upload_button.click(fn=analyzer.process_user_report, inputs=[report_file], outputs=[upload_output])
-         answer_button.click(fn=analyzer.answer_question, inputs=[question_input], outputs=[answer_output])
-
-         demo.launch(
-             share=True,
-             favicon_path="favicon.ico" if os.path.exists("favicon.ico") else None,
-             server_name="0.0.0.0",
-             server_port=7860
-         )
 
  import os
  import gradio as gr
+ import requests
+ import tempfile
+ from fastapi import FastAPI, HTTPException, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from langchain_community.vectorstores import FAISS
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
+ from langchain.chains import RetrievalQA
+ from langchain_core.prompts import PromptTemplate
+ from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from collections import OrderedDict

+ # Retrieve HF_TOKEN from environment
+ HF_TOKEN = os.environ.get("HF_TOKEN")

  # Constants
  CACHE_DIR = "/tmp/models_cache"
+ DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
+ USER_REPORT_DB_PATH = "/tmp/vectorstore/user_report_db"
+ HUGGINGFACE_REPO_ID = "microsoft/Phi-3-mini-4k-instruct"

+ # Create directories
  os.makedirs(CACHE_DIR, exist_ok=True)
+ os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
+ os.makedirs(os.path.dirname(USER_REPORT_DB_PATH), exist_ok=True)

+ # Initialize FastAPI app
+ app = FastAPI()

+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )

+ # Load the embedding model
+ embedding_model = HuggingFaceEmbeddings(
+     model_name="rishi002/all-MiniLM-L6-v2",
+     cache_folder=CACHE_DIR
+ )
+
+ # Global variables to track report status and database
+ user_report_processed = False
+ user_report_db = None

+ # Load LLM
+ def load_llm():
+     return HuggingFaceEndpoint(
+         repo_id=HUGGINGFACE_REPO_ID,
+         task="text-generation",
+         temperature=0.5,
+         model_kwargs={"token": HF_TOKEN, "max_length": 512}
+     )

+ # Custom prompt template for medical report analysis
+ MEDICAL_REPORT_PROMPT = """
+ You are a helpful medical assistant analyzing a patient's medical report.
+ Use only the information provided in the context to answer the user's question.
+ If you don't know the answer based on the given context, simply state that you don't have enough information.
+ Don't make up any medical information or conclusions not supported by the report.
+ Provide concise, clear explanations in simple language that a patient can understand.
+ Avoid using complex medical terminology unless necessary, and if used, briefly explain what it means.
+ Keep your answer concise and focused on the question asked.
+
+ Context: {context}
+ Question: {question}
+
+ Start the answer directly without repeating the question.
+ """
+
+ # Function to download and process a PDF from a URL
+ def process_pdf_from_url(pdf_url):
+     try:
+         # Download the PDF from the URL
+         response = requests.get(pdf_url)
+         response.raise_for_status()  # Raise an exception for bad status codes
+
+         # Create a temporary file to save the PDF
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
+             temp_pdf.write(response.content)
+             temp_path = temp_pdf.name
+
+         # Load the PDF
+         loader = PyPDFLoader(temp_path)
+         documents = loader.load()
+
+         # Split documents into chunks
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=1000,
+             chunk_overlap=200
          )
+         text_chunks = text_splitter.split_documents(documents)
+
+         # Create a vector database from the text chunks
+         db = FAISS.from_documents(text_chunks, embedding_model)
+         db.save_local(USER_REPORT_DB_PATH)
+
+         # Clean up the temporary file
+         os.unlink(temp_path)
+
+         return True
+
+     except Exception as e:
+         print(f"Error processing PDF: {str(e)}")
+         return False
+
+ # Create the QA chain for the user report
+ def create_user_report_qa_chain():
+     if not os.path.exists(USER_REPORT_DB_PATH):
+         return None
+
+     db = FAISS.load_local(USER_REPORT_DB_PATH, embedding_model, allow_dangerous_deserialization=True)
+
+     prompt = PromptTemplate(template=MEDICAL_REPORT_PROMPT, input_variables=["context", "question"])
+
+     return RetrievalQA.from_chain_type(
+         llm=load_llm(),
+         chain_type="stuff",
+         retriever=db.as_retriever(search_kwargs={'k': 3}),
+         return_source_documents=False,
+         chain_type_kwargs={'prompt': prompt}
+     )
+
+ # API models
+ class ReportURL(BaseModel):
+     url: str
+
+ class Question(BaseModel):
+     query: str
+
+ # Combined API endpoint to process a PDF report from a URL and return status
+ @app.post("/api/process-report")
+ async def process_report(report_data: ReportURL):
+     global user_report_processed, user_report_db
+
+     # Process the PDF from the URL
+     success = process_pdf_from_url(report_data.url)
+
+     if success:
+         user_report_processed = True
+         user_report_db = create_user_report_qa_chain()
+         return {
+             "status": "success",
+             "message": "Medical report data extracted successfully",
+             "processed": True
+         }
+     else:
+         user_report_processed = False
+         return {
+             "status": "error",
+             "message": "Failed to process the medical report",
+             "processed": False
+         }
+
+ # API endpoint to ask questions about the processed report
+ @app.post("/api/ask-question")
+ async def ask_question(question_data: Question):
+     global user_report_db, user_report_processed
+
+     if not user_report_processed or user_report_db is None:
+         raise HTTPException(status_code=400, detail="No medical report has been processed yet")
+
+     try:
+         # Get the answer from the QA chain
+         response = user_report_db.invoke({'query': question_data.query})
+
+         # Get the raw result
+         result = response["result"]
+
+         # Remove duplicates by splitting into sentences and keeping only unique ones
+         sentences = [s.strip() for s in result.split('.') if s.strip()]
+         # Use OrderedDict to preserve order while removing duplicates
+         unique_sentences = list(OrderedDict.fromkeys(sentences))
+
+         # Rejoin with periods
+         cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else ""
+
+         return {"answer": cleaned_result}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")

+ # Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Medical Report Analysis")
+
+     with gr.Row():
+         with gr.Column():
+             pdf_url_input = gr.Textbox(label="Enter PDF Report URL")
+             process_button = gr.Button("Analyze Report")
+             status_text = gr.Textbox(label="Status", interactive=False)

+     with gr.Row():
+         with gr.Column():
+             query_input = gr.Textbox(label="Ask a question about your report")
+             query_button = gr.Button("Submit Question")
+             answer_output = gr.Textbox(label="Answer", interactive=False)
+
+     def process_report_gradio(url):
+         global user_report_processed, user_report_db
+
+         if not url:
+             return "Please enter a valid URL"
+
+         success = process_pdf_from_url(url)
+
+         if success:
+             user_report_processed = True
+             user_report_db = create_user_report_qa_chain()
+             return "Medical report data extracted successfully. You can now ask questions about your report."
+         else:
+             user_report_processed = False
+             return "Failed to process the medical report. Please check the URL and try again."
+
+     def ask_question_gradio(query):
+         global user_report_db, user_report_processed
+
+         if not user_report_processed or user_report_db is None:
+             return "No medical report has been processed yet. Please upload and analyze a report first."

          try:
+             # Get the answer from the QA chain
+             response = user_report_db.invoke({'query': query})
+
+             # Get the raw result
+             result = response["result"]
+
+             # Remove duplicates by splitting into sentences and keeping only unique ones
+             sentences = [s.strip() for s in result.split('.') if s.strip()]
+             # Use OrderedDict to preserve order while removing duplicates
+             unique_sentences = list(OrderedDict.fromkeys(sentences))
+
+             # Rejoin with periods
+             cleaned_result = '. '.join(unique_sentences) + '.' if unique_sentences else ""
+
+             return cleaned_result
+
+         except Exception as e:
+             return f"Error: {str(e)}"
+
+     process_button.click(
+         fn=process_report_gradio,
+         inputs=pdf_url_input,
+         outputs=status_text
+     )
+
+     query_button.click(
+         fn=ask_question_gradio,
+         inputs=query_input,
+         outputs=answer_output
+     )

+ # Mount the Gradio app to FastAPI
+ app = gr.mount_gradio_app(app, iface, path="/")

+ # Run the app with uvicorn
  if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
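
For reference, the two new JSON endpoints can be exercised with any HTTP client once the Space is running. A minimal client sketch follows; the base URL stands in for wherever the Space is served, and the report URL is a hypothetical placeholder, neither of which is part of this commit:

    import requests

    BASE_URL = "http://localhost:7860"  # assumption: local run via uvicorn on port 7860

    # Step 1: point the service at a PDF report by URL (the ReportURL model expects "url").
    resp = requests.post(
        f"{BASE_URL}/api/process-report",
        json={"url": "https://example.com/sample_report.pdf"},  # hypothetical report URL
    )
    print(resp.json())  # {'status': 'success', 'message': '...', 'processed': True}

    # Step 2: ask a question about the processed report (the Question model expects "query").
    resp = requests.post(
        f"{BASE_URL}/api/ask-question",
        json={"query": "Is the hemoglobin value within the normal range?"},
    )
    print(resp.json())  # {'answer': '...'}

Note that the report state lives in module-level globals (user_report_processed, user_report_db), so answers always refer to the most recently processed report, regardless of which client submitted it.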