recommendation / src /core /recommender.py
sundaram22verma's picture
Bug fix
ca2bd00
import pandas as pd # Import pandas for data manipulation, aliased as pd
import numpy as np # Import numpy for numerical operations, aliased as np
import faiss # Import faiss for efficient similarity search
import json # Import json for working with JSON data
from typing import List, Dict, Optional # Import typing hints for better code readability and static analysis
from datetime import datetime # Import datetime for handling date and time objects
import logging # Import logging module for application logging
from sentence_transformers import SentenceTransformer, CrossEncoder, util # Import specific classes from sentence_transformers library
from indicnlp.tokenize import indic_tokenize # Import tokenizer for Indic languages
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory # Import normalizer factory for Indic languages
import torch # Import PyTorch for deep learning functionalities
import os # Import os module for interacting with the operating system (e.g., file paths)
from indicnlp import common # Import common utilities from the indicnlp library
import pickle # Import pickle for saving/loading Python objects (like lists of IDs)
import re # Import re module for regular expression operations
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Import classes for loading pre-trained models from Hugging Face Transformers
from fastapi import HTTPException # Import HTTPException for raising HTTP errors in FastAPI applications
import asyncio # Import asyncio for asynchronous operations
from src.fine_tuning.trainer import ModelTrainer
from src.fine_tuning.config import MODEL_STATUS
from src.config.settings import ( # Import configuration constants from the settings file
EMBED_MODEL_NAME, GENERATOR_MODEL_NAME, RERANKER_MODEL_NAME, # Model names
INDEX_PATH, INTERACTION_LOG_PATH, INDIC_NLP_RESOURCES_PATH, MONGO_FAISS_META_COLLECTION_NAME, # File paths & DB names
HEADLINE_COL, SEOLOCATION_COL, DEEPLINK_COL, LAST_UPDATED_COL, FINE_TUNED_RERANKER_SAVE_PATH, # Column names & paths
IMAGE_ID_COL, IMAGE_RATIO_COL, IMAGE_SIZE_COL, TAXONOMY_COL, INDEX_IDS_PATH, # Column names
SYN_COL, KEY_COL, ID_COL, TOPIC_COL, PROPERTY_COL, # Existing Column names
DEFAULT_K, SIMILARITY_THRESHOLD, CANDIDATE_MULTIPLIER # Recommendation parameters
)
from src.database.mongodb import mongodb # Import the MongoDB client instance
logger = logging.getLogger(__name__) # Initialize a logger instance for this module
class RecoRecommender:
"""
A RAG-based recommender system for multi language content using FAISS retrieval.
"""
FAISS_IDS_DOC_ID = "faiss_indexed_document_ids_v1" # Document ID for storing indexed IDs in MongoDB
MODEL_METADATA_DOC_ID = "model_metadata"
EMBEDDING_CHECKSUM_KEY = "embedding_checksum"
MODEL_VERSION_KEY = "model_version"
FINE_TUNING_CONFIG_KEY = "fine_tuning_config"
EMBEDDING_STATUS_KEY = "embedding_status"
MODEL_STATUS = MODEL_STATUS # Add MODEL_STATUS as class attribute
def __init__(self):
    """Initialize the recommender with configuration.

    Copies the column-name settings onto the instance, resets all lazily
    loaded state (DataFrame, FAISS index, models) to empty values, selects
    the compute device, binds the MongoDB metadata collection when a
    connection exists, creates the ``ModelTrainer`` and finally sets up
    the Indic NLP resources.

    Raises:
        FileNotFoundError: Propagated from ``_setup_indic_nlp`` when the
            Indic NLP resources directory does not exist.
    """
    logger.info("Initializing RecoRecommender...")

    # Configuration: column names taken from the settings module.
    self.headline_col = HEADLINE_COL
    self.key = KEY_COL
    self.syn = SYN_COL
    self.id_col = ID_COL
    self.topic_col = TOPIC_COL
    self.property_col = PROPERTY_COL
    self.taxonomy_col = TAXONOMY_COL
    self.seolocation_col = SEOLOCATION_COL
    self.deeplink_col = DEEPLINK_COL
    self.last_updated_col = LAST_UPDATED_COL
    self.image_id_col = IMAGE_ID_COL
    self.image_ratio_col = IMAGE_RATIO_COL
    self.image_size_col = IMAGE_SIZE_COL
    self.processed_content_col = "Processed Content"

    # State variables: populated later by the load/build methods.
    self.df: Optional[pd.DataFrame] = None
    self.index: Optional[faiss.Index] = None
    self.embed_model: Optional[SentenceTransformer] = None
    self.base_reranker: Optional[CrossEncoder] = None
    self.fine_tuned_reranker: Optional[CrossEncoder] = None
    # Active reranker. Defined up front so that reading it before
    # load_models() yields None instead of raising AttributeError.
    self.reranker: Optional[CrossEncoder] = None
    self.indexed_ids: List[str] = []
    self.tokenizer: Optional[AutoTokenizer] = None
    self.generator: Optional[AutoModelForSeq2SeqLM] = None
    self.normalizer = None

    logger.info("Determining compute device...")
    self.device = self._select_compute_device()
    logger.info(f"Selected device: {self.device}")

    # Initialize MongoDB collection only if connection is available
    if mongodb.db is not None:
        self.faiss_meta_collection = mongodb.db[MONGO_FAISS_META_COLLECTION_NAME]
    else:
        self.faiss_meta_collection = None
        logger.warning("MongoDB not available. Some features may be limited.")

    # The trainer receives the same device that was selected above.
    self.model_trainer = ModelTrainer(RERANKER_MODEL_NAME, device=self.device)
    self._setup_indic_nlp()
    logger.info("RecoRecommender initialized.")

def _select_compute_device(self) -> str:
    """Return "cuda", "xpu" or "cpu" depending on what the runtime offers."""
    if torch.cuda.is_available():
        logger.info("CUDA is available. Using NVIDIA GPU.")
        return "cuda"
    try:
        # The module is imported but never referenced afterwards;
        # presumably the import itself enables torch.xpu -- TODO confirm.
        import intel_extension_for_pytorch as ipex  # noqa: F401
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            logger.info("Intel XPU is available via IPEX. Using Intel GPU.")
            return "xpu"
        logger.info("CUDA not available. Intel XPU not available or IPEX not fully configured. Using CPU.")
    except ImportError:
        logger.info("CUDA not available. Intel Extension for PyTorch (IPEX) not found. Using CPU.")
    except Exception as e:  # Any other failure during the XPU probe falls back to CPU
        logger.error(f"Error during Intel XPU check: {e}. Using CPU.")
    return "cpu"
def _setup_indic_nlp(self):
    """Point the indicnlp library at its resources and create the normalizer.

    The resources directory is exported through the ``INDIC_RESOURCES_PATH``
    environment variable and registered with ``indicnlp.common``; the
    normalizer for language code "hi" is then stored on ``self.normalizer``.

    Raises:
        FileNotFoundError: If the configured resources directory is missing.
        Exception: Any error raised by indicnlp is logged and re-raised.
    """
    logger.info("Setting up Indic NLP resources...")
    resources_dir = INDIC_NLP_RESOURCES_PATH
    resources_present = os.path.exists(resources_dir)
    if not resources_present:
        raise FileNotFoundError(f"Indic NLP resources not found at {resources_dir}")

    os.environ["INDIC_RESOURCES_PATH"] = resources_dir
    try:
        common.set_resources_path(resources_dir)
        normalizer_factory = IndicNormalizerFactory()
        self.normalizer = normalizer_factory.get_normalizer("hi")
        logger.info("Indic NLP resources setup complete.")
    except Exception as exc:
        # Setup failures are fatal for the caller, so log with traceback and propagate.
        logger.error(f"Error setting up Indic NLP resources: {exc}", exc_info=True)
        raise
def load_models(self):
    """Load every ML model that is not already held in memory.

    Loads, in order: the sentence-embedding model, the seq2seq generator
    with its tokenizer, and the base cross-encoder reranker. Afterwards
    ``_load_fine_tuned_model`` chooses which reranker becomes active.

    Raises:
        Exception: Any loading error is logged with traceback and re-raised.
    """
    logger.info("Loading ML models...")
    try:
        embedder_missing = self.embed_model is None
        if embedder_missing:
            logger.info(f"Loading embedding model: {EMBED_MODEL_NAME}")
            self.embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=self.device)

        generator_ready = self.tokenizer is not None and self.generator is not None
        if not generator_ready:
            logger.info(f"Loading generator model: {GENERATOR_MODEL_NAME}")
            self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME, model_max_length=1024)
            self.generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_NAME)
            # The generator is only moved off the CPU when the device is CUDA.
            on_cuda = self.device == "cuda"
            if on_cuda:
                self.generator = self.generator.to(self.device)

        reranker_missing = self.base_reranker is None
        if reranker_missing:
            logger.info(f"Loading base reranker model: {RERANKER_MODEL_NAME}")
            self.base_reranker = CrossEncoder(RERANKER_MODEL_NAME, device=self.device)
            self.reranker = self.base_reranker  # Base model is the default

        # Swap in the fine-tuned reranker when one is available.
        self._load_fine_tuned_model()
        logger.info("ML models loaded.")
    except Exception as exc:
        logger.error(f"Error loading ML models: {exc}", exc_info=True)
        raise
