# Source: random2345t6/rag_system.py (Hugging Face Space upload, commit eac6673)
# rag_system.py
import os
import logging
import shutil
import json
from typing import Optional
from rag_components import KnowledgeRAG
from utils import download_and_unzip_gdrive_folder
from config import (
GROQ_API_KEY, GDRIVE_SOURCES_ENABLED, GDRIVE_FOLDER_ID_OR_URL, RAG_SOURCES_DIR,
RAG_STORAGE_PARENT_DIR, RAG_FAISS_INDEX_SUBDIR_NAME, RAG_LOAD_INDEX_ON_STARTUP,
RAG_EMBEDDING_MODEL_NAME, RAG_LLM_MODEL_NAME,
RAG_EMBEDDING_USE_GPU, RAG_LLM_TEMPERATURE, RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP,
RAG_RERANKER_MODEL_NAME, RAG_RERANKER_ENABLED, RAG_CHUNKED_SOURCES_FILENAME
)
# Module-level logger; handlers/levels are expected to be configured by the application entry point.
logger = logging.getLogger(__name__)
def _refresh_sources_from_gdrive() -> None:
    """Clear RAG_SOURCES_DIR and repopulate it from the configured Google Drive folder.

    Best-effort: every failure is logged but never raised, so index building
    proceeds with whatever source files are present afterwards.
    """
    if not GDRIVE_FOLDER_ID_OR_URL:
        # Config inconsistency: downloads enabled but no folder to download from.
        logger.warning("[RAG_SYSTEM_INIT] GDRIVE_SOURCES_ENABLED is True but GDRIVE_FOLDER_ID_OR_URL not set")
        return
    logger.info(f"[RAG_SYSTEM_INIT] Downloading from Google Drive: {GDRIVE_FOLDER_ID_OR_URL}")
    if os.path.isdir(RAG_SOURCES_DIR):
        # Wipe stale files first so removed Drive files do not linger locally.
        logger.info(f"[RAG_SYSTEM_INIT] Clearing existing contents of {RAG_SOURCES_DIR}")
        try:
            for item_name in os.listdir(RAG_SOURCES_DIR):
                item_path = os.path.join(RAG_SOURCES_DIR, item_name)
                if os.path.isfile(item_path) or os.path.islink(item_path):
                    os.unlink(item_path)
                elif os.path.isdir(item_path):
                    shutil.rmtree(item_path)
            logger.info(f"[RAG_SYSTEM_INIT] Successfully cleared {RAG_SOURCES_DIR}")
        except Exception as e_clear:
            logger.error(f"[RAG_SYSTEM_INIT] Could not clear {RAG_SOURCES_DIR}: {e_clear}")
    if download_and_unzip_gdrive_folder(GDRIVE_FOLDER_ID_OR_URL, RAG_SOURCES_DIR):
        logger.info("[RAG_SYSTEM_INIT] Successfully populated sources from Google Drive")
    else:
        logger.error("[RAG_SYSTEM_INIT] Failed to download sources from Google Drive")


def _delete_existing_index(faiss_index_actual_path: str) -> None:
    """Remove the on-disk FAISS index directory, if present (used for force_rebuild).

    Deletion failures are logged and swallowed; the subsequent build may then
    overwrite the old index in place.
    """
    logger.info(f"[RAG_SYSTEM_INIT] Force rebuild: Deleting existing FAISS index at '{faiss_index_actual_path}'")
    if os.path.exists(faiss_index_actual_path):
        try:
            shutil.rmtree(faiss_index_actual_path)
            logger.info("[RAG_SYSTEM_INIT] Deleted existing FAISS index")
        except Exception as e_del:
            logger.error(f"[RAG_SYSTEM_INIT] Could not delete existing FAISS index: {e_del}", exc_info=True)


def _write_placeholder_index(faiss_index_actual_path: str) -> None:
    """Create empty index.faiss / index.pkl files so the expected directory layout exists.

    NOTE(review): these zero-byte files are not a loadable index; presumably
    downstream tooling only checks for their presence — verify that a later
    startup does not try to load them as a real index.
    """
    os.makedirs(faiss_index_actual_path, exist_ok=True)
    for placeholder_name in ("index.faiss", "index.pkl"):
        with open(os.path.join(faiss_index_actual_path, placeholder_name), "w") as f_dummy:
            f_dummy.write("")
    logger.info("[RAG_SYSTEM_INIT] Created dummy index files")


