jarvis_gaia_agent / tools /guest_info.py
onisj's picture
feat(tools): add more tools to extend the functionality of jarvis
751d628
import asyncio
import functools
import logging

from datasets import load_dataset
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
# Module-level logger named after this module (via __name__); used below to
# record each lookup and any failure while retrieving guest info.
logger = logging.getLogger(__name__)
class GuestInfoInput(BaseModel):
    # Argument schema for the guest-info tool: exactly one required field,
    # a free-text query. The Ellipsis default marks the field as required,
    # just like omitting the default.
    query: str = Field(
        ...,
        description="Query about guest information",
    )
@functools.lru_cache(maxsize=1)
def _load_guest_index():
    """Load the invitee dataset and build a BM25 index over it, once per process.

    Returns:
        tuple: ``(dataset, bm25)``. ``bm25`` is ``None`` when the dataset is
        empty, because BM25Okapi divides by the corpus size and cannot be
        built from an empty corpus.

    Failures are not cached by ``lru_cache``, so a transient download error
    is retried on the next call. Concurrent first calls may each build the
    index; the result is the same either way.
    """
    dataset = load_dataset("agents-course/unit3-invitees", split="train")
    logger.info("Loaded %d guests from Hugging Face dataset", len(dataset))
    if len(dataset) == 0:
        return dataset, None
    # Each document is the guest's name followed by their relation.
    tokenized_docs = [
        f"{row['name']} {row['relation']}".lower().split() for row in dataset
    ]
    return dataset, BM25Okapi(tokenized_docs)


def _lookup_guest(query: str) -> str:
    """Return the best BM25 match for *query* as text (blocking)."""
    tokenized_query = query.lower().split()
    if not tokenized_query:
        # No terms: every score would be 0, so nothing can match.
        return "No matching guest found"
    dataset, bm25 = _load_guest_index()
    if bm25 is None:
        return "No matching guest found"
    scores = bm25.get_scores(tokenized_query)
    # argmax yields a numpy integer; convert to a plain int for dataset indexing.
    best_idx = int(scores.argmax())
    if scores[best_idx] <= 0:
        return "No matching guest found"
    guest = dataset[best_idx]
    return f"Guest: {guest['name']}, Relation: {guest['relation']}"


async def guest_info_func(query: str) -> str:
    """
    Retrieve guest information based on a query.

    Ranks invitees from the "agents-course/unit3-invitees" dataset by BM25
    over "name relation" text and returns the best match.

    Args:
        query (str): Query about guest information.

    Returns:
        str: Guest information, "No matching guest found", or an
        "Error: ..." message if loading or scoring fails.
    """
    logger.info("Retrieving guest info for query: %s", query)
    try:
        # The dataset download and BM25 scoring are blocking calls; run them
        # in a worker thread so the event loop stays responsive.
        return await asyncio.to_thread(_lookup_guest, query)
    except Exception as e:
        # Tool boundary: report the failure to the agent as text rather than raising.
        logger.exception("Error retrieving guest info for query '%s'", query)
        return f"Error: {e}"
# `func` must be a plain synchronous callable. Passing the async
# guest_info_func there made `tool.invoke(...)` return an un-awaited
# coroutine. This adapter runs the coroutine to completion instead.
# functools.wraps copies guest_info_func's docstring, which from_function
# uses as the tool description (unchanged from before).
# asyncio.run cannot be used from a thread that already has a running event
# loop; async callers should use `tool.ainvoke`, which dispatches to
# `coroutine`.
@functools.wraps(guest_info_func)
def _guest_info_sync(query: str) -> str:
    return asyncio.run(guest_info_func(query))


guest_info_retriever_tool = StructuredTool.from_function(
    func=_guest_info_sync,
    name="guest_info_retriever_tool",
    args_schema=GuestInfoInput,
    coroutine=guest_info_func,
)