DeathBlade020 committed
Commit c896826 · verified
1 Parent(s): 2ac2fa7

Update src/get_graph.py

Files changed (1)
  1. src/get_graph.py +403 -368
src/get_graph.py CHANGED
@@ -1,368 +1,403 @@
  from typing import TypedDict, Literal
  from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
  from pydantic import Field
  from pydantic import BaseModel
  import streamlit as st
  from langchain_core.messages import HumanMessage, SystemMessage
  from get_medical_system import load_medical_system
  from langchain.prompts import PromptTemplate
  from langchain.chains import ConversationalRetrievalChain
  from langchain_community.document_loaders import AmazonTextractPDFLoader
  from langgraph.graph import StateGraph, END, START
  from langchain_community.document_loaders import UnstructuredPDFLoader


  class Route(BaseModel):
      step: Literal["RAG", "GENERAL", "EMERGENCY", "MEMORY"] = Field(None, description="The next step in the routing process")  # type: ignore

  class State(TypedDict):
      question: str
      answer: str
      decision: str


  from langchain_core.retrievers import BaseRetriever
  from typing import List
  from langchain_core.documents import Document
+ from functools import lru_cache
+ import hashlib



+ def get_cache_key(question: str) -> str:
+     """Create a cache key from the question"""
+     return hashlib.md5(question.lower().strip().encode()).hexdigest()
+
+ def get_cached_answer(question: str):
+     """Get cached answer if exists"""
+     if "qa_cache" not in st.session_state:
+         st.session_state.qa_cache = {}
+
+     cache_key = get_cache_key(question)
+     return st.session_state.qa_cache.get(cache_key)
+
+ def cache_answer(question: str, answer: str):
+     """Cache the question-answer pair"""
+     if "qa_cache" not in st.session_state:
+         st.session_state.qa_cache = {}
+
+     cache_key = get_cache_key(question)
+     st.session_state.qa_cache[cache_key] = answer
+
+
  def init_document_memory():
      """Initialize document memory in session state"""
      if "uploaded_documents" not in st.session_state:
          st.session_state.uploaded_documents = {}

  documents, ensemble_retriever, llm, reranker = load_medical_system()
  router = llm.with_structured_output(Route, method="function_calling")


  def extract_conversation_history():
      """Extract conversation from session state"""
      if "messages" not in st.session_state:
          return []

      conversation = []
      for msg in st.session_state.messages:
          if msg["role"] == "user":
              conversation.append(f"User: {msg['content']}")
          elif msg["role"] == "assistant" and not msg["content"].startswith("Hello!"):
              conversation.append(f"Assistant: {msg['content']}")

      return conversation

  def handle_conversation_query(state: State):
      """Handle questions about conversation history"""

      conversation = extract_conversation_history()

      if not conversation:
          return {"answer": "We haven't had any conversation yet. Feel free to ask me a medical question though!"}

      # Create conversation context
      conversation_text = "\n".join(conversation[-10:])  # Last 10 exchanges

      result = llm.invoke([
          SystemMessage(content=f"""
  Based on this conversation history, answer the user's question about our previous discussion:

  Conversation History:
  {conversation_text}

  Rules:
  - If they ask for a summary, provide a brief overview
  - If they ask about specific questions, reference them
  - If they ask about previous answers, summarize the key points
  - Always maintain medical disclaimers in your response
  """),
          HumanMessage(content=state['question'])
      ])

      return {"answer": result.content}

  def is_conversation_query(question: str) -> bool:
      """Check if the question is about conversation history"""
      memory_keywords = [
          "previous", "last", "earlier", "before", "summarize", "summarise",
          "what did i ask", "my questions", "conversation", "history",
          "we talked", "discussed", "mentioned"
      ]

      question_lower = question.lower()
      return any(keyword in question_lower for keyword in memory_keywords)


  def llm_call_router(state: State):
      """Enhanced router that includes document routing"""
      # if st.session_state.get("current_document"):
      #     return {'decision': "DOCUMENT"}

      # Check for conversation/memory queries FIRST
      if is_conversation_query(state['question']):
          return {'decision': "MEMORY"}

      # Check if question is about an uploaded document
      # document_keywords = ["document", "report", "lab results", "test results", "my results", "uploaded", "file"]
      # if any(keyword in state['question'].lower() for keyword in document_keywords):
      #     if "current_document" in st.session_state and st.session_state.current_document:
      #         return {'decision': "DOCUMENT"}

      # Emergency check
      emergency_keywords = ["severe", "chest pain", "can't breathe", "emergency", "urgent",
                            "heart attack", "stroke", "bleeding", "unconscious"]
      question_lower = state['question'].lower()
      if any(keyword in question_lower for keyword in emergency_keywords):
          return {'decision': "EMERGENCY"}

      # Regular routing
      decision = router.invoke([
          SystemMessage(content="Route the input to RAG (medical questions) or GENERAL based on the user's request"),
          HumanMessage(content=state['question'])
      ])
      return {"decision": decision.step}  # type: ignore

  def emergency_node(state: State):
      """Handle emergency queries safely"""
      return {"answer": "🚨 EMERGENCY: Please seek immediate medical attention or call emergency services (911). This system cannot provide emergency medical care."}

  def rag_node(state: State):
-     """Uses RAG to answer the question with reranking"""
+     """Uses RAG to answer the question with caching"""

+     # Check cache first
+     cached_answer = get_cached_answer(state['question'])
+     if cached_answer:
+         return {"answer": f"🔄 {cached_answer}"}  # Add emoji to show it's cached
+
+
      custom_prompt = PromptTemplate(
          input_variables=["context", "question"],
          template="""You are a medical information assistant. Use the following medical Q&A context to answer questions accurately and safely.

  Context: {context}

  Question: {question}

  Guidelines:
  - Provide accurate medical information based on the context above
  - Always recommend consulting healthcare professionals for medical decisions
  - If uncertain, clearly state limitations
  - If the question is not suitable for this bot, respond with: "I'm not able to provide medical advice. Please consult a medical professional."

  Answer:"""
      )

      qa_chain = ConversationalRetrievalChain.from_llm(
          llm=llm,
          retriever=ensemble_retriever,
          return_source_documents=True,
          combine_docs_chain_kwargs={"prompt": custom_prompt}
      )

      result = qa_chain.invoke({
          "question": state['question'],
          "chat_history": []
      })

-     # Reranking
+     # Reranking logic
      docs = result.get('source_documents', [])
      if docs and len(docs) > 1:
          pairs = [(state['question'], doc.page_content) for doc in docs]
          scores = reranker.predict(pairs)

          doc_scores = list(zip(docs, scores))
          doc_scores.sort(key=lambda x: x[1], reverse=True)
          top_docs = [doc for doc, score in doc_scores[:3]]

          better_context = "\\n\\n".join([doc.page_content for doc in top_docs])
          improved_answer = llm.invoke([
              SystemMessage(content=f"""Use this medical context to answer the question safely:

  Context: {better_context}

  Always recommend consulting healthcare professionals."""),
              HumanMessage(content=state['question'])
          ])
-         return {"answer": improved_answer.content}

-     return {"answer": result['answer']}
+         final_answer = improved_answer.content
+     else:
+         final_answer = result['answer']
+
+
+     cache_answer(state['question'], final_answer)
+
+     return {"answer": final_answer}

  def general_node(state: State):
      """Enhanced general node with sarcastic responses for identity questions"""

      question_lower = state['question'].lower().strip()

      # Identity/philosophical questions - sarcastic responses
      identity_keywords = [
          "what are you", "who are you", "what is your name", "are you human",
          "are you real", "are you ai", "are you robot", "are you chatbot",
          "what's your name", "who made you", "are you alive", "do you think",
          "are you conscious", "do you feel", "what do you do", "your purpose"
      ]

      if any(keyword in question_lower for keyword in identity_keywords):
          # Sarcastic responses for identity questions
          sarcastic_responses = [
              "🤖 Oh, just your friendly neighborhood medical AI trying to keep people from WebMD-ing themselves into thinking they have every disease known to humanity. You know, the usual.",

              "🩺 I'm a sophisticated medical assistant, which is a fancy way of saying I'm here to tell you to 'consult a healthcare professional' in 47 different ways.",

              "🏥 I'm an AI that reads medical textbooks faster than you can say 'Google symptoms at 3 AM.' My purpose? Giving you actual medical info instead of letting you convince yourself that headache is definitely a brain tumor.",

              "💊 I'm basically a walking medical disclaimer with a personality. Think of me as that friend who went to med school but actually remembers what they learned.",

              "🔬 I'm an artificial intelligence trained on medical knowledge, which means I can tell you about symptoms but I still can't fix your tendency to ignore doctor's appointments.",

              "🧠 I'm a medical AI assistant. I exist to answer your health questions and remind you that, no, that WebMD article probably doesn't apply to you."
          ]

          import random
          return {"answer": random.choice(sarcastic_responses)}

      # Greeting responses - also with some personality
      greeting_keywords = ["hello", "hi", "hey", "good morning", "good evening", "greetings"]
      if any(keyword in question_lower for keyword in greeting_keywords):
          friendly_responses = [
              "Hello! 👋 Ready to get some actual medical information instead of falling down a WebMD rabbit hole?",
              "Hi there! 🏥 I'm here to answer your medical questions. Fair warning: I'll probably tell you to see a real doctor.",
              "Hey! 👨‍⚕️ What medical mystery can I help solve today? (Spoiler: the answer might be 'drink more water')",
              "Greetings! 🩺 Ask me anything medical-related. I promise to give you better advice than your cousin's Facebook post."
          ]

          import random
          return {"answer": random.choice(friendly_responses)}

      # Regular medical or general questions
      result = llm.invoke([
          SystemMessage(content="""
  Answer the user's question helpfully and accurately.

  IMPORTANT SAFETY RULES:
  - For medical questions: Always end with "Please consult a healthcare professional"
  - For emergencies: Direct to call emergency services immediately
  - If unsure: Say "I don't know" rather than guess

  Be helpful but prioritize user safety. You can be slightly witty or conversational, but always maintain professionalism for serious medical topics.
  """),
          HumanMessage(content=state['question'])
      ])

      return {"answer": result.content}

  def document_node(state: State):
      """Simple document processing node that integrates with your existing workflow"""


      # Check if there's an uploaded document in session state
      if "current_document" not in st.session_state or not st.session_state.current_document:
          return {"answer": "Please upload a medical document first using the file uploader in the sidebar."}

      file_path = st.session_state.current_document
      question = state['question']

      try:
          # Check if document already processed
          if file_path not in st.session_state.uploaded_documents:
              # Extract document content
              # loader = AmazonTextractPDFLoader(file_path, region_name="us-east-1")
              loader = UnstructuredPDFLoader(file_path)
              documents = loader.load()

              # Clean and store content
              content = "\n".join([doc.page_content for doc in documents])
              st.session_state.uploaded_documents[file_path] = {
                  "content": content,
                  "conversation": []
              }

          # Get stored document
          doc_data = st.session_state.uploaded_documents[file_path]

          # Build context with previous questions about this document
          context_parts = [f"Document Content:\n{doc_data['content']}"]

          if doc_data['conversation']:
              context_parts.append("\nPrevious questions about this document:")
              for qa in doc_data['conversation'][-3:]:  # Last 3 Q&As
                  context_parts.append(f"Q: {qa['question']}\nA: {qa['answer'][:200]}...")

          full_context = "\n".join(context_parts)

          # Generate answer using your existing LLM
          from langchain_core.messages import HumanMessage, SystemMessage

          result = llm.invoke([
              SystemMessage(content=f"""
  You are analyzing a medical document. Use the document content and any previous conversation to answer the user's question.

  Guidelines:
  - Base your answer on the document content provided
  - Reference specific values or sections when possible
  - If information isn't in the document, clearly state this
  - Always include medical disclaimers
  - Maintain conversation continuity with previous questions

  {full_context}
  """),
              HumanMessage(content=f"Question about the document: {question}")
          ])

          # Store this Q&A in document conversation history
          doc_data['conversation'].append({
              "question": question,
              "answer": result.content
          })

          return {"answer": f"📄 **Document Analysis:**\n\n{result.content}"}

      except Exception as e:
          return {"answer": f"Error processing document: {str(e)}. Please ensure the file is accessible and try again."}


  def route_decision(state: State):
      """Enhanced route decision with memory"""
      if state["decision"] == "MEMORY":
          return "memory_node"
      elif state["decision"] == "DOCUMENT":
          return "document_node"
      elif state["decision"] == "RAG":
          return "rag_node"
      elif state["decision"] == "EMERGENCY":
          return "emergency_node"
      else:
          return "general_node"

  # ==================== CREATE WORKFLOW ====================

  @st.cache_resource
  def create_workflow():
      """Create the enhanced workflow graph with memory"""

      init_document_memory()

      router_builder = StateGraph(State)

      # Add all nodes (including new memory node)
      router_builder.add_node("rag_node", rag_node)
      router_builder.add_node("general_node", general_node)
      router_builder.add_node("llm_call_router", llm_call_router)
      router_builder.add_node("emergency_node", emergency_node)
      router_builder.add_node("memory_node", handle_conversation_query)  # NEW NODE
      # router_builder.add_node("document_node", document_node)


      router_builder.add_edge(START, "llm_call_router")
      router_builder.add_conditional_edges(
          "llm_call_router",
          route_decision,
          {
              "rag_node": "rag_node",
              "general_node": "general_node",
              "emergency_node": "emergency_node",
              "memory_node": "memory_node",  # NEW ROUTE
              # "document_node": "document_node"
          },
      )

      # Add edges to END
      router_builder.add_edge("rag_node", END)
      router_builder.add_edge("general_node", END)
      router_builder.add_edge("emergency_node", END)
      router_builder.add_edge("memory_node", END)  # NEW EDGE
      # router_builder.add_edge("document_node", END)

      return router_builder.compile()
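The cache this commit adds is exact-match only: get_cache_key() lowercases and strips the question before hashing, so trivially re-phrased repeats are served from the cache, while any real paraphrase hashes differently and goes back through the full RAG path. A minimal sketch of that behavior, using a plain dict as a hypothetical stand-in for st.session_state.qa_cache:

import hashlib

qa_cache = {}  # hypothetical stand-in for st.session_state.qa_cache

def get_cache_key(question: str) -> str:
    # Same normalization as the commit: lowercase, strip, then md5
    return hashlib.md5(question.lower().strip().encode()).hexdigest()

qa_cache[get_cache_key("What is hypertension?")] = "High blood pressure. Please consult a healthcare professional."

# Hit: normalization maps both phrasings to the same key
assert qa_cache.get(get_cache_key("  WHAT IS HYPERTENSION?  ")) is not None
# Miss: a paraphrase produces a different hash, so rag_node runs in full
assert qa_cache.get(get_cache_key("What causes hypertension?")) is None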
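For context on how the compiled graph is consumed: create_workflow() returns a compiled LangGraph, and a caller such as the app's Streamlit chat handler (assumed here; it is not part of this commit) would presumably invoke it once per user turn with a State-shaped input:

# Hypothetical caller; only create_workflow() comes from this file
workflow = create_workflow()  # memoized across reruns by @st.cache_resource
result = workflow.invoke({"question": "What are the symptoms of anemia?"})
print(result["answer"])  # filled in by whichever node the router selected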