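"""Evaluate a RecoRecommender with TorchMetrics retrieval metrics.

Builds test queries from MongoDB article headlines and reports Precision@k,
Recall@k, and NDCG@k for the recommender's top-k results. Requires the
MONGO_URI environment variable (see get_all_queries_and_relevant_docs).
"""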
import torch
from torchmetrics.retrieval import RetrievalPrecision, RetrievalRecall, RetrievalNormalizedDCG
from typing import List, Tuple, Dict, Any, Set
import sys
import os
import logging

# Add the project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Constants for field names — update here if needed
FIELD_QUERY = "query"
FIELD_RELEVANT_DOCS = "relevant_docs"

def evaluate_with_torchmetrics(model: Any, test_data: List[Tuple[str, List[str]]], k: int = 10) -> Tuple[float, float, float]:
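    """Return (precision@k, recall@k, ndcg@k) for `model` over `test_data`.

    Each test_data entry is a (query, relevant_doc_ids) pair; `model` must
    expose `get_recommendations(query, k=...)`.
    """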
    # Use the model's device when it exposes one; torch.device accepts either
    # a string or an existing torch.device. Fall back to CPU otherwise.
    raw_device = getattr(model, 'device', None)
    device = torch.device(raw_device) if isinstance(raw_device, (str, torch.device)) else torch.device('cpu')

    precision_metric = RetrievalPrecision(top_k=k).to(device)  # type: ignore
    recall_metric = RetrievalRecall(top_k=k).to(device)        # type: ignore
    ndcg_metric = RetrievalNormalizedDCG(top_k=k).to(device)   # type: ignore

    for i, (query, relevant_items_list) in enumerate(test_data):
        relevant_items_set: Set[str] = set(relevant_items_list)
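        # Expected response shape (the keys read below): a dict with a
        # "retrieved_documents" list of {"id": ..., "score": ...} entries.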
        recommendation_output: Dict[str, Any] = model.get_recommendations(query, k=k)
        recommended_docs: List[Dict[str, Any]] = recommendation_output.get("retrieved_documents", [])

        if not recommended_docs:
            logging.warning(f"No recommendations for query: {query}")
            continue  # an empty update would contribute no rows for this query anyway

        recommended_ids = [str(doc['id']) for doc in recommended_docs]
        recommended_scores = [float(doc.get('score', 1.0)) for doc in recommended_docs]
        current_targets = torch.tensor([item_id in relevant_items_set for item_id in recommended_ids],
                                       dtype=torch.bool, device=device)
        current_preds = torch.tensor(recommended_scores, dtype=torch.float, device=device)

        current_indexes = torch.full_like(current_preds, fill_value=i, dtype=torch.long, device=device)
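        # Every row in this batch shares index i; torchmetrics groups rows by
        # `indexes` and averages the per-query scores at compute() time.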

        precision_metric.update(current_preds, current_targets, indexes=current_indexes)
        recall_metric.update(current_preds, current_targets, indexes=current_indexes)
        ndcg_metric.update(current_preds, current_targets, indexes=current_indexes)

    # `indexes` is the metric's accumulated list state; it is empty only when
    # no query produced any recommendations (or test_data itself was empty).
    if not precision_metric.indexes:
        logging.warning("No data accumulated in metrics. Returning 0.0 for all scores.")
        return 0.0, 0.0, 0.0

    precision = precision_metric.compute()
    recall = recall_metric.compute()
    ndcg = ndcg_metric.compute()
    return precision.item(), recall.item(), ndcg.item()

def get_all_queries_and_relevant_docs() -> List[Tuple[str, List[str]]]:
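    """Build (headline, relevant_doc_ids) test pairs from the MongoDB article collection."""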
    from pymongo import MongoClient

    # Read the MongoDB URI from the environment so credentials are never
    # hardcoded or committed (env var name MONGO_URI is this script's convention).
    mongo_uri = os.environ.get("MONGO_URI")
    if not mongo_uri:
        raise RuntimeError("Set the MONGO_URI environment variable to the NewsDataSet connection string.")
    client = MongoClient(mongo_uri)
    db = client["NewsDataSet"]
    collection = db["parsedXmlArticles"]

    # Print a sample document to help debug field names
    sample_doc = collection.find_one()
    print("Sample document from MongoDB:")
    print(sample_doc)

    # Get all documents that have headlines and IDs
    find_conditions = {
        "hl": {"$exists": True},
        "id": {"$exists": True}
    }
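    # Fetch only the fields used downstream: "hl" (headline), "id", and "syn"
    # (assumed to be the article synopsis; fetched here but not read below).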
    projection = {
        "hl": 1,
        "id": 1,
        "syn": 1,
        "_id": 0
    }

    cursor = collection.find(find_conditions, projection)
    fetched_docs = list(cursor)
    logging.info(f"Fetched {len(fetched_docs)} documents from MongoDB.")

    test_data_from_db: List[Tuple[str, List[str]]] = []

    # Create test pairs: each headline is a query. As a naive placeholder for
    # real relevance judgments, every *other* document's ID is treated as
    # relevant to the query (no content-based similarity is computed here).
    all_ids = [doc.get("id") for doc in fetched_docs if isinstance(doc.get("id"), str)]

    for doc in fetched_docs:
        headline = doc.get("hl")
        doc_id = doc.get("id")

        if isinstance(headline, str) and headline.strip() and isinstance(doc_id, str):
            relevant_ids = [other_id for other_id in all_ids if other_id != doc_id]
            if relevant_ids:  # only add queries with at least one relevant document
                test_data_from_db.append((headline, relevant_ids))

    if not test_data_from_db:
        logging.error("No valid test data created from MongoDB documents.")
    else:
        logging.info(f"Created {len(test_data_from_db)} test queries with relevant documents.")
    
    return test_data_from_db

if __name__ == "__main__":
    from core.recommender import RecoRecommender  # Updated import path

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG,  # Set to DEBUG for verbose output
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    print("Loading recommender components (models, data, index)...")
    try:
        model = RecoRecommender()
        model.load_components()
        print("Recommender components loaded successfully.")
    except Exception:
        logging.exception("Error loading recommender components.")
        sys.exit(1)

    print("Fetching all queries and relevant documents...")
    test_data = get_all_queries_and_relevant_docs()
    print(f"Loaded {len(test_data)} queries for evaluation.")

    if not test_data:
        print("No valid test data available. Exiting.")
        sys.exit(1)

    print("Evaluating with k=5...")
    precision, recall, ndcg = evaluate_with_torchmetrics(model, test_data, k=5)
    print(f"Precision@5: {precision:.4f}")
    print(f"Recall@5: {recall:.4f}")
    print(f"NDCG@5: {ndcg:.4f}")