"""Semantic guest lookup.

Provides a SentenceTransformer-backed retriever, a smolagents ``Tool`` that
wraps it, and a loader that builds the tool from the
``agents-course/unit3-invitees`` dataset.
"""

from typing import List

import datasets
import numpy as np
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
from smolagents import Tool


class SentenceTransformerRetriever:
    """Retriever that uses SentenceTransformer embeddings for semantic search."""

    def __init__(self, docs: List[Document], model_name: str = "all-MiniLM-L6-v2"):
        """Initialize with documents and a SentenceTransformer model.

        Args:
            docs: List of Document objects.
            model_name: Name of the SentenceTransformer model to use.
        """
        self.docs = docs
        self.model = SentenceTransformer(model_name)

        # Embed every document once up front; queries are embedded on demand.
        self.doc_texts = [doc.page_content for doc in self.docs]
        if self.doc_texts:
            self.doc_embeddings = self.model.encode(
                self.doc_texts, convert_to_numpy=True
            )
        else:
            # Nothing to encode: keep an empty matrix so later code can rely
            # on ``doc_embeddings`` always being a NumPy array.
            self.doc_embeddings = np.empty((0, 0), dtype=np.float32)

    @staticmethod
    def _cosine_similarities(query_embedding, doc_embeddings) -> np.ndarray:
        """Return the cosine similarity of the query to each document row.

        A pair whose norm product is zero (or not a number) gets a similarity
        of 0 instead of triggering a division by zero.
        """
        doc_matrix = np.asarray(doc_embeddings)
        query_vector = np.asarray(query_embedding)

        dot_products = doc_matrix @ query_vector
        # The query norm is computed once here rather than once per document.
        norm_products = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(
            query_vector
        )

        similarities = np.zeros_like(dot_products)
        valid = norm_products > 0
        similarities[valid] = dot_products[valid] / norm_products[valid]
        return similarities

    def get_relevant_documents(self, query: str, k: int = 3) -> List[Document]:
        """Return the documents most similar to the query.

        Args:
            query: Query string.
            k: Maximum number of documents to return.

        Returns:
            Up to ``k`` Document objects ordered from most to least similar.
            Empty when ``k`` is not positive or there are no documents.
        """
        # A negative k would otherwise slice from the wrong end of the ranking.
        if k <= 0 or len(self.doc_embeddings) == 0:
            return []

        query_embedding = self.model.encode(query, convert_to_numpy=True)
        similarities = self._cosine_similarities(
            query_embedding, self.doc_embeddings
        )

        # Negate for descending order; a stable sort keeps tied documents in
        # their original order so results are deterministic.
        top_k_indices = np.argsort(-similarities, kind="stable")[:k]
        return [self.docs[i] for i in top_k_indices]


class GuestInfoRetrieverTool(Tool):
    """Tool that looks up guest information through a semantic retriever."""

    name = "guest_info_retriever"
    description = (
        "Retrieves detailed information about gala guests based on their "
        "name or relation using semantic search."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2"):
        """Build the retriever over the given documents.

        Args:
            docs: Documents to search; passed to SentenceTransformerRetriever.
            model_name: Name of the SentenceTransformer model to use.
        """
        super().__init__()
        # Kept from the original code; presumably the flag the Tool base class
        # uses to track setup state -- confirm against the smolagents docs.
        self.is_initialized = False
        self.retriever = SentenceTransformerRetriever(docs, model_name)

    def forward(self, query: str) -> str:
        """Return the page content of the top three matches for the query.

        Args:
            query: The name or relation of the guest to look up.

        Returns:
            The matching documents' text separated by blank lines, or a
            fallback message when nothing is found.
        """
        results = self.retriever.get_relevant_documents(query, k=3)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results)


def _guest_to_document(guest) -> Document:
    """Convert one guest record into a Document with a labelled text body."""
    page_content = "\n".join(
        [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]
    )
    return Document(page_content=page_content, metadata={"name": guest["name"]})


def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
    """Load the guest dataset and create a retriever tool.

    Args:
        model_name: Name of the SentenceTransformer model to use.

    Returns:
        GuestInfoRetrieverTool: A tool for retrieving guest information.
    """
    guest_dataset = datasets.load_dataset(
        "agents-course/unit3-invitees", split="train"
    )
    docs = [_guest_to_document(guest) for guest in guest_dataset]
    return GuestInfoRetrieverTool(docs, model_name=model_name)