"""Utility functions for reranker model fine-tuning.

Persists training samples (JSONL), model metadata (JSON) and user feedback
(JSONL) under ``DATA_DIR``, and manages versioned model artifacts named
``reranker_v<N>`` inside ``MODEL_DIR``.

Apart from ``get_model_path``, every helper is best-effort: errors are logged
instead of raised, and loaders fall back to an empty result or ``None``.
"""
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Constants
DATA_DIR = Path("src/fine_tuning/data")
TRAINING_DATA_FILE = DATA_DIR / "reranker_training_data.jsonl"
MODEL_METADATA_FILE = DATA_DIR / "model_metadata.json"
USER_FEEDBACK_FILE = DATA_DIR / "user_feedback.jsonl"
MODEL_DIR = DATA_DIR / "models/fine_tuned"
# Number of most recent model versions (the newest included) that
# cleanup_old_models() retains by default.
MAX_OLD_MODELS = 3


def _iter_jsonl(path: Path) -> Iterator[Tuple[int, object]]:
    """Yield ``(line_number, parsed_value)`` for each JSON line in *path*.

    Blank lines are ignored. A malformed line is logged and skipped so that a
    single bad record does not hide every record after it.
    """
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                value = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning(
                    "Skipping malformed JSON on line %d of %s: %s", line_no, path, e
                )
                continue
            # Yield outside the try so consumer exceptions are not swallowed.
            yield line_no, value


def _list_model_versions() -> List[Tuple[int, str, Path]]:
    """Return ``(number, "v<number>", path)`` per model version, newest first.

    Considers entries matching ``reranker_v*`` whose last ``'_'``-separated
    name segment is ``v`` followed by decimal digits. Entries with equal
    numbers keep the order produced by ``Path.glob``.
    """
    versions = []
    if MODEL_DIR.exists():
        for path in MODEL_DIR.glob("reranker_v*"):
            version = path.name.split('_')[-1]
            # isdecimal() accepts exactly the digits int() can parse; isdigit()
            # would also accept e.g. "²" and make int() raise ValueError.
            if version.startswith('v') and version[1:].isdecimal():
                versions.append((int(version[1:]), version, path))
    versions.sort(key=lambda item: item[0], reverse=True)
    return versions


