"""
Comprehensive error handling and recovery utilities for BitTransformerLM.

Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""

import logging
import traceback
import functools
from typing import Dict, Any, Optional, Callable, Union, Type
from contextlib import contextmanager

import torch
import numpy as np

from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike


class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors.

    Attributes:
        message: Human-readable description of the failure.
        error_code: Short machine-readable code; ``str(err)`` is
            ``"[<error_code>] <message>"``.
        context: Diagnostic key/value pairs (an empty dict when not given).
    """

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        self.message = message
        self.error_code = error_code
        self.context = context or {}
        super().__init__(f"[{error_code}] {message}")


class ModelError(BitTransformerError):
    """Errors related to model operations."""
    pass


class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
    pass


class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
    pass


class DataError(BitTransformerError):
    """Errors related to data processing."""
    pass


class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
    pass


class ErrorRecoveryManager:
    """Manages error recovery strategies and fallback mechanisms.

    A recovery strategy is a zero-argument callable registered against an
    exception type. When :meth:`handle_error` receives a matching exception,
    it returns the strategy's result instead of raising. Each distinct error
    (exception type plus message) can be recovered at most ``max_retries``
    times. After that it is always re-raised, until
    :meth:`reset_error_counts` is called.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        # Number of times each "<ExceptionType>:<message>" key has been seen.
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self, error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register (or replace) the recovery strategy for ``error_type``."""
        self.recovery_strategies[error_type] = strategy

    def reset_error_counts(self) -> None:
        """Forget all recorded error occurrences, re-enabling recovery."""
        self.error_counts.clear()

    def _find_strategy(self, error: Exception) -> Optional[RecoveryStrategy]:
        """Return the strategy registered for the most specific matching type."""
        # Walk the MRO so a strategy for a subclass wins over one registered
        # for a base class, independent of registration order.
        for klass in type(error).__mro__:
            strategy = self.recovery_strategies.get(klass)
            if strategy is not None:
                return strategy
        # Fall back to isinstance() to honour virtual subclasses (ABC.register).
        for error_type, strategy in self.recovery_strategies.items():
            if isinstance(error, error_type):
                return strategy
        return None

    def handle_error(self, error: Exception, context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Log ``error`` and try to recover from it.

        Args:
            error: The exception to handle.
            context: Extra diagnostic information included in the log entry.
            allow_recovery: When False, the error is logged and re-raised.

        Returns:
            The result of the recovery strategy registered for the most
            specific matching exception type.

        Raises:
            The original ``error`` when recovery is disabled, no strategy
            matches, the strategy itself raises, or this exact error has
            already been seen more than ``max_retries`` times.
        """
        error_key = f"{type(error).__name__}:{error}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        # Format from the exception object itself. traceback.format_exc()
        # would print "NoneType: None" if this is called outside an except.
        formatted_tb = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        self.logger.error(
            "Error occurred: %s\nContext: %s\nTraceback: %s",
            error, context, formatted_tb,
        )

        if allow_recovery and self.error_counts[error_key] <= self.max_retries:
            strategy = self._find_strategy(error)
            if strategy is not None:
                self.logger.info("Attempting recovery for %s", type(error).__name__)
                try:
                    return strategy()
                except Exception as recovery_error:
                    self.logger.error("Recovery failed: %s", recovery_error)

        # If no recovery or recovery failed, raise the original error
        raise error


# Global error recovery manager instance
error_manager = ErrorRecoveryManager()


def with_error_recovery(recovery_value: Any = None, max_retries: int = 3,
                        error_types: Optional[tuple] = None):
    """Decorator that retries a function and optionally falls back to a value.

    Args:
        recovery_value: Returned after the final failed attempt. ``None``
            (the default) means "re-raise the last error" instead.
        max_retries: Number of extra attempts after the first one. Negative
            values are treated as 0.
        error_types: If given (and non-empty), only these exception types
            are retried; any other exception propagates immediately.
    """
    # Guarantee at least one attempt so last_error is always set below.
    retries = max(0, max_retries)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error: Optional[Exception] = None
            for attempt in range(retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Check if we should handle this error type
                    if error_types and not isinstance(e, error_types):
                        raise
                    last_error = e
                    if attempt < retries:
                        error_manager.logger.warning(
                            f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                        )

            # Final attempt failed
            error_manager.logger.error(
                f"Function {func.__name__} failed after {retries + 1} attempts"
            )
            if recovery_value is not None:
                return recovery_value
            raise last_error

        return wrapper

    return decorator


class _OperationOutcome:
    """Result holder yielded by :func:`safe_operation`.

    Attributes:
        error: The exception raised inside the ``with`` block, if any.
        recovered: True when that exception was suppressed by recovery.
        value: The recovery strategy's result or ``recovery_value``.
    """

    __slots__ = ("error", "recovered", "value")

    def __init__(self) -> None:
        self.error: Optional[BaseException] = None
        self.recovered = False
        self.value: Any = None

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(recovered={self.recovered!r}, "
                f"error={self.error!r}, value={self.value!r})")


@contextmanager
def safe_operation(operation_name: str, context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager for safe operations with error handling.

    Exceptions raised in the block are passed to ``error_manager.handle_error``.
    If that recovers, or ``recovery_value`` is not None, the exception is
    SUPPRESSED and execution continues after the ``with`` statement. A
    ``with`` block cannot return a value, so the recovered value is exposed
    on the yielded outcome object instead::

        with safe_operation("load", recovery_value=default) as outcome:
            result = load()
        if outcome.recovered:
            result = outcome.value

    Otherwise the original exception propagates unchanged.
    """
    outcome = _OperationOutcome()
    error_manager.logger.debug(f"Starting operation: {operation_name}")
    try:
        yield outcome
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)
        try:
            outcome.value = error_manager.handle_error(e, error_context)
        except Exception:
            if recovery_value is None:
                raise
            error_manager.logger.warning(
                f"Operation {operation_name} failed, using recovery value"
            )
            outcome.value = recovery_value
        outcome.error = e
        outcome.recovered = True
        # Finishing the generator normally tells contextmanager to suppress
        # the exception raised in the with-block.
    else:
        error_manager.logger.debug(f"Completed operation: {operation_name}")


