# recommendation/src/accuracy.py
import torch
from torchmetrics.retrieval import RetrievalPrecision, RetrievalRecall, RetrievalNormalizedDCG
from typing import List, Tuple, Dict, Any, Set
import sys
import os
import logging
# Add the project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Constants for field names — update here if needed
FIELD_QUERY = "query"
FIELD_RELEVANT_DOCS = "relevant_docs"
def evaluate_with_torchmetrics(model: Any, test_data: List[Tuple[str, List[str]]], k: int = 10) -> Tuple[float, float, float]:
    """Evaluate the recommender with Precision@k, Recall@k and NDCG@k using torchmetrics."""
    device_str = model.device if hasattr(model, 'device') and isinstance(model.device, str) else 'cpu'
    device = torch.device(device_str)
    precision_metric = RetrievalPrecision(top_k=k).to(device)  # type: ignore
    recall_metric = RetrievalRecall(top_k=k).to(device)  # type: ignore
    ndcg_metric = RetrievalNormalizedDCG(top_k=k).to(device)  # type: ignore
    num_updates = 0
    for i, (query, relevant_items_list) in enumerate(test_data):
        relevant_items_set: Set[str] = set(relevant_items_list)
        recommendation_output: Dict[str, Any] = model.get_recommendations(query, k=k)
        recommended_docs: List[Dict[str, Any]] = recommendation_output.get("retrieved_documents", [])
        if not recommended_docs:
            logging.warning(f"No recommendations for query: {query}")
            continue  # nothing to score for this query
        recommended_ids = [str(doc.get('id')) for doc in recommended_docs]
        recommended_scores = [float(doc.get('score', 1.0)) for doc in recommended_docs]
        # A recommended document is a hit (target=True) if its ID is in the query's relevant set.
        current_targets = torch.tensor([item_id in relevant_items_set for item_id in recommended_ids],
                                       dtype=torch.bool, device=device)
        current_preds = torch.tensor(recommended_scores, dtype=torch.float, device=device)
        # torchmetrics retrieval metrics group predictions per query via `indexes`.
        current_indexes = torch.full_like(current_preds, fill_value=i, dtype=torch.long, device=device)
        precision_metric.update(current_preds, current_targets, indexes=current_indexes)
        recall_metric.update(current_preds, current_targets, indexes=current_indexes)
        ndcg_metric.update(current_preds, current_targets, indexes=current_indexes)
        num_updates += 1
    if num_updates == 0:
        logging.warning("No data accumulated in metrics. Returning 0.0 for all scores.")
        return 0.0, 0.0, 0.0
    precision = precision_metric.compute()
    recall = recall_metric.compute()
    ndcg = ndcg_metric.compute()
    return precision.item(), recall.item(), ndcg.item()
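
# Illustrative usage sketch (kept as a comment, not executed): the only contract this
# evaluation assumes is that the model exposes `get_recommendations(query, k)` returning
# a dict with a "retrieved_documents" list of {"id": ..., "score": ...} entries, plus an
# optional string `device` attribute. The dummy model below is hypothetical and exists
# purely to show the expected shapes.
#
#     class _DummyModel:
#         device = "cpu"
#         def get_recommendations(self, query: str, k: int = 10) -> Dict[str, Any]:
#             return {"retrieved_documents": [{"id": "a", "score": 0.9},
#                                             {"id": "b", "score": 0.4}][:k]}
#
#     _p, _r, _n = evaluate_with_torchmetrics(_DummyModel(), [("some query", ["a"])], k=2)
#     # -> Precision@2 = 0.5, Recall@2 = 1.0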
def get_all_queries_and_relevant_docs() -> List[Tuple[str, List[str]]]:
    """Build (query, relevant_doc_ids) pairs from the MongoDB article collection."""
    from pymongo import MongoClient
    # Prefer an environment variable for the connection string; fall back to the existing
    # hard-coded URI so current behaviour is unchanged.
    # TODO: Move the MongoDB URI fully out of the source (env var or config file) for security.
    mongo_uri = os.environ.get(
        "MONGO_URI",
        "mongodb+srv://sundram22verma:[email protected]/NewsDataSet?retryWrites=true&w=majority",
    )
    client = MongoClient(mongo_uri)
    db = client["NewsDataSet"]
    collection = db["parsedXmlArticles"]
    # Log a sample document to help debug field names
    sample_doc = collection.find_one()
    logging.debug("Sample document from MongoDB: %s", sample_doc)
    # Get all documents that have headlines and IDs
    find_conditions = {
        "hl": {"$exists": True},
        "id": {"$exists": True}
    }
    projection = {
        "hl": 1,
        "id": 1,
        "syn": 1,
        "_id": 0
    }
    cursor = collection.find(find_conditions, projection)
    fetched_docs = list(cursor)
    client.close()
    logging.info(f"Fetched {len(fetched_docs)} documents from MongoDB.")
    test_data_from_db: List[Tuple[str, List[str]]] = []
    # Create test data using headlines as queries and document IDs as relevant documents.
    for doc in fetched_docs:
        headline = doc.get("hl")
        doc_id = doc.get("id")
        if isinstance(headline, str) and headline.strip() and isinstance(doc_id, str):
            # NOTE: every *other* document is treated as relevant to this headline. This is a
            # naive placeholder relevance signal; for a stricter ground truth, swap in a
            # content-based selection (see the sketch after this function).
            similar_docs = [other_doc.get("id") for other_doc in fetched_docs
                            if other_doc.get("id") and other_doc.get("id") != doc_id]
            if similar_docs:  # Only add if there are other documents to mark as relevant
                test_data_from_db.append((headline, similar_docs))
    if not test_data_from_db:
        logging.error("No valid test data created from MongoDB documents.")
    else:
        logging.info(f"Created {len(test_data_from_db)} test queries with relevant documents.")
    return test_data_from_db
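
# Optional sketch: stricter, content-based relevance judgements. This is an alternative to
# the "all other documents are relevant" placeholder above. It assumes scikit-learn is
# installed and that the `syn` field holds the article body text (falling back to the
# headline when it is missing). It is not wired into the evaluation flow by default.
def build_content_based_test_data(fetched_docs: List[Dict[str, Any]],
                                  top_n: int = 20) -> List[Tuple[str, List[str]]]:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    docs = [d for d in fetched_docs if isinstance(d.get("hl"), str) and isinstance(d.get("id"), str)]
    corpus = [(d.get("syn") or d["hl"]) for d in docs]  # prefer body text, fall back to headline
    tfidf = TfidfVectorizer(stop_words="english").fit_transform(corpus)
    sims = cosine_similarity(tfidf)  # pairwise cosine similarity between articles

    test_data: List[Tuple[str, List[str]]] = []
    for idx, doc in enumerate(docs):
        # Rank the other articles by similarity and keep the top_n IDs as "relevant".
        ranked = sims[idx].argsort()[::-1]
        relevant_ids = [docs[j]["id"] for j in ranked if j != idx][:top_n]
        if relevant_ids:
            test_data.append((doc["hl"], relevant_ids))
    return test_data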
if __name__ == "__main__":
    from core.recommender import RecoRecommender  # Updated import path
    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG,  # Set to DEBUG for verbose output
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    print("Loading recommender components (models, data, index)...")
    try:
        model = RecoRecommender()
        model.load_components()
        print("Recommender components loaded successfully.")
    except Exception:
        logging.exception("Error loading recommender components.")
        sys.exit(1)
    print("Fetching all queries and relevant documents...")
    test_data = get_all_queries_and_relevant_docs()
    print(f"Loaded {len(test_data)} queries for evaluation.")
    if not test_data:
        print("No valid test data available. Exiting.")
        sys.exit(1)
    print("Evaluating with k=5...")
    precision, recall, ndcg = evaluate_with_torchmetrics(model, test_data, k=5)
    print(f"Precision@5: {precision:.4f}")
    print(f"Recall@5: {recall:.4f}")
    print(f"NDCG@5: {ndcg:.4f}")