from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI
import numpy as np
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
import streamlit as st
from datasets import load_dataset
import os
import pickle
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
import time

load_dotenv()


def get_vector_store():
    """Load the FAISS vectorstore from pre-computed embeddings."""
    try:
        if not os.path.exists('src/medical_embeddings.npy'):
            raise FileNotFoundError("medical_embeddings.npy not found")
        if not os.path.exists('src/medical_texts.pkl'):
            raise FileNotFoundError("medical_texts.pkl not found")

        print("Loading pre-computed embeddings...")
        embeddings_array = np.load('src/medical_embeddings.npy')

        with open('src/medical_texts.pkl', 'rb') as f:
            texts = pickle.load(f)

        print(f"Loaded {len(embeddings_array)} pre-computed embeddings")

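        # Build a flat (exact) L2 FAISS index directly from the pre-computed
        # vectors, so nothing has to be re-embedded at startup.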
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array.astype('float32'))

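        # Query-time embedding model; this should match the model that was
        # used to pre-compute the stored vectors, otherwise query and index
        # vectors will not live in the same embedding space.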
        embeddings_function = HuggingFaceEmbeddings(
            model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )

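        # Wrap each raw text in a LangChain Document and keep a parallel
        # dict keyed by string id for the FAISS docstore.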
        documents_dict = {}
        documents = []
        for i, text in enumerate(texts):
            doc = Document(
                page_content=text,
                metadata={"doc_id": i, "type": "medical_qa"}
            )
            documents_dict[str(i)] = doc
            documents.append(doc)

        docstore = InMemoryDocstore(documents_dict)
        index_to_docstore_id = {i: str(i) for i in range(len(texts))}

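        # Assemble the LangChain FAISS wrapper from the pre-built index, the
        # in-memory docstore, and the row-to-document-id mapping.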
        vectorstore = FAISS(
            embedding_function=embeddings_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        return vectorstore, documents

    except FileNotFoundError as e:
        print(f"Pre-computed files not found: {e}")
        print("Falling back to creating embeddings...")
        return None, None

    except Exception as e:
        print(f"Error loading pre-computed embeddings: {e}")
        print("Falling back to creating embeddings...")
        return None, None


@st.cache_resource
def load_medical_system():
    """Load the medical RAG system (cached for performance)."""
    with st.spinner("Loading medical knowledge base..."):
        start = time.time()
        vectorstore, documents = get_vector_store()
        end = time.time()

        if vectorstore is None or documents is None:
            st.error("Could not load the vectorstore. Please ensure the embeddings and text files exist.")
            st.stop()

        total_time = end - start
        st.success(f"Loaded existing vectorstore in {total_time:.2f} seconds")

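        # Hybrid retrieval: sparse BM25 keyword matching combined with dense
        # FAISS similarity search, weighted 30/70 in favour of the dense side.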
        bm25_retriever = BM25Retriever.from_documents(documents)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.3, 0.7]
        )

        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            st.error("OpenAI API key not found! Please set it in your environment variables or .streamlit/secrets.toml")
            st.stop()
        llm = ChatOpenAI(temperature=0, api_key=openai_key)

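        # Cross-encoder reranker, returned so callers can re-score retrieved
        # passages against the user query before answering.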
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        return documents, ensemble_retriever, llm, reranker
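

# --- Illustrative sketch (not part of the original pipeline) ----------------
# A minimal example of how the returned components could be combined:
# retrieve candidates with the ensemble retriever, then reorder them with the
# cross-encoder. The function name and parameters below are assumptions made
# for illustration only.
def rerank_candidates(query, retriever, reranker, top_k=3):
    """Retrieve candidate documents and reorder them by cross-encoder score."""
    candidates = retriever.invoke(query)
    pairs = [[query, doc.page_content] for doc in candidates]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:top_k]]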