import pandas as pd # Import pandas for data manipulation, aliased as pd | |
import numpy as np # Import numpy for numerical operations, aliased as np | |
import faiss # Import faiss for efficient similarity search | |
import json # Import json for working with JSON data | |
from typing import List, Dict, Optional # Import typing hints for better code readability and static analysis | |
from datetime import datetime # Import datetime for handling date and time objects | |
import logging # Import logging module for application logging | |
from sentence_transformers import SentenceTransformer, CrossEncoder, util # Import specific classes from sentence_transformers library | |
from indicnlp.tokenize import indic_tokenize # Import tokenizer for Indic languages | |
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory # Import normalizer factory for Indic languages | |
import torch # Import PyTorch for deep learning functionalities | |
import os # Import os module for interacting with the operating system (e.g., file paths) | |
from indicnlp import common # Import common utilities from the indicnlp library | |
import pickle # Import pickle for saving/loading Python objects (like lists of IDs) | |
import re # Import re module for regular expression operations | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Import classes for loading pre-trained models from Hugging Face Transformers | |
from fastapi import HTTPException # Import HTTPException for raising HTTP errors in FastAPI applications | |
import asyncio # Import asyncio for asynchronous operations | |
from src.fine_tuning.trainer import ModelTrainer | |
from src.fine_tuning.config import MODEL_STATUS | |
from src.config.settings import ( # Import configuration constants from the settings file | |
EMBED_MODEL_NAME, GENERATOR_MODEL_NAME, RERANKER_MODEL_NAME, # Model names | |
INDEX_PATH, INTERACTION_LOG_PATH, INDIC_NLP_RESOURCES_PATH, MONGO_FAISS_META_COLLECTION_NAME, # File paths & DB names | |
HEADLINE_COL, SEOLOCATION_COL, DEEPLINK_COL, LAST_UPDATED_COL, FINE_TUNED_RERANKER_SAVE_PATH, # Column names & paths | |
IMAGE_ID_COL, IMAGE_RATIO_COL, IMAGE_SIZE_COL, TAXONOMY_COL, INDEX_IDS_PATH, # Column names | |
SYN_COL, KEY_COL, ID_COL, TOPIC_COL, PROPERTY_COL, # Existing Column names | |
DEFAULT_K, SIMILARITY_THRESHOLD, CANDIDATE_MULTIPLIER # Recommendation parameters | |
) | |
from src.database.mongodb import mongodb # Import the MongoDB client instance | |
logger = logging.getLogger(__name__) # Initialize a logger instance for this module | |
class RecoRecommender: | |
""" | |
A RAG-based recommender system for multilingual content using FAISS retrieval.
""" | |
FAISS_IDS_DOC_ID = "faiss_indexed_document_ids_v1" # Document ID for storing indexed IDs in MongoDB | |
MODEL_METADATA_DOC_ID = "model_metadata" | |
EMBEDDING_CHECKSUM_KEY = "embedding_checksum" | |
MODEL_VERSION_KEY = "model_version" | |
FINE_TUNING_CONFIG_KEY = "fine_tuning_config" | |
EMBEDDING_STATUS_KEY = "embedding_status" | |
MODEL_STATUS = MODEL_STATUS # Expose the imported MODEL_STATUS mapping as a class attribute
def __init__(self): | |
"""Initialize the recommender with configuration.""" | |
logger.info("Initializing RecoRecommender...") | |
# Configuration | |
self.headline_col = HEADLINE_COL | |
self.key = KEY_COL | |
self.syn = SYN_COL | |
self.id_col = ID_COL | |
self.topic_col = TOPIC_COL | |
self.property_col = PROPERTY_COL | |
self.taxonomy_col = TAXONOMY_COL | |
self.seolocation_col = SEOLOCATION_COL | |
self.deeplink_col = DEEPLINK_COL | |
self.last_updated_col = LAST_UPDATED_COL | |
self.image_id_col = IMAGE_ID_COL | |
self.image_ratio_col = IMAGE_RATIO_COL | |
self.image_size_col = IMAGE_SIZE_COL | |
self.processed_content_col = f"Processed Content" | |
# State Variables | |
self.df: Optional[pd.DataFrame] = None | |
self.index: Optional[faiss.Index] = None | |
self.embed_model: Optional[SentenceTransformer] = None | |
self.base_reranker: Optional[CrossEncoder] = None | |
self.fine_tuned_reranker: Optional[CrossEncoder] = None | |
self.indexed_ids: List[str] = [] | |
self.tokenizer: Optional[AutoTokenizer] = None | |
self.generator: Optional[AutoModelForSeq2SeqLM] = None | |
self.normalizer = None | |
# self.device = "cuda" if torch.cuda.is_available() else "cpu" # Old device detection | |
logger.info("Determining compute device...") | |
if torch.cuda.is_available(): | |
self.device = "cuda" | |
logger.info("CUDA is available. Using NVIDIA GPU.") | |
else: | |
try: | |
import intel_extension_for_pytorch as ipex # Attempt to import IPEX | |
if hasattr(torch, 'xpu') and torch.xpu.is_available(): | |
self.device = "xpu" | |
logger.info("Intel XPU is available via IPEX. Using Intel GPU.") | |
else: | |
self.device = "cpu" | |
logger.info("CUDA not available. Intel XPU not available or IPEX not fully configured. Using CPU.") | |
except ImportError: | |
self.device = "cpu" | |
logger.info("CUDA not available. Intel Extension for PyTorch (IPEX) not found. Using CPU.") | |
except Exception as e: # Catch other potential errors during XPU check | |
self.device = "cpu" | |
logger.error(f"Error during Intel XPU check: {e}. Using CPU.") | |
logger.info(f"Selected device: {self.device}") | |
# Initialize MongoDB collection only if connection is available | |
if mongodb.db is not None: | |
self.faiss_meta_collection = mongodb.db[MONGO_FAISS_META_COLLECTION_NAME] | |
else: | |
self.faiss_meta_collection = None | |
logger.warning("MongoDB not available. Some features may be limited.") | |
# Initialize model trainer | |
self.model_trainer = ModelTrainer(RERANKER_MODEL_NAME, device=self.device) # Pass the determined device | |
self._setup_indic_nlp() | |
logger.info("RecoRecommender initialized.") | |
def _setup_indic_nlp(self): | |
"""Initialize Indic NLP resources.""" | |
logger.info("Setting up Indic NLP resources...") # Log the start of Indic NLP setup | |
if not os.path.exists(INDIC_NLP_RESOURCES_PATH): # Check if the Indic NLP resources path exists | |
raise FileNotFoundError(f"Indic NLP resources not found at {INDIC_NLP_RESOURCES_PATH}") # Raise error if path not found | |
os.environ["INDIC_RESOURCES_PATH"] = INDIC_NLP_RESOURCES_PATH # Set environment variable for Indic NLP resource path | |
try: # Start a try-except block for error handling | |
common.set_resources_path(INDIC_NLP_RESOURCES_PATH) # Set the resource path for the indicnlp library | |
self.normalizer = IndicNormalizerFactory().get_normalizer("hi") # Initialize the language normalizer | |
logger.info("Indic NLP resources setup complete.") # Log successful setup | |
except Exception as e: # Catch any exception during setup | |
logger.error(f"Error setting up Indic NLP resources: {e}", exc_info=True) # Log the error with traceback | |
raise # Re-raise the exception to halt execution if setup fails | |
def load_models(self): | |
"""Load all required ML models.""" | |
logger.info("Loading ML models...") # Log the start of ML model loading | |
try: # Start a try-except block for error handling | |
if self.embed_model is None: # Check if the embedding model is not already loaded | |
logger.info(f"Loading embedding model: {EMBED_MODEL_NAME}") # Log which embedding model is being loaded | |
self.embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=self.device) # Load the sentence transformer model | |
if self.tokenizer is None or self.generator is None: # Check if tokenizer or generator model is not loaded | |
logger.info(f"Loading generator model: {GENERATOR_MODEL_NAME}") # Log which generator model is being loaded | |
self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME, model_max_length=1024) # Load tokenizer for the generator | |
self.generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_NAME) # Load the sequence-to-sequence LM | |
if self.device == "cuda": # Check if CUDA (GPU) is the selected device | |
self.generator = self.generator.to(self.device) # Move the generator model to the GPU | |
# Load base reranker model | |
if self.base_reranker is None: | |
logger.info(f"Loading base reranker model: {RERANKER_MODEL_NAME}") | |
self.base_reranker = CrossEncoder(RERANKER_MODEL_NAME, device=self.device) | |
self.reranker = self.base_reranker # Default to base model | |
# Try to load fine-tuned model if available | |
self._load_fine_tuned_model() | |
logger.info("ML models loaded.") # Log successful loading of all models | |
except Exception as e: # Catch any exception during model loading | |
logger.error(f"Error loading ML models: {e}", exc_info=True) # Log the error with traceback | |
raise # Re-raise the exception | |
def _load_fine_tuned_model(self): | |
"""Attempt to load the fine-tuned reranker model if available.""" | |
try: | |
metadata = self.model_trainer.get_model_status() | |
if metadata.get("current_model_status") == MODEL_STATUS["FINE_TUNED"]: | |
current_version = metadata.get("current_version", "v0") | |
model_path = str(self.model_trainer.get_model_path(current_version)) | |
if os.path.exists(model_path): | |
logger.info(f"Loading fine-tuned reranker model version {current_version}") | |
self.fine_tuned_reranker = CrossEncoder(model_path, device=self.device) | |
self.reranker = self.fine_tuned_reranker | |
logger.info("Successfully loaded fine-tuned reranker model") | |
return | |
logger.info("No fine-tuned model found or not in fine-tuned state, using base model") | |
self.reranker = self.base_reranker | |
except Exception as e: | |
logger.error(f"Error loading fine-tuned model: {e}. Falling back to base model.") | |
self.reranker = self.base_reranker | |
def _calculate_embedding_checksum(self, content: str) -> str: | |
"""Calculate a checksum for content to detect changes in embedding logic.""" | |
import hashlib | |
# Include model name and any relevant preprocessing parameters in the checksum | |
checksum_content = f"{content}_{EMBED_MODEL_NAME}" | |
return hashlib.md5(checksum_content.encode()).hexdigest() | |
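# Illustrative behaviour (assumption: EMBED_MODEL_NAME is constant between runs):
# the checksum binds each stored embedding to both its source text and the embedding
# model name, so changing either marks the document for re-embedding.
#
#   self._calculate_embedding_checksum("शीर्षक: ...")   # -> '<32-char md5 hex digest>'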
def _get_model_metadata(self) -> Dict: | |
"""Retrieve current model metadata from MongoDB.""" | |
try: | |
if self.faiss_meta_collection is None: | |
logger.warning("MongoDB not available. Returning default metadata.") | |
return { | |
"_id": self.MODEL_METADATA_DOC_ID, | |
"embedding_model_name": EMBED_MODEL_NAME, | |
} | |
metadata = self.faiss_meta_collection.find_one({"_id": self.MODEL_METADATA_DOC_ID}) | |
if metadata: | |
return metadata | |
# Default metadata for the recommender service, focusing on aspects not directly | |
# managed by ModelTrainer's file-based metadata (e.g., embedding model name). | |
# Reranker model status, version, and its specific fine-tuning configuration | |
# are primarily managed by ModelTrainer and its model_metadata.json file. | |
return { | |
"_id": self.MODEL_METADATA_DOC_ID, | |
"embedding_model_name": EMBED_MODEL_NAME, | |
# Add other global operational metadata specific to RecoRecommender here if needed. | |
} | |
except Exception as e: | |
logger.error(f"Error retrieving model metadata: {e}") | |
# Return a minimal fallback to ensure core functionalities can proceed if possible | |
return {"_id": self.MODEL_METADATA_DOC_ID, "embedding_model_name": EMBED_MODEL_NAME} | |
def _update_model_metadata(self, updates: Dict) -> bool: | |
"""Update model metadata in MongoDB.""" | |
try: | |
if self.faiss_meta_collection is None: | |
logger.warning("MongoDB not available. Cannot update model metadata.") | |
return False | |
result = self.faiss_meta_collection.update_one( | |
{"_id": self.MODEL_METADATA_DOC_ID}, | |
{"$set": {**updates, "metadata_last_updated": datetime.now()}}, # Key changed for clarity | |
upsert=True | |
) | |
return result.acknowledged | |
except Exception as e: | |
logger.error(f"Error updating model metadata: {e}") | |
return False | |
def _increment_model_version(self) -> str: | |
"""DEPRECATED: Reranker model versioning is handled by ModelTrainer.""" | |
# This method appears unused and its logic conflicts with ModelTrainer's | |
# file-system based versioning for fine-tuned models. | |
logger.warning("RecoRecommender._increment_model_version() is deprecated. Reranker versioning is managed by ModelTrainer.") | |
# Fallback to ModelTrainer's current version if absolutely needed, but ideally this method should be removed. | |
return self.model_trainer.get_model_status().get("current_version", "v0") | |
def _needs_reembedding_batch(self, doc_ids: List[str], current_checksum: str) -> List[str]: | |
"""Check which documents from a batch need reembedding.""" | |
try: | |
if self.faiss_meta_collection is None: | |
logger.warning("MongoDB not available. Assuming all documents need reembedding.") | |
return doc_ids | |
# Query for all documents in one go | |
metadata_docs = self.faiss_meta_collection.find( | |
{"_id": {"$in": doc_ids}} | |
) | |
# Create a dict for quick lookup | |
metadata_map = {doc["_id"]: doc.get(self.EMBEDDING_CHECKSUM_KEY) for doc in metadata_docs} | |
# Return IDs that need reembedding | |
return [ | |
doc_id for doc_id in doc_ids | |
if doc_id not in metadata_map or metadata_map[doc_id] != current_checksum | |
] | |
except Exception as e: | |
logger.error(f"Error checking embedding status for batch: {e}") | |
return doc_ids # If error, assume all need reembedding | |
def _update_embedding_metadata(self, doc_id: str, checksum: str): | |
"""Update metadata for document embeddings.""" | |
try: | |
self.faiss_meta_collection.update_one( | |
{"_id": doc_id}, | |
{ | |
"$set": { | |
self.EMBEDDING_CHECKSUM_KEY: checksum, | |
"last_updated": datetime.now() | |
} | |
}, | |
upsert=True | |
) | |
except Exception as e: | |
logger.error(f"Error updating embedding metadata for doc {doc_id}: {e}") | |
def _needs_reembedding(self, doc_id: str, current_checksum: str) -> bool: | |
"""Check if a document needs to be reembedded.""" | |
try: | |
metadata = self.faiss_meta_collection.find_one({"_id": doc_id}) | |
if not metadata or metadata.get(self.EMBEDDING_CHECKSUM_KEY) != current_checksum: | |
return True | |
return False | |
except Exception as e: | |
logger.error(f"Error checking embedding status for doc {doc_id}: {e}") | |
return True | |
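# Summary of the re-embedding decision above (per-document metadata is assumed to be
# stored in faiss_meta_collection, keyed by the item ID):
#   - no metadata document for the ID       -> re-embed (never embedded before)
#   - stored checksum != current checksum   -> re-embed (content or model changed)
#   - checksums match                       -> skip, the existing vector is still valid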
async def update_embeddings_and_index(self, force_reload_data: bool = True): | |
""" | |
Updates embeddings and the FAISS index. | |
If existing indexed documents have changed or the embedding model is different, | |
a full rebuild of the index is triggered. | |
Otherwise, only new documents are added incrementally. | |
The FAISS index file (INDEX_PATH) and metadata are saved. | |
""" | |
if force_reload_data: | |
logger.info("Reloading data from MongoDB for index update...") | |
self._load_data_from_mongo() | |
if self.df is None or self.df.empty: | |
logger.warning("DataFrame is empty. Cannot update embeddings or index.") | |
return | |
if self.embed_model is None: | |
logger.info("Embedding model not loaded. Loading models...") | |
self.load_models() # Ensure embed_model is available | |
logger.info("Checking for documents requiring embedding updates or additions...") | |
trigger_full_rebuild = False | |
if self.index and self.indexed_ids: | |
df_indexed_docs = self.df[self.df[self.id_col].isin(self.indexed_ids)] | |
content_map_for_indexed_ids = pd.Series( | |
df_indexed_docs[self.processed_content_col].values, | |
index=df_indexed_docs[self.id_col] | |
).to_dict() | |
for doc_id in self.indexed_ids: | |
if doc_id in content_map_for_indexed_ids: | |
current_content = content_map_for_indexed_ids[doc_id] | |
current_checksum = self._calculate_embedding_checksum(current_content) | |
if self._needs_reembedding(doc_id, current_checksum): | |
logger.info(f"Existing document ID {doc_id} requires re-embedding. Content/model change detected.") | |
trigger_full_rebuild = True | |
break | |
else: | |
logger.info(f"Document ID {doc_id} was in index but not in current data (deleted). Rebuild needed.") | |
trigger_full_rebuild = True | |
break | |
if self.index is None: | |
logger.info("No existing FAISS index found. A full build is required.") | |
trigger_full_rebuild = True | |
if trigger_full_rebuild: | |
logger.info("Triggering a full rebuild of the FAISS index.") | |
self.build_indexes_and_save(data_already_loaded=True) # self.df is already loaded | |
return | |
logger.info("No changes detected in existing indexed documents that require a full rebuild.") | |
current_indexed_ids_set = set(self.indexed_ids) | |
all_df_ids_set = set(self.df[self.id_col].tolist()) | |
new_doc_ids_to_add = list(all_df_ids_set - current_indexed_ids_set) | |
if new_doc_ids_to_add: | |
logger.info(f"Found {len(new_doc_ids_to_add)} new documents to add to the index.") | |
new_docs_df = self.df[self.df[self.id_col].isin(new_doc_ids_to_add)].copy() | |
if not new_docs_df.empty: | |
new_embeddings, actual_new_ids_embedded = self._generate_embeddings(new_docs_df) | |
if new_embeddings.size > 0: | |
self.index.add(new_embeddings.astype(np.float32)) | |
self.indexed_ids.extend(actual_new_ids_embedded) | |
logger.info(f"Added {len(actual_new_ids_embedded)} new vectors to FAISS index. Total vectors: {self.index.ntotal}.") | |
faiss.write_index(self.index, INDEX_PATH) | |
self.faiss_meta_collection.update_one( | |
{"_id": self.FAISS_IDS_DOC_ID}, | |
{"$set": {"ids": self.indexed_ids, "last_updated": datetime.now(), "total_vectors": self.index.ntotal}}, | |
upsert=True | |
) | |
logger.info("Updated FAISS index and list of indexed IDs saved.") | |
for doc_id in actual_new_ids_embedded: | |
content_for_checksum = new_docs_df[new_docs_df[self.id_col] == doc_id][self.processed_content_col].iloc[0] | |
checksum = self._calculate_embedding_checksum(content_for_checksum) | |
self._update_embedding_metadata(doc_id, checksum) | |
logger.info(f"Updated embedding metadata for {len(actual_new_ids_embedded)} new documents.") | |
else: | |
logger.info("No embeddings were generated for the new documents (e.g., content was empty).") | |
else: | |
logger.info("Identified new document IDs, but the corresponding DataFrame slice was empty.") | |
else: | |
logger.info("No new documents to add. Index is up-to-date with the current DataFrame.") | |
# Persistence is handled above: the index and ID list are saved whenever new embeddings
# are added, and a full rebuild saves via build_indexes_and_save(). If neither a rebuild
# nor new documents occurred, there is nothing further to persist here.
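# Illustrative call site (hypothetical scheduler wiring, not part of this module):
# a periodic refresh job could await this coroutine to pick up new documents without
# forcing a full rebuild unless existing content or the embedding model changed.
#
#   async def refresh_job(recommender: "RecoRecommender") -> None:
#       await recommender.update_embeddings_and_index(force_reload_data=True)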
def fallback_to_base_model(self): | |
"""Switch to using the base model for inference.""" | |
if self.base_reranker is not None: | |
self.reranker = self.base_reranker | |
logger.info("Switched to base reranker model") | |
else: | |
logger.warning("Base reranker model not available") | |
def use_fine_tuned_model(self): | |
"""Switch to using the fine-tuned model for inference.""" | |
if self.fine_tuned_reranker is not None: | |
self.reranker = self.fine_tuned_reranker | |
logger.info("Switched to fine-tuned reranker model") | |
else: | |
logger.warning("Fine-tuned model not available, staying with current model") | |
def _load_data_from_mongo(self): | |
"""Load and preprocess data from MongoDB.""" | |
logger.info("Loading data from MongoDB...") # Log the start of data loading from MongoDB | |
try: # Start a try-except block for error handling | |
# Fetch documents with projection | |
projection = { # Define which fields to retrieve from MongoDB | |
self.id_col: 1, | |
self.headline_col: 1, # Include the headline column | |
self.key: 1, # Include the key column | |
self.syn: 1, # Include the synopsis column
self.topic_col: 1, # Include the topic column | |
self.taxonomy_col: 1, # Include the taxonomy column | |
self.property_col: 1, # Include the property column | |
self.seolocation_col: 1, | |
self.deeplink_col: 1, | |
self.last_updated_col: 1, | |
self.image_id_col: 1, | |
self.image_ratio_col: 1, | |
self.image_size_col: 1, | |
"_id": 0 # Exclude the default MongoDB _id field | |
} | |
cursor = mongodb.news_collection.find({}, projection) # Execute the find query on the news_collection | |
data = list(cursor) # Convert the cursor result to a list of dictionaries | |
if not data: # Check if no data was returned from MongoDB | |
logger.warning("No documents found in MongoDB.") # Log a warning if no documents are found | |
self.df = pd.DataFrame(columns=[ # Create an empty DataFrame with expected columns | |
self.id_col, self.headline_col, self.key, self.syn, self.taxonomy_col, | |
self.topic_col, self.property_col, self.processed_content_col, | |
self.seolocation_col, self.deeplink_col, self.last_updated_col, | |
self.image_id_col, self.image_ratio_col, self.image_size_col | |
]) | |
return # Exit the method as there's no data to process | |
self.df = pd.DataFrame(data) # Convert the list of dictionaries to a pandas DataFrame | |
logger.info(f"Loaded {len(self.df)} documents from MongoDB.") # Log the number of documents loaded | |
# Data cleaning and preprocessing | |
required_cols = [self.id_col, self.headline_col, self.key, self.syn] # Update the list of required columns | |
for col in required_cols: # Iterate over the list of required columns | |
if col not in self.df.columns: # Check if a required column is missing in the DataFrame | |
raise ValueError(f"Required column '{col}' not found in MongoDB documents.") # Raise error if missing | |
# Handle optional columns that are primarily textual or have simple defaults | |
textual_optional_cols = { | |
self.topic_col: "N/A", # Default to "N/A" string | |
self.property_col: "N/A", # Default to "N/A" string | |
self.key: "" # Default to empty string, used for text processing | |
} | |
for col, default_val in textual_optional_cols.items(): | |
if col not in self.df.columns: | |
logger.warning(f"Optional column '{col}' not found. Adding with default value '{default_val}'.") # Log a warning | |
self.df[col] = default_val | |
else: | |
self.df[col] = self.df[col].fillna(default_val).astype(str) # Fill NA and ensure string | |
# Handle other optional columns that are expected to be None if missing or all NaN | |
other_optional_cols = [ | |
self.seolocation_col, self.deeplink_col, self.last_updated_col, | |
self.image_id_col, self.image_ratio_col, self.image_size_col | |
] | |
for col in other_optional_cols: | |
if col not in self.df.columns: | |
logger.warning(f"Optional column '{col}' not found. Adding with default value None.") | |
self.df[col] = None | |
else: | |
# If column exists, convert np.nan to None for object dtypes. | |
# For numeric dtypes, np.nan is the standard missing value representation. | |
if pd.api.types.is_object_dtype(self.df[col].dtype): | |
self.df[col] = self.df[col].replace({np.nan: None}) | |
# Special handling for taxonomy column (list of objects) | |
if self.taxonomy_col not in self.df.columns: | |
logger.warning(f"Taxonomy column '{self.taxonomy_col}' not found. Adding with default empty list for each row.") | |
self.df[self.taxonomy_col] = [[] for _ in range(len(self.df))] | |
else: | |
# Apply the cleaning function to the taxonomy column | |
self.df[self.taxonomy_col] = self.df[self.taxonomy_col].apply(self._clean_taxonomy_cell) | |
# Clean data | |
initial_len = len(self.df) # Store the number of rows before cleaning | |
self.df = self.df.dropna(subset=[self.id_col, self.headline_col, self.syn]) # Drop rows where essential columns are NaN | |
self.df = self.df[self.df[self.headline_col].apply(lambda x: isinstance(x, str) and x.strip() != '')] # Keep rows with non-empty string headlines | |
self.df = self.df[self.df[self.syn].apply(lambda x: isinstance(x, str) and x.strip() != '')] | |
if len(self.df) < initial_len: # Check if any rows were dropped during cleaning | |
logger.warning(f"Dropped {initial_len - len(self.df)} rows due to missing/invalid values.") # Log the number of dropped rows | |
if self.df.empty: # Check if the DataFrame is empty after cleaning | |
logger.warning("DataFrame is empty after cleaning.") # Log a warning | |
self.df = pd.DataFrame(columns=[ # Re-initialize an empty DataFrame with expected columns | |
self.id_col, self.headline_col, self.key, self.syn, self.taxonomy_col, | |
self.topic_col, self.property_col, self.processed_content_col, | |
self.seolocation_col, self.deeplink_col, self.last_updated_col, | |
self.image_id_col, self.image_ratio_col, self.image_size_col | |
]) | |
return # Exit if DataFrame is empty | |
# Improved contextual preprocessing: combine headline, synopsis, and keywords for richer embeddings | |
logger.info("Preprocessing and combining content (headline, synopsis, keywords, taxonomy) for contextual embeddings...") | |
self.df[self.headline_col] = self.df[self.headline_col].astype(str) | |
self.df[self.syn] = self.df[self.syn].astype(str) | |
self.df[self.key] = self.df[self.key].astype(str) | |
def combine_and_preprocess_content(row): | |
# Combine headline, synopsis, keywords, and taxonomy for better context | |
headline = row[self.headline_col].strip() | |
synopsis = row[self.syn].strip() | |
keywords = str(row[self.key]).strip() # Ensure string, though already handled | |
# Preprocess each part | |
processed_headline = self._preprocess_text(headline) | |
processed_synopsis = self._preprocess_text(synopsis) | |
processed_keywords = self._preprocess_text(keywords) | |
# Process taxonomy terms | |
taxonomy_terms_list = row[self.taxonomy_col] | |
processed_taxonomy_names = [] | |
if isinstance(taxonomy_terms_list, list): | |
for term_obj in taxonomy_terms_list: | |
if isinstance(term_obj, dict) and "name" in term_obj and term_obj["name"]: | |
processed_taxonomy_names.append(self._preprocess_text(str(term_obj["name"]))) | |
processed_taxonomy_string = " ".join(p_name for p_name in processed_taxonomy_names if p_name) | |
# Combine the non-empty fields into a labelled, structured string so the embedding
# model can distinguish headline, synopsis, keywords, and taxonomy
structured_parts = []
if processed_headline: structured_parts.append(f"शीर्षक: {processed_headline}") | |
if processed_synopsis: structured_parts.append(f"सारांश: {processed_synopsis}") | |
if processed_keywords: structured_parts.append(f"कीवर्ड: {processed_keywords}") | |
if processed_taxonomy_string: structured_parts.append(f"श्रेणी: {processed_taxonomy_string}") | |
return " ".join(structured_parts) | |
self.df[self.processed_content_col] = self.df.apply(combine_and_preprocess_content, axis=1) | |
logger.info("Data loading and preprocessing complete.") # Log completion of data loading and preprocessing | |
except Exception as e: # Catch any exception during the process | |
# Ensure df is None or empty on failure | |
logger.error(f"Error loading data from MongoDB: {e}", exc_info=True) # Log the error with traceback | |
self.df = None # Set DataFrame to None to indicate failure | |
raise # Re-raise the exception | |
def _clean_taxonomy_cell(self, cell_value) -> list: | |
""" | |
Cleans individual cells of the taxonomy column. | |
Ensures that the cell content is a list, handling scalars, NaNs, and unexpected array types. | |
""" | |
if isinstance(cell_value, list): | |
# If it's already a list (empty or not), return it. | |
return cell_value | |
elif pd.api.types.is_scalar(cell_value) and pd.isna(cell_value): | |
# Handles scalar np.nan, None, pd.NA by returning an empty list. | |
return [] | |
elif isinstance(cell_value, (np.ndarray, pd.Series)): | |
# Handles unexpected np.ndarray or pd.Series in a cell. | |
logger.warning( | |
f"Unexpected array/Series type in taxonomy column cell: {type(cell_value)}. " | |
f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list." | |
) | |
return [] | |
else: | |
# Handles other scalar types (e.g., string, int) or unhandled types. | |
# These are converted to an empty list, mimicking the original lambda's behavior. | |
if not pd.api.types.is_scalar(cell_value): # Log if it's an unexpected non-scalar, non-list, non-array type | |
logger.warning( | |
f"Unexpected non-scalar, non-list/array type in taxonomy column cell: {type(cell_value)}. " | |
f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list." | |
) | |
return [] | |
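# Behaviour sketch for _clean_taxonomy_cell (example values are illustrative):
#   [{"name": "क्रिकेट"}]   -> returned unchanged (already a list)
#   np.nan / None / pd.NA   -> []
#   "sports" (scalar str)   -> []  (non-list scalars are treated as missing taxonomy)
#   np.array([...])         -> []  (logged as unexpected, then dropped)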
def _preprocess_text(self, text: str) -> str:
"""
Normalize and tokenize text. Applies Indic normalization for Hindi (Devanagari)
text and generic tokenization for other languages.
"""
if not isinstance(text, str): | |
logger.warning(f"Received non-string input for preprocessing: {type(text)}") | |
return "" | |
text = text.strip() # Remove leading/trailing whitespace | |
if not text: # Handle empty strings after stripping | |
return "" | |
try: | |
# Check for presence of Devanagari characters to identify Hindi text. | |
# This is a basic check; for more complex scenarios, consider a language detection library. | |
is_hindi_text = bool(re.search(r'[\u0900-\u097F]', text))
if is_hindi_text:
if self.normalizer is None:
logger.error("IndicNormalizer not initialized, but Hindi text detected. Proceeding without normalization.")
normalized_text = text # Fallback: use original text if normalizer is missing | |
else: | |
normalized_text = self.normalizer.normalize(text) | |
else: | |
# For non-Hindi text, skip Hindi-specific normalization | |
# logger.debug(f"Skipping Hindi normalization for non-Hindi text: {text[:50]}...") # Optional: enable if needed | |
normalized_text = text | |
tokens = indic_tokenize.trivial_tokenize(normalized_text) # Tokenize the (potentially normalized) text | |
return " ".join(tokens) | |
except Exception as e: | |
logger.error(f"Error during text preprocessing for text starting with '{text[:50]}...': {e}", exc_info=True) | |
# Fallback to returning the original text to prevent downstream errors | |
return text | |
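# Illustrative behaviour (assumes the "hi" IndicNormalizer initialised in _setup_indic_nlp):
#   "दिल्ली में बारिश"        -> Devanagari detected -> normalised, then trivially tokenized
#   "Monsoon update 2024"   -> no Devanagari -> normalisation skipped, tokenization only
# Either way the return value is the tokens re-joined with single spaces.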
def _generate_embeddings(self, df_subset: pd.DataFrame) -> tuple[np.ndarray, List[str]]: | |
""" | |
Generate embeddings for the processed content column of a DataFrame subset. | |
Returns a tuple of (embeddings_array, list_of_ids). | |
""" | |
if not self.embed_model: # Check if the embedding model is loaded | |
raise RuntimeError("Embedding model not loaded") # Raise an error if the model is not loaded | |
if df_subset.empty or self.processed_content_col not in df_subset.columns or self.id_col not in df_subset.columns: | |
logger.warning( | |
f"DataFrame subset is empty or missing required columns " | |
f"('{self.processed_content_col}', '{self.id_col}') for embedding generation." | |
) | |
return np.array([]), [] | |
texts_to_embed = df_subset[self.processed_content_col].tolist() | |
# Ensure all texts are strings (preprocessing should ideally handle this, but as a safeguard) | |
texts_to_embed = [str(t) if pd.notna(t) else "" for t in texts_to_embed] | |
ids_to_embed = df_subset[self.id_col].tolist() | |
if not texts_to_embed: # If all texts became empty strings or list was initially empty | |
logger.warning("No non-empty texts available in the subset to generate embeddings.") | |
return np.array([]), [] | |
try: # Start a try-except block for error handling | |
embeddings = self.embed_model.encode(texts_to_embed, show_progress_bar=False, convert_to_numpy=True) # Generate embeddings | |
return embeddings, ids_to_embed # Return the generated embeddings and their IDs | |
except Exception as e: # Catch any exception during embedding generation | |
logger.error(f"Error generating embeddings: {e}", exc_info=True) # Log the error with traceback | |
return np.array([]), [] # Return empty array and list in case of error | |
def _build_faiss_index(self) -> List[str]: | |
"""Build FAISS index from processed content and return the list of indexed IDs.""" | |
if self.df is None or self.df.empty or self.processed_content_col not in self.df.columns: | |
raise ValueError("Cannot build FAISS index: Data not ready (DataFrame is None, empty, or missing processed content column)") | |
if self.embed_model is None: | |
raise RuntimeError("Cannot build FAISS index: Embedding model not loaded") | |
logger.info("Building FAISS index...") | |
if self.df.empty: | |
raise ValueError("DataFrame is empty, cannot build FAISS index") | |
# Process in batches for better memory management and parallelization | |
BATCH_SIZE = 1000 # Adjust based on available memory | |
total_docs = len(self.df) | |
all_embeddings = [] | |
all_ids = [] | |
for start_idx in range(0, total_docs, BATCH_SIZE): | |
end_idx = min(start_idx + BATCH_SIZE, total_docs) | |
batch_df = self.df.iloc[start_idx:end_idx] | |
# Generate embeddings for the batch | |
batch_embeddings, batch_ids = self._generate_embeddings(batch_df) | |
if batch_embeddings.size > 0: | |
all_embeddings.append(batch_embeddings) | |
all_ids.extend(batch_ids) | |
logger.info(f"Processed batch {start_idx//BATCH_SIZE + 1}/{(total_docs + BATCH_SIZE - 1)//BATCH_SIZE}") | |
if not all_embeddings: | |
logger.warning("No embeddings were generated. FAISS index will be empty.") | |
self.index = None | |
return [] | |
# Concatenate all batch embeddings | |
embeddings = np.vstack(all_embeddings).astype(np.float32) | |
dimension = embeddings.shape[1] | |
# Initialize HNSW index with optimized parameters | |
hnsw_m = 32 # Number of neighbors per layer | |
ef_construction = 100 # Higher value = better accuracy but slower construction | |
self.index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT) | |
self.index.hnsw.efConstruction = ef_construction | |
# Add vectors in batches to reduce memory usage | |
BATCH_SIZE_INDEX = 10000 # Adjust based on available memory | |
for i in range(0, len(embeddings), BATCH_SIZE_INDEX): | |
batch = embeddings[i:i + BATCH_SIZE_INDEX] | |
self.index.add(batch) | |
logger.info(f"Added batch {i//BATCH_SIZE_INDEX + 1}/{(len(embeddings) + BATCH_SIZE_INDEX - 1)//BATCH_SIZE_INDEX} to index") | |
logger.info(f"FAISS index built with {self.index.ntotal} vectors") | |
return all_ids | |
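# Note on the index configuration above (a sketch, not a tuning recommendation):
# IndexHNSWFlat with METRIC_INNER_PRODUCT ranks by raw dot product, so it only matches
# cosine similarity when the embeddings are L2-normalised. If the configured
# EMBED_MODEL_NAME does not normalise its outputs, one option (a hypothetical change to
# _generate_embeddings) is:
#
#   embeddings = self.embed_model.encode(texts_to_embed, convert_to_numpy=True,
#                                        normalize_embeddings=True)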
def _load_faiss_index_and_ids(self): | |
"""Load FAISS index and corresponding IDs from files.""" | |
if not os.path.exists(INDEX_PATH): # Check if the FAISS index file exists at the specified path | |
raise FileNotFoundError(f"FAISS index file not found at {INDEX_PATH}") # Raise error if file not found | |
logger.info(f"Loading FAISS index from: {INDEX_PATH}") # Log the path from which the index is being loaded | |
try: # Start a try-except block for error handling | |
self.index = faiss.read_index(INDEX_PATH) # Read the FAISS index from the file | |
logger.info(f"FAISS index loaded with {self.index.ntotal} vectors") # Log the number of vectors in the loaded index | |
except Exception as e: # Catch any exception during index loading | |
logger.error(f"Error loading FAISS index: {e}", exc_info=True) # Log the error with traceback | |
self.index = None # Ensure index is None on failure | |
raise # Re-raise the exception | |
# Try to load IDs from MongoDB if available | |
if self.faiss_meta_collection is not None: | |
logger.info(f"Loading FAISS index IDs from MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', document_id '{self.FAISS_IDS_DOC_ID}'") | |
try: | |
ids_document = self.faiss_meta_collection.find_one({"_id": self.FAISS_IDS_DOC_ID}) | |
if ids_document and "ids" in ids_document: | |
self.indexed_ids = ids_document["ids"] | |
logger.info(f"Loaded {len(self.indexed_ids)} indexed IDs from MongoDB.") | |
# Basic consistency check | |
if self.index and self.index.ntotal != len(self.indexed_ids): | |
logger.warning( | |
f"FAISS index vector count ({self.index.ntotal}) " | |
f"does not match loaded ID count from MongoDB ({len(self.indexed_ids)}). " | |
"Index might be inconsistent. Consider rebuilding." | |
) | |
else: | |
logger.warning(f"FAISS index IDs document not found in MongoDB or 'ids' field missing. Will attempt to build if necessary.") | |
self.indexed_ids = [] # Initialize as empty if not found | |
except Exception as e: | |
logger.error(f"Error loading FAISS index IDs from MongoDB: {e}", exc_info=True) | |
self.indexed_ids = [] # Ensure IDs list is empty on failure | |
# We don't re-raise here, as load_components will decide if a rebuild is needed | |
# based on whether self.indexed_ids is populated. | |
else: | |
logger.warning("MongoDB not available. Cannot load indexed IDs. Operating with empty ID list.") | |
self.indexed_ids = [] | |
def build_indexes_and_save(self, data_already_loaded: bool = False): | |
""" | |
Load data (if not already loaded), build FAISS index from current self.df, and save. | |
Assumes self.df is populated if data_already_loaded is True. | |
""" | |
logger.info("Starting index building process...") # Log the start of the index building and saving process | |
try: # Start a try-except block for error handling | |
if not data_already_loaded: | |
self._load_data_from_mongo() # Load data from MongoDB if not already loaded | |
if self.df is None or self.df.empty: # Check if DataFrame is None or empty after loading | |
raise ValueError("Data is empty. Cannot build index.") # Raise error if data is not loaded properly | |
if self.embed_model is None: # Ensure models are loaded before building index | |
self.load_models() | |
# Build the FAISS index using the loaded data and models | |
# This method now returns the list of IDs that were indexed | |
indexed_ids = self._build_faiss_index() | |
self.indexed_ids = indexed_ids # Store the list of IDs | |
# Save FAISS index | |
logger.info(f"Saving FAISS index to: {INDEX_PATH}") # Log the path where the index will be saved | |
index_dir = os.path.dirname(INDEX_PATH) # Get the directory part of the index path | |
if index_dir and not os.path.exists(index_dir): # If a directory component is present but does not yet exist
os.makedirs(index_dir) # Create the directory if it doesn't exist | |
# Save the index and the corresponding IDs | |
faiss.write_index(self.index, INDEX_PATH) | |
# Save indexed_ids to MongoDB | |
logger.info(f"Saving FAISS index IDs to MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', document_id '{self.FAISS_IDS_DOC_ID}'") | |
try: | |
if self.faiss_meta_collection is not None: | |
self.faiss_meta_collection.update_one( | |
{"_id": self.FAISS_IDS_DOC_ID}, | |
{"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}}, | |
upsert=True | |
) | |
logger.info(f"Saved {len(self.indexed_ids)} indexed IDs to MongoDB.") | |
else: | |
logger.warning("MongoDB not available. Skipping indexed IDs save.") | |
except Exception as e: | |
logger.error(f"Error saving FAISS index IDs to MongoDB: {e}", exc_info=True) | |
# Don't raise here as the FAISS index was saved successfully | |
# The IDs can be regenerated if needed | |
logger.info("Index building and saving complete") # Log successful completion | |
except Exception as e: # Catch any exception during the process | |
logger.error(f"Error during index building: {e}", exc_info=True) # Log the error with traceback | |
raise # Re-raise the exception | |
def load_components(self): | |
"""Load all components (models, data, index).""" | |
logger.info("Loading components...") # Log the start of component loading | |
try: # Start a try-except block for error handling | |
self.load_models() # Load all machine learning models | |
# Try to load data from MongoDB, but handle failures gracefully | |
try: | |
self._load_data_from_mongo() # Load data from MongoDB and preprocess it | |
logger.info("Successfully loaded data from MongoDB") | |
except Exception as mongo_error: | |
logger.warning(f"Failed to load data from MongoDB: {mongo_error}") | |
logger.info("Attempting to work with existing FAISS index without MongoDB data...") | |
# Set df to None to indicate no MongoDB data is available | |
self.df = None | |
# Try to load FAISS index and IDs | |
try: | |
self._load_faiss_index_and_ids() # Tries to load .bin and IDs from Mongo. | |
# Sets self.index and self.indexed_ids. | |
# self.indexed_ids will be [] if Mongo data for IDs is missing. | |
# Raises FileNotFoundError if .bin (INDEX_PATH) is missing. | |
if self.index and self.index.ntotal > 0: | |
logger.info(f"FAISS index loaded successfully with {self.index.ntotal} vectors") | |
# If we have MongoDB data, proceed with normal logic | |
if self.df is not None and not self.df.empty: | |
# Consistency check and incremental update logic | |
if not self.indexed_ids: | |
logger.warning("FAISS index file loaded, but no corresponding IDs found in MongoDB. Rebuilding for consistency.") | |
self.build_indexes_and_save(data_already_loaded=True) | |
else: | |
logger.info("Existing FAISS index and IDs loaded from storage.") | |
# Proceed with incremental update logic | |
current_df_ids = set(self.df[self.id_col].tolist()) | |
indexed_ids_set = set(self.indexed_ids) | |
new_ids_to_add = list(current_df_ids - indexed_ids_set) | |
if new_ids_to_add: | |
logger.info(f"Found {len(new_ids_to_add)} new documents to add to the index.") | |
new_docs_df = self.df[self.df[self.id_col].isin(new_ids_to_add)].copy() | |
new_embeddings, new_doc_ids_added = self._generate_embeddings(new_docs_df) | |
if new_embeddings.size > 0: | |
self.index.add(new_embeddings.astype(np.float32)) | |
self.indexed_ids.extend(new_doc_ids_added) | |
logger.info(f"Added {len(new_doc_ids_added)} new vectors to FAISS index. Total vectors: {self.index.ntotal}") | |
# Save the updated FAISS index | |
faiss.write_index(self.index, INDEX_PATH) | |
# Try to save the updated IDs to MongoDB, but don't fail if it doesn't work | |
try: | |
self.faiss_meta_collection.update_one( | |
{"_id": self.FAISS_IDS_DOC_ID}, | |
{"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}}, | |
upsert=True | |
) | |
logger.info("Updated FAISS index and IDs saved to MongoDB.") | |
except Exception as e: | |
logger.warning(f"Could not save IDs to MongoDB: {e}") | |
else: | |
logger.info("No new documents found to add to the index. Index is up-to-date.") | |
else: | |
# No MongoDB data available, but we have a FAISS index | |
logger.info("FAISS index available but no MongoDB data. Operating in limited mode.") | |
if not self.indexed_ids: | |
logger.warning("No indexed IDs available. Some functionality may be limited.") | |
else: | |
# This case handles if self.index is None (FileNotFoundError caught below) | |
# or if index was loaded but empty and no IDs from Mongo. | |
if self.df is not None and not self.df.empty: | |
logger.info("FAISS index and/or IDs not found or empty. Building new index.") | |
self.build_indexes_and_save(data_already_loaded=True) | |
else: | |
logger.warning("No data available (neither MongoDB nor FAISS index). Cannot build index.") | |
except FileNotFoundError: # This means INDEX_PATH (.bin file) was not found. | |
logger.warning(f"FAISS index file ({INDEX_PATH}) not found.") | |
if self.df is not None and not self.df.empty: | |
logger.info("Building index from scratch.") | |
self.build_indexes_and_save(data_already_loaded=True) | |
else: | |
logger.error("Cannot build index: no data available.") | |
except Exception as e: | |
logger.error(f"Error loading FAISS index: {e}", exc_info=True) | |
if self.df is not None and not self.df.empty: | |
logger.info("Attempting to rebuild index due to loading error.") | |
self.build_indexes_and_save(data_already_loaded=True) | |
else: | |
logger.error("Cannot rebuild index: no data available.") | |
logger.info("Components loaded successfully") # Log successful loading of all components | |
except Exception as e: # Catch any exception during component loading | |
logger.error(f"Error loading components: {e}", exc_info=True) # Log the error with traceback | |
raise # Re-raise the exception | |
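# Decision flow implemented by load_components (summary for reference):
#   index file + IDs in MongoDB           -> load both, then add only new documents incrementally
#   index file present but IDs missing    -> rebuild from self.df for consistency
#   no index file (FileNotFoundError)     -> build from scratch if data is available
#   index load error                      -> attempt a rebuild if data is available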
def get_recommendations(
self,
query: str, | |
k: int = DEFAULT_K, # Number of recommendations to return, defaults to DEFAULT_K from config | |
similarity_threshold: float = SIMILARITY_THRESHOLD # Similarity threshold, defaults to SIMILARITY_THRESHOLD from config | |
) -> Dict: | |
""" | |
Get recommendations for a query. | |
Returns a dictionary with retrieved documents and generated response. | |
""" | |
# Check prerequisites | |
if self.df is None or self.df.empty: # Check if the DataFrame (content data) is loaded | |
raise HTTPException(status_code=503, detail="Recommender data not available") # Raise 503 error if data is missing | |
if not all([self.index, self.embed_model, self.reranker, self.generator]): # Check if all essential components are loaded | |
missing = [ # List comprehension to find names of missing components | |
name for name, component in [ # Iterate through component names and their instances | |
("FAISS index", self.index), # FAISS index | |
("Embedding model", self.embed_model), # Embedding model | |
("Reranker model", self.reranker), # Reranker model | |
("Generator model", self.generator) # Generator model | |
] if component is None # Check if the component is None (not loaded) | |
] | |
raise HTTPException( # Raise 503 error if components are missing | |
status_code=503, # HTTP status code for Service Unavailable | |
detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}" # Error detail listing missing components | |
) | |
logger.info(f"Processing recommendation request: query='{query}', k={k}") # Log the incoming recommendation request | |
# Preprocess the query | |
processed_query = self._preprocess_text(query) # Preprocess the input Hindi query | |
query_embedding, _ = self._generate_embeddings(pd.DataFrame({self.processed_content_col: [processed_query], self.id_col: ["query"]})) # Generate embedding for the processed query | |
if query_embedding.size == 0: # Check if the query embedding is empty (e.g., generation failed) | |
logger.warning("Query embedding is empty.") # Log a warning | |
return { # Return an empty result | |
"retrieved_documents": [], # Empty list of documents | |
"generated_response": "No recommendations found. (Query embedding failed)" # Informative message | |
} | |
# Retrieve candidates from FAISS | |
num_candidates = max(k * CANDIDATE_MULTIPLIER, k) # Determine number of candidates to fetch (k * multiplier, or at least k) | |
try: # Start a try-except block for FAISS search | |
D, I = self.index.search(query_embedding.astype(np.float32), num_candidates) # Search FAISS index (D=distances, I=indices) | |
except Exception as e: # Catch any exception during FAISS search | |
logger.error(f"Error during FAISS search: {e}", exc_info=True) # Log the error with traceback | |
return { # Return an empty result | |
"retrieved_documents": [], # Empty list of documents | |
"generated_response": "No recommendations found. (FAISS search failed)" # Informative message | |
} | |
# Process FAISS results | |
retrieved_faiss_indices = I[0] | |
retrieved_faiss_scores = D[0] | |
# Map FAISS indices to original document IDs and collect scores | |
valid_candidate_data = [] | |
for faiss_idx, score in zip(retrieved_faiss_indices, retrieved_faiss_scores): | |
# Ensure FAISS index is valid and within bounds of self.indexed_ids | |
if faiss_idx != -1 and faiss_idx < len(self.indexed_ids): | |
valid_candidate_data.append({ | |
"original_id": self.indexed_ids[faiss_idx], # Actual ID of the item | |
"faiss_score": score | |
}) | |
elif faiss_idx != -1: # Log if faiss_idx is valid but out of bounds for indexed_ids | |
logger.warning( | |
f"FAISS index {faiss_idx} is out of bounds for self.indexed_ids (len: {len(self.indexed_ids)}). Skipping." | |
) | |
if not valid_candidate_data: | |
logger.info(f"No valid candidates found from FAISS/ID mapping for query '{query}'.") | |
return { | |
"retrieved_documents": [], | |
"generated_response": f"No recommendations found for '{query}' (no FAISS results or ID mapping issue)." | |
} | |
# Create a DataFrame from valid FAISS candidates (contains 'original_id' and 'faiss_score') | |
faiss_candidates_df = pd.DataFrame(valid_candidate_data) | |
# Fetch full candidate details from self.df by merging | |
# This uses the 'original_id' (which are actual item IDs) to robustly fetch data | |
# and preserves the order from FAISS retrieval. | |
candidates = pd.merge( | |
faiss_candidates_df, # Left DataFrame (dictates order and includes faiss_score) | |
self.df, # Right DataFrame (provides full item details) | |
left_on="original_id",# Key in faiss_candidates_df | |
right_on=self.id_col, # Key in self.df | |
how="inner" # Ensures only items present in both are kept | |
) | |
# If 'original_id' column is different from self.id_col and still exists, drop it as it's redundant | |
if "original_id" in candidates.columns and "original_id" != self.id_col: | |
candidates = candidates.drop(columns=["original_id"]) | |
# Ensure the ID column is of a consistent type if needed, though it should match indexed_ids type | |
# candidates[self.id_col] = candidates[self.id_col].astype(str) # Example if IDs need to be strings | |
# Filter out exact matches with query (case-insensitive, strip spaces) | |
candidates = candidates[ | |
(candidates[self.headline_col].str.strip().str.lower() != query.strip().lower()) & | |
(candidates[self.syn].str.strip().str.lower() != query.strip().lower()) | |
] | |
if candidates.empty: | |
logger.info(f"No candidates left after filtering exact query matches for query '{query}'.") | |
return {"retrieved_documents": [], "generated_response": f"No distinct recommendations found for '{query}'."} | |
candidates = candidates.drop_duplicates(subset=[self.syn]).copy() # Use .copy() after selection/drop_duplicates | |
if candidates.empty: | |
logger.info(f"No candidates left after dropping duplicates for query '{query}'.") | |
return {"retrieved_documents": [], "generated_response": f"No unique recommendations found for '{query}'."} | |
# Rerank using cross-encoder | |
#rerank_pairs = [(query, str(row[self.headline_col])) for _, row in candidates.iterrows()] # Create pairs of (query, candidate_headline) for reranking | |
rerank_pairs = [(query, str(row[self.syn])) for _, row in candidates.iterrows()] | |
if rerank_pairs: # Check if there are any candidate pairs to rerank | |
rerank_scores = self.reranker.predict(rerank_pairs, show_progress_bar=False) # Predict reranking scores | |
logger.info(f"Raw rerank scores for query '{query}': {rerank_scores.tolist()}") # Log raw scores | |
candidates["rerank_score"] = rerank_scores # Add rerank scores as a new column | |
candidates = candidates.sort_values("rerank_score", ascending=False) # Sort candidates by rerank score in descending order | |
logger.debug(f"Top candidates before thresholding (query='{query}', threshold={similarity_threshold}):") | |
for _, row in candidates.head().iterrows(): # Log top few candidates before filtering | |
logger.debug(f" ID: {row[self.id_col]}, Synopsis: {str(row[self.syn])[:50]}..., Rerank Score: {row['rerank_score']:.4f}") | |
candidates = candidates[candidates["rerank_score"] >= similarity_threshold]
logger.info(f"Number of candidates after applying similarity_threshold ({similarity_threshold}): {len(candidates)}") | |
else: # If no pairs to rerank (e.g., all candidates were filtered out) | |
logger.info(f"No candidate pairs to rerank for query '{query}'.") | |
candidates["rerank_score"] = 0.0 # Assign a default rerank score of 0.0 | |
# Select top-k | |
top_candidates = candidates.head(k) # Select the top-k candidates after reranking | |
# Prepare output | |
retrieved_documents = [] # Initialize an empty list to store formatted retrieved documents | |
for _, row in top_candidates.iterrows(): # Iterate through the top-k candidate rows | |
taxonomy_data = row[self.taxonomy_col] | |
taxonomy_names = [] | |
if isinstance(taxonomy_data, list): | |
for term_obj in taxonomy_data: | |
if isinstance(term_obj, dict) and term_obj.get("name"): # Check key exists and has a value | |
taxonomy_names.append(str(term_obj["name"])) | |
retrieved_documents.append({ # Create a dictionary for each retrieved document | |
"id": row[self.id_col], # Document ID | |
"hl": str(row[self.headline_col]), # Document headline | |
"synopsis": row[self.syn], # Document primary content (synopsis) | |
"keywords": row[self.key], # Document secondary content (keywords) | |
"type": row.get(self.topic_col, None), # Document topic (or None if not available) | |
"taxonomy": taxonomy_names, # List of taxonomy names | |
"score": row["rerank_score"], # Rerank score of the document | |
"seolocation": row.get(self.seolocation_col, None), | |
"dl": row.get(self.deeplink_col, None), | |
"lu": row.get(self.last_updated_col, None), | |
"imageid": row.get(self.image_id_col, None), | |
"imgratio": row.get(self.image_ratio_col, None), | |
"imgsize": row.get(self.image_size_col, None) | |
}) | |
# Optionally, generate a response using the generator model (not implemented here) | |
generated_response = f"Top {len(retrieved_documents)} recommendations for '{query}'." # Create a simple generated response string | |
return { # Return the final result as a dictionary | |
"retrieved_documents": retrieved_documents, # List of retrieved documents | |
"generated_response": generated_response # Generated textual response | |
} | |
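# Example call and (trimmed) response shape; all values are illustrative:
#
#   result = recommender.get_recommendations("आईपीएल 2024", k=3)
#   # result == {
#   #     "retrieved_documents": [
#   #         {"id": "...", "hl": "...", "synopsis": "...", "score": 0.87, ...},
#   #         ...
#   #     ],
#   #     "generated_response": "Top 3 recommendations for 'आईपीएल 2024'."
#   # }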
def _format_retrieved_documents(self, documents_df: pd.DataFrame) -> List[Dict]: | |
"""Helper function to format DataFrame rows into a list of document dictionaries.""" | |
retrieved_documents = [] | |
for _, row in documents_df.iterrows(): | |
taxonomy_data = row[self.taxonomy_col] | |
taxonomy_names = [] | |
if isinstance(taxonomy_data, list): | |
for term_obj in taxonomy_data: | |
if isinstance(term_obj, dict) and term_obj.get("name"): | |
taxonomy_names.append(str(term_obj["name"])) | |
retrieved_documents.append({ | |
"id": row[self.id_col], | |
"hl": str(row[self.headline_col]), | |
"synopsis": str(row.get(self.syn, "")), # Ensure synopsis is string | |
"keywords": str(row.get(self.key, "")), # Ensure keywords is string | |
"type": row.get(self.topic_col, None), | |
"taxonomy": taxonomy_names, | |
"score": row.get("rerank_score", 0.0), # Use .get for safety if rerank_score might be missing | |
"seolocation": row.get(self.seolocation_col, None), | |
"dl": row.get(self.deeplink_col, None), | |
"lu": row.get(self.last_updated_col, None), | |
"imageid": row.get(self.image_id_col, None), | |
"imgratio": row.get(self.image_ratio_col, None), | |
"imgsize": row.get(self.image_size_col, None) | |
}) | |
return retrieved_documents | |
def get_recommendations_by_id( | |
self, | |
msid: str, | |
k: int = DEFAULT_K, | |
similarity_threshold: float = SIMILARITY_THRESHOLD | |
) -> Dict: | |
""" | |
Get recommendations based on a given item ID (msid). | |
Finds items similar to the item identified by msid. | |
Returns a dictionary with retrieved documents. | |
""" | |
# Check prerequisites | |
if self.df is None or self.df.empty: | |
logger.error("Recommender data not available for get_recommendations_by_id.") | |
raise HTTPException(status_code=503, detail="Recommender data not available") | |
# Generator model is not strictly needed for this item-to-item recommendation path | |
if not all([self.index, self.embed_model, self.reranker]): | |
missing = [ | |
name for name, component in [ | |
("FAISS index", self.index), | |
("Embedding model", self.embed_model), | |
("Reranker model", self.reranker), | |
] if component is None | |
] | |
logger.error(f"Recommender not fully initialized for get_recommendations_by_id. Missing: {', '.join(missing)}") | |
raise HTTPException( | |
status_code=503, | |
detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}" | |
) | |
logger.info(f"Processing recommendation request for item ID: '{msid}', k={k}") | |
# Find the source item in the DataFrame | |
source_item_row = self.df[self.df[self.id_col] == msid] | |
if source_item_row.empty: | |
logger.warning(f"Item with ID '{msid}' not found in DataFrame.") | |
# Ensure the response structure matches what RecommendationResponse expects | |
# even on failure, to avoid FastAPI server errors during response model validation. | |
return {"retrieved_documents": [], "generated_response": f"Item with ID '{msid}' not found."} | |
# Or, if you prefer to let the route handler catch this as a 404: | |
# raise HTTPException(status_code=404, detail=f"Item with ID '{msid}' not found") | |
source_item = source_item_row.iloc[0] | |
source_item_content_for_embedding = source_item[self.processed_content_col] | |
source_item_content_for_reranking = str(source_item[self.syn]) # Using synopsis for reranker query | |
# Generate embedding for the source item's content | |
item_embedding, _ = self._generate_embeddings( | |
pd.DataFrame({ | |
self.processed_content_col: [source_item_content_for_embedding], | |
self.id_col: [msid] # Dummy ID, as _generate_embeddings expects it | |
}) | |
) | |
if item_embedding.size == 0: | |
logger.warning(f"Embedding generation failed for item ID '{msid}'.") | |
return {"retrieved_documents": [], "generated_response": "No recommendations found (source item embedding failed)"} | |
# Retrieve candidates from FAISS: fetch k+1 (or more with multiplier) to account for filtering source item | |
num_candidates_to_fetch = max((k + 1) * CANDIDATE_MULTIPLIER, k + 1) | |
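# Illustrative arithmetic (assumed values): with k=10 and CANDIDATE_MULTIPLIER=3 this fetches
# max(11 * 3, 11) = 33 candidates, leaving room to drop the source item and duplicate synopses.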
try: | |
D, I = self.index.search(item_embedding.astype(np.float32), num_candidates_to_fetch) | |
except Exception as e: | |
logger.error(f"Error during FAISS search for item ID '{msid}': {e}", exc_info=True) | |
return {"retrieved_documents": [], "generated_response": "No recommendations found (FAISS search failed)"} | |
candidate_faiss_indices = I[0] | |
# candidate_scores = D[0] # FAISS scores, can be used if needed | |
valid_mask = candidate_faiss_indices != -1 | |
candidate_faiss_indices = candidate_faiss_indices[valid_mask] | |
if len(candidate_faiss_indices) == 0: | |
logger.info(f"No candidates found from FAISS for item ID '{msid}'.") | |
return {"retrieved_documents": [], "generated_response": f"No similar items found for ID '{msid}'."} | |
candidate_original_ids = [self.indexed_ids[i] for i in candidate_faiss_indices if i < len(self.indexed_ids)] | |
# Fetch candidates from the main DataFrame, excluding the source item itself | |
candidates_df = self.df[self.df[self.id_col].isin(candidate_original_ids) & (self.df[self.id_col] != msid)].copy() | |
candidates_df = candidates_df.drop_duplicates(subset=[self.syn]) # Avoid duplicate content | |
if candidates_df.empty: | |
logger.info(f"No candidates left after filtering for item ID '{msid}'.") | |
return {"retrieved_documents": [], "generated_response": f"No other similar items found for ID '{msid}'."} | |
# Rerank using cross-encoder | |
rerank_pairs = [(source_item_content_for_reranking, str(row[self.syn])) for _, row in candidates_df.iterrows()] | |
if rerank_pairs: | |
rerank_scores = self.reranker.predict(rerank_pairs, show_progress_bar=False) | |
candidates_df["rerank_score"] = rerank_scores | |
candidates_df = candidates_df.sort_values("rerank_score", ascending=False) | |
candidates_df = candidates_df[candidates_df["rerank_score"] >= similarity_threshold] | |
else: | |
candidates_df["rerank_score"] = 0.0 # Default score if no pairs or reranking skipped | |
top_candidates = candidates_df.head(k) | |
retrieved_documents = self._format_retrieved_documents(top_candidates) | |
generated_response = f"Top {len(retrieved_documents)} recommendations similar to item ID '{msid}'." | |
if not retrieved_documents: | |
generated_response = f"No recommendations found similar to item ID '{msid}'." | |
return {"retrieved_documents": retrieved_documents, "generated_response": generated_response} | |
def prepare_reranker_training_data_from_new_feedback_format(self, user_id: str, training_event_details: Dict) -> List[Dict]: | |
""" | |
Prepares training data for the reranker model from a user's feedback document. | |
Generates both positive and negative training samples, using semantic-similarity-based
negative sampling (hard negatives drawn from nearby FAISS results plus random negatives)
so the reranker learns sharper discrimination.
""" | |
if self.df is None or self.df.empty: | |
logger.warning(f"User {user_id}: DataFrame not loaded. Cannot prepare training data.") | |
return [] | |
training_samples = [] | |
query_msid = training_event_details.get("query_msid") | |
positive_msids_list = training_event_details.get("positive_msids") | |
if not query_msid or not isinstance(query_msid, str): | |
logger.warning(f"User {user_id}: 'query_msid' missing or invalid in training_event_details. Details: {str(training_event_details)[:200]}") | |
return [] | |
if not isinstance(positive_msids_list, list) or not positive_msids_list: | |
logger.warning(f"User {user_id}: 'positive_msids' field missing, not a list, or empty in training_event_details for query_msid '{query_msid}'. Details: {str(training_event_details)[:200]}") | |
return [] | |
source_item_row = self.df[self.df[self.id_col] == query_msid] | |
if source_item_row.empty: | |
logger.warning(f"User {user_id}: Query item (msid: {query_msid}) for training data not found in DataFrame.") | |
return [] | |
query_text = str(source_item_row.iloc[0].get(self.syn, "")).strip() | |
if not query_text: | |
logger.warning(f"User {user_id}: Query text (synopsis) is empty for msid {query_msid}. Skipping training sample generation for this event.") | |
return [] | |
# Process positive samples | |
positive_samples = [] | |
for positive_msid in positive_msids_list: | |
if not isinstance(positive_msid, str) or not positive_msid.strip(): | |
continue | |
clicked_item_row = self.df[self.df[self.id_col] == positive_msid] | |
if not clicked_item_row.empty: | |
candidate_text_positive = str(clicked_item_row.iloc[0].get(self.syn, "")).strip() | |
if candidate_text_positive: | |
positive_samples.append({ | |
"query_text": query_text, | |
"candidate_text": candidate_text_positive, | |
"label": 1.0, | |
"msid": positive_msid | |
}) | |
if not positive_samples: | |
logger.warning(f"User {user_id}: No valid positive samples found for query_msid {query_msid}") | |
return [] | |
# Generate negative samples through semantic similarity based sampling | |
num_negatives_per_positive = 5  # Number of negative samples generated for each positive sample
all_msids = set(self.df[self.id_col].tolist()) | |
positive_msids_set = set(p["msid"] for p in positive_samples) | |
# Get query embedding for semantic search | |
query_embedding, _ = self._generate_embeddings( | |
pd.DataFrame({ | |
self.processed_content_col: [query_text], | |
self.id_col: ["temp_query"] | |
}) | |
) | |
if query_embedding.size > 0: | |
# Get semantically similar candidates (harder negatives) | |
D, I = self.index.search(query_embedding.astype(np.float32), k=50) # Get more candidates | |
candidate_indices = I[0] | |
candidate_msids = [ | |
self.indexed_ids[idx] for idx in candidate_indices | |
if idx != -1 and idx < len(self.indexed_ids) | |
] | |
# Filter out positives and query item | |
negative_candidates = [ | |
msid for msid in candidate_msids | |
if msid not in positive_msids_set and msid != query_msid | |
] | |
import random | |
for pos_sample in positive_samples: | |
# Mix of hard and random negatives | |
num_hard_negatives = min(3, len(negative_candidates)) | |
num_random_negatives = num_negatives_per_positive - num_hard_negatives | |
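# Illustrative split with the defaults above: 5 negatives per positive = up to 3 hard negatives
# taken from the FAISS neighbours plus at least 2 random negatives from the remaining pool.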
# Select hard negatives (semantically similar) | |
hard_negatives = negative_candidates[:num_hard_negatives] | |
# Select random negatives from remaining pool | |
remaining_candidates = list(all_msids - positive_msids_set - set(hard_negatives) - {query_msid}) | |
random_negatives = random.sample(remaining_candidates, min(num_random_negatives, len(remaining_candidates)))  # Cap at the pool size so random.sample cannot raise ValueError
# Combine hard and random negatives | |
selected_negatives = hard_negatives + random_negatives | |
for neg_msid in selected_negatives: | |
neg_item_row = self.df[self.df[self.id_col] == neg_msid] | |
if not neg_item_row.empty: | |
candidate_text_negative = str(neg_item_row.iloc[0].get(self.syn, "")).strip() | |
if candidate_text_negative: | |
training_samples.append({ | |
"query_text": pos_sample["query_text"], | |
"candidate_text": candidate_text_negative, | |
"label": 0.0, | |
"msid": neg_msid | |
}) | |
# Add positive samples to final training data | |
training_samples.extend(positive_samples) | |
if training_samples: | |
logger.info(f"User {user_id}: Prepared {len(training_samples)} training samples ({len(positive_samples)} positive, {len(training_samples)-len(positive_samples)} negative) from the interaction event (query_msid: {query_msid}).") | |
return training_samples | |
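# Illustrative input/output for this method (hypothetical msids, not real data):
#   training_event_details = {"query_msid": "1001", "positive_msids": ["2002", "3003"]}
#   samples = recommender.prepare_reranker_training_data_from_new_feedback_format("user_42", training_event_details)
#   # Each sample looks like:
#   # {"query_text": "...", "candidate_text": "...", "label": 1.0 or 0.0, "msid": "2002"}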
def _log_training_data_for_refinement(self, training_data: List[Dict]): | |
"""Save training data for future fine-tuning.""" | |
if not training_data: | |
logger.info("No training data provided for saving.") | |
return | |
self.model_trainer.prepare_training_data(training_data) | |
async def check_and_trigger_fine_tuning(self): | |
""" | |
Check if fine-tuning should be triggered based on conditions and start if needed. | |
Returns True if fine-tuning was triggered, False otherwise. | |
""" | |
try: | |
if not self.model_trainer.check_training_conditions(): | |
return False | |
# Start fine-tuning process in background | |
asyncio.create_task(self._run_fine_tuning_process()) | |
return True | |
except Exception as e: | |
logger.error(f"Error in check_and_trigger_fine_tuning: {e}") | |
return False | |
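# Usage sketch (illustrative; assumes an async caller such as a FastAPI route or a scheduler):
#   triggered = await recommender.check_and_trigger_fine_tuning()
#   if triggered:
#       logger.info("Fine-tuning started in the background")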
async def _run_fine_tuning_process(self): | |
"""Run the fine-tuning process in the background.""" | |
try: | |
# Start fine-tuning | |
logger.info("Starting fine-tuning process") | |
# Run fine-tuning using the model trainer | |
new_version = await asyncio.to_thread(self.model_trainer.fine_tune) | |
if new_version: | |
# Load and validate the fine-tuned model before deploying | |
model_path = str(self.model_trainer.get_model_path(new_version)) | |
if os.path.exists(model_path): | |
# Load the new model for validation | |
# This is a synchronous operation, potentially okay if quick, | |
# but could be moved to to_thread if model loading is slow. | |
new_model = await asyncio.to_thread(CrossEncoder, model_path, device=self.device) | |
# Validate model performance | |
validation_passed = await self._validate_fine_tuned_model(new_model) | |
if validation_passed: | |
logger.info(f"Fine-tuned model validation passed. Deploying version: {new_version}") | |
self.fine_tuned_reranker = new_model # new_model is already on self.device | |
self.reranker = self.fine_tuned_reranker # Switch active reranker | |
# Update embeddings and index if needed | |
# update_embeddings_and_index is synchronous and can be long | |
await asyncio.to_thread(self.update_embeddings_and_index) | |
logger.info(f"Fine-tuning process completed successfully. New version: {new_version}") | |
else: | |
logger.warning("Fine-tuned model validation failed. Keeping current model.") | |
self.reranker = self.base_reranker | |
else: | |
logger.error(f"Fine-tuned model file not found at {model_path}") | |
self.reranker = self.base_reranker | |
else: | |
logger.error("Fine-tuning process failed") | |
self.reranker = self.base_reranker | |
except Exception as e: | |
logger.error(f"Error during fine-tuning process: {e}") | |
self.reranker = self.base_reranker | |
async def _validate_fine_tuned_model(self, new_model: CrossEncoder) -> bool: | |
""" | |
Validate the fine-tuned model's performance before deployment. | |
Uses multiple metrics for a more comprehensive evaluation. | |
Returns True if validation passes, False otherwise. | |
""" | |
try: | |
# Get a sample of validation data | |
validation_data = self.model_trainer.get_validation_data() | |
if not validation_data: | |
logger.warning("No validation data available") | |
return False | |
# Initialize identical confusion-matrix counters for the base and fine-tuned models
def _new_counters() -> Dict:
return {"true_positives": 0, "false_positives": 0, "true_negatives": 0, "false_negatives": 0, "scores": []}
base_metrics = _new_counters()
new_metrics = _new_counters()
# Evaluate both models on validation data | |
for sample in validation_data: | |
query = sample["query_text"] | |
candidate = sample["candidate_text"] | |
label = float(sample["label"]) | |
# Get predictions from both models | |
# predict is synchronous and CPU/GPU bound | |
base_pred_array = await asyncio.to_thread(self.base_reranker.predict, [(query, candidate)]) | |
base_pred = base_pred_array[0] | |
new_pred_array = await asyncio.to_thread(new_model.predict, [(query, candidate)]) | |
new_pred = new_pred_array[0] | |
# Update metrics for base model | |
base_metrics["scores"].append(base_pred) | |
if label == 1.0: | |
if base_pred >= 0.5: | |
base_metrics["true_positives"] += 1 | |
else: | |
base_metrics["false_negatives"] += 1 | |
else: | |
if base_pred >= 0.5: | |
base_metrics["false_positives"] += 1 | |
else: | |
base_metrics["true_negatives"] += 1 | |
# Update metrics for new model | |
new_metrics["scores"].append(new_pred) | |
if label == 1.0: | |
if new_pred >= 0.5: | |
new_metrics["true_positives"] += 1 | |
else: | |
new_metrics["false_negatives"] += 1 | |
else: | |
if new_pred >= 0.5: | |
new_metrics["false_positives"] += 1 | |
else: | |
new_metrics["true_negatives"] += 1 | |
if not base_metrics["scores"] or not new_metrics["scores"]: | |
logger.warning("No predictions generated during validation") | |
return False | |
# Calculate metrics for both models | |
def calculate_model_metrics(metrics): | |
tp = metrics["true_positives"] | |
fp = metrics["false_positives"] | |
tn = metrics["true_negatives"] | |
fn = metrics["false_negatives"] | |
# Prevent division by zero | |
precision = tp / (tp + fp) if (tp + fp) > 0 else 0 | |
recall = tp / (tp + fn) if (tp + fn) > 0 else 0 | |
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 | |
accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0 | |
return { | |
"precision": precision, | |
"recall": recall, | |
"f1": f1, | |
"accuracy": accuracy | |
} | |
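# Worked example for calculate_model_metrics (illustrative counts): tp=8, fp=2, tn=7, fn=3 gives
# precision=0.80, recall≈0.73, f1≈0.76 and accuracy=0.75.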
base_performance = calculate_model_metrics(base_metrics) | |
new_performance = calculate_model_metrics(new_metrics) | |
# Log detailed performance comparison | |
logger.info("Model validation results:") | |
logger.info(f"Base model metrics: {base_performance}") | |
logger.info(f"Fine-tuned model metrics: {new_performance}") | |
# Validation criteria: the fine-tuned model must reach a minimum absolute F1 and show at least a small relative improvement over the base model in one tracked metric
min_relative_improvement = 0.001  # 0.1% minimum relative improvement
min_absolute_f1 = 0.3  # Minimum absolute F1 score the fine-tuned model must reach
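# Worked example of the relative-improvement check (illustrative numbers): if the base F1 is 0.50
# and the fine-tuned F1 is 0.505, the relative improvement is 0.005 / 0.50 = 1%, which clears the
# 0.1% threshold above, provided the new F1 also meets the 0.3 minimum.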
# Check if new model shows improvement in any metric | |
f1_improvement = new_performance["f1"] - base_performance["f1"] | |
precision_improvement = new_performance["precision"] - base_performance["precision"] | |
recall_improvement = new_performance["recall"] - base_performance["recall"] | |
accuracy_improvement = new_performance["accuracy"] - base_performance["accuracy"] | |
# Calculate relative improvements | |
relative_f1_imp = f1_improvement / base_performance["f1"] if base_performance["f1"] > 0 else float('inf') | |
relative_prec_imp = precision_improvement / base_performance["precision"] if base_performance["precision"] > 0 else float('inf') | |
relative_recall_imp = recall_improvement / base_performance["recall"] if base_performance["recall"] > 0 else float('inf') | |
relative_acc_imp = accuracy_improvement / base_performance["accuracy"] if base_performance["accuracy"] > 0 else float('inf') | |
# Model passes validation if it shows improvement in any metric and meets minimum F1 | |
if new_performance["f1"] >= min_absolute_f1 and ( | |
relative_f1_imp >= min_relative_improvement or | |
relative_prec_imp >= min_relative_improvement or | |
relative_recall_imp >= min_relative_improvement or | |
relative_acc_imp >= min_relative_improvement | |
): | |
logger.info( | |
f"Fine-tuned model shows improvement. Metrics changes:\n" | |
f"F1: {f1_improvement:.4f} ({relative_f1_imp:.2%})\n" | |
f"Precision: {precision_improvement:.4f} ({relative_prec_imp:.2%})\n" | |
f"Recall: {recall_improvement:.4f} ({relative_recall_imp:.2%})\n" | |
f"Accuracy: {accuracy_improvement:.4f} ({relative_acc_imp:.2%})" | |
) | |
return True | |
else: | |
logger.warning( | |
f"Fine-tuned model does not meet improvement criteria.\n" | |
f"F1 change: {f1_improvement:.4f} ({relative_f1_imp:.2%})\n" | |
f"Precision change: {precision_improvement:.4f} ({relative_prec_imp:.2%})\n" | |
f"Recall change: {recall_improvement:.4f} ({relative_recall_imp:.2%})\n" | |
f"Accuracy change: {accuracy_improvement:.4f} ({relative_acc_imp:.2%})" | |
) | |
return False | |
except Exception as e: | |
logger.error(f"Error during model validation: {e}") | |
return False | |
def reload_fine_tuned_model(self): | |
""" | |
Reload the fine-tuned reranker model from disk and switch to it if available. | |
""" | |
self._load_fine_tuned_model() | |
logger.info("Reloaded fine-tuned reranker model (if available).") | |
def get_recommendations_summary(self, msid: str, k: int = DEFAULT_K, summary: bool = True, smart_tip: bool = True): | |
""" | |
Synchronous wrapper for recommendations with summary and smart tip. This is a simplified version for Gradio or direct calls. | |
""" | |
# Get base recommendations by msid | |
recommendations_data = self.get_recommendations_by_id(msid, k) | |
if not recommendations_data or "retrieved_documents" not in recommendations_data: | |
return { | |
"generated_response": f"No recommendations found for item ID '{msid}'.", | |
"retrieved_documents": [] | |
} | |
retrieved_docs = recommendations_data.get("retrieved_documents", []) | |
if not retrieved_docs: | |
return recommendations_data | |
# Fetch article details from MongoDB using the module-level mongodb client imported at the top of this file
doc_ids_to_fetch = [doc["id"] for doc in retrieved_docs if doc.get("id")] | |
articles_details_map = {} | |
if doc_ids_to_fetch and (summary or smart_tip): | |
projection = {"_id": 0, "id": 1} | |
if summary: | |
projection.update({"story": 1, "syn": 1}) | |
if smart_tip: | |
projection.update({"seolocation": 1, "tn": 1, "hl": 1}) | |
fetched_articles_list = list(mongodb.news_collection.find({"id": {"$in": doc_ids_to_fetch}}, projection)) | |
for article in fetched_articles_list: | |
if article.get("id"): | |
if summary and not article.get("story") and article.get("syn"): | |
article["story"] = article["syn"] | |
articles_details_map[article["id"]] = article | |
# Helper functions for summary and smart tip | |
def _generate_summary(article_data): | |
try: | |
from src.test_summarize import get_summary_points | |
story = article_data.get("story", "") | |
if not story: | |
return None | |
summary_points = get_summary_points(story) | |
if isinstance(summary_points, list): | |
return " ".join(summary_points) if summary_points else None | |
elif isinstance(summary_points, str): | |
return summary_points if summary_points.strip() else None | |
return None | |
except Exception: | |
return None | |
def _generate_smart_tip(article_data): | |
seolocation = article_data.get("seolocation") | |
title = article_data.get("tn") | |
headline = article_data.get("hl") | |
if not all([seolocation, title, headline]): | |
return None | |
# Find related articles | |
topic = title.lower() if title else "" | |
query = {} | |
if topic: | |
query["$or"] = [ | |
{"tn": {"$regex": topic, "$options": "i"}}, | |
{"hl": {"$regex": topic, "$options": "i"}} | |
] | |
if article_data.get("id"): | |
query["id"] = {"$ne": article_data["id"]} | |
related_articles = list(mongodb.news_collection.find(query, {"hl": 1, "seolocation": 1, "tn": 1, "_id": 0}).limit(3)) | |
suggestions = [] | |
for rel_article in related_articles: | |
if rel_article.get("hl") and rel_article.get("seolocation"): | |
suggestions.append({ | |
"label": rel_article.get("hl", ""), | |
"url": rel_article.get("seolocation", "") | |
}) | |
if not suggestions: | |
suggestions = [{ | |
"label": headline, | |
"url": seolocation | |
}] | |
return { | |
"title": f"\U0001F50D Smart Tip: {title}", | |
"description": "You might also be interested in:", | |
"suggestions": suggestions | |
} | |
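# Illustrative return value of _generate_smart_tip (example strings only):
#   {"title": "\U0001F50D Smart Tip: sample title", "description": "You might also be interested in:",
#    "suggestions": [{"label": "Related headline", "url": "https://example.com/article"}]}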
# Process each document | |
processed_documents = [] | |
for doc in retrieved_docs: | |
article_data = articles_details_map.get(doc.get("id")) | |
if summary and article_data: | |
doc["summary"] = _generate_summary(article_data) | |
if smart_tip and article_data: | |
doc["smart_tip"] = _generate_smart_tip(article_data) | |
processed_documents.append(doc) | |
recommendations_data["retrieved_documents"] = processed_documents | |
return recommendations_data | |
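# Usage sketch (illustrative; assumes a loaded recommender and a valid msid):
#   enriched = recommender.get_recommendations_summary("12345", k=5, summary=True, smart_tip=True)
#   # Each returned document may now carry optional "summary" and "smart_tip" fields alongside the base fields.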
def get_recommendations_user_feedback(self, user_id: str, msid: str, clicked_msid: str, k: int = DEFAULT_K): | |
""" | |
Synchronous wrapper for user feedback recommendations. Returns recommendations based on clicked items. | |
""" | |
# clicked_msid can be a comma-separated string | |
actual_clicked_msids = [s.strip() for s in clicked_msid.split(',') if s.strip()] | |
combined_recommendations_docs = [] | |
seen_recommendation_ids = set() | |
for c_msid in actual_clicked_msids: | |
result = self.get_recommendations_by_id(c_msid, k) | |
for doc in result.get("retrieved_documents", []): | |
if doc['id'] not in seen_recommendation_ids: | |
combined_recommendations_docs.append(doc) | |
seen_recommendation_ids.add(doc['id']) | |
if not combined_recommendations_docs: | |
recommendations_result = {"retrieved_documents": [], "generated_response": "No recommendations found for the clicked items."} | |
else: | |
combined_recommendations_docs.sort(key=lambda x: x.get('score', 0.0), reverse=True) | |
final_retrieved_documents = combined_recommendations_docs[:k] | |
recommendations_result = { | |
"retrieved_documents": final_retrieved_documents, | |
"generated_response": f"Top {len(final_retrieved_documents)} recommendations based on your recent clicks on: {', '.join(actual_clicked_msids)}." | |
} | |
return recommendations_result | |
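# Usage sketch (illustrative msids): clicked_msid may be a single id or a comma-separated list, e.g.
#   recs = recommender.get_recommendations_user_feedback("user_42", msid="1001", clicked_msid="2002,3003", k=5)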
# Instantiate the recommender for use by other modules | |
recommender = RecoRecommender() |