def _load_fine_tuned_model(self):
    """Activate the fine-tuned reranker if the trainer reports one on disk.

    The trainer's status is consulted; when it reports the fine-tuned state
    and the model path for the current version exists, that model is loaded
    and made the active reranker. In every other case, including any
    exception, the base reranker is made active.
    """
    try:
        trainer_status = self.model_trainer.get_model_status()
        reported_state = trainer_status.get("current_model_status")
        if reported_state == MODEL_STATUS["FINE_TUNED"]:
            version_tag = trainer_status.get("current_version", "v0")
            checkpoint_dir = str(self.model_trainer.get_model_path(version_tag))
            if os.path.exists(checkpoint_dir):
                logger.info(f"Loading fine-tuned reranker model version {version_tag}")
                self.fine_tuned_reranker = CrossEncoder(checkpoint_dir, device=self.device)
                self.reranker = self.fine_tuned_reranker
                logger.info("Successfully loaded fine-tuned reranker model")
                return

        # Reached when the status is not fine-tuned or the path is missing.
        logger.info("No fine-tuned model found or not in fine-tuned state, using base model")
        self.reranker = self.base_reranker
    except Exception as exc:
        logger.error(f"Error loading fine-tuned model: {exc}. Falling back to base model.")
        self.reranker = self.base_reranker
def _calculate_embedding_checksum(self, content: str) -> str:
    """Return the MD5 hex digest of ``content`` joined with the embedding model name.

    Because the model name is part of the hashed text, the checksum changes
    when either the content or ``EMBED_MODEL_NAME`` changes.
    """
    import hashlib

    fingerprint_source = "{}_{}".format(content, EMBED_MODEL_NAME)
    hasher = hashlib.md5()
    hasher.update(fingerprint_source.encode())
    return hasher.hexdigest()
def _get_model_metadata(self) -> Dict:
"""Retrieve current model metadata from MongoDB."""
try:
if self.faiss_meta_collection is None:
logger.warning("MongoDB not available. Returning default metadata.")
return {
"_id": self.MODEL_METADATA_DOC_ID,
"embedding_model_name": EMBED_MODEL_NAME,
}
metadata = self.faiss_meta_collection.find_one({"_id": self.MODEL_METADATA_DOC_ID})
if metadata:
return metadata
# Default metadata for the recommender service, focusing on aspects not directly
# managed by ModelTrainer's file-based metadata (e.g., embedding model name).
# Reranker model status, version, and its specific fine-tuning configuration
# are primarily managed by ModelTrainer and its model_metadata.json file.
return {
"_id": self.MODEL_METADATA_DOC_ID,
"embedding_model_name": EMBED_MODEL_NAME,
# Add other global operational metadata specific to RecoRecommender here if needed.
}
except Exception as e:
logger.error(f"Error retrieving model metadata: {e}")
# Return a minimal fallback to ensure core functionalities can proceed if possible
return {"_id": self.MODEL_METADATA_DOC_ID, "embedding_model_name": EMBED_MODEL_NAME}
def _update_model_metadata(self, updates: Dict) -> bool:
"""Update model metadata in MongoDB."""
try:
if self.faiss_meta_collection is None:
logger.warning("MongoDB not available. Cannot update model metadata.")
return False
result = self.faiss_meta_collection.update_one(
{"_id": self.MODEL_METADATA_DOC_ID},
{"$set": {**updates, "metadata_last_updated": datetime.now()}}, # Key changed for clarity
upsert=True
)
return result.acknowledged
except Exception as e:
logger.error(f"Error updating model metadata: {e}")
return False
def _increment_model_version(self) -> str:
    """DEPRECATED: Reranker model versioning is handled by ModelTrainer.

    Nothing is incremented. A deprecation warning is logged and the
    trainer's current version is returned ("v0" when none is reported).
    """
    logger.warning("RecoRecommender._increment_model_version() is deprecated. Reranker versioning is managed by ModelTrainer.")
    trainer_status = self.model_trainer.get_model_status()
    return trainer_status.get("current_version", "v0")
def _needs_reembedding_batch(self, doc_ids: List[str], current_checksum: str) -> List[str]:
"""Check which documents from a batch need reembedding."""
try:
if self.faiss_meta_collection is None:
logger.warning("MongoDB not available. Assuming all documents need reembedding.")
return doc_ids
# Query for all documents in one go
metadata_docs = self.faiss_meta_collection.find(
{"_id": {"$in": doc_ids}}
)
# Create a dict for quick lookup
metadata_map = {doc["_id"]: doc.get(self.EMBEDDING_CHECKSUM_KEY) for doc in metadata_docs}
# Return IDs that need reembedding
return [
doc_id for doc_id in doc_ids
if doc_id not in metadata_map or metadata_map[doc_id] != current_checksum
]
except Exception as e:
logger.error(f"Error checking embedding status for batch: {e}")
return doc_ids # If error, assume all need reembedding
def _update_embedding_metadata(self, doc_id: str, checksum: str):
"""Update metadata for document embeddings."""
try:
self.faiss_meta_collection.update_one(
{"_id": doc_id},
{
"$set": {
self.EMBEDDING_CHECKSUM_KEY: checksum,
"last_updated": datetime.now()
}
},
upsert=True
)
except Exception as e:
logger.error(f"Error updating embedding metadata for doc {doc_id}: {e}")
def _needs_reembedding(self, doc_id: str, current_checksum: str) -> bool:
"""Check if a document needs to be reembedded."""
try:
metadata = self.faiss_meta_collection.find_one({"_id": doc_id})
if not metadata or metadata.get(self.EMBEDDING_CHECKSUM_KEY) != current_checksum:
return True
return False
except Exception as e:
logger.error(f"Error checking embedding status for doc {doc_id}: {e}")
return True
async def update_embeddings_and_index(self, force_reload_data: bool = True):
    """
    Updates embeddings and the FAISS index.

    If existing indexed documents have changed, were deleted, the embedding
    model is different, or the index and its ID list disagree in length,
    a full rebuild of the index is triggered. Otherwise, only new documents
    are added incrementally. The FAISS index file (INDEX_PATH) and metadata
    are saved.

    Args:
        force_reload_data: When true, the DataFrame is reloaded from
            MongoDB before anything is checked.

    Note: the body contains no ``await``; all work runs synchronously.
    """
    if force_reload_data:
        logger.info("Reloading data from MongoDB for index update...")
        self._load_data_from_mongo()
    if self.df is None or self.df.empty:
        logger.warning("DataFrame is empty. Cannot update embeddings or index.")
        return
    if self.embed_model is None:
        logger.info("Embedding model not loaded. Loading models...")
        self.load_models()  # Ensure embed_model is available

    logger.info("Checking for documents requiring embedding updates or additions...")
    if self._index_requires_full_rebuild():
        logger.info("Triggering a full rebuild of the FAISS index.")
        self.build_indexes_and_save(data_already_loaded=True)  # self.df is already loaded
        return

    logger.info("No changes detected in existing indexed documents that require a full rebuild.")
    self._append_new_documents_to_index()

def _index_requires_full_rebuild(self) -> bool:
    """Return True when the current index cannot be updated incrementally."""
    if self.index is None:
        logger.info("No existing FAISS index found. A full build is required.")
        return True

    # New vectors are appended to both the index and indexed_ids, so the two
    # must stay the same length; appending to a mismatched pair would leave
    # index positions pointing at the wrong IDs.
    if self.index.ntotal != len(self.indexed_ids):
        logger.warning(
            f"FAISS index vector count ({self.index.ntotal}) does not match "
            f"indexed ID count ({len(self.indexed_ids)}). Rebuild needed."
        )
        return True

    if not self.indexed_ids:
        return False

    df_indexed_docs = self.df[self.df[self.id_col].isin(self.indexed_ids)]
    content_map_for_indexed_ids = pd.Series(
        df_indexed_docs[self.processed_content_col].values,
        index=df_indexed_docs[self.id_col]
    ).to_dict()
    for doc_id in self.indexed_ids:
        if doc_id not in content_map_for_indexed_ids:
            logger.info(f"Document ID {doc_id} was in index but not in current data (deleted). Rebuild needed.")
            return True
        current_checksum = self._calculate_embedding_checksum(content_map_for_indexed_ids[doc_id])
        if self._needs_reembedding(doc_id, current_checksum):
            logger.info(f"Existing document ID {doc_id} requires re-embedding. Content/model change detected.")
            return True
    return False

def _append_new_documents_to_index(self) -> None:
    """Embed documents not yet indexed, add them to the index and persist."""
    current_indexed_ids_set = set(self.indexed_ids)
    all_df_ids_set = set(self.df[self.id_col].tolist())
    new_doc_ids_to_add = list(all_df_ids_set - current_indexed_ids_set)
    if not new_doc_ids_to_add:
        logger.info("No new documents to add. Index is up-to-date with the current DataFrame.")
        return

    logger.info(f"Found {len(new_doc_ids_to_add)} new documents to add to the index.")
    new_docs_df = self.df[self.df[self.id_col].isin(new_doc_ids_to_add)].copy()
    if new_docs_df.empty:
        logger.info("Identified new document IDs, but the corresponding DataFrame slice was empty.")
        return

    new_embeddings, actual_new_ids_embedded = self._generate_embeddings(new_docs_df)
    if new_embeddings.size == 0:
        logger.info("No embeddings were generated for the new documents (e.g., content was empty).")
        return

    self.index.add(new_embeddings.astype(np.float32))
    self.indexed_ids.extend(actual_new_ids_embedded)
    logger.info(f"Added {len(actual_new_ids_embedded)} new vectors to FAISS index. Total vectors: {self.index.ntotal}.")
    faiss.write_index(self.index, INDEX_PATH)

    if self.faiss_meta_collection is None:
        # The index file is already on disk; only the MongoDB side is skipped.
        logger.warning(
            "MongoDB not available. FAISS index file saved, but indexed IDs "
            "and embedding metadata were not persisted."
        )
        return

    self.faiss_meta_collection.update_one(
        {"_id": self.FAISS_IDS_DOC_ID},
        {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now(), "total_vectors": self.index.ntotal}},
        upsert=True
    )
    logger.info("Updated FAISS index and list of indexed IDs saved.")

    # Build the id -> content lookup once instead of filtering the frame per
    # document; setdefault keeps the first row when an ID occurs twice.
    content_by_id = {}
    for row_id, row_content in zip(new_docs_df[self.id_col], new_docs_df[self.processed_content_col]):
        content_by_id.setdefault(row_id, row_content)
    for doc_id in actual_new_ids_embedded:
        checksum = self._calculate_embedding_checksum(content_by_id[doc_id])
        self._update_embedding_metadata(doc_id, checksum)
    logger.info(f"Updated embedding metadata for {len(actual_new_ids_embedded)} new documents.")
