"""
Retrieval and FAISS Embedding Module for Medical QA Chatbot
============================================================
This module handles:
1. Embedding documents
2. Building and saving FAISS index
3. Retrieval with initial FAISS search + reranking using BioBERT similarity
"""

import os

import faiss
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.preprocessing import normalize

from Query_processing import preprocess_query

# -------------------------------
# File Paths
# -------------------------------

# Project root: the parent of the directory that contains this file.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Absolute paths for dataset and index files
csv_path = os.path.join(project_root, 'Datasets', 'flattened_drug_dataset_cleaned.csv')
faiss_index_path = os.path.join(project_root, 'Vectors', 'faiss_index.idx')
doc_metadata_path = os.path.join(project_root, 'Vectors', 'doc_metadata.pkl')
doc_vectors_path = os.path.join(project_root, 'Vectors', 'doc_vectors.npy')

# Load the dataset. Rows without chunk text are dropped, which leaves gaps in
# the DataFrame's index labels; FAISS ids therefore correspond to row
# *positions*, not index labels.
df = pd.read_csv(csv_path).dropna(subset=['chunk_text'])

# -------------------------------
# Model Initialization
# -------------------------------
fast_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
biobert = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')

# Column layout shared by the retrieval function and its wrapper, so an empty
# result still exposes every expected column.
_RESULT_COLUMNS = [
    'chunk_id',
    'drug_name',
    'section',
    'subsection',
    'chunk_text',
    'faiss_score',
    'semantic_similarity_score',
]


def _as_text(value):
    """Return *value* as a string: unwrap a tuple to its first item, map non-strings (e.g. NaN) to ''."""
    if isinstance(value, tuple):
        value = value[0] if value else ''
    return value if isinstance(value, str) else ''


