#!/usr/bin/env python3
"""
Enhanced checkpointing system for BitTransformerLM with multiple training runs support.
Optimized for Claude Code environment with HF Pro + 20GB persistent storage.
"""

import os
import json
import shutil
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Union, Callable
from datetime import datetime

import torch
from huggingface_hub import HfApi, hf_hub_download

from bit_transformer.error_handling import with_error_recovery, safe_operation
from bit_transformer.types import PathLike, ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)


class EnhancedCheckpointManager:
    """Advanced checkpoint management for multiple training runs with HF integration.

    On-disk layout under ``base_dir``::

        training_sessions/<session_id>/metadata.json
        training_sessions/<session_id>/checkpoint_epoch_NNNN/model.pt
        training_sessions/<session_id>/checkpoint_epoch_NNNN/optimizer.pt        (optional)
        training_sessions/<session_id>/checkpoint_epoch_NNNN/scheduler.pt        (optional)
        training_sessions/<session_id>/checkpoint_epoch_NNNN/additional_data.json (optional)
        best_models/<session_id>_best.pt
        best_models/<session_id>_best_meta.json
    """

    def __init__(self, base_dir: PathLike = "/data/checkpoints",
                 hf_repo_id: str = "WCNegentropy/BitTransformerLM",
                 hf_token: Optional[str] = None,
                 max_local_checkpoints: int = 5):
        """Create the manager and its directory tree.

        Args:
            base_dir: Root directory for all sessions and best models (created if missing).
            hf_repo_id: HuggingFace Hub repository used by ``push_to_hf`` / ``pull_from_hf``.
            hf_token: Hub token; falls back to the ``HF_TOKEN`` environment variable.
                Without any token, Hub operations are disabled (``self.api`` is None).
            max_local_checkpoints: Number of checkpoint directories kept per session.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.api = HfApi(token=self.hf_token) if self.hf_token else None
        self.max_local_checkpoints = max_local_checkpoints

        # Training session tracking
        self.sessions_dir = self.base_dir / "training_sessions"
        self.sessions_dir.mkdir(exist_ok=True)

        # Best models storage
        self.best_models_dir = self.base_dir / "best_models"
        self.best_models_dir.mkdir(exist_ok=True)

    def create_training_session(self, session_name: str,
                                model_config: ModelConfig,
                                training_config: TrainingConfig) -> str:
        """Create a new training session directory with a ``metadata.json`` file.

        The session id is ``<session_name>_<YYYYmmdd_HHMMSS>``. Configs are
        serialized with ``default=str``, so non-JSON values are stored as strings.

        Returns:
            The new session id.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        session_id = f"{session_name}_{timestamp}"
        session_dir = self.sessions_dir / session_id
        session_dir.mkdir(exist_ok=True)

        # Save session metadata
        metadata = {
            "session_id": session_id,
            "session_name": session_name,
            "created_at": timestamp,
            "model_config": model_config,
            "training_config": training_config,
            "checkpoints": [],
            "best_metric": None,
            "status": "active"
        }

        with open(session_dir / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2, default=str)

        logger.info(f"Created training session: {session_id}")
        return session_id

    @with_error_recovery(recovery_value=False)
    def save_checkpoint(self,
                        model: torch.nn.Module,
                        session_id: str,
                        epoch: int,
                        metrics: Dict[str, float],
                        optimizer_state: Optional[Dict] = None,
                        scheduler_state: Optional[Dict] = None,
                        additional_data: Optional[Dict] = None) -> bool:
        """Save a checkpoint into ``checkpoint_epoch_<epoch:04d>`` of the session.

        Writes ``model.pt`` (state dict, epoch, metrics, ``model.config`` if
        present, timestamp) and, when non-empty, ``optimizer.pt``,
        ``scheduler.pt`` and ``additional_data.json``. It then records the
        checkpoint in the session metadata and prunes old checkpoints.

        Returns:
            True on success. Failure handling is delegated to
            ``with_error_recovery(recovery_value=False)``; presumably it returns
            False on error, but confirm in ``bit_transformer.error_handling``.

        Raises:
            ValueError: if the session does not exist (subject to the decorator).
        """
        session_dir = self.sessions_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Training session {session_id} not found")

        # Create checkpoint directory
        checkpoint_name = f"checkpoint_epoch_{epoch:04d}"
        checkpoint_dir = session_dir / checkpoint_name
        checkpoint_dir.mkdir(exist_ok=True)

        # Save model state
        model_path = checkpoint_dir / "model.pt"
        torch.save({
            'model_state_dict': model.state_dict(),
            'epoch': epoch,
            'metrics': metrics,
            'model_config': getattr(model, 'config', {}),
            'timestamp': datetime.now().isoformat()
        }, model_path)

        # Save optimizer state if provided (empty dicts are skipped)
        if optimizer_state:
            torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")

        # Save scheduler state if provided (empty dicts are skipped)
        if scheduler_state:
            torch.save(scheduler_state, checkpoint_dir / "scheduler.pt")

        # Save additional data
        if additional_data:
            with open(checkpoint_dir / "additional_data.json", "w") as f:
                json.dump(additional_data, f, indent=2, default=str)

        # Update session metadata
        self._update_session_metadata(session_id, checkpoint_name, metrics)

        # Cleanup old checkpoints to save space
        self._cleanup_old_checkpoints(session_dir)

        logger.info(f"Saved checkpoint {checkpoint_name} for session {session_id}")
        return True

    def _resolve_checkpoint_dir(self, session_id: str,
                                checkpoint_name: Optional[str] = None) -> Path:
        """Locate a checkpoint directory without loading any of its files.

        When *checkpoint_name* is None, the ``checkpoint_*`` directory whose name
        sorts last is chosen. Epochs are zero-padded to 4 digits, so among
        ``checkpoint_epoch_*`` this is the highest epoch.

        NOTE(review): ``checkpoint_from_hf`` (created by ``pull_from_hf``) sorts
        after every ``checkpoint_epoch_*`` name, so it wins the "latest" selection
        when present.

        Raises:
            ValueError: if the session, any checkpoint, or the named checkpoint is missing.
        """
        session_dir = self.sessions_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Training session {session_id} not found")

        if checkpoint_name is None:
            checkpoints = [d for d in session_dir.iterdir()
                           if d.is_dir() and d.name.startswith("checkpoint_")]
            if not checkpoints:
                raise ValueError(f"No checkpoints found for session {session_id}")
            checkpoint_name = max(checkpoints, key=lambda x: x.name).name

        checkpoint_dir = session_dir / checkpoint_name
        if not checkpoint_dir.exists():
            raise ValueError(f"Checkpoint {checkpoint_name} not found in session {session_id}")
        return checkpoint_dir

    def load_checkpoint(self,
                        session_id: str,
                        checkpoint_name: Optional[str] = None,
                        model: Optional[torch.nn.Module] = None) -> Dict[str, Any]:
        """Load a checkpoint (latest by name if *checkpoint_name* is None).

        If *model* is given, its weights are restored in place from
        ``model_state_dict``. All tensors are mapped to CPU.

        SECURITY: files are read with ``torch.load(weights_only=False)``, which
        unpickles arbitrary objects. Only load checkpoints from trusted sources,
        including those fetched by ``pull_from_hf``.

        Returns:
            Dict with ``model_data``, ``optimizer_state`` (or None),
            ``scheduler_state`` (or None), ``additional_data`` (dict, possibly
            empty) and ``checkpoint_path`` (str).

        Raises:
            ValueError: if the session or checkpoint cannot be found.
        """
        checkpoint_dir = self._resolve_checkpoint_dir(session_id, checkpoint_name)

        # Load model state
        model_path = checkpoint_dir / "model.pt"
        checkpoint_data = torch.load(model_path, map_location='cpu', weights_only=False)

        if model is not None:
            model.load_state_dict(checkpoint_data['model_state_dict'])

        # Load optimizer state if exists
        optimizer_state = None
        optimizer_path = checkpoint_dir / "optimizer.pt"
        if optimizer_path.exists():
            optimizer_state = torch.load(optimizer_path, map_location='cpu', weights_only=False)

        # Load scheduler state if exists
        scheduler_state = None
        scheduler_path = checkpoint_dir / "scheduler.pt"
        if scheduler_path.exists():
            scheduler_state = torch.load(scheduler_path, map_location='cpu', weights_only=False)

        # Load additional data if exists
        additional_data = {}
        additional_path = checkpoint_dir / "additional_data.json"
        if additional_path.exists():
            with open(additional_path) as f:
                additional_data = json.load(f)

        return {
            'model_data': checkpoint_data,
            'optimizer_state': optimizer_state,
            'scheduler_state': scheduler_state,
            'additional_data': additional_data,
            'checkpoint_path': str(checkpoint_dir)
        }

    def save_best_model(self,
                        session_id: str,
                        model: torch.nn.Module,
                        metric_name: str,
                        metric_value: float,
                        is_better_func: Callable[[float, float], bool] = lambda x, y: x > y) -> bool:
        """Save *model* as the session's best model if *metric_value* beats the stored best.

        ``is_better_func(new, current)`` decides the comparison; the default
        treats larger values as better. Pass ``lambda x, y: x < y`` for losses.

        Returns:
            True if a new best model was written, False otherwise.
        """
        best_model_path = self.best_models_dir / f"{session_id}_best.pt"
        best_meta_path = self.best_models_dir / f"{session_id}_best_meta.json"

        # Check if this is the best model so far
        current_best = None
        if best_meta_path.exists():
            with open(best_meta_path) as f:
                current_best = json.load(f)

        if current_best is None or is_better_func(metric_value, current_best['metric_value']):
            # One timestamp so the weights file and its metadata agree exactly.
            timestamp = datetime.now().isoformat()

            # Save new best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'metric_name': metric_name,
                'metric_value': metric_value,
                'session_id': session_id,
                'timestamp': timestamp
            }, best_model_path)

            # Save metadata
            with open(best_meta_path, "w") as f:
                json.dump({
                    'metric_name': metric_name,
                    'metric_value': metric_value,
                    'session_id': session_id,
                    'timestamp': timestamp
                }, f, indent=2)

            logger.info(f"New best model saved for session {session_id}: {metric_name}={metric_value}")
            return True

        return False

    def push_to_hf(self, session_id: str, checkpoint_name: Optional[str] = None,
                   include_optimizer: bool = False) -> bool:
        """Upload a checkpoint's ``model.pt`` (and optionally ``optimizer.pt``) to the Hub.

        Files go to ``checkpoints/<session_id>/`` in ``self.hf_repo_id``, which
        overwrites any previously pushed checkpoint of the same session. The
        checkpoint is located on disk without being deserialized.

        Returns:
            True on success. False if no token is configured or any error occurs;
            errors are logged.
        """
        if not self.api:
            logger.error("HuggingFace API not available - check token")
            return False

        try:
            checkpoint_dir = self._resolve_checkpoint_dir(session_id, checkpoint_name)

            # Upload model weights
            self.api.upload_file(
                path_or_fileobj=str(checkpoint_dir / "model.pt"),
                path_in_repo=f"checkpoints/{session_id}/model.pt",
                repo_id=self.hf_repo_id,
                commit_message=f"Upload checkpoint {checkpoint_dir.name} from session {session_id}"
            )

            # Upload optimizer state if requested and exists
            if include_optimizer and (checkpoint_dir / "optimizer.pt").exists():
                self.api.upload_file(
                    path_or_fileobj=str(checkpoint_dir / "optimizer.pt"),
                    path_in_repo=f"checkpoints/{session_id}/optimizer.pt",
                    repo_id=self.hf_repo_id
                )

            logger.info(f"Successfully pushed checkpoint to HuggingFace: {self.hf_repo_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to push to HuggingFace: {e}")
            return False

    def pull_from_hf(self, session_id: str, local_session_id: Optional[str] = None) -> bool:
        """Download ``checkpoints/<session_id>/model.pt`` from the Hub.

        The file is placed at
        ``training_sessions/<local_session_id or session_id>/checkpoint_from_hf/model.pt``.
        That is the same layout as a local checkpoint, so
        ``load_checkpoint(local_session, "checkpoint_from_hf")`` can read it.
        Only ``model.pt`` is fetched. No session ``metadata.json`` is created.

        Returns:
            True on success. False if no token is configured or any error occurs;
            errors are logged.
        """
        if not self.api:
            logger.error("HuggingFace API not available - check token")
            return False

        try:
            local_session = local_session_id or session_id
            local_dir = self.sessions_dir / local_session / "checkpoint_from_hf"
            local_dir.mkdir(parents=True, exist_ok=True)

            # Download model weights (token needed for private repos)
            model_file = Path(hf_hub_download(
                repo_id=self.hf_repo_id,
                filename=f"checkpoints/{session_id}/model.pt",
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
                token=self.hf_token
            ))

            # With local_dir, the repo path is mirrored under it
            # (local_dir/checkpoints/<session_id>/model.pt). Flatten it so the
            # directory matches a locally saved checkpoint.
            target = local_dir / "model.pt"
            if model_file.resolve() != target.resolve():
                os.replace(model_file, target)
                shutil.rmtree(local_dir / "checkpoints", ignore_errors=True)

            logger.info(f"Successfully pulled checkpoint from HuggingFace to {local_dir}")
            return True

        except Exception as e:
            logger.error(f"Failed to pull from HuggingFace: {e}")
            return False

    def get_storage_usage(self) -> Dict[str, Any]:
        """Report disk usage of the checkpoint tree.

        Sizes are the sum of regular-file sizes, in decimal GB (1e9 bytes).

        Returns:
            Dict with ``total_gb``, ``sessions_gb``, ``best_models_gb``,
            ``num_sessions``, ``num_best_models`` and ``sessions``. The
            ``sessions`` entry is a list of per-session dicts sorted by size,
            largest first.
        """
        def get_dir_size(path: Path) -> int:
            return sum(item.stat().st_size for item in path.rglob('*') if item.is_file())

        session_dirs = [d for d in self.sessions_dir.iterdir() if d.is_dir()]

        usage = {
            'total_gb': get_dir_size(self.base_dir) / 1e9,
            'sessions_gb': get_dir_size(self.sessions_dir) / 1e9,
            'best_models_gb': get_dir_size(self.best_models_dir) / 1e9,
            # Count directories only, consistent with the per-session breakdown below.
            'num_sessions': len(session_dirs),
            'num_best_models': len(list(self.best_models_dir.glob('*_best.pt'))),
        }

        # Get per-session breakdown
        sessions = [
            {
                'session_id': session_dir.name,
                'size_gb': get_dir_size(session_dir) / 1e9,
                'num_checkpoints': len(list(session_dir.glob('checkpoint_*')))
            }
            for session_dir in session_dirs
        ]

        usage['sessions'] = sorted(sessions, key=lambda x: x['size_gb'], reverse=True)

        return usage

    def _update_session_metadata(self, session_id: str, checkpoint_name: str, metrics: Dict[str, float]):
        """Append a checkpoint record to the session's ``metadata.json``.

        If *metrics* contains ``loss`` and it is lower than the stored best loss,
        or no best is stored yet, the whole metrics dict becomes ``best_metric``.
        """
        metadata_path = self.sessions_dir / session_id / "metadata.json"

        with open(metadata_path) as f:
            metadata = json.load(f)

        metadata['checkpoints'].append({
            'name': checkpoint_name,
            'metrics': metrics,
            'timestamp': datetime.now().isoformat()
        })

        # Update best metric if applicable (lower loss is better)
        if 'loss' in metrics:
            if metadata['best_metric'] is None or metrics['loss'] < metadata['best_metric'].get('loss', float('inf')):
                metadata['best_metric'] = metrics.copy()

        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2, default=str)

    def _cleanup_old_checkpoints(self, session_dir: Path):
        """Delete the oldest ``checkpoint_*`` directories (by mtime) beyond ``max_local_checkpoints``."""
        checkpoints = sorted([d for d in session_dir.iterdir()
                              if d.is_dir() and d.name.startswith("checkpoint_")],
                             key=lambda x: x.stat().st_mtime)

        while len(checkpoints) > self.max_local_checkpoints:
            old_checkpoint = checkpoints.pop(0)
            shutil.rmtree(old_checkpoint)
            logger.info(f"Cleaned up old checkpoint: {old_checkpoint.name}")


# Convenience functions for easy usage
def create_checkpoint_manager(hf_token: Optional[str] = None) -> EnhancedCheckpointManager:
    """Create a pre-configured checkpoint manager for this environment.

    Args:
        hf_token: HuggingFace token. When None, the manager falls back to the
            ``HF_TOKEN`` environment variable. The previous default was a string
            literal that merely looked like an ``os.environ.get`` call and was
            passed through verbatim as the token.
    """
    return EnhancedCheckpointManager(
        base_dir="/data/checkpoints",
        hf_repo_id="WCNegentropy/BitTransformerLM",
        hf_token=hf_token,
        max_local_checkpoints=3  # Conservative for 20GB storage
    )


if __name__ == "__main__":
    # Demo usage
    manager = create_checkpoint_manager()
    usage = manager.get_storage_usage()
    print(f"Current storage usage: {usage['total_gb']:.2f} GB")
    print(f"Number of training sessions: {usage['num_sessions']}")