def fallback_to_base_model(self):
    """Make the base reranker the active one, if it has been loaded.

    When no base reranker is loaded only a warning is logged and the
    active reranker is left untouched.
    """
    base_model = self.base_reranker
    if base_model is None:
        logger.warning("Base reranker model not available")
        return
    self.reranker = base_model
    logger.info("Switched to base reranker model")
def use_fine_tuned_model(self):
    """Make the fine-tuned reranker the active one, if it has been loaded.

    When no fine-tuned reranker is loaded only a warning is logged and the
    active reranker is left untouched.
    """
    tuned_model = self.fine_tuned_reranker
    if tuned_model is None:
        logger.warning("Fine-tuned model not available, staying with current model")
        return
    self.reranker = tuned_model
    logger.info("Switched to fine-tuned reranker model")
def _load_data_from_mongo(self):
    """Load and preprocess data from MongoDB.

    Reads the configured columns from ``mongodb.news_collection`` into
    ``self.df``, fills defaults for optional columns, drops rows without a
    usable ID, headline or synopsis, and adds the processed-content column
    that the embedding code reads.

    When the collection is empty, or no row survives cleaning, ``self.df``
    becomes an empty DataFrame that still carries the expected columns.

    Raises:
        ValueError: If a required column is absent from every document.
        Exception: Any other failure is logged and re-raised; ``self.df``
            is set to ``None`` first.
    """
    logger.info("Loading data from MongoDB...")
    try:
        # Fetch only the fields that are used downstream.
        projection = {
            self.id_col: 1,
            self.headline_col: 1,
            self.key: 1,  # keywords
            self.syn: 1,  # synopsis
            self.topic_col: 1,
            self.taxonomy_col: 1,
            self.property_col: 1,
            self.seolocation_col: 1,
            self.deeplink_col: 1,
            self.last_updated_col: 1,
            self.image_id_col: 1,
            self.image_ratio_col: 1,
            self.image_size_col: 1,
            "_id": 0  # Exclude the default MongoDB _id field
        }
        cursor = mongodb.news_collection.find({}, projection)
        data = list(cursor)
        if not data:
            logger.warning("No documents found in MongoDB.")
            self.df = self._empty_content_frame()
            return

        self.df = pd.DataFrame(data)
        logger.info(f"Loaded {len(self.df)} documents from MongoDB.")

        # A column missing from the frame means no document carried the field.
        required_cols = [self.id_col, self.headline_col, self.key, self.syn]
        for col in required_cols:
            if col not in self.df.columns:
                raise ValueError(f"Required column '{col}' not found in MongoDB documents.")

        # Textual columns: fill gaps with a default and force string dtype.
        # (self.key is also required above, so only its fillna branch can run.)
        textual_optional_cols = {
            self.topic_col: "N/A",
            self.property_col: "N/A",
            self.key: ""
        }
        for col, default_val in textual_optional_cols.items():
            if col not in self.df.columns:
                logger.warning(f"Optional column '{col}' not found. Adding with default value '{default_val}'.")
                self.df[col] = default_val
            else:
                self.df[col] = self.df[col].fillna(default_val).astype(str)

        # Remaining optional columns default to None when absent.
        other_optional_cols = [
            self.seolocation_col, self.deeplink_col, self.last_updated_col,
            self.image_id_col, self.image_ratio_col, self.image_size_col
        ]
        for col in other_optional_cols:
            if col not in self.df.columns:
                logger.warning(f"Optional column '{col}' not found. Adding with default value None.")
                self.df[col] = None
            elif pd.api.types.is_object_dtype(self.df[col].dtype):
                # Object columns: NaN -> None. Numeric columns keep NaN.
                self.df[col] = self.df[col].replace({np.nan: None})

        # Taxonomy cells are normalised so that every cell holds a list.
        if self.taxonomy_col not in self.df.columns:
            logger.warning(f"Taxonomy column '{self.taxonomy_col}' not found. Adding with default empty list for each row.")
            self.df[self.taxonomy_col] = [[] for _ in range(len(self.df))]
        else:
            self.df[self.taxonomy_col] = self.df[self.taxonomy_col].apply(self._clean_taxonomy_cell)

        def is_non_blank_str(value) -> bool:
            return isinstance(value, str) and value.strip() != ''

        # Drop rows that lack an ID, or a non-blank headline or synopsis.
        initial_len = len(self.df)
        self.df = self.df.dropna(subset=[self.id_col, self.headline_col, self.syn])
        self.df = self.df[self.df[self.headline_col].apply(is_non_blank_str)]
        self.df = self.df[self.df[self.syn].apply(is_non_blank_str)]
        if len(self.df) < initial_len:
            logger.warning(f"Dropped {initial_len - len(self.df)} rows due to missing/invalid values.")
        if self.df.empty:
            logger.warning("DataFrame is empty after cleaning.")
            self.df = self._empty_content_frame()
            return

        logger.info("Preprocessing and combining content (headline, synopsis, keywords, taxonomy) for contextual embeddings...")
        self.df[self.headline_col] = self.df[self.headline_col].astype(str)
        self.df[self.syn] = self.df[self.syn].astype(str)
        self.df[self.key] = self.df[self.key].astype(str)
        self.df[self.processed_content_col] = self.df.apply(self._combine_row_content, axis=1)
        logger.info("Data loading and preprocessing complete.")
    except Exception as e:
        logger.error(f"Error loading data from MongoDB: {e}", exc_info=True)
        self.df = None  # Signal failure to callers that check for None
        raise

def _empty_content_frame(self) -> pd.DataFrame:
    """Return an empty DataFrame that carries every expected column."""
    return pd.DataFrame(columns=[
        self.id_col, self.headline_col, self.key, self.syn, self.taxonomy_col,
        self.topic_col, self.property_col, self.processed_content_col,
        self.seolocation_col, self.deeplink_col, self.last_updated_col,
        self.image_id_col, self.image_ratio_col, self.image_size_col
    ])

def _combine_row_content(self, row) -> str:
    """Build the labelled text for one row from headline, synopsis, keywords and taxonomy names."""
    processed_headline = self._preprocess_text(row[self.headline_col].strip())
    processed_synopsis = self._preprocess_text(row[self.syn].strip())
    processed_keywords = self._preprocess_text(str(row[self.key]).strip())

    # Only dict entries with a truthy "name" contribute to the taxonomy text.
    processed_taxonomy_names = []
    taxonomy_terms_list = row[self.taxonomy_col]
    if isinstance(taxonomy_terms_list, list):
        for term_obj in taxonomy_terms_list:
            if isinstance(term_obj, dict) and "name" in term_obj and term_obj["name"]:
                processed_taxonomy_names.append(self._preprocess_text(str(term_obj["name"])))
    processed_taxonomy_string = " ".join(p_name for p_name in processed_taxonomy_names if p_name)

    # Each non-empty part is prefixed with a label; empty parts are omitted.
    structured_parts = []
    if processed_headline:
        structured_parts.append(f"शीर्षक: {processed_headline}")
    if processed_synopsis:
        structured_parts.append(f"सारांश: {processed_synopsis}")
    if processed_keywords:
        structured_parts.append(f"कीवर्ड: {processed_keywords}")
    if processed_taxonomy_string:
        structured_parts.append(f"श्रेणी: {processed_taxonomy_string}")
    return " ".join(structured_parts)
