"""Enhanced checkpointing system for BitTransformerLM with support for multiple training runs.

Optimized for the Claude Code environment with HF Pro and 20 GB of persistent storage.
"""
import os
import json
import shutil
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from datetime import datetime

import torch
from huggingface_hub import HfApi, hf_hub_download

from bit_transformer.error_handling import with_error_recovery, safe_operation
from bit_transformer.types import PathLike, ModelConfig, TrainingConfig

logger = logging.getLogger(__name__)


class EnhancedCheckpointManager:
    """Advanced checkpoint management for multiple training runs with HF integration."""

    def __init__(self,
                 base_dir: PathLike = "/data/checkpoints",
                 hf_repo_id: str = "WCNegentropy/BitTransformerLM",
                 hf_token: Optional[str] = None,
                 max_local_checkpoints: int = 5):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.api = HfApi(token=self.hf_token) if self.hf_token else None

        self.max_local_checkpoints = max_local_checkpoints

        # One subdirectory per training session.
        self.sessions_dir = self.base_dir / "training_sessions"
        self.sessions_dir.mkdir(exist_ok=True)

        # Best-performing models live outside the sessions tree so checkpoint
        # cleanup never removes them.
        self.best_models_dir = self.base_dir / "best_models"
        self.best_models_dir.mkdir(exist_ok=True)
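
    # On-disk layout produced by this manager (sketch, derived from the paths above):
    #
    #   <base_dir>/
    #       training_sessions/<session_name>_<YYYYmmdd_HHMMSS>/
    #           metadata.json
    #           checkpoint_epoch_0000/
    #               model.pt  [optimizer.pt]  [scheduler.pt]  [additional_data.json]
    #       best_models/
    #           <session_id>_best.pt
    #           <session_id>_best_meta.json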

    def create_training_session(self,
                                session_name: str,
                                model_config: ModelConfig,
                                training_config: TrainingConfig) -> str:
        """Create a new training session with metadata."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        session_id = f"{session_name}_{timestamp}"
        session_dir = self.sessions_dir / session_id
        session_dir.mkdir(exist_ok=True)

        metadata = {
            "session_id": session_id,
            "session_name": session_name,
            "created_at": timestamp,
            "model_config": model_config,
            "training_config": training_config,
            "checkpoints": [],
            "best_metric": None,
            "status": "active"
        }

        with open(session_dir / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2, default=str)

        logger.info(f"Created training session: {session_id}")
        return session_id

    @with_error_recovery(recovery_value=False)
    def save_checkpoint(self,
                        model: torch.nn.Module,
                        session_id: str,
                        epoch: int,
                        metrics: Dict[str, float],
                        optimizer_state: Optional[Dict] = None,
                        scheduler_state: Optional[Dict] = None,
                        additional_data: Optional[Dict] = None) -> bool:
        """Save checkpoint with comprehensive metadata."""
        session_dir = self.sessions_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Training session {session_id} not found")

        # Zero-padded epoch numbers keep checkpoint names lexicographically sortable.
        checkpoint_name = f"checkpoint_epoch_{epoch:04d}"
        checkpoint_dir = session_dir / checkpoint_name
        checkpoint_dir.mkdir(exist_ok=True)

        # Model weights plus enough metadata to reconstruct the run context.
        model_path = checkpoint_dir / "model.pt"
        torch.save({
            'model_state_dict': model.state_dict(),
            'epoch': epoch,
            'metrics': metrics,
            'model_config': getattr(model, 'config', {}),
            'timestamp': datetime.now().isoformat()
        }, model_path)

        # Optimizer and scheduler states are stored as separate files so they
        # can be omitted from uploads (see push_to_hf).
        if optimizer_state is not None:
            torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")

        if scheduler_state is not None:
            torch.save(scheduler_state, checkpoint_dir / "scheduler.pt")

        if additional_data is not None:
            with open(checkpoint_dir / "additional_data.json", "w") as f:
                json.dump(additional_data, f, indent=2, default=str)

        self._update_session_metadata(session_id, checkpoint_name, metrics)

        # Enforce the local retention limit for this session.
        self._cleanup_old_checkpoints(session_dir)

        logger.info(f"Saved checkpoint {checkpoint_name} for session {session_id}")
        return True

    def load_checkpoint(self,
                        session_id: str,
                        checkpoint_name: Optional[str] = None,
                        model: Optional[torch.nn.Module] = None) -> Dict[str, Any]:
        """Load checkpoint with all associated data."""
        session_dir = self.sessions_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Training session {session_id} not found")

        # Default to the latest checkpoint; the zero-padded names sort correctly.
        if checkpoint_name is None:
            checkpoints = [d for d in session_dir.iterdir()
                           if d.is_dir() and d.name.startswith("checkpoint_")]
            if not checkpoints:
                raise ValueError(f"No checkpoints found for session {session_id}")
            checkpoint_name = max(checkpoints, key=lambda x: x.name).name

        checkpoint_dir = session_dir / checkpoint_name
        if not checkpoint_dir.exists():
            raise ValueError(f"Checkpoint {checkpoint_name} not found in session {session_id}")

        # weights_only=False because the checkpoint stores non-tensor metadata
        # alongside the state dict; only load checkpoints from trusted sources.
        model_path = checkpoint_dir / "model.pt"
        checkpoint_data = torch.load(model_path, map_location='cpu', weights_only=False)

        if model is not None:
            model.load_state_dict(checkpoint_data['model_state_dict'])

        optimizer_state = None
        optimizer_path = checkpoint_dir / "optimizer.pt"
        if optimizer_path.exists():
            optimizer_state = torch.load(optimizer_path, map_location='cpu', weights_only=False)

        scheduler_state = None
        scheduler_path = checkpoint_dir / "scheduler.pt"
        if scheduler_path.exists():
            scheduler_state = torch.load(scheduler_path, map_location='cpu', weights_only=False)

        additional_data = {}
        additional_path = checkpoint_dir / "additional_data.json"
        if additional_path.exists():
            with open(additional_path) as f:
                additional_data = json.load(f)

        return {
            'model_data': checkpoint_data,
            'optimizer_state': optimizer_state,
            'scheduler_state': scheduler_state,
            'additional_data': additional_data,
            'checkpoint_path': str(checkpoint_dir)
        }
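
    # Resuming training (sketch): the mapping returned above can be fed back
    # into freshly constructed optimizer/scheduler objects, e.g.
    #
    #     state = manager.load_checkpoint(session_id, model=model)
    #     if state["optimizer_state"] is not None:
    #         optimizer.load_state_dict(state["optimizer_state"])
    #     if state["scheduler_state"] is not None:
    #         scheduler.load_state_dict(state["scheduler_state"])
    #     start_epoch = state["model_data"]["epoch"] + 1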

    def save_best_model(self,
                        session_id: str,
                        model: torch.nn.Module,
                        metric_name: str,
                        metric_value: float,
                        is_better_func: Callable[[float, float], bool] = lambda x, y: x > y) -> bool:
        """Save model if it achieves best performance.

        The default comparison treats higher values as better; pass e.g.
        ``lambda x, y: x < y`` for loss-like metrics.
        """
        best_model_path = self.best_models_dir / f"{session_id}_best.pt"
        best_meta_path = self.best_models_dir / f"{session_id}_best_meta.json"

        current_best = None
        if best_meta_path.exists():
            with open(best_meta_path) as f:
                current_best = json.load(f)

        if current_best is None or is_better_func(metric_value, current_best['metric_value']):
            torch.save({
                'model_state_dict': model.state_dict(),
                'metric_name': metric_name,
                'metric_value': metric_value,
                'session_id': session_id,
                'timestamp': datetime.now().isoformat()
            }, best_model_path)

            with open(best_meta_path, "w") as f:
                json.dump({
                    'metric_name': metric_name,
                    'metric_value': metric_value,
                    'session_id': session_id,
                    'timestamp': datetime.now().isoformat()
                }, f, indent=2)

            logger.info(f"New best model saved for session {session_id}: {metric_name}={metric_value}")
            return True

        return False

    def push_to_hf(self,
                   session_id: str,
                   checkpoint_name: Optional[str] = None,
                   include_optimizer: bool = False) -> bool:
        """Push checkpoint to HuggingFace Hub."""
        if not self.api:
            logger.error("HuggingFace API not available - check token")
            return False

        try:
            checkpoint_data = self.load_checkpoint(session_id, checkpoint_name)
            checkpoint_dir = Path(checkpoint_data['checkpoint_path'])

            self.api.upload_file(
                path_or_fileobj=str(checkpoint_dir / "model.pt"),
                path_in_repo=f"checkpoints/{session_id}/model.pt",
                repo_id=self.hf_repo_id,
                commit_message=f"Upload checkpoint {checkpoint_name or 'latest'} from session {session_id}"
            )

            if include_optimizer and (checkpoint_dir / "optimizer.pt").exists():
                self.api.upload_file(
                    path_or_fileobj=str(checkpoint_dir / "optimizer.pt"),
                    path_in_repo=f"checkpoints/{session_id}/optimizer.pt",
                    repo_id=self.hf_repo_id
                )

            logger.info(f"Successfully pushed checkpoint to HuggingFace: {self.hf_repo_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to push to HuggingFace: {e}")
            return False
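
    # Uploaded files land in the Hub repo under checkpoints/<session_id>/;
    # pull_from_hf below retrieves them using the same path convention.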

    def pull_from_hf(self,
                     session_id: str,
                     local_session_id: Optional[str] = None) -> bool:
        """Pull checkpoint from HuggingFace Hub."""
        if not self.api:
            logger.error("HuggingFace API not available - check token")
            return False

        try:
            local_session = local_session_id or session_id
            local_dir = self.sessions_dir / local_session / "checkpoint_from_hf"
            local_dir.mkdir(parents=True, exist_ok=True)

            model_file = hf_hub_download(
                repo_id=self.hf_repo_id,
                filename=f"checkpoints/{session_id}/model.pt",
                local_dir=str(local_dir),
                local_dir_use_symlinks=False
            )

            logger.info(f"Successfully pulled checkpoint from HuggingFace to {model_file}")
            return True

        except Exception as e:
            logger.error(f"Failed to pull from HuggingFace: {e}")
            return False

    def get_storage_usage(self) -> Dict[str, Any]:
        """Get detailed storage usage breakdown."""

        def get_dir_size(path: Path) -> int:
            total = 0
            for item in path.rglob('*'):
                if item.is_file():
                    total += item.stat().st_size
            return total

        usage = {
            'total_gb': get_dir_size(self.base_dir) / 1e9,
            'sessions_gb': get_dir_size(self.sessions_dir) / 1e9,
            'best_models_gb': get_dir_size(self.best_models_dir) / 1e9,
            'num_sessions': len(list(self.sessions_dir.iterdir())),
            'num_best_models': len(list(self.best_models_dir.glob('*_best.pt'))),
        }

        sessions = []
        for session_dir in self.sessions_dir.iterdir():
            if session_dir.is_dir():
                sessions.append({
                    'session_id': session_dir.name,
                    'size_gb': get_dir_size(session_dir) / 1e9,
                    'num_checkpoints': len(list(session_dir.glob('checkpoint_*')))
                })

        usage['sessions'] = sorted(sessions, key=lambda x: x['size_gb'], reverse=True)

        return usage

    def _update_session_metadata(self, session_id: str, checkpoint_name: str, metrics: Dict[str, float]):
        """Update session metadata with new checkpoint info."""
        metadata_path = self.sessions_dir / session_id / "metadata.json"

        with open(metadata_path) as f:
            metadata = json.load(f)

        metadata['checkpoints'].append({
            'name': checkpoint_name,
            'metrics': metrics,
            'timestamp': datetime.now().isoformat()
        })

        # Track the best metrics seen so far, keyed on loss (lower is better).
        if 'loss' in metrics:
            if metadata['best_metric'] is None or metrics['loss'] < metadata['best_metric'].get('loss', float('inf')):
                metadata['best_metric'] = metrics.copy()

        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2, default=str)

    def _cleanup_old_checkpoints(self, session_dir: Path):
        """Remove oldest checkpoints to stay within limits."""
        checkpoints = sorted([d for d in session_dir.iterdir()
                              if d.is_dir() and d.name.startswith("checkpoint_")],
                             key=lambda x: x.stat().st_mtime)

        while len(checkpoints) > self.max_local_checkpoints:
            old_checkpoint = checkpoints.pop(0)
            shutil.rmtree(old_checkpoint)
            logger.info(f"Cleaned up old checkpoint: {old_checkpoint.name}")


def create_checkpoint_manager(hf_token: Optional[str] = None) -> EnhancedCheckpointManager:
    """Create a pre-configured checkpoint manager for this environment."""
    return EnhancedCheckpointManager(
        base_dir="/data/checkpoints",
        hf_repo_id="WCNegentropy/BitTransformerLM",
        hf_token=hf_token or os.environ.get("HF_TOKEN"),
        max_local_checkpoints=3
    )


if __name__ == "__main__":
    # Quick storage report for the configured checkpoint directory.
    manager = create_checkpoint_manager()
    usage = manager.get_storage_usage()
    print(f"Current storage usage: {usage['total_gb']:.2f} GB")
    print(f"Number of training sessions: {usage['num_sessions']}")