# MedicalQAChatBotRAG / src / get_medical_system.py
# Uploaded by DeathBlade020 — commit 2ac2fa7 (verified): "Update src/get_medical_system.py"
# (Hugging Face Hub page text captured with the file; kept as a comment so the module parses.)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI
import numpy as np
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
import streamlit as st
from datasets import load_dataset
import os
import pickle
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore # Add this import
import time
load_dotenv()
def get_vector_store():
"""Load vectorstore from pre-computed embeddings"""
try:
# Load pre-computed data
if not os.path.exists('src/medical_embeddings.npy'):
raise FileNotFoundError("medical_embeddings.npy not found")
if not os.path.exists('src/medical_texts.pkl'):
raise FileNotFoundError("medical_texts.pkl not found")
print("πŸ“₯ Loading pre-computed embeddings...")
embeddings_array = np.load('src/medical_embeddings.npy')
with open('src/medical_texts.pkl', 'rb') as f:
texts = pickle.load(f)
print(f"βœ… Loaded {len(embeddings_array)} pre-computed embeddings")
# Create FAISS index from pre-computed embeddings
dimension = embeddings_array.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings_array.astype('float32')) # type: ignore
# import os
# os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp'
# os.makedirs('/tmp', exist_ok=True)
# Create embedding function for new queries
embeddings_function = HuggingFaceEmbeddings(
model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
# Create proper Document objects and InMemoryDocstore
documents_dict = {}
documents = []
for i, text in enumerate(texts):
# Create Document objects with proper metadata
doc = Document(
page_content=text,
metadata={"doc_id": i, "type": "medical_qa"}
)
documents_dict[str(i)] = doc
documents.append(doc)
# Create proper docstore
docstore = InMemoryDocstore(documents_dict)
# Create index to docstore mapping
index_to_docstore_id = {i: str(i) for i in range(len(texts))}
# Create FAISS vectorstore with proper parameters
vectorstore = FAISS(
embedding_function=embeddings_function,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_docstore_id
)
return vectorstore, documents
except FileNotFoundError as e:
print(f"❌ Pre-computed files not found: {e}")
print("πŸ”„ Falling back to creating embeddings...")
return None, None
except Exception as e:
print(f"❌ Error loading pre-computed embeddings: {e}")
print("πŸ”„ Falling back to creating embeddings...")
return None, None
@st.cache_resource
def load_medical_system():
"""Load the medical RAG system (cached for performance)"""
with st.spinner("πŸ”„ Loading medical knowledge base..."):
# Load dataset
# ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")
# # Create documents
# documents = []
# for i, item in enumerate(ds['train']): # type: ignore
# content = f"Question: {item['Question']}\nAnswer: {item['Answer']}" # type: ignore
# metadata = {
# "doc_id": i,
# "question": item['Question'], # type: ignore
# "answer": item['Answer'], # type: ignore
# "question_type": item['qtype'], # type: ignore
# "type": "qa_pair"
# }
# documents.append(Document(page_content=content, metadata=metadata))
start = time.time()
# Try to load existing vectorstore
vectorstore, documents = get_vector_store()
end = time.time()
if vectorstore is None or documents is None:
st.error("❌ Could not load the vectorstore. Please ensure the embeddings and text files exist.")
st.stop()
total_time = end - start
st.success(f"βœ… Loaded existing vectorstore in {total_time:.2f} seconds")
# Create retrievers
bm25_retriever = BM25Retriever.from_documents(documents)
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_retriever],
weights=[0.3, 0.7]
)
# create LLM
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
st.error("❌ OpenAI API key not found! Please set it in your environment variables or .streamlit/secrets.toml")
st.stop()
llm = ChatOpenAI(temperature=0, api_key=openai_key) # type: ignore
# Create reranker
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
return documents, ensemble_retriever, llm, reranker