def _clean_taxonomy_cell(self, cell_value) -> list:
"""
Cleans individual cells of the taxonomy column.
Ensures that the cell content is a list, handling scalars, NaNs, and unexpected array types.
"""
if isinstance(cell_value, list):
# If it's already a list (empty or not), return it.
return cell_value
elif pd.api.types.is_scalar(cell_value) and pd.isna(cell_value):
# Handles scalar np.nan, None, pd.NA by returning an empty list.
return []
elif isinstance(cell_value, (np.ndarray, pd.Series)):
# Handles unexpected np.ndarray or pd.Series in a cell.
logger.warning(
f"Unexpected array/Series type in taxonomy column cell: {type(cell_value)}. "
f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list."
)
return []
else:
# Handles other scalar types (e.g., string, int) or unhandled types.
# These are converted to an empty list, mimicking the original lambda's behavior.
if not pd.api.types.is_scalar(cell_value): # Log if it's an unexpected non-scalar, non-list, non-array type
logger.warning(
f"Unexpected non-scalar, non-list/array type in taxonomy column cell: {type(cell_value)}. "
f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list."
)
return []
def _preprocess_text(self, text: str) -> str:
"""Normalize and tokenize Hindi text."""
"""
Normalize and tokenize text. Applies Indic normalization for Hindi text
and generic tokenization for others.
"""
if not isinstance(text, str):
logger.warning(f"Received non-string input for preprocessing: {type(text)}")
return ""
text = text.strip() # Remove leading/trailing whitespace
if not text: # Handle empty strings after stripping
return ""
try:
# Check for presence of Devanagari characters to identify Hindi text.
# This is a basic check; for more complex scenarios, consider a language detection library.
is_text = bool(re.search(r'[\u0900-\u097F]', text))
if is_text:
if self.normalizer is None:
logger.error("IndicNormalizer not initialized, but text detected. Proceeding without normalization.")
normalized_text = text # Fallback: use original text if normalizer is missing
else:
normalized_text = self.normalizer.normalize(text)
else:
# For non-Hindi text, skip Hindi-specific normalization
# logger.debug(f"Skipping Hindi normalization for non-Hindi text: {text[:50]}...") # Optional: enable if needed
normalized_text = text
tokens = indic_tokenize.trivial_tokenize(normalized_text) # Tokenize the (potentially normalized) text
return " ".join(tokens)
except Exception as e:
logger.error(f"Error during text preprocessing for text starting with '{text[:50]}...': {e}", exc_info=True)
# Fallback to returning the original text to prevent downstream errors
return text
def _generate_embeddings(self, df_subset: pd.DataFrame) -> tuple[np.ndarray, List[str]]:
"""
Generate embeddings for the processed content column of a DataFrame subset.
Returns a tuple of (embeddings_array, list_of_ids).
"""
if not self.embed_model: # Check if the embedding model is loaded
raise RuntimeError("Embedding model not loaded") # Raise an error if the model is not loaded
if df_subset.empty or self.processed_content_col not in df_subset.columns or self.id_col not in df_subset.columns:
logger.warning(
f"DataFrame subset is empty or missing required columns "
f"('{self.processed_content_col}', '{self.id_col}') for embedding generation."
)
return np.array([]), []
texts_to_embed = df_subset[self.processed_content_col].tolist()
# Ensure all texts are strings (preprocessing should ideally handle this, but as a safeguard)
texts_to_embed = [str(t) if pd.notna(t) else "" for t in texts_to_embed]
ids_to_embed = df_subset[self.id_col].tolist()
if not texts_to_embed: # If all texts became empty strings or list was initially empty
logger.warning("No non-empty texts available in the subset to generate embeddings.")
return np.array([]), []
try: # Start a try-except block for error handling
embeddings = self.embed_model.encode(texts_to_embed, show_progress_bar=False, convert_to_numpy=True) # Generate embeddings
return embeddings, ids_to_embed # Return the generated embeddings and their IDs
except Exception as e: # Catch any exception during embedding generation
logger.error(f"Error generating embeddings: {e}", exc_info=True) # Log the error with traceback
return np.array([]), [] # Return empty array and list in case of error
def _build_faiss_index(self) -> List[str]:
"""Build FAISS index from processed content and return the list of indexed IDs."""
if self.df is None or self.df.empty or self.processed_content_col not in self.df.columns:
raise ValueError("Cannot build FAISS index: Data not ready (DataFrame is None, empty, or missing processed content column)")
if self.embed_model is None:
raise RuntimeError("Cannot build FAISS index: Embedding model not loaded")
logger.info("Building FAISS index...")
if self.df.empty:
raise ValueError("DataFrame is empty, cannot build FAISS index")
# Process in batches for better memory management and parallelization
BATCH_SIZE = 1000 # Adjust based on available memory
total_docs = len(self.df)
all_embeddings = []
all_ids = []
for start_idx in range(0, total_docs, BATCH_SIZE):
end_idx = min(start_idx + BATCH_SIZE, total_docs)
batch_df = self.df.iloc[start_idx:end_idx]
# Generate embeddings for the batch
batch_embeddings, batch_ids = self._generate_embeddings(batch_df)
if batch_embeddings.size > 0:
all_embeddings.append(batch_embeddings)
all_ids.extend(batch_ids)
logger.info(f"Processed batch {start_idx//BATCH_SIZE + 1}/{(total_docs + BATCH_SIZE - 1)//BATCH_SIZE}")
if not all_embeddings:
logger.warning("No embeddings were generated. FAISS index will be empty.")
self.index = None
return []
# Concatenate all batch embeddings
embeddings = np.vstack(all_embeddings).astype(np.float32)
dimension = embeddings.shape[1]
# Initialize HNSW index with optimized parameters
hnsw_m = 32 # Number of neighbors per layer
ef_construction = 100 # Higher value = better accuracy but slower construction
self.index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT)
self.index.hnsw.efConstruction = ef_construction
# Add vectors in batches to reduce memory usage
BATCH_SIZE_INDEX = 10000 # Adjust based on available memory
for i in range(0, len(embeddings), BATCH_SIZE_INDEX):
batch = embeddings[i:i + BATCH_SIZE_INDEX]
self.index.add(batch)
logger.info(f"Added batch {i//BATCH_SIZE_INDEX + 1}/{(len(embeddings) + BATCH_SIZE_INDEX - 1)//BATCH_SIZE_INDEX} to index")
logger.info(f"FAISS index built with {self.index.ntotal} vectors")
return all_ids
def _load_faiss_index_and_ids(self):
    """Read the persisted FAISS index and the ID list that maps its rows to documents.

    The index is read from ``INDEX_PATH``. The ID list is read from the
    ``self.FAISS_IDS_DOC_ID`` document of ``self.faiss_meta_collection``.
    ``self.indexed_ids`` ends up as an empty list whenever the IDs cannot be
    obtained, so the caller can decide whether a rebuild is needed.

    Raises:
        FileNotFoundError: If nothing exists at ``INDEX_PATH``.
        Exception: Whatever ``faiss.read_index`` raises; ``self.index`` is set
            to ``None`` before the error is re-raised.
    """
    if not os.path.exists(INDEX_PATH):
        raise FileNotFoundError(f"FAISS index file not found at {INDEX_PATH}")

    logger.info(f"Loading FAISS index from: {INDEX_PATH}")
    try:
        self.index = faiss.read_index(INDEX_PATH)
        logger.info(f"FAISS index loaded with {self.index.ntotal} vectors")
    except Exception as e:
        logger.error(f"Error loading FAISS index: {e}", exc_info=True)
        self.index = None  # Never leave a half-loaded index behind
        raise

    # Without a metadata collection there is nowhere to read the IDs from.
    if self.faiss_meta_collection is None:
        logger.warning("MongoDB not available. Cannot load indexed IDs. Operating with empty ID list.")
        self.indexed_ids = []
        return

    logger.info(f"Loading FAISS index IDs from MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', document_id '{self.FAISS_IDS_DOC_ID}'")
    try:
        ids_document = self.faiss_meta_collection.find_one({"_id": self.FAISS_IDS_DOC_ID})
        if not (ids_document and "ids" in ids_document):
            logger.warning(f"FAISS index IDs document not found in MongoDB or 'ids' field missing. Will attempt to build if necessary.")
            self.indexed_ids = []
        else:
            self.indexed_ids = ids_document["ids"]
            logger.info(f"Loaded {len(self.indexed_ids)} indexed IDs from MongoDB.")
            # Warn only: deciding what to do about a mismatch is left to the caller.
            if self.index and self.index.ntotal != len(self.indexed_ids):
                logger.warning(
                    f"FAISS index vector count ({self.index.ntotal}) "
                    f"does not match loaded ID count from MongoDB ({len(self.indexed_ids)}). "
                    "Index might be inconsistent. Consider rebuilding."
                )
    except Exception as e:
        # Deliberately not re-raised: an empty ID list signals "IDs unavailable".
        logger.error(f"Error loading FAISS index IDs from MongoDB: {e}", exc_info=True)
        self.indexed_ids = []
def build_indexes_and_save(self, data_already_loaded: bool = False):
    """
    Build the FAISS index from ``self.df`` and persist it together with its ID list.

    Steps: optionally load the data from MongoDB, make sure the embedding
    model is loaded, build the index, write it to ``INDEX_PATH`` and store the
    indexed IDs in the FAISS metadata collection.

    Args:
        data_already_loaded: When True, ``self.df`` is used as-is; otherwise
            ``self._load_data_from_mongo()`` is called first.

    Raises:
        ValueError: If there is no data to index, or if building produced no
            index (no embeddings could be generated).
        Exception: Any other error from loading, building or writing the
            index is logged and re-raised. A failure to store the IDs in
            MongoDB is only logged.
    """
    logger.info("Starting index building process...")
    try:
        if not data_already_loaded:
            self._load_data_from_mongo()
        if self.df is None or self.df.empty:
            raise ValueError("Data is empty. Cannot build index.")
        if self.embed_model is None:
            self.load_models()

        # The builder returns the IDs in the same order as the vectors it added.
        self.indexed_ids = self._build_faiss_index()

        # When no embedding could be generated the builder returns an empty list
        # and leaves self.index as None. Fail with a clear message instead of
        # handing None to faiss.write_index.
        if self.index is None:
            raise ValueError("No embeddings were generated. Cannot save an empty FAISS index.")

        logger.info(f"Saving FAISS index to: {INDEX_PATH}")
        index_dir = os.path.dirname(INDEX_PATH)
        if index_dir:
            # exist_ok avoids the check-then-create race; '' means the current directory.
            os.makedirs(index_dir, exist_ok=True)
        faiss.write_index(self.index, INDEX_PATH)

        if self.faiss_meta_collection is None:
            logger.warning("MongoDB not available. Skipping indexed IDs save.")
        else:
            logger.info(f"Saving FAISS index IDs to MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', document_id '{self.FAISS_IDS_DOC_ID}'")
            try:
                self.faiss_meta_collection.update_one(
                    {"_id": self.FAISS_IDS_DOC_ID},
                    {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}},
                    upsert=True
                )
                logger.info(f"Saved {len(self.indexed_ids)} indexed IDs to MongoDB.")
            except Exception as e:
                # Best effort: the index file itself was written successfully
                # and the IDs can be regenerated by a later rebuild.
                logger.error(f"Error saving FAISS index IDs to MongoDB: {e}", exc_info=True)

        logger.info("Index building and saving complete")
    except Exception as e:
        logger.error(f"Error during index building: {e}", exc_info=True)
        raise
