jarvis_gaia_agent / retriever.py
onisj's picture
feat(tools): add more tools to extend the functionality of jarvis
751d628
import json
import os
import logging
import torch
from typing import List
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
try:
from datasets import load_dataset
except ImportError:
load_dataset = None
logger = logging.getLogger(__name__)
def get_device():
    """
    Choose which PyTorch device models should run on.

    Backends are probed in priority order: an NVIDIA GPU via CUDA first,
    then Apple's Metal Performance Shaders, and finally the CPU, which is
    always available.

    Returns:
        str: The selected device name: 'cuda', 'mps', or 'cpu'.
    """
    # Each probe is wrapped in a lambda so a backend is only queried when
    # every higher-priority backend has already been ruled out.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_usable in probes:
        if is_usable():
            return device_name
    return "cpu"
# Guest fields carried in every Document's page_content and metadata.
_GUEST_FIELDS = ("name", "relation", "description", "email")

# Single hard-coded guest used when no real dataset can be loaded.
# NOTE(review): the email value looks like a placeholder left behind by web
# scraping rather than a real address — confirm the intended value.
_MOCK_GUEST = {
    "name": "Dr. Nikola Tesla",
    "relation": "old friend from university days",
    "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
    "email": "[email protected]",
}


def _guest_field(guest, key: str) -> str:
    """Return ``guest[key]`` as a string, using "" when the key is missing or None."""
    value = guest.get(key)
    return "" if value is None else str(value)


def _guest_to_document(guest) -> Document:
    """Build one Document from a guest mapping, with the same layout for every source."""
    # A fresh dict per call, so callers mutating one Document's metadata
    # cannot affect another Document (or _MOCK_GUEST).
    fields = {key: _guest_field(guest, key) for key in _GUEST_FIELDS}
    page_content = "\n".join([
        f"Name: {fields['name']}",
        f"Relation: {fields['relation']}",
        f"Description: {fields['description']}",
        f"Email: {fields['email']}",
    ])
    return Document(page_content=page_content, metadata=fields)


def _mock_guest_docs() -> List[Document]:
    """Return the built-in one-guest fallback dataset."""
    return [_guest_to_document(_MOCK_GUEST)]


def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load the guest list as LangChain Documents.

    Resolution order:
      1. If ``dataset_path`` exists on disk, it is read as a UTF-8 JSON list
         of guest objects.
      2. Otherwise, if the ``datasets`` library is importable, the ``train``
         split of the Hugging Face dataset named ``dataset_path`` is loaded.
      3. Otherwise, a built-in one-guest mock dataset is returned.
    Any exception raised while loading is logged (with traceback) and the
    mock dataset is returned instead.

    Every Document, regardless of source, has page_content made of one
    "Field: value" line each for Name, Relation, Description and Email, and
    metadata with the keys name, relation, description and email. Missing
    or None values become "".

    Args:
        dataset_path (str): Path to a local JSON file or a Hugging Face dataset name.

    Returns:
        List[Document]: One Document per guest; empty if the source has no guests.
    """
    try:
        if os.path.exists(dataset_path):
            logger.info(f"Loading guest dataset from local path: {dataset_path}")
            with open(dataset_path, "r", encoding="utf-8") as f:
                guests = json.load(f)
            docs = [_guest_to_document(guest) for guest in guests]
            logger.info(f"Loaded {len(docs)} guests from local JSON")
            return docs

        # load_dataset is None when the optional `datasets` import failed.
        if load_dataset is not None:
            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
            guest_dataset = load_dataset(dataset_path, split="train")
            docs = [_guest_to_document(guest) for guest in guest_dataset]
            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
            return docs

        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs
    except Exception as e:
        # Deliberate best-effort: callers always get a usable guest list.
        logger.exception(f"Failed to load guest dataset: {e}")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs
class BM25Retriever:
    """
    Keyword (BM25) retriever over the guest dataset.

    ``search`` ranks guests with langchain_community's BM25Retriever over each
    guest's name, relation and description. ``__init__`` also loads a
    SentenceTransformer into ``self.model``, but ``search`` does not use it;
    it is kept so existing callers that read ``self.model`` keep working.
    """

    def __init__(self, dataset_path: str):
        """
        Load the SentenceTransformer model and remember the dataset location.

        Args:
            dataset_path (str): Local JSON path or Hugging Face dataset name,
                passed to ``load_guest_dataset`` on every search.

        Raises:
            Exception: Whatever SentenceTransformer construction raises; it is
                logged and then re-raised.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.exception(f"Failed to initialize embedder: {e}")
            raise

    @staticmethod
    def _guest_text(metadata: dict) -> str:
        """Return the text BM25 indexes for one guest: name, relation and description."""
        return f"{metadata['name']} {metadata['relation']} {metadata['description']}"

    def search(self, query: str, k: int = 3) -> List[dict]:
        """
        Return the guests that best match ``query``, best match first.

        The dataset is reloaded and the BM25 index rebuilt on every call, so
        changes to the underlying dataset are picked up immediately.

        Args:
            query (str): Search query (e.g., guest name or relation).
            k (int): Maximum number of guests to return (default 3).

        Returns:
            List[dict]: Up to ``k`` guest metadata dictionaries, ordered by BM25
                relevance. Empty when ``k <= 0``, when the dataset is empty, or
                when any error occurs (the error is logged).
        """
        try:
            if k <= 0:
                return []
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Imported lazily so the dependency is only needed when searching;
            # aliased so it does not shadow this class's own name.
            from langchain_community.retrievers import BM25Retriever as LangchainBM25Retriever

            texts = [self._guest_text(doc.metadata) for doc in docs]
            # Tag every indexed text with its position so each hit maps back to
            # exactly one guest, preserving BM25's best-first ordering.
            retriever = LangchainBM25Retriever.from_texts(
                texts,
                metadatas=[{"doc_index": index} for index in range(len(docs))],
            )
            retriever.k = k

            results = retriever.invoke(query)
            matches = [docs[hit.metadata["doc_index"]].metadata for hit in results][:k]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches
        except Exception as e:
            logger.exception(f"Search failed for query '{query}': {e}")
            return []