# Spaces: Runtime error
from smolagents import Tool | |
from langchain.docstore.document import Document | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
import datasets | |
from typing import List | |
class SentenceTransformerRetriever:
    """Retriever that ranks documents by cosine similarity of SentenceTransformer embeddings.

    Document embeddings are computed once at construction time; each query is
    embedded on demand and compared against all of them.
    """

    def __init__(self, docs: List[Document], model_name: str = "all-MiniLM-L6-v2"):
        """Embed ``docs`` with the given SentenceTransformer model.

        Args:
            docs: Documents to index. Any iterable is accepted; it is
                materialized into a list so results can be indexed later.
            model_name: Name of the SentenceTransformer model to load.
        """
        # Materialize first: a one-shot iterable would otherwise be consumed
        # by the text extraction below, leaving nothing to index into.
        self.docs = list(docs)
        self.model = SentenceTransformer(model_name)
        self.doc_texts = [doc.page_content for doc in self.docs]
        # convert_to_numpy=True keeps the similarity math below in NumPy.
        self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_numpy=True)

    def get_relevant_documents(self, query: str, k: int = 3) -> List[Document]:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: Query string.
            k: Maximum number of documents to return; values <= 0 yield an
                empty list.

        Returns:
            Documents ordered by descending cosine similarity (ties keep their
            original order). Documents whose embedding has zero norm are
            scored 0 instead of producing NaN.
        """
        # Guard before any matrix math: an empty index may come back from
        # encode() as a 1-D array that cannot be multiplied by the query, and
        # a negative k would otherwise slice off only the last element.
        if k <= 0 or len(self.docs) == 0:
            return []

        query_embedding = np.asarray(self.model.encode(query, convert_to_numpy=True))
        doc_embeddings = np.asarray(self.doc_embeddings)

        # Vectorized cosine similarity: a single matrix-vector product instead
        # of a Python loop, with the query norm computed only once.
        dot_products = doc_embeddings @ query_embedding
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.divide(
            dot_products,
            norms,
            out=np.zeros_like(dot_products, dtype=float),
            where=norms > 0,  # zero-norm vectors have no direction; score them 0
        )

        # Descending order; a stable sort keeps ties in document order.
        top_k_indices = np.argsort(-similarities, kind="stable")[:k]
        return [self.docs[i] for i in top_k_indices]
class GuestInfoRetrieverTool(Tool):
    """smolagents tool that looks up gala guests by semantic similarity.

    The query is matched against the indexed guest documents with a
    ``SentenceTransformerRetriever``, and the best matches are returned as text.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation using semantic search."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2"):
        """Build the semantic retriever over ``docs``.

        Args:
            docs: Documents (objects exposing ``page_content``) to search.
            model_name: SentenceTransformer model used for the embeddings.
        """
        # Run the Tool base-class initializer, which the original skipped,
        # then keep the explicit flag reset that the original relied on.
        super().__init__()
        self.is_initialized = False
        self.retriever = SentenceTransformerRetriever(docs, model_name)

    def forward(self, query: str) -> str:
        """Return the best-matching guest records for ``query``.

        Args:
            query: Guest name or relation to search for.

        Returns:
            Up to three guest records separated by blank lines, or a fallback
            message when the retriever returns nothing.
        """
        # Ask the retriever for exactly three results instead of relying on
        # its default and re-slicing afterwards.
        results = self.retriever.get_relevant_documents(query, k=3)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results)
def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
    """Build a guest-information retrieval tool from the invitee dataset.

    Each row of the ``train`` split of ``agents-course/unit3-invitees`` becomes
    one Document. Its text lists the guest's name, relation, description and
    email on separate lines, and its metadata records the guest name.

    Args:
        model_name: SentenceTransformer model used to embed the guest records.

    Returns:
        GuestInfoRetrieverTool: Tool that answers guest-information queries.
    """
    rows = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    guest_docs = []
    for guest in rows:
        # One "Label: value" line per field, in a fixed order.
        lines = [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]
        text = "\n".join(lines)
        guest_docs.append(Document(page_content=text, metadata={"name": guest["name"]}))

    return GuestInfoRetrieverTool(guest_docs, model_name=model_name)