def load_components(self):
    """
    Load models, document data and the FAISS index, repairing the index when needed.

    Behaviour:
      * Models are always loaded first.
      * A MongoDB data-load failure is tolerated: ``self.df`` becomes ``None``
        and the method continues with whatever index is on disk.
      * With data available, the index is rebuilt from scratch when the index
        file is missing, empty or unreadable, when no ID list was stored, or
        when the number of stored IDs differs from the number of vectors
        (positions could no longer be mapped to documents reliably).
      * Otherwise documents not yet in the index are appended incrementally.
      * A full rebuild is attempted at most once per call; if the rebuild
        itself fails the error propagates instead of being retried.

    Raises:
        Exception: Any unrecoverable error is logged and re-raised.
    """
    logger.info("Loading components...")
    try:
        self.load_models()

        try:
            self._load_data_from_mongo()
            logger.info("Successfully loaded data from MongoDB")
        except Exception as mongo_error:
            logger.warning(f"Failed to load data from MongoDB: {mongo_error}")
            logger.info("Attempting to work with existing FAISS index without MongoDB data...")
            self.df = None  # Marks "no document data available"

        has_data = self.df is not None and not self.df.empty
        # Set right before a full rebuild starts, so the handlers below can tell
        # "loading failed" (rebuild as recovery) from "the rebuild failed" (give up).
        rebuild_attempted = False

        try:
            # Raises FileNotFoundError when the index file is missing; leaves
            # self.indexed_ids empty when the stored IDs are unavailable.
            self._load_faiss_index_and_ids()

            if self.index and self.index.ntotal > 0:
                logger.info(f"FAISS index loaded successfully with {self.index.ntotal} vectors")
                if not has_data:
                    logger.info("FAISS index available but no MongoDB data. Operating in limited mode.")
                    if not self.indexed_ids:
                        logger.warning("No indexed IDs available. Some functionality may be limited.")
                elif not self.indexed_ids:
                    logger.warning("FAISS index file loaded, but no corresponding IDs found in MongoDB. Rebuilding for consistency.")
                    rebuild_attempted = True
                    self.build_indexes_and_save(data_already_loaded=True)
                elif self.index.ntotal != len(self.indexed_ids):
                    # Appending to a misaligned index would keep returning wrong
                    # documents for FAISS positions, so start over instead.
                    logger.warning(
                        f"FAISS index vector count ({self.index.ntotal}) does not match "
                        f"stored ID count ({len(self.indexed_ids)}). Rebuilding for consistency."
                    )
                    rebuild_attempted = True
                    self.build_indexes_and_save(data_already_loaded=True)
                else:
                    logger.info("Existing FAISS index and IDs loaded from storage.")
                    self._append_new_documents_to_index()
            elif has_data:
                logger.info("FAISS index and/or IDs not found or empty. Building new index.")
                rebuild_attempted = True
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.warning("No data available (neither MongoDB nor FAISS index). Cannot build index.")
        except FileNotFoundError:
            if rebuild_attempted:
                raise  # The rebuild itself failed; do not run it a second time
            logger.warning(f"FAISS index file ({INDEX_PATH}) not found.")
            if has_data:
                logger.info("Building index from scratch.")
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.error("Cannot build index: no data available.")
        except Exception as e:
            if rebuild_attempted:
                raise  # The rebuild itself failed; do not run it a second time
            logger.error(f"Error loading FAISS index: {e}", exc_info=True)
            if has_data:
                logger.info("Attempting to rebuild index due to loading error.")
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.error("Cannot rebuild index: no data available.")

        logger.info("Components loaded successfully")
    except Exception as e:
        logger.error(f"Error loading components: {e}", exc_info=True)
        raise

def _append_new_documents_to_index(self):
    """Embed documents present in ``self.df`` but absent from ``self.indexed_ids`` and append them to the index."""
    new_ids_to_add = list(set(self.df[self.id_col].tolist()) - set(self.indexed_ids))
    if not new_ids_to_add:
        logger.info("No new documents found to add to the index. Index is up-to-date.")
        return

    logger.info(f"Found {len(new_ids_to_add)} new documents to add to the index.")
    new_docs_df = self.df[self.df[self.id_col].isin(new_ids_to_add)].copy()
    new_embeddings, new_doc_ids_added = self._generate_embeddings(new_docs_df)
    if new_embeddings.size == 0:
        return  # Nothing could be embedded; index and ID list stay untouched

    # Vectors and IDs are appended together so positions keep matching.
    self.index.add(new_embeddings.astype(np.float32))
    self.indexed_ids.extend(new_doc_ids_added)
    logger.info(f"Added {len(new_doc_ids_added)} new vectors to FAISS index. Total vectors: {self.index.ntotal}")
    faiss.write_index(self.index, INDEX_PATH)

    if self.faiss_meta_collection is None:
        logger.warning("MongoDB not available. Skipping indexed IDs save.")
        return
    try:
        self.faiss_meta_collection.update_one(
            {"_id": self.FAISS_IDS_DOC_ID},
            {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}},
            upsert=True
        )
        logger.info("Updated FAISS index and IDs saved to MongoDB.")
    except Exception as e:
        # Best effort: a stale ID list is detected as a count mismatch on the next load.
        logger.warning(f"Could not save IDs to MongoDB: {e}")
def get_recommendations(
    self,
    query: str,
    k: int = DEFAULT_K,
    similarity_threshold: float = SIMILARITY_THRESHOLD
) -> Dict:
    """
    Return up to ``k`` documents relevant to a free-text query.

    Pipeline: preprocess and embed the query, fetch
    ``max(k * CANDIDATE_MULTIPLIER, k)`` neighbours from FAISS, map FAISS
    positions to document IDs through ``self.indexed_ids``, join with
    ``self.df``, drop rows whose headline or synopsis equals the query
    (case-insensitive, stripped) and rows with a duplicate synopsis, rerank
    (query, synopsis) pairs with the cross-encoder, keep scores
    ``>= similarity_threshold`` and take the best ``k``.

    Args:
        query: Raw user query.
        k: Maximum number of documents to return.
        similarity_threshold: Minimum rerank score a document must reach.

    Returns:
        Dict with ``"retrieved_documents"`` (list of document dicts) and
        ``"generated_response"`` (a summary string). Recoverable failures
        (empty embedding, FAISS error, nothing left after filtering) yield an
        empty document list with an explanatory message.

    Raises:
        HTTPException: 503 when the data or any required component is not loaded.
    """
    if self.df is None or self.df.empty:
        raise HTTPException(status_code=503, detail="Recommender data not available")

    # The generator is required to be loaded even though the response text
    # below is currently a fixed template.
    required_components = [
        ("FAISS index", self.index),
        ("Embedding model", self.embed_model),
        ("Reranker model", self.reranker),
        ("Generator model", self.generator),
    ]
    # One `is None` test drives both the decision and the message, so the
    # 503 detail can never come out with an empty "Missing:" list.
    missing = [name for name, component in required_components if component is None]
    if missing:
        raise HTTPException(
            status_code=503,
            detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}"
        )

    logger.info(f"Processing recommendation request: query='{query}', k={k}")

    processed_query = self._preprocess_text(query)
    query_embedding, _ = self._generate_embeddings(
        pd.DataFrame({self.processed_content_col: [processed_query], self.id_col: ["query"]})
    )
    if query_embedding.size == 0:
        logger.warning("Query embedding is empty.")
        return {
            "retrieved_documents": [],
            "generated_response": "No recommendations found. (Query embedding failed)"
        }

    # Over-fetch so that filtering and thresholding can still leave k results.
    num_candidates = max(k * CANDIDATE_MULTIPLIER, k)
    try:
        scores, positions = self.index.search(query_embedding.astype(np.float32), num_candidates)
    except Exception as e:
        logger.error(f"Error during FAISS search: {e}", exc_info=True)
        return {
            "retrieved_documents": [],
            "generated_response": "No recommendations found. (FAISS search failed)"
        }

    # Translate FAISS positions into document IDs; -1 marks "no neighbour".
    valid_candidate_data = []
    for faiss_idx, score in zip(positions[0], scores[0]):
        if faiss_idx == -1:
            continue
        if faiss_idx < len(self.indexed_ids):
            valid_candidate_data.append({
                "original_id": self.indexed_ids[faiss_idx],
                "faiss_score": score
            })
        else:
            logger.warning(
                f"FAISS index {faiss_idx} is out of bounds for self.indexed_ids (len: {len(self.indexed_ids)}). Skipping."
            )

    if not valid_candidate_data:
        logger.info(f"No valid candidates found from FAISS/ID mapping for query '{query}'.")
        return {
            "retrieved_documents": [],
            "generated_response": f"No recommendations found for '{query}' (no FAISS results or ID mapping issue)."
        }

    # Inner merge keeps only IDs still present in self.df, in FAISS order.
    candidates = pd.merge(
        pd.DataFrame(valid_candidate_data),
        self.df,
        left_on="original_id",
        right_on=self.id_col,
        how="inner"
    )
    if "original_id" in candidates.columns and "original_id" != self.id_col:
        candidates = candidates.drop(columns=["original_id"])

    # Do not recommend an item whose headline or synopsis simply is the query.
    normalized_query = query.strip().lower()
    candidates = candidates[
        (candidates[self.headline_col].str.strip().str.lower() != normalized_query) &
        (candidates[self.syn].str.strip().str.lower() != normalized_query)
    ]
    if candidates.empty:
        logger.info(f"No candidates left after filtering exact query matches for query '{query}'.")
        return {"retrieved_documents": [], "generated_response": f"No distinct recommendations found for '{query}'."}

    candidates = candidates.drop_duplicates(subset=[self.syn]).copy()
    if candidates.empty:
        logger.info(f"No candidates left after dropping duplicates for query '{query}'.")
        return {"retrieved_documents": [], "generated_response": f"No unique recommendations found for '{query}'."}

    # Rerank on the synopsis. `candidates` is non-empty here, so there is
    # always at least one pair to score.
    rerank_pairs = [(query, str(synopsis)) for synopsis in candidates[self.syn]]
    rerank_scores = self.reranker.predict(rerank_pairs, show_progress_bar=False)
    logger.info(f"Raw rerank scores for query '{query}': {rerank_scores.tolist()}")
    candidates["rerank_score"] = rerank_scores
    candidates = candidates.sort_values("rerank_score", ascending=False)

    logger.debug(f"Top candidates before thresholding (query='{query}', threshold={similarity_threshold}):")
    for _, row in candidates.head().iterrows():
        logger.debug(f" ID: {row[self.id_col]}, Synopsis: {str(row[self.syn])[:50]}..., Rerank Score: {row['rerank_score']:.4f}")

    candidates = candidates[candidates["rerank_score"] >= similarity_threshold]
    logger.info(f"Number of candidates after applying similarity_threshold ({similarity_threshold}): {len(candidates)}")

    retrieved_documents = []
    for _, row in candidates.head(k).iterrows():
        taxonomy_data = row[self.taxonomy_col]
        # Taxonomy is expected to be a list of dicts; anything else yields [].
        if isinstance(taxonomy_data, list):
            taxonomy_names = [
                str(term_obj["name"])
                for term_obj in taxonomy_data
                if isinstance(term_obj, dict) and term_obj.get("name")
            ]
        else:
            taxonomy_names = []
        retrieved_documents.append({
            "id": row[self.id_col],
            "hl": str(row[self.headline_col]),
            "synopsis": row[self.syn],
            "keywords": row[self.key],
            "type": row.get(self.topic_col, None),
            "taxonomy": taxonomy_names,
            "score": row["rerank_score"],
            "seolocation": row.get(self.seolocation_col, None),
            "dl": row.get(self.deeplink_col, None),
            "lu": row.get(self.last_updated_col, None),
            "imageid": row.get(self.image_id_col, None),
            "imgratio": row.get(self.image_ratio_col, None),
            "imgsize": row.get(self.image_size_col, None)
        })

    # No generator call is made here; the response text is a fixed template.
    generated_response = f"Top {len(retrieved_documents)} recommendations for '{query}'."
    return {
        "retrieved_documents": retrieved_documents,
        "generated_response": generated_response
    }
