"""Load a medical RAG system from pre-computed embeddings.

Builds a FAISS vectorstore from cached embedding vectors, pairs it with a
BM25 lexical retriever in an ensemble, and wires up an OpenAI chat model
plus a cross-encoder reranker for use in a Streamlit app.
"""

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI
import numpy as np
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
import streamlit as st
from datasets import load_dataset
import os
import pickle
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
import time

load_dotenv()

# Paths to the artifacts produced by the offline embedding job.
_EMBEDDINGS_PATH = 'src/medical_embeddings.npy'
_TEXTS_PATH = 'src/medical_texts.pkl'


def get_vector_store():
    """Load a FAISS vectorstore from pre-computed embeddings.

    Returns:
        tuple: ``(vectorstore, documents)`` where ``vectorstore`` is a
        LangChain ``FAISS`` store and ``documents`` is the list of
        ``Document`` objects it indexes, or ``(None, None)`` when the
        pre-computed files are missing or loading fails.
    """
    try:
        # Fail fast with a clear message if either artifact is absent.
        if not os.path.exists(_EMBEDDINGS_PATH):
            raise FileNotFoundError("medical_embeddings.npy not found")
        if not os.path.exists(_TEXTS_PATH):
            raise FileNotFoundError("medical_texts.pkl not found")

        print("📥 Loading pre-computed embeddings...")
        embeddings_array = np.load(_EMBEDDINGS_PATH)
        with open(_TEXTS_PATH, 'rb') as f:
            texts = pickle.load(f)
        print(f"✅ Loaded {len(embeddings_array)} pre-computed embeddings")

        # Build a flat L2 index directly from the cached vectors;
        # faiss requires float32 input.
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array.astype('float32'))  # type: ignore

        # Embedding function is only needed to embed *new* queries; it must
        # match the model used to produce the cached document embeddings.
        embeddings_function = HuggingFaceEmbeddings(
            model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )

        # Wrap each text in a Document, then derive the docstore mapping
        # (string id -> Document) from the same list.
        documents = [
            Document(
                page_content=text,
                metadata={"doc_id": i, "type": "medical_qa"},
            )
            for i, text in enumerate(texts)
        ]
        docstore = InMemoryDocstore(
            {str(i): doc for i, doc in enumerate(documents)}
        )

        # FAISS row i maps to docstore id str(i).
        index_to_docstore_id = {i: str(i) for i in range(len(texts))}
        vectorstore = FAISS(
            embedding_function=embeddings_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )
        return vectorstore, documents

    except FileNotFoundError as e:
        print(f"❌ Pre-computed files not found: {e}")
        print("🔄 Falling back to creating embeddings...")
        return None, None
    except Exception as e:
        print(f"❌ Error loading pre-computed embeddings: {e}")
        print("🔄 Falling back to creating embeddings...")
        return None, None


@st.cache_resource
def load_medical_system():
    """Load the medical RAG system (cached for performance).

    Returns:
        tuple: ``(documents, ensemble_retriever, llm, reranker)``.

    Side effects:
        Shows Streamlit spinner/success/error widgets and calls
        ``st.stop()`` when the vectorstore or the OpenAI key is missing.
    """
    with st.spinner("🔄 Loading medical knowledge base..."):
        start = time.time()
        vectorstore, documents = get_vector_store()
        end = time.time()

        if vectorstore is None or documents is None:
            st.error("❌ Could not load the vectorstore. Please ensure the embeddings and text files exist.")
            st.stop()

        total_time = end - start
        st.success(f"✅ Loaded existing vectorstore in {total_time:.2f} seconds")

        # Hybrid retrieval: lexical BM25 (weight 0.3) + dense FAISS (0.7).
        bm25_retriever = BM25Retriever.from_documents(documents)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.3, 0.7],
        )

        # Deterministic generation (temperature=0) for medical answers.
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            st.error("❌ OpenAI API key not found! Please set it in your environment variables or .streamlit/secrets.toml")
            st.stop()
        llm = ChatOpenAI(temperature=0, api_key=openai_key)  # type: ignore

        # Cross-encoder scores (query, passage) pairs jointly for reranking.
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        return documents, ensemble_retriever, llm, reranker