def save_training_data(training_samples: List[Dict], append: bool = True) -> None:
    """Write training samples to ``TRAINING_DATA_FILE``, one JSON object per line.

    Args:
        training_samples: JSON-serializable samples to write.
        append: Append to the existing file when True; overwrite it when False.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        mode = 'a' if append else 'w'
        with open(TRAINING_DATA_FILE, mode, encoding='utf-8') as f:
            for sample in training_samples:
                json.dump(sample, f, ensure_ascii=False)
                f.write('\n')
        logger.info(
            "Saved %d training samples to %s", len(training_samples), TRAINING_DATA_FILE
        )
    except Exception as e:
        logger.error("Error saving training data: %s", e)


def load_training_data() -> List[Dict]:
    """Load all training samples from ``TRAINING_DATA_FILE``.

    Returns:
        The parsed samples in file order; an empty list if the file is missing.
        Malformed lines are skipped (and logged). On a read error, the samples
        parsed before the error are returned.
    """
    samples: List[Dict] = []
    try:
        if TRAINING_DATA_FILE.exists():
            # extend() keeps already-parsed samples if reading fails midway.
            samples.extend(value for _, value in _iter_jsonl(TRAINING_DATA_FILE))
            logger.info(
                "Loaded %d training samples from %s", len(samples), TRAINING_DATA_FILE
            )
    except Exception as e:
        logger.error("Error loading training data: %s", e)
    return samples


def save_model_metadata(metadata: Dict) -> None:
    """Write *metadata* as indented JSON to ``MODEL_METADATA_FILE``.

    Side effect: sets ``metadata['last_updated']`` in the caller's dict to the
    current local time (naive ISO-8601 string) before writing.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        metadata['last_updated'] = datetime.now().isoformat()
        with open(MODEL_METADATA_FILE, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        logger.info("Saved model metadata to %s", MODEL_METADATA_FILE)
    except Exception as e:
        logger.error("Error saving model metadata: %s", e)


def load_model_metadata() -> Optional[Dict]:
    """Return the metadata stored in ``MODEL_METADATA_FILE``.

    Returns None when the file does not exist or cannot be read/parsed.
    """
    try:
        if MODEL_METADATA_FILE.exists():
            with open(MODEL_METADATA_FILE, 'r', encoding='utf-8') as f:
                return json.load(f)
    except Exception as e:
        logger.error("Error loading model metadata: %s", e)
    return None


def get_model_path(version: str) -> Path:
    """Return ``MODEL_DIR / f"reranker_{version}"``, creating ``MODEL_DIR`` if needed.

    The version entry itself is not created. Raises OSError if ``MODEL_DIR``
    cannot be created.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    return MODEL_DIR / f"reranker_{version}"


def get_latest_model_version() -> str:
    """Return the highest existing model version (e.g. ``"v7"``).

    Returns ``"v0"`` when no version exists or ``MODEL_DIR`` cannot be scanned.
    """
    try:
        versions = _list_model_versions()
    except Exception as e:
        logger.error("Error getting latest model version: %s", e)
        return "v0"
    return versions[0][1] if versions else "v0"


def cleanup_old_models(keep: Optional[int] = None) -> None:
    """Delete every model version except the newest ones.

    Args:
        keep: How many of the most recent versions to retain. Defaults to
            ``MAX_OLD_MODELS`` (read at call time); negative values act as 0.

    A version may be a single file or a directory; directories are removed
    recursively. Failures are logged per version and do not stop the cleanup.
    """
    if keep is None:
        keep = MAX_OLD_MODELS
    try:
        versions = _list_model_versions()
    except Exception as e:
        logger.error("Error during model cleanup: %s", e)
        return
    for _, version, path in versions[max(keep, 0):]:
        try:
            # Path.unlink() cannot remove directories, so those need rmtree;
            # symlinks are unlinked rather than followed.
            if path.is_dir() and not path.is_symlink():
                shutil.rmtree(path)
            else:
                path.unlink()
            logger.info("Removed old model version: %s", version)
        except OSError as e:
            logger.error("Error removing model version %s: %s", version, e)


def load_user_feedback() -> Dict[str, Dict]:
    """Load user feedback entries from ``USER_FEEDBACK_FILE``.

    Returns:
        A dict keyed by ``f"{query_text}_{candidate_text}"``. Each value holds
        ``rating`` plus ``confidence`` (default 1.0), ``timestamp`` (default:
        the load time), ``user_id`` (default ``'anonymous'``),
        ``interaction_type`` (default ``'explicit'``) and ``session_id``
        (default None). If a pair appears more than once, the last line wins.
        Lines that are not valid JSON objects with ``query_text``,
        ``candidate_text`` and ``rating`` are logged and skipped.
    """
    feedback_data: Dict[str, Dict] = {}
    try:
        if USER_FEEDBACK_FILE.exists():
            for line_no, feedback in _iter_jsonl(USER_FEEDBACK_FILE):
                try:
                    # NOTE: the '_' separator is ambiguous when a text contains
                    # '_' ("a_b"+"c" vs "a"+"b_c"); kept for key compatibility.
                    key = f"{feedback['query_text']}_{feedback['candidate_text']}"
                    entry = {
                        'rating': feedback['rating'],
                        'confidence': feedback.get('confidence', 1.0),
                        'timestamp': feedback.get('timestamp', datetime.now().isoformat()),
                        'user_id': feedback.get('user_id', 'anonymous'),
                        'interaction_type': feedback.get('interaction_type', 'explicit'),
                        'session_id': feedback.get('session_id', None),
                    }
                except (KeyError, TypeError) as e:
                    logger.warning(
                        "Skipping invalid feedback on line %d of %s: %r",
                        line_no, USER_FEEDBACK_FILE, e,
                    )
                    continue
                feedback_data[key] = entry
            logger.info("Loaded %d user feedback entries", len(feedback_data))
    except Exception as e:
        logger.error("Error loading user feedback data: %s", e)
    return feedback_data


def save_user_feedback(feedback: Dict) -> None:
    """Append one feedback entry as a JSON line to ``USER_FEEDBACK_FILE``.

    Side effect: sets ``feedback['timestamp']`` in the caller's dict to the
    current local time, overwriting any timestamp already present.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        with open(USER_FEEDBACK_FILE, 'a', encoding='utf-8') as f:
            feedback['timestamp'] = datetime.now().isoformat()
            json.dump(feedback, f, ensure_ascii=False)
            f.write('\n')
        logger.info(
            "Saved user feedback for query: %s", feedback.get('query_text', 'unknown')
        )
    except Exception as e:
        logger.error("Error saving user feedback: %s", e)