# -------------------------------
# Function: Embed and Build FAISS Index
# -------------------------------
def Embed_and_FAISS():
    """
    Embeds the drug dataset and builds a FAISS index for fast retrieval.
    Saves the index, metadata, and document vectors to disk.

    Side effects:
        Adds a 'full_text' column to the module-level ``df`` and writes
        ``faiss_index_path``, ``doc_metadata_path`` and ``doc_vectors_path``.
    """
    print("Embedding document chunks using fast embedder...")

    # Build full context strings
    df['full_text'] = df.apply(
        lambda x: f"{x['drug_name']} | {x['section']} > {x['subsection']} | {x['chunk_text']}",
        axis=1,
    )
    full_texts = df['full_text'].tolist()
    doc_embeddings = fast_embedder.encode(full_texts, convert_to_numpy=True, show_progress_bar=True)

    # L2-normalize so that inner product equals cosine similarity. FAISS
    # expects a C-contiguous float32 matrix.
    doc_embeddings = normalize(doc_embeddings, axis=1, norm='l2')
    doc_embeddings = np.ascontiguousarray(doc_embeddings, dtype=np.float32)
    dimension = doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(doc_embeddings)

    # Make sure the output directories exist before writing anything.
    for out_path in (faiss_index_path, doc_metadata_path, doc_vectors_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Save index and metadata
    faiss.write_index(index, faiss_index_path)
    df.to_pickle(doc_metadata_path)
    np.save(doc_vectors_path, doc_embeddings)
    print("FAISS index built and saved successfully.")


# -------------------------------
# Function: Retrieve with Context and Averaged Embeddings
# -------------------------------
def retrieve_with_context_averagedembeddings(query, top_k=10, predicted_intent=None, detected_entities=None, alpha=0.8):
    """
    Retrieve top chunks using FAISS followed by reranking with BioBERT similarity.

    Parameters:
        query (str): User query text.
        top_k (int): Number of top results to retrieve.
        predicted_intent (str or tuple, optional): Detected intent to adjust
            retrieval. A tuple is unwrapped to its first element.
        detected_entities (list, optional): List of named entities.
        alpha (float): Weight for combining query and intent embeddings.

    Returns:
        pd.DataFrame: Retrieved chunks with metadata and reranked scores,
        ordered by 'semantic_similarity_score' (descending). Empty, but with
        all expected columns, when the index returns no hits.
    """
    print(f"[Retrieval Pipeline Started] Query: {query}")

    # Unwrap a tuple intent once, before it is embedded or compared.
    if isinstance(predicted_intent, tuple):
        predicted_intent = predicted_intent[0] if predicted_intent else None

    # Embed the query, optionally blending in the intent embedding.
    query_vec = fast_embedder.encode([query], convert_to_numpy=True)
    if predicted_intent:
        intent_vec = fast_embedder.encode([predicted_intent], convert_to_numpy=True)
        query_vec = alpha * query_vec + (1 - alpha) * intent_vec
    # Always normalize: the index holds unit vectors, so the inner-product
    # score is only a cosine similarity when the query is unit-length too.
    query_vec = np.ascontiguousarray(normalize(query_vec, axis=1), dtype=np.float32)

    # Load FAISS index and search
    index = faiss.read_index(faiss_index_path)
    D, I = index.search(query_vec, top_k)
    # NOTE: pickle is only safe for the locally generated metadata file;
    # never point doc_metadata_path at untrusted data.
    df_meta = pd.read_pickle(doc_metadata_path)

    # FAISS pads with -1 when fewer than top_k vectors are available.
    hit_mask = I[0] >= 0
    positions = I[0][hit_mask]
    if positions.size == 0:
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # FAISS ids are row positions (insertion order), not index labels, so
    # positional lookup is required.
    retrieved_df = df_meta.iloc[positions].copy()
    retrieved_df['faiss_score'] = D[0][hit_mask]

    # BioBERT reranking
    query_emb = biobert.encode(query, convert_to_tensor=True)
    chunk_embs = biobert.encode(retrieved_df['full_text'].tolist(), convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_emb, chunk_embs)[0].tolist()

    # Loop-invariant lowercase forms used by the boosts below.
    intent_lower = predicted_intent.lower() if predicted_intent else None
    intent_key = intent_lower.strip() if intent_lower else None
    entity_terms = [ent.lower() for ent in (detected_entities or []) if ent]

    # Boost scores based on intent, subsection match, or entity presence
    results = []
    for (_, row), score in zip(retrieved_df.iterrows(), cos_scores):
        section = _as_text(row['section']).strip().lower()
        subsection = _as_text(row['subsection']).strip().lower()

        if intent_key and section == intent_key:
            score += 0.05
        if intent_lower and intent_lower in subsection:
            score += 0.03
        if entity_terms:
            chunk_lower = row['chunk_text'].lower()
            if any(term in chunk_lower for term in entity_terms):
                score += 0.1

        results.append({
            'chunk_id': row['chunk_id'],
            'drug_name': row['drug_name'],
            'section': row['section'],
            'subsection': row['subsection'],
            'chunk_text': row['chunk_text'],
            'faiss_score': row['faiss_score'],
            'semantic_similarity_score': score,
        })

    # Rank by the final (boosted) score so the boosts affect the ordering.
    # The sort is stable, so ties keep FAISS order.
    results.sort(key=lambda item: item['semantic_similarity_score'], reverse=True)
    return pd.DataFrame(results, columns=_RESULT_COLUMNS)


# -------------------------------
# Function: Retrieval Wrapper
# -------------------------------
def Retrieval_averagedQP(raw_query, intent, entities, top_k=10, alpha=0.8):
    """
    Wrapper to retrieve top-k chunks given a raw user query.

    Parameters:
        raw_query (str): The user query.
        intent (str): Predicted intent from query processing.
        entities (list): Detected biomedical entities.
        top_k (int): Number of top results to return.
        alpha (float): Weighting between query and intent embeddings.

    Returns:
        pd.DataFrame: Top retrieved chunks with scores.
    """
    results_df = retrieve_with_context_averagedembeddings(
        raw_query,
        top_k=top_k,
        predicted_intent=intent,
        detected_entities=entities,
        alpha=alpha,
    )
    return results_df[_RESULT_COLUMNS]