"""Guest-dataset loading and BM25-based guest lookup utilities."""

import json
import logging
import os
from typing import List

import torch
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer

try:
    from datasets import load_dataset
except ImportError:
    # `datasets` is optional; without it we fall back to local JSON / mock data.
    load_dataset = None

logger = logging.getLogger(__name__)

# Canonical fallback guest record, defined once (the original duplicated this
# literal in two separate fallback branches).
_MOCK_GUEST = {
    "name": "Dr. Nikola Tesla",
    "relation": "old friend from university days",
    "description": (
        "Dr. Nikola Tesla is an old friend from your university days. "
        "He's recently patented a new wireless energy transmission system."
    ),
    "email": "nikola.tesla@gmail.com",
}


def get_device() -> str:
    """
    Determine the appropriate device for PyTorch.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _guest_to_document(guest: dict) -> Document:
    """Convert one guest record (name/relation/description/email) to a Document."""
    return Document(
        page_content="\n".join([
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]),
        metadata={
            "name": guest["name"],
            "relation": guest["relation"],
            "description": guest["description"],
            "email": guest["email"],
        },
    )


def _mock_docs() -> List[Document]:
    """Build the single-guest mock dataset used as a last-resort fallback."""
    return [_guest_to_document(_MOCK_GUEST)]


def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load guest dataset from a local JSON file or Hugging Face dataset.

    Resolution order:
        1. Hugging Face dataset, when `datasets` is installed and
           `dataset_path` is not an existing local file.
        2. Local JSON file at `dataset_path`.
        3. Built-in single-guest mock dataset (also used on any error).

    Args:
        dataset_path (str): Path to local JSON file or Hugging Face dataset name.

    Returns:
        List[Document]: List of Document objects with guest information.
    """
    try:
        if load_dataset and not os.path.exists(dataset_path):
            logger.info("Attempting to load Hugging Face dataset: %s", dataset_path)
            guest_dataset = load_dataset(dataset_path, split="train")
            docs = [_guest_to_document(guest) for guest in guest_dataset]
            logger.info("Loaded %d guests from Hugging Face dataset", len(docs))
            return docs

        if os.path.exists(dataset_path):
            logger.info("Loading guest dataset from local path: %s", dataset_path)
            # Fix: decode explicitly as UTF-8 instead of the platform locale.
            with open(dataset_path, "r", encoding="utf-8") as f:
                guests = json.load(f)
            docs = [
                Document(
                    # Local records deliberately keep description-only
                    # page_content (original behavior) for compatibility.
                    page_content=guest.get("description", ""),
                    metadata={
                        "name": guest.get("name", ""),
                        "relation": guest.get("relation", ""),
                        "description": guest.get("description", ""),
                        "email": guest.get("email", ""),  # Optional email field
                    },
                )
                for guest in guests
            ]
            logger.info("Loaded %d guests from local JSON", len(docs))
            return docs

        # Fallback to mock dataset if both sources are unavailable.
        logger.warning("Dataset not found at %s, using mock dataset", dataset_path)
        docs = _mock_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs
    except Exception as e:
        logger.error("Failed to load guest dataset: %s", e)
        # Return mock dataset as final fallback.
        docs = _mock_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs


class BM25Retriever:
    """
    A retriever class using SentenceTransformer for embedding-based search.

    NOTE(review): although it holds a SentenceTransformer, `search` currently
    performs lexical BM25 ranking only; the embedder is initialized but unused
    there — confirm whether embedding search is intended elsewhere.
    """

    def __init__(self, dataset_path: str):
        """
        Initialize the retriever with a SentenceTransformer model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error("Failed to initialize embedder: %s", e)
            raise

    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: Up to 3 matching guest metadata dicts, best match first.
        """
        try:
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Index name/relation/description and carry each guest's metadata
            # on the Document itself. The original indexed bare texts with
            # from_texts and then re-mapped results via substring matching,
            # which was O(n*k), fragile, and discarded BM25 ranking order.
            index_docs = [
                Document(
                    page_content=(
                        f"{doc.metadata['name']} {doc.metadata['relation']} "
                        f"{doc.metadata['description']}"
                    ),
                    metadata=doc.metadata,
                )
                for doc in docs
            ]

            # Alias the import so it does not shadow this class's name.
            from langchain_community.retrievers import BM25Retriever as _LCBM25Retriever

            retriever = _LCBM25Retriever.from_documents(index_docs)
            retriever.k = 3  # Limit to top 3 results
            results = retriever.invoke(query)

            # Metadata maps back directly, in ranked order; `k` already
            # bounds the result count.
            matches = [r.metadata for r in results]
            logger.info("Found %d matches for query: %s", len(matches), query)
            return matches
        except Exception as e:
            logger.error("Search failed for query '%s': %s", query, e)
            return []