def _format_retrieved_documents(self, documents_df: pd.DataFrame) -> List[Dict]:
"""Helper function to format DataFrame rows into a list of document dictionaries."""
retrieved_documents = []
for _, row in documents_df.iterrows():
taxonomy_data = row[self.taxonomy_col]
taxonomy_names = []
if isinstance(taxonomy_data, list):
for term_obj in taxonomy_data:
if isinstance(term_obj, dict) and term_obj.get("name"):
taxonomy_names.append(str(term_obj["name"]))
retrieved_documents.append({
"id": row[self.id_col],
"hl": str(row[self.headline_col]),
"synopsis": str(row.get(self.syn, "")), # Ensure synopsis is string
"keywords": str(row.get(self.key, "")), # Ensure keywords is string
"type": row.get(self.topic_col, None),
"taxonomy": taxonomy_names,
"score": row.get("rerank_score", 0.0), # Use .get for safety if rerank_score might be missing
"seolocation": row.get(self.seolocation_col, None),
"dl": row.get(self.deeplink_col, None),
"lu": row.get(self.last_updated_col, None),
"imageid": row.get(self.image_id_col, None),
"imgratio": row.get(self.image_ratio_col, None),
"imgsize": row.get(self.image_size_col, None)
})
return retrieved_documents
def get_recommendations_by_id(
    self,
    msid: str,
    k: int = DEFAULT_K,
    similarity_threshold: float = SIMILARITY_THRESHOLD
) -> Dict:
    """
    Return up to ``k`` documents similar to the document identified by ``msid``.

    The source document's processed content is embedded and used for the
    FAISS search; its synopsis is the query side of the cross-encoder pairs.
    The source document and rows with a duplicate synopsis are excluded, and
    only candidates scoring ``>= similarity_threshold`` are kept.

    Args:
        msid: ID of the source document, compared against ``self.df[self.id_col]``.
        k: Maximum number of documents to return.
        similarity_threshold: Minimum rerank score a document must reach.

    Returns:
        Dict with ``"retrieved_documents"`` and ``"generated_response"``.
        An unknown ``msid`` and other recoverable failures yield an empty
        document list with an explanatory message rather than an error.

    Raises:
        HTTPException: 503 when the data, index, embedding model or reranker
            is not loaded. The generator is not required on this path.
    """
    if self.df is None or self.df.empty:
        logger.error("Recommender data not available for get_recommendations_by_id.")
        raise HTTPException(status_code=503, detail="Recommender data not available")

    required_components = [
        ("FAISS index", self.index),
        ("Embedding model", self.embed_model),
        ("Reranker model", self.reranker),
    ]
    # One `is None` test drives both the decision and the message, so the
    # 503 detail can never come out with an empty "Missing:" list.
    missing = [name for name, component in required_components if component is None]
    if missing:
        logger.error(f"Recommender not fully initialized for get_recommendations_by_id. Missing: {', '.join(missing)}")
        raise HTTPException(
            status_code=503,
            detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}"
        )

    logger.info(f"Processing recommendation request for item ID: '{msid}', k={k}")

    source_item_row = self.df[self.df[self.id_col] == msid]
    if source_item_row.empty:
        logger.warning(f"Item with ID '{msid}' not found in DataFrame.")
        # Same shape as a successful response, so the caller's response
        # handling keeps working for an unknown ID.
        return {"retrieved_documents": [], "generated_response": f"Item with ID '{msid}' not found."}

    source_item = source_item_row.iloc[0]
    source_item_content_for_embedding = source_item[self.processed_content_col]
    source_item_content_for_reranking = str(source_item[self.syn])

    item_embedding, _ = self._generate_embeddings(
        pd.DataFrame({
            self.processed_content_col: [source_item_content_for_embedding],
            self.id_col: [msid]
        })
    )
    if item_embedding.size == 0:
        logger.warning(f"Embedding generation failed for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": "No recommendations found (source item embedding failed)"}

    # k + 1 because the source item itself is normally its own nearest neighbour.
    num_candidates_to_fetch = max((k + 1) * CANDIDATE_MULTIPLIER, k + 1)
    try:
        _, positions = self.index.search(item_embedding.astype(np.float32), num_candidates_to_fetch)
    except Exception as e:
        logger.error(f"Error during FAISS search for item ID '{msid}': {e}", exc_info=True)
        return {"retrieved_documents": [], "generated_response": "No recommendations found (FAISS search failed)"}

    candidate_faiss_indices = positions[0]
    candidate_faiss_indices = candidate_faiss_indices[candidate_faiss_indices != -1]  # -1 = no neighbour
    if len(candidate_faiss_indices) == 0:
        logger.info(f"No candidates found from FAISS for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": f"No similar items found for ID '{msid}'."}

    # Positions beyond the ID list cannot be mapped to a document and are ignored.
    id_count = len(self.indexed_ids)
    candidate_original_ids = [self.indexed_ids[i] for i in candidate_faiss_indices if i < id_count]

    candidates_df = self.df[
        self.df[self.id_col].isin(candidate_original_ids) & (self.df[self.id_col] != msid)
    ].copy()
    candidates_df = candidates_df.drop_duplicates(subset=[self.syn])
    if candidates_df.empty:
        logger.info(f"No candidates left after filtering for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": f"No other similar items found for ID '{msid}'."}

    # `candidates_df` is non-empty here, so there is always a pair to score.
    rerank_pairs = [
        (source_item_content_for_reranking, str(synopsis))
        for synopsis in candidates_df[self.syn]
    ]
    candidates_df["rerank_score"] = self.reranker.predict(rerank_pairs, show_progress_bar=False)
    candidates_df = candidates_df.sort_values("rerank_score", ascending=False)
    candidates_df = candidates_df[candidates_df["rerank_score"] >= similarity_threshold]

    retrieved_documents = self._format_retrieved_documents(candidates_df.head(k))
    if retrieved_documents:
        generated_response = f"Top {len(retrieved_documents)} recommendations similar to item ID '{msid}'."
    else:
        generated_response = f"No recommendations found similar to item ID '{msid}'."
    return {"retrieved_documents": retrieved_documents, "generated_response": generated_response}