def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrap ``tensor_op`` with input validation and OOM/device fallbacks.

    The returned wrapper:

    * raises DataError for non-tensors,
    * returns ``fallback_value`` for empty tensors (or raises DataError),
    * replaces NaN with 0 and ±inf with ±1e6 (capped at the dtype's max),
    * on an "out of memory" error retries in chunks along dim 0,
    * on a "device" error retries with the input moved to CPU.
    """
    @functools.wraps(tensor_op)
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Validate input tensor
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")
        if tensor.numel() == 0:
            if fallback_value is not None:
                return fallback_value
            raise DataError("Cannot operate on empty tensor")

        # Check for NaN or Inf values
        has_nan = bool(torch.isnan(tensor).any())
        has_inf = bool(torch.isinf(tensor).any())
        if has_nan:
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
        if has_inf:
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
        if has_nan or has_inf:
            # Single pass with explicit bounds. A NaN-only nan_to_num would map
            # ±inf to the dtype's extreme finite values instead of ±1e6.
            limit = 1e6
            if tensor.is_floating_point():
                # Keep the replacement representable (e.g. float16 max is 65504).
                limit = min(limit, torch.finfo(tensor.dtype).max)
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=limit, neginf=-limit)

        fallback = None
        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as e:
            message = str(e).lower()
            if "out of memory" in message:
                fallback = "chunked"
            elif "device" in message:
                fallback = "cpu"
            else:
                raise

        # Retry outside the except clause: the live exception's traceback
        # references the failed call's frames, which would keep their tensors
        # allocated during the retry.
        if fallback == "chunked":
            error_manager.logger.warning("OOM detected, attempting chunked operation")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)

        # Device mismatch recovery
        error_manager.logger.warning("Device mismatch, attempting CPU fallback")
        return tensor_op(tensor.cpu(), *args, **kwargs)

    return wrapper


def _chunked_tensor_operation(tensor_op: Callable, tensor: torch.Tensor, *args,
                              chunk_size: int = 1024, **kwargs) -> torch.Tensor:
    """Execute ``tensor_op`` on dim-0 slices and concatenate results on dim 0.

    ``chunk_size`` is keyword-only so positional ``*args`` meant for
    ``tensor_op`` can never be mistaken for it.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # A 0-dim tensor has no dim 0 to split; small tensors gain nothing.
    if tensor.dim() == 0 or tensor.size(0) <= chunk_size:
        return tensor_op(tensor, *args, **kwargs)

    results = [
        tensor_op(chunk, *args, **kwargs)
        for chunk in torch.split(tensor, chunk_size, dim=0)
    ]
    return torch.cat(results, dim=0)


def validate_model_inputs(inputs: torch.Tensor, max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs into a ``(batch, seq)`` tensor.

    Args:
        inputs: 0-, 1- or 2-D tensor. A 0-D value becomes shape (1, 1) and a
            1-D sequence gains a leading batch dimension.
        max_seq_len: Longer sequences are truncated (with a warning).
        expected_dtype: Target dtype. For ``torch.long``, values are clamped
            to the bit range [0, 1].

    Returns:
        The sanitized 2-D tensor.

    Raises:
        DataError: If ``inputs`` is not a tensor or has more than 2 dims.
    """
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    # Check dimensions
    if inputs.dim() == 0:
        inputs = inputs.reshape(1, 1)  # single value -> batch of one length-1 sequence
    elif inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)  # Add batch dimension
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")

    # Check sequence length
    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(
            f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating"
        )
        inputs = inputs[:, :max_seq_len]

    # Check dtype
    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(
            f"Converting input dtype from {inputs.dtype} to {expected_dtype}"
        )
        inputs = inputs.to(expected_dtype)

    # Check value range for bit sequences
    if expected_dtype == torch.long:
        if ((inputs < 0) | (inputs > 1)).any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs


def safe_model_forward(model: torch.nn.Module, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
    """Safely execute a model forward pass with OOM and device recovery.

    Inputs are first normalised by :func:`validate_model_inputs`. On a
    RuntimeError mentioning "out of memory", the forward is retried under
    non-reentrant activation checkpointing. Note that this only lowers memory
    when autograd is recording; in ``no_grad``/inference it may OOM again and
    raise. On a "device" error, the inputs are moved to the device of the
    model's first parameter and the forward is retried. Any other error
    propagates.
    """
    inputs = validate_model_inputs(inputs)

    retry_mode: Optional[str] = None
    try:
        return model(inputs, **kwargs)
    except RuntimeError as e:
        message = str(e).lower()
        if "out of memory" in message:
            retry_mode = "checkpoint"
        elif "device" in message:
            first_param = next(model.parameters(), None)
            # Nothing to fix if the model has no parameters or the inputs are
            # already on its device; surface the original error.
            if first_param is None or inputs.device == first_param.device:
                raise
            target_device = first_param.device
            retry_mode = "device"
        else:
            raise

    # Retry outside the except clause so the failed attempt's traceback (and
    # the activations its frames reference) can be freed first.
    if retry_mode == "checkpoint":
        error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        from torch.utils.checkpoint import checkpoint
        # use_reentrant=False: the reentrant variant drops parameter gradients
        # when no *input* requires grad (always true for integer bit inputs),
        # and it does not accept keyword arguments for the wrapped call.
        return checkpoint(functools.partial(model, **kwargs), inputs, use_reentrant=False)

    # Device mismatch recovery
    error_manager.logger.warning("Device mismatch in forward pass, moving inputs to model device")
    return model(inputs.to(target_device), **kwargs)


def recovery_checkpoint_save(model: torch.nn.Module, path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save a model checkpoint, falling back to ``<path>.backup`` on failure.

    Returns:
        True if the checkpoint was written to ``path`` or the backup location.
        False if the checkpoint data could not be assembled or both writes
        failed. Errors are logged, never raised.
    """
    try:
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'timestamp': torch.tensor(0),  # placeholder
        }
        if additional_data:
            checkpoint_data.update(additional_data)
    except Exception as e:
        # Without the data there is nothing to write to the backup either.
        error_manager.logger.error(f"Failed to collect checkpoint data for {path}: {e}")
        return False

    try:
        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")

    # Try backup location (f-string also works for os.PathLike paths).
    backup_path = f"{path}.backup"
    try:
        torch.save(checkpoint_data, backup_path)
        error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
        return True
    except Exception as backup_e:
        error_manager.logger.error(f"Backup save also failed: {backup_e}")
        return False


# Attribute set on handlers installed by setup_error_logging, so repeated
# calls can replace them instead of stacking duplicates.
_MANAGED_HANDLER_ATTR = "_btlm_error_logging"


def setup_error_logging(log_level: LogLevel = "INFO", log_file: Optional[str] = None) -> logging.Logger:
    """Configure the "BitTransformerLM" logger with console (and file) output.

    Safe to call repeatedly: handlers installed by an earlier call are removed
    and closed first, so messages are not emitted multiple times. The level
    name is case-insensitive (e.g. "info" or "INFO").

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, str(log_level).upper()))

    for handler in list(logger.handlers):
        if getattr(handler, _MANAGED_HANDLER_ATTR, False):
            logger.removeHandler(handler)
            handler.close()

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    ))
    setattr(console_handler, _MANAGED_HANDLER_ATTR, True)
    logger.addHandler(console_handler)

    # File handler if specified
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        ))
        setattr(file_handler, _MANAGED_HANDLER_ATTR, True)
        logger.addHandler(file_handler)

    return logger


# Default recovery strategies
def default_tensor_recovery() -> torch.Tensor:
    """Default recovery strategy for tensor operations: a single long zero."""
    return torch.zeros(1, dtype=torch.long)


def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Default recovery strategy for model operations: ``{"output": zeros(1)}``."""
    return {"output": torch.zeros(1, dtype=torch.float32)}


# Register default recovery strategies
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)