import asyncio
import logging
from functools import lru_cache

from datasets import load_dataset
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi

logger = logging.getLogger(__name__)

# Reply used whenever no guest scores above zero for the query.
_NO_MATCH = "No matching guest found"


# Argument schema for the guest information retriever tool. Documented with a
# comment (not a docstring) so the generated JSON schema stays unchanged.
class GuestInfoInput(BaseModel):
    query: str = Field(description="Query about guest information")


@lru_cache(maxsize=1)
def _load_guest_index():
    """Load the invitee dataset and build the BM25 index over it, once.

    The result is cached for the lifetime of the process so that repeated
    tool calls do not re-fetch the dataset or rebuild the index. A failed
    load raises and is therefore not cached; the next call retries.

    Returns:
        A ``(guests, bm25)`` pair. ``guests`` is a tuple of dicts with the
        ``name`` and ``relation`` of each row; ``bm25`` is the index built
        over ``"<name> <relation>"`` per guest, or ``None`` when the dataset
        has no rows (``BM25Okapi`` cannot be built from an empty corpus).
    """
    dataset = load_dataset("agents-course/unit3-invitees", split="train")
    guests = tuple(
        {"name": row["name"], "relation": row["relation"]} for row in dataset
    )
    logger.info("Loaded %d guests from Hugging Face dataset", len(guests))
    if not guests:
        return guests, None

    tokenized_docs = [
        f"{guest['name']} {guest['relation']}".lower().split() for guest in guests
    ]
    return guests, BM25Okapi(tokenized_docs)


def _lookup_guest(query: str) -> str:
    """Synchronously find the guest that best matches ``query``.

    Args:
        query: Free-text query about a guest; matched case-insensitively on
            whitespace-separated tokens against each guest's name and relation.

    Returns:
        ``"Guest: <name>, Relation: <relation>"`` for the highest-scoring
        guest when its BM25 score is positive, ``"No matching guest found"``
        otherwise, or ``"Error: <message>"`` if anything fails.
    """
    # Deliberately broad: this is the tool boundary, and the agent expects a
    # string back rather than an exception.
    try:
        logger.info("Retrieving guest info for query: %s", query)
        guests, bm25 = _load_guest_index()
        tokens = query.lower().split()
        if bm25 is None or not tokens:
            return _NO_MATCH

        scores = bm25.get_scores(tokens)
        # argmax() yields a NumPy integer; convert to a plain int for indexing.
        best_idx = int(scores.argmax())
        if scores[best_idx] > 0:
            guest = guests[best_idx]
            return f"Guest: {guest['name']}, Relation: {guest['relation']}"
        return _NO_MATCH
    except Exception as e:
        logger.exception("Error retrieving guest info for query '%s': %s", query, e)
        return f"Error: {e}"


async def guest_info_func(query: str) -> str:
    """Retrieve guest information based on a query.

    Async entry point. The lookup does blocking I/O (dataset load) and CPU
    work (BM25 scoring), so it runs in the default executor instead of
    blocking the event loop.

    Args:
        query: Query about guest information.

    Returns:
        Guest information, a "no match" message, or an error message.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _lookup_guest, query)


# `func` must be a regular callable: passing the coroutine function there makes
# synchronous invocations return an un-awaited coroutine object instead of a
# string. The description is passed explicitly so the text shown to the model
# does not depend on which function's docstring gets inspected.
guest_info_retriever_tool = StructuredTool.from_function(
    func=_lookup_guest,
    coroutine=guest_info_func,
    name="guest_info_retriever_tool",
    description="Retrieve guest information based on a query.",
    args_schema=GuestInfoInput,
)