Spaces:
Runtime error
Runtime error
| import torch | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| from chunker_final import chunk_documents_to_dict | |
| import numpy as np | |
class Retriever:
    """Hybrid lexical/semantic retriever over chunked documents.

    The input documents are chunked once at construction time. Every chunk
    is indexed twice: lexically with BM25 and semantically with a
    sentence-transformer embedding. Queries can then be answered with
    either index or with a weighted sum of both.
    """

    # Sentence-transformer checkpoint used for chunk and query embeddings.
    MODEL_NAME = 'sentence-transformers/all-distilroberta-v1'

    # Weights applied to the two score vectors in "combined search".
    # NOTE(review): BM25 scores are unbounded while cosine similarities lie
    # in [-1, 1], so the raw weighted sum is likely dominated by BM25.
    # Consider normalising both score vectors before mixing - left as is
    # here because it would change existing rankings.
    BM25_WEIGHT = 0.3
    SEMANTIC_WEIGHT = 0.7

    def __init__(self, docs: dict) -> None:
        """Chunk ``docs`` and build the BM25 and embedding indexes.

        Args:
            docs: Documents passed straight to ``chunk_documents_to_dict``.
                The result is consumed as a mapping of chunk id to chunk
                text (its keys and values are read below).

        Raises:
            ValueError: If chunking yields no chunks at all.
        """
        self.chunked_docs = chunk_documents_to_dict(docs)
        if not self.chunked_docs:
            # An empty corpus cannot be indexed or ranked; fail early with
            # a clear message instead of an opaque error further down.
            raise ValueError("No chunks were produced from the supplied documents")

        # Parallel lists: position i in every per-chunk array refers to
        # chunk_ids[i] / chunk_texts[i].
        self.chunk_ids = list(self.chunked_docs.keys())
        self.chunk_texts = list(self.chunked_docs.values())

        self.bm25 = BM25Okapi([self._tokenize(text) for text in self.chunk_texts])

        self.sbert = SentenceTransformer(self.MODEL_NAME)
        self.doc_embeddings = self.sbert.encode(self.chunk_texts)

    def get_docs(self, query: str, method: str, n: int = 15) -> dict:
        """Return the ``n`` best-scoring chunks for ``query``.

        Args:
            query: Free-text search query.
            method: One of ``"BM25"``, ``"semantic"`` or
                ``"combined search"``.
            n: Maximum number of chunks to return.

        Returns:
            A dict mapping chunk id to chunk text, inserted in descending
            score order.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == "BM25":
            scores = self._get_bm25_scores(query)
        elif method == "semantic":
            scores = self._get_semantic_scores(query)
        elif method == "combined search":
            scores = (
                self.BM25_WEIGHT * self._get_bm25_scores(query)
                + self.SEMANTIC_WEIGHT * self._get_semantic_scores(query)
            )
        else:
            raise ValueError(f"Invalid search method: {method}")

        # Convert to plain ints so the Python lists are indexed with
        # ordinary integers rather than 0-d tensors.
        top_indices = scores.argsort(descending=True)[:n].tolist()
        return {self.chunk_ids[i]: self.chunk_texts[i] for i in top_indices}

    def rerank(self, query: str, retrieved_docs: dict) -> dict:
        """Re-order ``retrieved_docs`` by cosine similarity to ``query``.

        Args:
            query: Free-text search query.
            retrieved_docs: Mapping of chunk id to chunk text, e.g. the
                result of :meth:`get_docs`.

        Returns:
            A new dict with the same items, inserted from most to least
            similar. Items with equal similarity keep their input order.
        """
        if not retrieved_docs:
            return {}

        ids = list(retrieved_docs)
        query_embedding = self.sbert.encode(query)
        # Encode every candidate in one batch instead of one model call
        # per chunk.
        chunk_embeddings = self.sbert.encode([retrieved_docs[cid] for cid in ids])
        similarities = self._cosine_similarity(chunk_embeddings, query_embedding)

        # sorted() is stable, so ties preserve the incoming order.
        ranked = sorted(zip(ids, similarities), key=lambda pair: pair[1], reverse=True)
        return {cid: retrieved_docs[cid] for cid, _ in ranked}

    @staticmethod
    def _tokenize(text: str) -> list:
        """Lower-case ``text`` and split it on any run of whitespace.

        Used for both the corpus and the query so they are tokenised the
        same way. ``split()`` without an argument avoids the empty tokens
        that ``split(" ")`` produces for repeated spaces and also breaks
        on newlines and tabs.
        """
        return text.lower().split()

    @staticmethod
    def _cosine_similarity(matrix, vector):
        """Return the cosine similarity of each row of ``matrix`` to ``vector``.

        Args:
            matrix: 2-D array-like, one embedding per row.
            vector: 1-D array-like embedding.

        Returns:
            1-D ``numpy.ndarray`` with one similarity per row. Rows whose
            denominator is zero (a zero-norm row or a zero-norm ``vector``)
            get similarity 0 instead of NaN, so sorting stays well defined.
        """
        matrix = np.asarray(matrix)
        vector = np.asarray(vector)
        dots = matrix @ vector
        denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
        similarities = np.zeros(
            dots.shape, dtype=np.result_type(dots.dtype, denominators.dtype)
        )
        # Positions where the denominator is zero are skipped and keep
        # their initial value of 0.
        np.divide(dots, denominators, out=similarities, where=denominators != 0)
        return similarities

    def _get_bm25_scores(self, query):
        """Return a tensor of BM25 scores, one per chunk, for ``query``."""
        return torch.tensor(self.bm25.get_scores(self._tokenize(query)))

    def _get_semantic_scores(self, query):
        """Return a tensor of cosine similarities, one per chunk, for ``query``."""
        query_embedding = self.sbert.encode(query)
        return torch.tensor(
            self._cosine_similarity(self.doc_embeddings, query_embedding)
        )