def initialize_and_get_rag_system(force_rebuild: bool = False, source_dir_override: Optional[str] = None) -> Optional[KnowledgeRAG]:
    """Initialize and return the KnowledgeRAG system, or None on any failure.

    The index is loaded from disk when RAG_LOAD_INDEX_ON_STARTUP is set and
    ``force_rebuild`` is False; otherwise it is built from the files in the
    chosen source directory. When Google Drive sourcing is enabled and no
    override is given, the source directory is refreshed from Drive first.

    Args:
        force_rebuild: Delete any existing FAISS index and rebuild from sources.
        source_dir_override: Directory to use instead of RAG_SOURCES_DIR.
            Must exist; the GDrive download step is skipped when provided.

    Returns:
        A ready KnowledgeRAG instance with a populated vector store, or None
        when configuration is missing, the override directory does not exist,
        or the index can be neither loaded nor built.
    """
    logger.info("[RAG_SYSTEM_INIT] ========== Initializing RAG System ==========")
    if not GROQ_API_KEY:
        logger.error("[RAG_SYSTEM_INIT] Groq API Key (BOT_API_KEY) not found. RAG system cannot be initialized.")
        return None

    # Validate the override before choosing a directory: a bad override must
    # abort rather than silently fall back to the configured default.
    if source_dir_override and not os.path.isdir(source_dir_override):
        logger.error(f"[RAG_SYSTEM_INIT] Custom source directory override '{source_dir_override}' not found. Aborting.")
        return None
    source_dir_to_use = source_dir_override if source_dir_override else RAG_SOURCES_DIR
    logger.info(f"[RAG_SYSTEM_INIT] Using source directory: '{source_dir_to_use}'")

    # GDrive refresh only applies to the default sources dir, never an override.
    if not source_dir_override:
        if GDRIVE_SOURCES_ENABLED:
            logger.info("[RAG_SYSTEM_INIT] Google Drive sources download is ENABLED")
            _refresh_sources_from_gdrive()
        else:
            logger.info("[RAG_SYSTEM_INIT] Google Drive sources download is DISABLED")

    faiss_index_actual_path = os.path.join(RAG_STORAGE_PARENT_DIR, RAG_FAISS_INDEX_SUBDIR_NAME)
    processed_files_metadata_path = os.path.join(faiss_index_actual_path, "processed_files.json")
    if force_rebuild:
        _delete_existing_index(faiss_index_actual_path)

    try:
        logger.info("[RAG_SYSTEM_INIT] Creating KnowledgeRAG instance...")
        current_rag_instance = KnowledgeRAG(
            index_storage_dir=RAG_STORAGE_PARENT_DIR,
            embedding_model_name=RAG_EMBEDDING_MODEL_NAME,
            groq_model_name_for_rag=RAG_LLM_MODEL_NAME,
            use_gpu_for_embeddings=RAG_EMBEDDING_USE_GPU,
            groq_api_key_for_rag=GROQ_API_KEY,
            temperature=RAG_LLM_TEMPERATURE,
            chunk_size=RAG_CHUNK_SIZE,
            chunk_overlap=RAG_CHUNK_OVERLAP,
            reranker_model_name=RAG_RERANKER_MODEL_NAME,
            enable_reranker=RAG_RERANKER_ENABLED,
        )

        operation_successful = False
        if RAG_LOAD_INDEX_ON_STARTUP and not force_rebuild:
            logger.info("[RAG_SYSTEM_INIT] Attempting to load index from disk")
            try:
                current_rag_instance.load_index_from_disk()
                operation_successful = True
                logger.info(f"[RAG_SYSTEM_INIT] Index loaded successfully from: {faiss_index_actual_path}")
            except FileNotFoundError:
                logger.warning("[RAG_SYSTEM_INIT] Pre-built index not found. Will build from source files")
            except Exception as e_load:
                logger.error(f"[RAG_SYSTEM_INIT] Error loading index: {e_load}. Will build from source files", exc_info=True)

        if not operation_successful:
            logger.info(f"[RAG_SYSTEM_INIT] Building new index from source data in '{source_dir_to_use}'")
            try:
                pre_chunked_path = os.path.join(RAG_STORAGE_PARENT_DIR, RAG_CHUNKED_SOURCES_FILENAME)
                # A missing/empty source dir is only fatal when there is no
                # pre-chunked JSON fallback either.
                if not os.path.exists(pre_chunked_path) and (not os.path.isdir(source_dir_to_use) or not os.listdir(source_dir_to_use)):
                    logger.error("[RAG_SYSTEM_INIT] Neither pre-chunked JSON nor raw source files found")
                    _write_placeholder_index(faiss_index_actual_path)
                    current_rag_instance.processed_source_files = ["No source files found to build index."]
                    raise FileNotFoundError(f"Sources directory '{source_dir_to_use}' is empty")
                current_rag_instance.build_index_from_source_files(
                    source_folder_path=source_dir_to_use
                )
                # Persist the list of processed files next to the index so later
                # runs can report what the index was built from.
                os.makedirs(faiss_index_actual_path, exist_ok=True)
                with open(processed_files_metadata_path, 'w') as f:
                    json.dump(current_rag_instance.processed_source_files, f)
                operation_successful = True
                logger.info("[RAG_SYSTEM_INIT] Index built successfully from source data")
            except FileNotFoundError as e_fnf:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: No source data found: {e_fnf}", exc_info=False)
                return None
            except ValueError as e_val:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: No processable documents found: {e_val}", exc_info=False)
                return None
            except Exception as e_build:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: Failed to build FAISS index: {e_build}", exc_info=True)
                return None

        # A loaded/built index without a vector store is unusable — treat as failure.
        if operation_successful and current_rag_instance.vector_store:
            logger.info("[RAG_SYSTEM_INIT] ========== RAG System Initialized Successfully ==========")
            return current_rag_instance
        else:
            logger.error("[RAG_SYSTEM_INIT] Index was neither loaded nor built successfully")
            return None
    except Exception as e_init_components:
        logger.critical(f"[RAG_SYSTEM_INIT] FATAL: Failed to initialize RAG system components: {e_init_components}", exc_info=True)
        return None