Spaces:
Starting
Starting
import json | |
import os | |
import logging | |
import torch | |
from typing import List | |
from langchain_core.documents import Document | |
from sentence_transformers import SentenceTransformer | |
try: | |
from datasets import load_dataset | |
except ImportError: | |
load_dataset = None | |
logger = logging.getLogger(__name__) | |
def get_device():
    """
    Determine the appropriate device for PyTorch.

    Preference order is CUDA first, then the MPS backend, then CPU.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"

    # Look the MPS backend up defensively: PyTorch builds that do not ship
    # ``torch.backends.mps`` would otherwise raise AttributeError here
    # instead of falling back to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
def _build_guest_document(name, relation, description, email) -> Document:
    """Build a Document whose page_content lists all four guest fields."""
    return Document(
        page_content="\n".join([
            f"Name: {name}",
            f"Relation: {relation}",
            f"Description: {description}",
            f"Email: {email}"
        ]),
        metadata={
            "name": name,
            "relation": relation,
            "description": description,
            "email": email
        }
    )


def _mock_guest_docs() -> List[Document]:
    """Return the built-in single-guest dataset used when loading fails."""
    # NOTE(review): the email value looks like a redacted placeholder rather
    # than a real address; it is kept exactly as found -- confirm the
    # intended value.
    return [
        _build_guest_document(
            name="Dr. Nikola Tesla",
            relation="old friend from university days",
            description=(
                "Dr. Nikola Tesla is an old friend from your university days. "
                "He's recently patented a new wireless energy transmission system."
            ),
            email="[email protected]",
        )
    ]


def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load guest dataset from a local JSON file or Hugging Face dataset.

    Resolution order:
      1. If the ``datasets`` library is importable and ``dataset_path`` is not
         an existing local path, load it as a Hugging Face dataset
         (``train`` split).
      2. If ``dataset_path`` exists locally, parse it as a JSON list of guest
         objects.
      3. Otherwise, or if anything above raises, return a one-guest mock
         dataset. This function never raises to the caller.

    Args:
        dataset_path (str): Path to local JSON file or Hugging Face dataset name.

    Returns:
        List[Document]: List of Document objects with guest information. Each
        document's metadata has the keys ``name``, ``relation``,
        ``description`` and ``email``.
    """
    try:
        is_local = os.path.exists(dataset_path)

        # Try loading from Hugging Face dataset if datasets library is available
        if load_dataset and not is_local:
            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
            guest_dataset = load_dataset(dataset_path, split="train")
            # Strict indexing on purpose: a record missing a field raises
            # KeyError, which the handler below turns into the mock fallback.
            docs = [
                _build_guest_document(
                    guest["name"],
                    guest["relation"],
                    guest["description"],
                    guest["email"],
                )
                for guest in guest_dataset
            ]
            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
            return docs

        # Try loading from local JSON file
        if is_local:
            logger.info(f"Loading guest dataset from local path: {dataset_path}")
            # Decode explicitly as UTF-8 rather than relying on the
            # platform's default encoding.
            with open(dataset_path, "r", encoding="utf-8") as f:
                guests = json.load(f)
            if not isinstance(guests, list):
                # Fail with a clear message instead of an opaque
                # AttributeError while iterating; the handler below still
                # falls back to the mock dataset.
                raise ValueError(
                    f"Expected a JSON list of guests, got {type(guests).__name__}"
                )
            # NOTE: unlike the Hugging Face and mock paths, local records use
            # the description alone as page_content. That is kept unchanged
            # for backward compatibility.
            docs = [
                Document(
                    page_content=guest.get('description', ''),
                    metadata={
                        'name': guest.get('name', ''),
                        'relation': guest.get('relation', ''),
                        'description': guest.get('description', ''),
                        'email': guest.get('email', '')  # Optional email field
                    }
                )
                for guest in guests
            ]
            logger.info(f"Loaded {len(docs)} guests from local JSON")
            return docs

        # Fallback to mock dataset if both fail
        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs
    except Exception as e:
        # Deliberate best-effort boundary: any failure degrades to the mock
        # dataset, but the traceback is logged so the cause is not lost.
        logger.error(f"Failed to load guest dataset: {e}", exc_info=True)
        # Return mock dataset as final fallback
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs
class BM25Retriever:
    """
    A retriever class that runs BM25 keyword search over the guest dataset.

    The constructor also loads a SentenceTransformer model into ``self.model``.
    ``search`` does not use that model; it is kept so existing callers that
    read the attribute keep working.
    """

    def __init__(self, dataset_path: str):
        """
        Initialize the retriever with a SentenceTransformer model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error(f"Failed to initialize embedder: {e}")
            raise

    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        The dataset is reloaded from ``self.dataset_path`` on every call, and
        a BM25 index is built over each guest's name, relation and
        description.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: Up to three guest metadata dictionaries, in the order
            returned by the BM25 retriever. An empty list is returned when no
            documents are available or when any step fails.
        """
        try:
            # Imported under an alias so it does not shadow this class's own
            # name. It stays inside the try so a missing package results in
            # an empty result instead of an exception.
            from langchain_community.retrievers import (
                BM25Retriever as LangChainBM25Retriever,
            )

            # Load dataset
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Convert documents to text for BM25 (using metadata for consistency)
            texts = [
                f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}"
                for doc in docs
            ]

            # Carry each document's position through the index so hits map
            # back to their source exactly. This avoids substring matching on
            # the text, which loses the ranking order and can match the wrong
            # document when one text contains another.
            retriever = LangChainBM25Retriever.from_texts(
                texts,
                metadatas=[{"doc_index": position} for position in range(len(docs))],
            )
            retriever.k = 3  # Limit to top 3 results

            # Perform search
            results = retriever.invoke(query)

            # Map results back to original metadata, preserving result order
            matches = [docs[hit.metadata["doc_index"]].metadata for hit in results]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches[:3]  # Return top 3 matches
        except Exception as e:
            # Best-effort search: log with traceback and return no results.
            logger.error(f"Search failed for query '{query}': {e}", exc_info=True)
            return []