|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
from langchain_community.vectorstores import FAISS
|
|
from langchain.schema import Document
|
|
from langchain.retrievers import EnsembleRetriever
|
|
from langchain_community.retrievers import BM25Retriever
|
|
from langchain_openai import ChatOpenAI
|
|
import numpy as np
|
|
from sentence_transformers import CrossEncoder
|
|
from dotenv import load_dotenv
|
|
import streamlit as st
|
|
from datasets import load_dataset
|
|
import os
|
|
import pickle
|
|
import faiss
|
|
from langchain_community.docstore.in_memory import InMemoryDocstore
|
|
import time
|
|
|
|
load_dotenv()
|
|
|
|
def get_vector_store():
    """Build a FAISS vectorstore from pre-computed embeddings on disk.

    Loads ``medical_embeddings.npy`` (a 2-D float array, one row per
    document) and ``medical_texts.pkl`` (a parallel list of text strings),
    builds an exact L2 FAISS index over the embeddings, and wraps it in a
    LangChain ``FAISS`` vectorstore backed by an in-memory docstore.

    Returns:
        FAISS | None: the populated vectorstore, or ``None`` when the
        pre-computed files are missing or fail to load — the caller is
        expected to fall back to creating embeddings from scratch.
    """
    try:
        print("Loading pre-computed embeddings...")
        embeddings_array = np.load('medical_embeddings.npy')

        # NOTE(review): pickle.load is only safe on trusted local artifacts;
        # never point this at files from untrusted sources.
        with open('medical_texts.pkl', 'rb') as f:
            texts = pickle.load(f)

        # The two files must be parallel: row i of the embeddings belongs to
        # texts[i]. A mismatch would silently associate wrong documents, so
        # fail loudly (caught below -> graceful None fallback).
        if len(texts) != len(embeddings_array):
            raise ValueError(
                f"Embeddings/texts length mismatch: "
                f"{len(embeddings_array)} embeddings vs {len(texts)} texts"
            )

        print(f"Loaded {len(embeddings_array)} pre-computed embeddings")

        # Exact (non-approximate) L2 index; FAISS requires float32 input.
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array.astype('float32'))

        # The embedding function is only used to embed *queries* at search
        # time; the document vectors above are already computed offline.
        embeddings_function = HuggingFaceEmbeddings(
            model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )

        # One Document per text, keyed in the docstore by its positional id.
        documents_dict = {
            str(i): Document(
                page_content=text,
                metadata={"doc_id": i, "type": "medical_qa"},
            )
            for i, text in enumerate(texts)
        }
        docstore = InMemoryDocstore(documents_dict)

        # FAISS index row i -> docstore key str(i).
        index_to_docstore_id = {i: str(i) for i in range(len(texts))}

        return FAISS(
            embedding_function=embeddings_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )

    except FileNotFoundError as e:
        print(f"Pre-computed files not found: {e}")
        print("Falling back to creating embeddings...")
        return None
    except Exception as e:
        # Broad catch is deliberate: any load/validation failure should
        # degrade to the fallback path rather than crash the app.
        print(f"Error loading pre-computed embeddings: {e}")
        print("Falling back to creating embeddings...")
        return None
|
|
|
|
|
|
@st.cache_resource
def load_medical_system():
    """Load the medical RAG system (cached across Streamlit reruns).

    Builds four collaborating pieces:
      1. the MedQuAD QA pairs as LangChain ``Document`` objects,
      2. a hybrid retriever combining lexical BM25 and dense FAISS search,
      3. the ChatOpenAI LLM (requires ``OPENAI_API_KEY``),
      4. a cross-encoder reranker for re-scoring retrieved candidates.

    Calls ``st.stop()`` (halting the Streamlit script) when the vectorstore
    cannot be loaded or the OpenAI key is missing.

    Returns:
        tuple: ``(documents, ensemble_retriever, llm, reranker)``
    """
    with st.spinner("Loading medical knowledge base..."):
        ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")

        # One Document per QA pair. Raw fields are kept in metadata so
        # downstream code can use them without re-parsing page_content.
        documents = []
        for i, item in enumerate(ds['train']):
            content = f"Question: {item['Question']}\nAnswer: {item['Answer']}"
            metadata = {
                "doc_id": i,
                "question": item['Question'],
                "answer": item['Answer'],
                "question_type": item['qtype'],
                "type": "qa_pair",
            }
            documents.append(Document(page_content=content, metadata=metadata))

        # Time the vectorstore load so the UI can report startup cost.
        start = time.time()
        vectorstore = get_vector_store()
        end = time.time()

        if vectorstore is None:
            st.error(
                "Could not load the vectorstore. Please ensure the "
                "embeddings and text files exist."
            )
            st.stop()

        st.success(f"Loaded existing vectorstore in {end - start:.2f} seconds")

        # Hybrid retrieval: lexical BM25 (weight 0.3) + dense vectors (0.7).
        bm25_retriever = BM25Retriever.from_documents(documents)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.3, 0.7],
        )

        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            st.error(
                "OpenAI API key not found! Please set it in your environment "
                "variables or .streamlit/secrets.toml"
            )
            st.stop()
        # temperature=0 for deterministic medical answers.
        llm = ChatOpenAI(temperature=0, api_key=openai_key)

        # Cross-encoder scores (query, passage) pairs jointly — slower than
        # the bi-encoder retrievers but more accurate for final ranking.
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    return documents, ensemble_retriever, llm, reranker