def prepare_reranker_training_data_from_new_feedback_format(self, user_id: str, training_event_details: Dict) -> List[Dict]:
    """
    Build reranker training samples from one user feedback event.

    ``training_event_details`` must provide ``"query_msid"`` (str) and
    ``"positive_msids"`` (non-empty list). The synopsis of the query item is
    the query text. Every positive item with a non-empty synopsis becomes a
    sample labelled 1.0. For each positive, up to 5 negatives labelled 0.0
    are added: up to 3 "hard" negatives taken from the FAISS neighbours of
    the query text, the rest drawn at random from the remaining catalogue.

    When hard negatives cannot be mined (no index loaded, or the query
    embedding is empty) only random negatives are used. When the catalogue
    has fewer remaining items than requested, as many as exist are used.

    Args:
        user_id: Used only in log messages.
        training_event_details: The feedback event described above.

    Returns:
        List of dicts with keys ``query_text``, ``candidate_text``, ``label``
        and ``msid``: negatives first, positives last. An empty list when the
        event is invalid or no usable positive sample exists.
    """
    import random  # Local import; only this method needs it

    if self.df is None or self.df.empty:
        logger.warning(f"User {user_id}: DataFrame not loaded. Cannot prepare training data.")
        return []

    query_msid = training_event_details.get("query_msid")
    positive_msids_list = training_event_details.get("positive_msids")
    if not query_msid or not isinstance(query_msid, str):
        logger.warning(f"User {user_id}: 'query_msid' missing or invalid in training_event_details. Details: {str(training_event_details)[:200]}")
        return []
    if not isinstance(positive_msids_list, list) or not positive_msids_list:
        logger.warning(f"User {user_id}: 'positive_msids' field missing, not a list, or empty in training_event_details for query_msid '{query_msid}'. Details: {str(training_event_details)[:200]}")
        return []

    def synopsis_for(item_id):
        """Return the stripped synopsis of ``item_id``, or None when the item is not in the DataFrame."""
        rows = self.df[self.df[self.id_col] == item_id]
        if rows.empty:
            return None
        return str(rows.iloc[0].get(self.syn, "")).strip()

    query_text = synopsis_for(query_msid)
    if query_text is None:
        logger.warning(f"User {user_id}: Query item (msid: {query_msid}) for training data not found in DataFrame.")
        return []
    if not query_text:
        logger.warning(f"User {user_id}: Query text (synopsis) is empty for msid {query_msid}. Skipping training sample generation for this event.")
        return []

    # Positive samples: items the user engaged with.
    positive_samples = []
    for positive_msid in positive_msids_list:
        if not isinstance(positive_msid, str) or not positive_msid.strip():
            continue
        candidate_text_positive = synopsis_for(positive_msid)
        if candidate_text_positive:
            positive_samples.append({
                "query_text": query_text,
                "candidate_text": candidate_text_positive,
                "label": 1.0,
                "msid": positive_msid
            })
    if not positive_samples:
        logger.warning(f"User {user_id}: No valid positive samples found for query_msid {query_msid}")
        return []

    num_negatives_per_positive = 5
    max_hard_negatives = 3
    positive_msids_set = {sample["msid"] for sample in positive_samples}
    excluded_msids = positive_msids_set | {query_msid}

    # Hard negatives: semantically close to the query but not marked positive.
    hard_negatives = []
    if self.index is None:
        logger.warning(f"User {user_id}: FAISS index not loaded. Using random negatives only.")
    else:
        query_embedding, _ = self._generate_embeddings(
            pd.DataFrame({
                self.processed_content_col: [query_text],
                self.id_col: ["temp_query"]
            })
        )
        if query_embedding.size > 0:
            _, positions = self.index.search(query_embedding.astype(np.float32), k=50)
            id_count = len(self.indexed_ids)
            neighbour_msids = [
                self.indexed_ids[idx] for idx in positions[0]
                if idx != -1 and idx < id_count
            ]
            hard_negatives = [
                msid for msid in neighbour_msids if msid not in excluded_msids
            ][:max_hard_negatives]

    # Random negatives come from everything that is neither positive, the query
    # item, nor already chosen as a hard negative. The pool is the same for
    # every positive, so it is built once.
    random_pool = list(set(self.df[self.id_col].tolist()) - excluded_msids - set(hard_negatives))
    # Clamp to the pool size: random.sample raises ValueError when asked for
    # more items than the population holds (small catalogues).
    num_random_negatives = min(num_negatives_per_positive - len(hard_negatives), len(random_pool))

    training_samples = []
    for pos_sample in positive_samples:
        random_negatives = random.sample(random_pool, num_random_negatives)
        for neg_msid in hard_negatives + random_negatives:
            candidate_text_negative = synopsis_for(neg_msid)
            if candidate_text_negative:
                training_samples.append({
                    "query_text": pos_sample["query_text"],
                    "candidate_text": candidate_text_negative,
                    "label": 0.0,
                    "msid": neg_msid
                })

    training_samples.extend(positive_samples)
    if training_samples:
        logger.info(f"User {user_id}: Prepared {len(training_samples)} training samples ({len(positive_samples)} positive, {len(training_samples)-len(positive_samples)} negative) from the interaction event (query_msid: {query_msid}).")
    return training_samples
def _log_training_data_for_refinement(self, training_data: List[Dict]):
"""Save training data for future fine-tuning."""
if not training_data:
logger.info("No training data provided for saving.")
return
self.model_trainer.prepare_training_data(training_data)
async def check_and_trigger_fine_tuning(self):
    """
    Start a background fine-tuning run when the trainer says conditions are met.

    The run is scheduled as an asyncio task and this coroutine returns
    immediately; it does not wait for fine-tuning to finish.

    Returns:
        True if a fine-tuning task was scheduled, False if conditions were
        not met or scheduling failed (the error is logged, not raised).
    """
    try:
        if not self.model_trainer.check_training_conditions():
            return False

        task = asyncio.create_task(self._run_fine_tuning_process())
        # The event loop keeps only weak references to tasks, so a task nobody
        # holds can be garbage-collected before it finishes. Keep a strong
        # reference until the task is done.
        background_tasks = getattr(self, "_background_tasks", None)
        if background_tasks is None:
            background_tasks = set()
            self._background_tasks = background_tasks
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return True
    except Exception as e:
        logger.error(f"Error in check_and_trigger_fine_tuning: {e}", exc_info=True)
        return False
async def _run_fine_tuning_process(self):
    """
    Fine-tune the reranker in the background and deploy the result if it validates.

    Blocking work (training, model loading, index refresh) runs in worker
    threads via ``asyncio.to_thread``. The new model becomes
    ``self.fine_tuned_reranker`` and the active ``self.reranker`` only after
    ``_validate_fine_tuned_model`` accepts it.

    On any failure (training returned nothing, model path missing, validation
    failed, or an exception) the reranker that was active when this run
    started is kept, so an earlier successfully deployed fine-tuned model is
    not discarded by a later failed run. Errors are logged, never raised.
    """
    # Reranker serving requests right now; fall back to the base model if none is set.
    current_reranker = getattr(self, "reranker", None)
    if current_reranker is None:
        current_reranker = self.base_reranker

    try:
        logger.info("Starting fine-tuning process")
        new_version = await asyncio.to_thread(self.model_trainer.fine_tune)
        if not new_version:
            logger.error("Fine-tuning process failed")
            self.reranker = current_reranker
            return

        model_path = str(self.model_trainer.get_model_path(new_version))
        if not os.path.exists(model_path):
            logger.error(f"Fine-tuned model file not found at {model_path}")
            self.reranker = current_reranker
            return

        # Model construction is blocking, so it runs off the event loop.
        new_model = await asyncio.to_thread(CrossEncoder, model_path, device=self.device)

        if not await self._validate_fine_tuned_model(new_model):
            logger.warning("Fine-tuned model validation failed. Keeping current model.")
            self.reranker = current_reranker
            return

        logger.info(f"Fine-tuned model validation passed. Deploying version: {new_version}")
        self.fine_tuned_reranker = new_model
        self.reranker = new_model  # Switch the active reranker

        # Synchronous and potentially long-running, hence the worker thread.
        await asyncio.to_thread(self.update_embeddings_and_index)
        logger.info(f"Fine-tuning process completed successfully. New version: {new_version}")
    except Exception as e:
        logger.error(f"Error during fine-tuning process: {e}", exc_info=True)
        self.reranker = current_reranker
async def _validate_fine_tuned_model(self, new_model: CrossEncoder) -> bool:
    """
    Validate the fine-tuned model's performance before deployment.

    Every validation sample (``query_text``, ``candidate_text``, ``label``) is
    scored by both ``self.base_reranker`` and ``new_model``. A score >= 0.5 is
    treated as a positive prediction and a label equal to 1.0 as a relevant
    pair; precision, recall, F1 and accuracy are derived from the resulting
    confusion counts and compared between the two models.

    Each model scores all pairs in a single batched ``predict`` call executed
    in a worker thread, so the event loop is not blocked and the per-sample
    thread hand-off is avoided.

    Args:
        new_model: Fine-tuned cross-encoder to compare against the base reranker.

    Returns:
        True if the new model reaches an F1 of at least 0.3 and improves at
        least one of F1/precision/recall/accuracy by >= 0.1% relative to the
        base model. False otherwise, including when there is no validation
        data or when any error occurs during validation (errors are logged,
        never raised).
    """
    decision_threshold = 0.5  # Scores at or above this count as "relevant"
    min_relative_improvement = 0.001  # 0.1% minimum relative improvement
    min_absolute_f1 = 0.3  # Lower minimum F1 score required
    metric_names = ("f1", "precision", "recall", "accuracy")

    def tally_confusion(predictions, labels):
        # Build confusion-matrix counts for one model's predictions.
        counts = {
            "true_positives": 0,
            "false_positives": 0,
            "true_negatives": 0,
            "false_negatives": 0,
        }
        for prediction, label in zip(predictions, labels):
            predicted_positive = prediction >= decision_threshold
            if label == 1.0:
                key = "true_positives" if predicted_positive else "false_negatives"
            else:
                key = "false_positives" if predicted_positive else "true_negatives"
            counts[key] += 1
        return counts

    def calculate_model_metrics(counts):
        # Derive precision/recall/F1/accuracy, guarding every division by zero.
        tp = counts["true_positives"]
        fp = counts["false_positives"]
        tn = counts["true_negatives"]
        fn = counts["false_negatives"]
        total = tp + tn + fp + fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (tp + tn) / total if total > 0 else 0
        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy
        }

    def relative_change(absolute_change, base_value):
        # A zero baseline has no meaningful ratio; report it as infinite.
        return absolute_change / base_value if base_value > 0 else float('inf')

    try:
        # Get a sample of validation data
        validation_data = self.model_trainer.get_validation_data()
        if not validation_data:
            logger.warning("No validation data available")
            return False

        # Collect all pairs first so each model can score them in one batch
        pairs = []
        labels = []
        for sample in validation_data:
            pairs.append((sample["query_text"], sample["candidate_text"]))
            labels.append(float(sample["label"]))
        if not pairs:
            logger.warning("No predictions generated during validation")
            return False

        # predict is synchronous and CPU/GPU bound, so run it off the event loop
        base_scores = await asyncio.to_thread(self.base_reranker.predict, pairs)
        new_scores = await asyncio.to_thread(new_model.predict, pairs)

        base_performance = calculate_model_metrics(tally_confusion(base_scores, labels))
        new_performance = calculate_model_metrics(tally_confusion(new_scores, labels))

        # Log detailed performance comparison
        logger.info("Model validation results:")
        logger.info(f"Base model metrics: {base_performance}")
        logger.info(f"Fine-tuned model metrics: {new_performance}")

        # Absolute and relative change of every metric (new minus base)
        delta = {name: new_performance[name] - base_performance[name] for name in metric_names}
        relative = {name: relative_change(delta[name], base_performance[name]) for name in metric_names}
        improved_somewhere = any(relative[name] >= min_relative_improvement for name in metric_names)

        # Model passes validation if it shows improvement in any metric and meets minimum F1
        if new_performance["f1"] >= min_absolute_f1 and improved_somewhere:
            logger.info(
                f"Fine-tuned model shows improvement. Metrics changes:\n"
                f"F1: {delta['f1']:.4f} ({relative['f1']:.2%})\n"
                f"Precision: {delta['precision']:.4f} ({relative['precision']:.2%})\n"
                f"Recall: {delta['recall']:.4f} ({relative['recall']:.2%})\n"
                f"Accuracy: {delta['accuracy']:.4f} ({relative['accuracy']:.2%})"
            )
            return True

        logger.warning(
            f"Fine-tuned model does not meet improvement criteria.\n"
            f"F1 change: {delta['f1']:.4f} ({relative['f1']:.2%})\n"
            f"Precision change: {delta['precision']:.4f} ({relative['precision']:.2%})\n"
            f"Recall change: {delta['recall']:.4f} ({relative['recall']:.2%})\n"
            f"Accuracy change: {delta['accuracy']:.4f} ({relative['accuracy']:.2%})"
        )
        return False
    except Exception as e:
        # Validation is a gate, not a hard dependency: log with traceback and reject the model.
        logger.exception(f"Error during model validation: {e}")
        return False
