import torch
from torchmetrics.retrieval import RetrievalPrecision, RetrievalRecall, RetrievalNormalizedDCG
from typing import List, Tuple, Dict, Any, Set
import sys
import os
import logging

# Add the project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Constants for field names — update here if needed
FIELD_QUERY = "query"
FIELD_RELEVANT_DOCS = "relevant_docs"
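
# evaluate_with_torchmetrics below expects each test item to pair a query string with the
# IDs of its relevant documents. A minimal illustrative sketch (hypothetical IDs, not taken
# from the database):
#
#     example_test_data = [
#         ("budget session highlights", ["a1", "b2", "c3"]),
#         ("monsoon forecast update", ["d4"]),
#     ]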

def evaluate_with_torchmetrics(model: Any, test_data: List[Tuple[str, List[str]]], k: int = 10) -> Tuple[float, float, float]:
    """Evaluate the recommender on (query, relevant_doc_ids) pairs and return Precision@k, Recall@k and NDCG@k."""
    # Use the model's device if it exposes one as a string, otherwise fall back to CPU.
    device_str = model.device if hasattr(model, 'device') and isinstance(model.device, str) else 'cpu'
    device = torch.device(device_str)

    precision_metric = RetrievalPrecision(top_k=k).to(device)  # type: ignore
    recall_metric = RetrievalRecall(top_k=k).to(device)  # type: ignore
    ndcg_metric = RetrievalNormalizedDCG(top_k=k).to(device)  # type: ignore

    for i, (query, relevant_items_list) in enumerate(test_data):
        relevant_items_set: Set[str] = set(relevant_items_list)

        recommendation_output: Dict[str, Any] = model.get_recommendations(query, k=k)
        recommended_docs: List[Dict[str, Any]] = recommendation_output.get("retrieved_documents", [])

        if not recommended_docs:
            # Nothing retrieved for this query: skip it so the metrics only accumulate real predictions.
            logging.warning(f"No recommendations for query: {query}")
            continue

        recommended_ids = [str(doc['id']) for doc in recommended_docs]
        recommended_scores = [float(doc.get('score', 1.0)) for doc in recommended_docs]

        current_targets = torch.tensor([1 if item_id in relevant_items_set else 0 for item_id in recommended_ids],
                                       dtype=torch.bool, device=device)
        current_preds = torch.tensor(recommended_scores, dtype=torch.float, device=device)
        # All predictions for one query share the same index so torchmetrics groups them together.
        current_indexes = torch.full_like(current_preds, fill_value=i, dtype=torch.long, device=device)

        precision_metric.update(current_preds, current_targets, indexes=current_indexes)
        recall_metric.update(current_preds, current_targets, indexes=current_indexes)
        ndcg_metric.update(current_preds, current_targets, indexes=current_indexes)

    if not precision_metric.indexes:
        logging.warning("No data accumulated in metrics. Returning 0.0 for all scores.")
        return 0.0, 0.0, 0.0

    precision = precision_metric.compute()
    recall = recall_metric.compute()
    ndcg = ndcg_metric.compute()
    return precision.item(), recall.item(), ndcg.item()
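

# Illustrative sketch (defined but never called): how the retrieval metrics above group
# predictions per query via the `indexes` tensor. The two toy queries and their scores and
# relevance labels are made up for illustration only.
def _sketch_indexes_grouping() -> None:
    metric = RetrievalPrecision(top_k=2)
    # Query 0: three candidates; only one of the top-2 scored items is relevant.
    metric.update(
        torch.tensor([0.9, 0.7, 0.1]),
        torch.tensor([True, False, True]),
        indexes=torch.tensor([0, 0, 0]),
    )
    # Query 1: two candidates, both relevant.
    metric.update(
        torch.tensor([0.8, 0.3]),
        torch.tensor([True, True]),
        indexes=torch.tensor([1, 1]),
    )
    # Precision@2 is averaged over the two query groups: (1/2 + 2/2) / 2 = 0.75.
    print(f"Precision@2 over both queries: {metric.compute().item():.2f}")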


def get_all_queries_and_relevant_docs() -> List[Tuple[str, List[str]]]:
    """Build (query, relevant_doc_ids) pairs from MongoDB, using article headlines as queries."""
    from pymongo import MongoClient

    # TODO: Move MongoDB URI to environment variables or a config file for security
    client = MongoClient("mongodb+srv://sundram22verma:[email protected]/NewsDataSet?retryWrites=true&w=majority")
    db = client["NewsDataSet"]
    collection = db["parsedXmlArticles"]

    # Print a sample document to help debug field names
    sample_doc = collection.find_one()
    print("Sample document from MongoDB:")
    print(sample_doc)

    # Get all documents that have headlines and IDs
    find_conditions = {
        "hl": {"$exists": True},
        "id": {"$exists": True}
    }
    projection = {
        "hl": 1,
        "id": 1,
        "syn": 1,
        "_id": 0
    }
    cursor = collection.find(find_conditions, projection)
    fetched_docs = list(cursor)
    logging.info(f"Fetched {len(fetched_docs)} documents from MongoDB.")

    test_data_from_db: List[Tuple[str, List[str]]] = []
    # Create test data using headlines as queries and document IDs as relevant documents.
    # NOTE: every other document is treated as relevant to each headline; there is no
    # content-based similarity filtering here, so this is a naive relevance judgement.
    for doc in fetched_docs:
        headline = doc.get("hl")
        doc_id = doc.get("id")
        if isinstance(headline, str) and headline.strip() and isinstance(doc_id, str):
            # Collect the IDs of all other documents (excluding the document itself).
            similar_docs = [other_doc.get("id") for other_doc in fetched_docs if other_doc.get("id") != doc_id]
            if similar_docs:  # Only add if other documents exist
                test_data_from_db.append((headline, similar_docs))

    if not test_data_from_db:
        logging.error("No valid test data created from MongoDB documents.")
    else:
        logging.info(f"Created {len(test_data_from_db)} test queries with relevant documents.")
    return test_data_from_db
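

# Sketch for the TODO above: read the connection string from the environment instead of
# hard-coding it. MONGODB_URI is an assumed variable name, not something this project defines.
def _mongo_client_from_env():
    from pymongo import MongoClient
    uri = os.environ.get("MONGODB_URI")
    if not uri:
        raise RuntimeError("Set the MONGODB_URI environment variable before running the evaluation.")
    return MongoClient(uri)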


if __name__ == "__main__":
    from core.recommender import RecoRecommender  # Updated import path

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG,  # Set to DEBUG for verbose output
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    print("Loading recommender components (models, data, index)...")
    try:
        model = RecoRecommender()
        model.load_components()
        print("Recommender components loaded successfully.")
    except Exception:
        logging.exception("Error loading recommender components.")
        sys.exit(1)

    print("Fetching all queries and relevant documents...")
    test_data = get_all_queries_and_relevant_docs()
    print(f"Loaded {len(test_data)} queries for evaluation.")
    if not test_data:
        print("No valid test data available. Exiting.")
        sys.exit(1)

    print("Evaluating with k=5...")
    precision, recall, ndcg = evaluate_with_torchmetrics(model, test_data, k=5)
    print(f"Precision@5: {precision:.4f}")
    print(f"Recall@5: {recall:.4f}")
    print(f"NDCG@5: {ndcg:.4f}")