# Unit_3_Agentic_RAG / retriever.py
# VPCSinfo — "[IMP] added improvement" (commit ff3806f)
from smolagents import Tool
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
import numpy as np
import datasets
from typing import List
class SentenceTransformerRetriever:
    """Retriever that uses SentenceTransformer embeddings for semantic search.

    Every document's ``page_content`` is embedded once at construction time.
    Each query is embedded on demand and ranked against the stored document
    embeddings by cosine similarity.
    """

    def __init__(self, docs: List["Document"], model_name: str = "all-MiniLM-L6-v2"):
        """Initialize with documents and a SentenceTransformer model.

        Args:
            docs: Document objects to search; each one's ``page_content`` is embedded.
            model_name: Name of the SentenceTransformer model to use.
        """
        self.docs = docs
        self.model = SentenceTransformer(model_name)
        # Embed the whole corpus up front so each query needs only one encode call.
        self.doc_texts = [doc.page_content for doc in self.docs]
        # convert_to_numpy=True keeps the similarity maths below in NumPy.
        self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_numpy=True)

    def get_relevant_documents(self, query: str, k: int = 3) -> List["Document"]:
        """Return the ``k`` documents most similar to ``query``.

        Args:
            query: Query string.
            k: Maximum number of documents to return. ``k <= 0`` returns an
                empty list; a ``k`` larger than the corpus returns every document.

        Returns:
            Up to ``k`` Document objects, ordered from most to least similar.
            Equally similar documents keep their original corpus order.
        """
        # Nothing to rank: skip encoding the query entirely. This also guards the
        # negative-k case, which would otherwise slice from the end below.
        if k <= 0 or len(self.doc_embeddings) == 0:
            return []

        query_vec = np.asarray(self.model.encode(query, convert_to_numpy=True), dtype=float)
        doc_matrix = np.asarray(self.doc_embeddings, dtype=float)

        # Cosine similarity for every document at once: (D @ q) / (|D_i| * |q|).
        dots = doc_matrix @ query_vec
        denom = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec)
        # A zero-norm vector has no direction. Score it -inf so it ranks last,
        # instead of dividing by zero and producing NaN.
        similarities = np.divide(dots, denom, out=np.full_like(dots, -np.inf), where=denom > 0)

        # A stable sort on the negated scores gives descending similarity with
        # deterministic tie-breaking (lower corpus index first).
        top_k_indices = np.argsort(-similarities, kind="stable")[:k]
        return [self.docs[i] for i in top_k_indices]
class GuestInfoRetrieverTool(Tool):
    """smolagents Tool that looks up gala guests via semantic search over guest documents."""

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation using semantic search."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2", top_k: int = 3):
        """Build the tool over ``docs``.

        Args:
            docs: Document objects to search; they are embedded by a
                SentenceTransformerRetriever.
            model_name: SentenceTransformer model used for the embeddings.
            top_k: Maximum number of guest entries returned per query
                (defaults to 3).
        """
        super().__init__()
        # NOTE(review): Tool.__init__ is expected to set this flag as well. It is
        # kept explicit so the flag exists regardless of the smolagents version.
        self.is_initialized = False
        self.retriever = SentenceTransformerRetriever(docs, model_name)
        self.top_k = top_k

    def forward(self, query: str):
        """Return up to ``top_k`` matching guest entries separated by blank lines.

        If the retriever finds nothing, a fixed "no match" message is returned
        instead.
        """
        # Pass the limit explicitly so the retriever's k and the slice below
        # can never disagree.
        results = self.retriever.get_relevant_documents(query, k=self.top_k)
        if not results:
            return "No matching guest information found."
        # The slice is defensive in case the retriever returns more than asked.
        return "\n\n".join(doc.page_content for doc in results[: self.top_k])
def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
    """Build a guest-lookup tool from the ``agents-course/unit3-invitees`` dataset.

    Each row of the dataset's ``train`` split becomes one Document. Its text
    holds the guest's name, relation, description and email on separate lines,
    and its metadata records the guest's name.

    Args:
        model_name: SentenceTransformer model the tool uses for embeddings.

    Returns:
        GuestInfoRetrieverTool: A tool that searches the guest Documents.
    """
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for invitee in invitees:
        fields = (
            f"Name: {invitee['name']}",
            f"Relation: {invitee['relation']}",
            f"Description: {invitee['description']}",
            f"Email: {invitee['email']}",
        )
        documents.append(
            Document(page_content="\n".join(fields), metadata={"name": invitee["name"]})
        )

    return GuestInfoRetrieverTool(documents, model_name=model_name)