def reload_fine_tuned_model(self):
    """
    Trigger a fresh load of the fine-tuned reranker.

    Delegates entirely to ``self._load_fine_tuned_model()``; whether a
    fine-tuned model is actually found and activated is decided there.
    An info-level log line records that the reload was attempted.
    """
    # The private loader owns the load-and-switch logic.
    self._load_fine_tuned_model()
    # Logged unconditionally, hence the "(if available)" wording.
    logger.info("Reloaded fine-tuned reranker model (if available).")
def get_recommendations_summary(self, msid: str, k: int = DEFAULT_K, summary: bool = True, smart_tip: bool = True):
    """
    Synchronous wrapper for recommendations with summary and smart tip.
    This is a simplified version for Gradio or direct calls.

    Args:
        msid: ID of the item to fetch recommendations for.
        k: Number of recommendations requested from ``get_recommendations_by_id``.
        summary: When True, attach a ``summary`` field to each document whose
            article details were found in MongoDB.
        smart_tip: When True, attach a ``smart_tip`` field (related-article
            suggestions) to each document whose article details were found.

    Returns:
        The dict returned by ``get_recommendations_by_id`` with its
        ``retrieved_documents`` enriched in place, or a fallback dict with an
        explanatory ``generated_response`` and an empty document list when no
        recommendations are available.
    """
    # Get base recommendations by msid
    recommendations_data = self.get_recommendations_by_id(msid, k)
    if not recommendations_data or "retrieved_documents" not in recommendations_data:
        return {
            "generated_response": f"No recommendations found for item ID '{msid}'.",
            "retrieved_documents": []
        }
    retrieved_docs = recommendations_data.get("retrieved_documents", [])
    if not retrieved_docs:
        return recommendations_data

    # Fetch article details from MongoDB
    from src.database.mongodb import mongodb
    doc_ids_to_fetch = [doc["id"] for doc in retrieved_docs if doc.get("id")]
    articles_details_map = {}
    if doc_ids_to_fetch and (summary or smart_tip):
        # Only request the fields the enabled features actually need
        projection = {"_id": 0, "id": 1}
        if summary:
            projection.update({"story": 1, "syn": 1})
        if smart_tip:
            projection.update({"seolocation": 1, "tn": 1, "hl": 1})
        for article in mongodb.news_collection.find({"id": {"$in": doc_ids_to_fetch}}, projection):
            if article.get("id"):
                # Fall back to the synopsis when the full story is missing
                if summary and not article.get("story") and article.get("syn"):
                    article["story"] = article["syn"]
                articles_details_map[article["id"]] = article

    # Helper functions for summary and smart tip
    def _generate_summary(article_data):
        # Return the summary text for an article, or None when unavailable.
        try:
            from src.test_summarize import get_summary_points
            story = article_data.get("story", "")
            if not story:
                return None
            summary_points = get_summary_points(story)
            if isinstance(summary_points, list):
                return " ".join(summary_points) if summary_points else None
            elif isinstance(summary_points, str):
                return summary_points if summary_points.strip() else None
            return None
        except Exception as exc:
            # Best-effort: a failed summary must not break the response,
            # but it should leave a trace instead of vanishing silently.
            logger.warning(f"Summary generation failed for article '{article_data.get('id')}': {exc}")
            return None

    def _generate_smart_tip(article_data):
        # Build a "you might also like" tip from articles with a matching title/headline.
        seolocation = article_data.get("seolocation")
        title = article_data.get("tn")
        headline = article_data.get("hl")
        if not all([seolocation, title, headline]):
            return None
        # Find related articles. The title is matched literally: escaping keeps
        # characters such as "(", "?" or "+" from being interpreted as regex syntax
        # (which would mis-match or make the query fail).
        topic_pattern = re.escape(title.lower())
        query = {
            "$or": [
                {"tn": {"$regex": topic_pattern, "$options": "i"}},
                {"hl": {"$regex": topic_pattern, "$options": "i"}}
            ]
        }
        if article_data.get("id"):
            # Never suggest the article itself
            query["id"] = {"$ne": article_data["id"]}
        related_articles = mongodb.news_collection.find(
            query, {"hl": 1, "seolocation": 1, "tn": 1, "_id": 0}
        ).limit(3)
        suggestions = [
            {"label": rel_article["hl"], "url": rel_article["seolocation"]}
            for rel_article in related_articles
            if rel_article.get("hl") and rel_article.get("seolocation")
        ]
        if not suggestions:
            # Nothing related found: fall back to pointing at the article itself
            suggestions = [{
                "label": headline,
                "url": seolocation
            }]
        return {
            "title": f"\U0001F50D Smart Tip: {title}",
            "description": "You might also be interested in:",
            "suggestions": suggestions
        }

    # Process each document (documents are enriched in place)
    processed_documents = []
    for doc in retrieved_docs:
        article_data = articles_details_map.get(doc.get("id"))
        if summary and article_data:
            doc["summary"] = _generate_summary(article_data)
        if smart_tip and article_data:
            doc["smart_tip"] = _generate_smart_tip(article_data)
        processed_documents.append(doc)
    recommendations_data["retrieved_documents"] = processed_documents
    return recommendations_data
def get_recommendations_user_feedback(self, user_id: str, msid: str, clicked_msid: str, k: int = DEFAULT_K):
    """
    Synchronous wrapper for user feedback recommendations. Returns recommendations based on clicked items.

    Recommendations are fetched for every clicked item, de-duplicated by
    document ``id`` (first occurrence wins), sorted by ``score`` descending
    and truncated to ``k``.

    Args:
        user_id: Accepted for interface compatibility; not used by this method.
        msid: Accepted for interface compatibility; not used by this method.
        clicked_msid: One item ID or a comma-separated list of item IDs.
        k: Number of recommendations to request per clicked item and the
            maximum number of documents returned.

    Returns:
        A dict with ``retrieved_documents`` (at most ``k`` documents) and a
        human-readable ``generated_response``.
    """
    # clicked_msid can be a comma-separated string; tolerate None/empty input
    actual_clicked_msids = [s.strip() for s in (clicked_msid or "").split(',') if s.strip()]
    combined_recommendations_docs = []
    seen_recommendation_ids = set()
    for c_msid in actual_clicked_msids:
        result = self.get_recommendations_by_id(c_msid, k)
        if not result:
            # The lookup may yield nothing for an unknown ID; skip it instead of crashing
            continue
        for doc in result.get("retrieved_documents") or []:
            doc_id = doc.get("id")
            # Documents without an id cannot be de-duplicated, so they are dropped
            if doc_id is None or doc_id in seen_recommendation_ids:
                continue
            seen_recommendation_ids.add(doc_id)
            combined_recommendations_docs.append(doc)

    if not combined_recommendations_docs:
        return {"retrieved_documents": [], "generated_response": "No recommendations found for the clicked items."}

    # Highest-scoring documents first; documents without a score sort as 0.0
    combined_recommendations_docs.sort(key=lambda x: x.get('score', 0.0), reverse=True)
    final_retrieved_documents = combined_recommendations_docs[:k]
    return {
        "retrieved_documents": final_retrieved_documents,
        "generated_response": f"Top {len(final_retrieved_documents)} recommendations based on your recent clicks on: {', '.join(actual_clicked_msids)}."
    }
# Shared recommender instance, created here so other modules can use it
# instead of constructing their own RecoRecommender.
# NOTE(review): this runs RecoRecommender.__init__ as a side effect of importing
# this module; presumably that loads models/indexes — confirm importers can afford it.
recommender = RecoRecommender()