DeathBlade020 committed on
Commit
d8140c0
·
verified ·
1 Parent(s): 6b2f56f

Upload 4 files

Files changed (4)
  1. src/__init__.py +0 -0
  2. src/constants.py +45 -0
  3. src/get_graph.py +368 -0
  4. src/get_medical_system.py +133 -0
src/__init__.py ADDED
File without changes
src/constants.py ADDED
@@ -0,0 +1,45 @@
+ spinner_messages = [
+     "Searching the universe...",
+     "Consulting the medical oracles...",
+     "Paging Dr. AI...",
+     "Googling responsibly...",
+     "Checking the medical textbooks...",
+     "Assembling a team of virtual doctors...",
+     "Running with scissors (just kidding)...",
+     "Putting on my lab coat...",
+     "Sterilizing the stethoscope...",
+     "Counting imaginary pills...",
+     "Reading the fine print on the prescription...",
+     "Asking the mitochondria (it's the powerhouse)...",
+     "Checking WebMD (not really)...",
+     "Looking for my AI degree...",
+     "Washing my hands for 20 seconds...",
+     "Trying not to diagnose you with everything...",
+ ]
+
+ sidebar_messages = """
+ **Medical Assistant Features:**
+ - 🏥 Medical Q&A database
+ - 🚨 Emergency detection
+ - 🔍 Smart document retrieval
+ - 🧠 Conversation memory
+ - ⚖️ Safety disclaimers
+
+ **💬 Memory Commands:**
+ - "Summarize my previous questions"
+ - "What did we discuss earlier?"
+ - "Can you review our conversation?"
+
+ **⚠️ Important:**
+ - This is for educational purposes only
+ - Always consult healthcare professionals
+ - Call 100 for emergencies
+ """
+
+ st_error_message = "I'm sorry, I'm having technical difficulties. Please try again or consult a healthcare professional."
+ st_title = "🏥 Medical Assistant Chatbot"
+ st_markdown = "Ask me medical questions! **For emergencies, call 100 immediately.**"
+ st_welcome_message = "Namaste! I'm your medical assistant. I can help answer medical questions, but for emergencies, please call 100 immediately. How can I help you today?"
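For context, a short sketch of how these constants are presumably consumed by the Streamlit front end; the app entry point is not included in this commit, so the wiring below is an assumption:

    # hypothetical app wiring -- the entry point is not part of this commit
    import random
    import streamlit as st
    from constants import spinner_messages, st_title, st_markdown, sidebar_messages

    st.title(st_title)
    st.markdown(st_markdown)
    st.sidebar.markdown(sidebar_messages)
    with st.spinner(random.choice(spinner_messages)):
        ...  # run the medical workflow here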
src/get_graph.py ADDED
@@ -0,0 +1,368 @@
+ import random
+ from typing import Literal, TypedDict
+
+ import streamlit as st
+ from pydantic import BaseModel, Field
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.prompts import PromptTemplate
+ from langchain_community.document_loaders import UnstructuredPDFLoader
+ # from langchain_community.document_loaders import AmazonTextractPDFLoader  # optional Textract-based loader
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from langgraph.graph import StateGraph, END, START
+
+ from get_medical_system import load_medical_system
+
+
+ class Route(BaseModel):
+     step: Literal["RAG", "GENERAL", "EMERGENCY", "MEMORY"] = Field(
+         ..., description="The next step in the routing process"
+     )
+
+
+ class State(TypedDict):
+     question: str
+     answer: str
+     decision: str
+
+
+ def init_document_memory():
+     """Initialize document memory in session state"""
+     if "uploaded_documents" not in st.session_state:
+         st.session_state.uploaded_documents = {}
+
+
+ documents, ensemble_retriever, llm, reranker = load_medical_system()
+ router = llm.with_structured_output(Route, method="function_calling")
+
+
+ def extract_conversation_history():
+     """Extract conversation from session state"""
+     if "messages" not in st.session_state:
+         return []
+
+     conversation = []
+     for msg in st.session_state.messages:
+         if msg["role"] == "user":
+             conversation.append(f"User: {msg['content']}")
+         elif msg["role"] == "assistant" and not msg["content"].startswith("Namaste!"):
+             # Skip the canned welcome message (st_welcome_message in constants.py)
+             conversation.append(f"Assistant: {msg['content']}")
+
+     return conversation
+
+
+ def handle_conversation_query(state: State):
+     """Handle questions about conversation history"""
+     conversation = extract_conversation_history()
+
+     if not conversation:
+         return {"answer": "We haven't had any conversation yet. Feel free to ask me a medical question though!"}
+
+     # Create conversation context from the last 10 exchanges
+     conversation_text = "\n".join(conversation[-10:])
+
+     result = llm.invoke([
+         SystemMessage(content=f"""
+ Based on this conversation history, answer the user's question about our previous discussion:
+
+ Conversation History:
+ {conversation_text}
+
+ Rules:
+ - If they ask for a summary, provide a brief overview
+ - If they ask about specific questions, reference them
+ - If they ask about previous answers, summarize the key points
+ - Always maintain medical disclaimers in your response
+ """),
+         HumanMessage(content=state['question'])
+     ])
+
+     return {"answer": result.content}
+
+
+ def is_conversation_query(question: str) -> bool:
+     """Check if the question is about conversation history"""
+     memory_keywords = [
+         "previous", "last", "earlier", "before", "summarize", "summarise",
+         "what did i ask", "my questions", "conversation", "history",
+         "we talked", "discussed", "mentioned"
+     ]
+
+     question_lower = question.lower()
+     return any(keyword in question_lower for keyword in memory_keywords)
+
+
+ def llm_call_router(state: State):
+     """Router that detects memory and emergency queries before falling back to the LLM"""
+     # Check for conversation/memory queries first
+     if is_conversation_query(state['question']):
+         return {'decision': "MEMORY"}
+
+     # Document routing (currently disabled)
+     # document_keywords = ["document", "report", "lab results", "test results", "my results", "uploaded", "file"]
+     # if any(keyword in state['question'].lower() for keyword in document_keywords):
+     #     if "current_document" in st.session_state and st.session_state.current_document:
+     #         return {'decision': "DOCUMENT"}
+
+     # Emergency check
+     emergency_keywords = ["severe", "chest pain", "can't breathe", "emergency", "urgent",
+                           "heart attack", "stroke", "bleeding", "unconscious"]
+     question_lower = state['question'].lower()
+     if any(keyword in question_lower for keyword in emergency_keywords):
+         return {'decision': "EMERGENCY"}
+
+     # Regular routing
+     decision = router.invoke([
+         SystemMessage(content="Route the input to RAG (medical questions) or GENERAL based on the user's request"),
+         HumanMessage(content=state['question'])
+     ])
+     return {"decision": decision.step}  # type: ignore
+
+
+ def emergency_node(state: State):
+     """Handle emergency queries safely"""
+     return {"answer": "🚨 EMERGENCY: Please seek immediate medical attention or call emergency services (100). This system cannot provide emergency medical care."}
+
+
+ def rag_node(state: State):
+     """Uses RAG to answer the question, with cross-encoder reranking"""
+     custom_prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template="""You are a medical information assistant. Use the following medical Q&A context to answer questions accurately and safely.
+
+ Context: {context}
+
+ Question: {question}
+
+ Guidelines:
+ - Provide accurate medical information based on the context above
+ - Always recommend consulting healthcare professionals for medical decisions
+ - If uncertain, clearly state limitations
+ - If the question is not suitable for this bot, respond with: "I'm not able to provide medical advice. Please consult a medical professional."
+
+ Answer:"""
+     )
+
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=ensemble_retriever,
+         return_source_documents=True,
+         combine_docs_chain_kwargs={"prompt": custom_prompt}
+     )
+
+     result = qa_chain.invoke({
+         "question": state['question'],
+         "chat_history": []
+     })
+
+     # Rerank the retrieved documents with the cross-encoder, then answer again from the top 3
+     docs = result.get('source_documents', [])
+     if docs and len(docs) > 1:
+         pairs = [(state['question'], doc.page_content) for doc in docs]
+         scores = reranker.predict(pairs)
+
+         doc_scores = list(zip(docs, scores))
+         doc_scores.sort(key=lambda x: x[1], reverse=True)
+         top_docs = [doc for doc, score in doc_scores[:3]]
+
+         better_context = "\n\n".join([doc.page_content for doc in top_docs])
+         improved_answer = llm.invoke([
+             SystemMessage(content=f"""Use this medical context to answer the question safely:
+
+ Context: {better_context}
+
+ Always recommend consulting healthcare professionals."""),
+             HumanMessage(content=state['question'])
+         ])
+         return {"answer": improved_answer.content}
+
+     return {"answer": result['answer']}
+
+
+ def general_node(state: State):
+     """General node with sarcastic responses for identity questions and greetings"""
+     question_lower = state['question'].lower().strip()
+
+     # Identity/philosophical questions get sarcastic responses
+     identity_keywords = [
+         "what are you", "who are you", "what is your name", "are you human",
+         "are you real", "are you ai", "are you robot", "are you chatbot",
+         "what's your name", "who made you", "are you alive", "do you think",
+         "are you conscious", "do you feel", "what do you do", "your purpose"
+     ]
+
+     if any(keyword in question_lower for keyword in identity_keywords):
+         sarcastic_responses = [
+             "🤖 Oh, just your friendly neighborhood medical AI trying to keep people from WebMD-ing themselves into thinking they have every disease known to humanity. You know, the usual.",
+             "🩺 I'm a sophisticated medical assistant, which is a fancy way of saying I'm here to tell you to 'consult a healthcare professional' in 47 different ways.",
+             "🏥 I'm an AI that reads medical textbooks faster than you can say 'Google symptoms at 3 AM.' My purpose? Giving you actual medical info instead of letting you convince yourself that headache is definitely a brain tumor.",
+             "💊 I'm basically a walking medical disclaimer with a personality. Think of me as that friend who went to med school but actually remembers what they learned.",
+             "🔬 I'm an artificial intelligence trained on medical knowledge, which means I can tell you about symptoms but I still can't fix your tendency to ignore doctor's appointments.",
+             "🧠 I'm a medical AI assistant. I exist to answer your health questions and remind you that, no, that WebMD article probably doesn't apply to you."
+         ]
+         return {"answer": random.choice(sarcastic_responses)}
+
+     # Greetings also get some personality
+     greeting_keywords = ["hello", "hi", "hey", "good morning", "good evening", "greetings"]
+     if any(keyword in question_lower for keyword in greeting_keywords):
+         friendly_responses = [
+             "Hello! 👋 Ready to get some actual medical information instead of falling down a WebMD rabbit hole?",
+             "Hi there! 🏥 I'm here to answer your medical questions. Fair warning: I'll probably tell you to see a real doctor.",
+             "Hey! 👨‍⚕️ What medical mystery can I help solve today? (Spoiler: the answer might be 'drink more water')",
+             "Greetings! 🩺 Ask me anything medical-related. I promise to give you better advice than your cousin's Facebook post."
+         ]
+         return {"answer": random.choice(friendly_responses)}
+
+     # Regular medical or general questions
+     result = llm.invoke([
+         SystemMessage(content="""
+ Answer the user's question helpfully and accurately.
+
+ IMPORTANT SAFETY RULES:
+ - For medical questions: Always end with "Please consult a healthcare professional"
+ - For emergencies: Direct to call emergency services immediately
+ - If unsure: Say "I don't know" rather than guess
+
+ Be helpful but prioritize user safety. You can be slightly witty or conversational, but always maintain professionalism for serious medical topics.
+ """),
+         HumanMessage(content=state['question'])
+     ])
+
+     return {"answer": result.content}
+
+
+ def document_node(state: State):
+     """Answer questions about an uploaded medical document"""
+     # Check if there's an uploaded document in session state
+     if "current_document" not in st.session_state or not st.session_state.current_document:
+         return {"answer": "Please upload a medical document first using the file uploader in the sidebar."}
+
+     file_path = st.session_state.current_document
+     question = state['question']
+
+     try:
+         # Extract and cache document content on first use
+         if file_path not in st.session_state.uploaded_documents:
+             # loader = AmazonTextractPDFLoader(file_path, region_name="us-east-1")
+             loader = UnstructuredPDFLoader(file_path)
+             documents = loader.load()
+
+             content = "\n".join([doc.page_content for doc in documents])
+             st.session_state.uploaded_documents[file_path] = {
+                 "content": content,
+                 "conversation": []
+             }
+
+         # Get stored document
+         doc_data = st.session_state.uploaded_documents[file_path]
+
+         # Build context with previous questions about this document
+         context_parts = [f"Document Content:\n{doc_data['content']}"]
+
+         if doc_data['conversation']:
+             context_parts.append("\nPrevious questions about this document:")
+             for qa in doc_data['conversation'][-3:]:  # Last 3 Q&As
+                 context_parts.append(f"Q: {qa['question']}\nA: {qa['answer'][:200]}...")
+
+         full_context = "\n".join(context_parts)
+
+         result = llm.invoke([
+             SystemMessage(content=f"""
+ You are analyzing a medical document. Use the document content and any previous conversation to answer the user's question.
+
+ Guidelines:
+ - Base your answer on the document content provided
+ - Reference specific values or sections when possible
+ - If information isn't in the document, clearly state this
+ - Always include medical disclaimers
+ - Maintain conversation continuity with previous questions
+
+ {full_context}
+ """),
+             HumanMessage(content=f"Question about the document: {question}")
+         ])
+
+         # Store this Q&A in the document's conversation history
+         doc_data['conversation'].append({
+             "question": question,
+             "answer": result.content
+         })
+
+         return {"answer": f"📄 **Document Analysis:**\n\n{result.content}"}
+
+     except Exception as e:
+         return {"answer": f"Error processing document: {str(e)}. Please ensure the file is accessible and try again."}
+
+
+ def route_decision(state: State):
+     """Map the router's decision to the node to run next"""
+     if state["decision"] == "MEMORY":
+         return "memory_node"
+     elif state["decision"] == "DOCUMENT":
+         return "document_node"
+     elif state["decision"] == "RAG":
+         return "rag_node"
+     elif state["decision"] == "EMERGENCY":
+         return "emergency_node"
+     else:
+         return "general_node"
+
+
+ # ==================== CREATE WORKFLOW ====================
+
+ @st.cache_resource
+ def create_workflow():
+     """Create the workflow graph with conversation memory"""
+     init_document_memory()
+
+     router_builder = StateGraph(State)
+
+     # Add all nodes (including the memory node)
+     router_builder.add_node("rag_node", rag_node)
+     router_builder.add_node("general_node", general_node)
+     router_builder.add_node("llm_call_router", llm_call_router)
+     router_builder.add_node("emergency_node", emergency_node)
+     router_builder.add_node("memory_node", handle_conversation_query)
+     # router_builder.add_node("document_node", document_node)
+
+     router_builder.add_edge(START, "llm_call_router")
+     router_builder.add_conditional_edges(
+         "llm_call_router",
+         route_decision,
+         {
+             "rag_node": "rag_node",
+             "general_node": "general_node",
+             "emergency_node": "emergency_node",
+             "memory_node": "memory_node",
+             # "document_node": "document_node",
+         },
+     )
+
+     # Add edges to END
+     router_builder.add_edge("rag_node", END)
+     router_builder.add_edge("general_node", END)
+     router_builder.add_edge("emergency_node", END)
+     router_builder.add_edge("memory_node", END)
+     # router_builder.add_edge("document_node", END)
+
+     return router_builder.compile()
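For reference, a minimal sketch of how the compiled graph might be driven from the Streamlit app; the entry point is not part of this commit, so the invocation below is an assumption based on the State schema (only "question" needs to be supplied, the nodes fill in "answer" and "decision"):

    # hypothetical caller -- not part of this commit
    import streamlit as st
    from get_graph import create_workflow

    workflow = create_workflow()
    result = workflow.invoke({"question": "What are the early symptoms of diabetes?"})
    st.write(result["answer"])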
src/get_medical_system.py ADDED
@@ -0,0 +1,133 @@
+ import os
+ import pickle
+ import time
+
+ import faiss
+ import numpy as np
+ import streamlit as st
+ from datasets import load_dataset
+ from dotenv import load_dotenv
+ from langchain.retrievers import EnsembleRetriever
+ from langchain.schema import Document
+ from langchain_community.docstore.in_memory import InMemoryDocstore
+ from langchain_community.retrievers import BM25Retriever
+ from langchain_community.vectorstores import FAISS
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from sentence_transformers import CrossEncoder
+
+ load_dotenv()
+
+
+ def get_vector_store():
+     """Load a FAISS vectorstore from pre-computed embeddings"""
+     try:
+         # Load pre-computed data
+         print("📥 Loading pre-computed embeddings...")
+         embeddings_array = np.load('medical_embeddings.npy')
+
+         with open('medical_texts.pkl', 'rb') as f:
+             texts = pickle.load(f)
+
+         print(f"✅ Loaded {len(embeddings_array)} pre-computed embeddings")
+
+         # Create a FAISS index from the pre-computed embeddings
+         dimension = embeddings_array.shape[1]
+         index = faiss.IndexFlatL2(dimension)
+         index.add(embeddings_array.astype('float32'))  # type: ignore
+
+         # Embedding function for new queries (must match the model used for pre-computation)
+         embeddings_function = HuggingFaceEmbeddings(
+             model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
+         )
+
+         # Create Document objects and an in-memory docstore
+         documents_dict = {}
+         for i, text in enumerate(texts):
+             doc = Document(
+                 page_content=text,
+                 metadata={"doc_id": i, "type": "medical_qa"}
+             )
+             documents_dict[str(i)] = doc
+
+         docstore = InMemoryDocstore(documents_dict)
+
+         # Map FAISS index positions to docstore ids
+         index_to_docstore_id = {i: str(i) for i in range(len(texts))}
+
+         # Assemble the FAISS vectorstore
+         vectorstore = FAISS(
+             embedding_function=embeddings_function,
+             index=index,
+             docstore=docstore,
+             index_to_docstore_id=index_to_docstore_id
+         )
+
+         return vectorstore
+
+     except FileNotFoundError as e:
+         print(f"❌ Pre-computed files not found: {e}")
+         return None
+
+     except Exception as e:
+         print(f"❌ Error loading pre-computed embeddings: {e}")
+         return None
+
+
+ @st.cache_resource
+ def load_medical_system():
+     """Load the medical RAG system (cached for performance)"""
+     with st.spinner("🔄 Loading medical knowledge base..."):
+         # Load dataset
+         ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")
+
+         # Create documents
+         documents = []
+         for i, item in enumerate(ds['train']):  # type: ignore
+             content = f"Question: {item['Question']}\nAnswer: {item['Answer']}"  # type: ignore
+             metadata = {
+                 "doc_id": i,
+                 "question": item['Question'],  # type: ignore
+                 "answer": item['Answer'],  # type: ignore
+                 "question_type": item['qtype'],  # type: ignore
+                 "type": "qa_pair"
+             }
+             documents.append(Document(page_content=content, metadata=metadata))
+
+         # Load the pre-built vectorstore and time it
+         start = time.time()
+         vectorstore = get_vector_store()
+         end = time.time()
+
+         if vectorstore is None:
+             st.error("❌ Could not load the vectorstore. Please ensure the embeddings and text files exist.")
+             st.stop()
+
+         total_time = end - start
+         st.success(f"✅ Loaded existing vectorstore in {total_time:.2f} seconds")
+
+         # Create retrievers: lexical (BM25) + dense (FAISS), combined in an ensemble
+         bm25_retriever = BM25Retriever.from_documents(documents)
+         vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
+
+         ensemble_retriever = EnsembleRetriever(
+             retrievers=[bm25_retriever, vector_retriever],
+             weights=[0.3, 0.7]
+         )
+
+         # Create the LLM
+         openai_key = os.getenv("OPENAI_API_KEY")
+         if not openai_key:
+             st.error("❌ OpenAI API key not found! Please set it in your environment variables or .streamlit/secrets.toml")
+             st.stop()
+         llm = ChatOpenAI(temperature=0, api_key=openai_key)  # type: ignore
+
+         # Create the cross-encoder reranker
+         reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+         return documents, ensemble_retriever, llm, reranker
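get_vector_store expects medical_embeddings.npy and medical_texts.pkl to already exist, but the script that produces them is not part of this commit. A minimal sketch of what that pre-computation step presumably looks like, using the same dataset and embedding model referenced above (the script name and text format are assumptions):

    # precompute_embeddings.py -- hypothetical one-off script, not in this commit
    import pickle
    import numpy as np
    from datasets import load_dataset
    from langchain_huggingface import HuggingFaceEmbeddings

    ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")
    # Mirror the "Question: ...\nAnswer: ..." format used in load_medical_system
    texts = [f"Question: {row['Question']}\nAnswer: {row['Answer']}" for row in ds["train"]]

    embedder = HuggingFaceEmbeddings(
        model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    )
    vectors = np.array(embedder.embed_documents(texts), dtype="float32")

    np.save("medical_embeddings.npy", vectors)
    with open("medical_texts.pkl", "wb") as f:
        pickle.dump(texts, f)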