# BitTransformerLM / bit_transformer / error_handling.py
"""
Comprehensive error handling and recovery utilities for BitTransformerLM.
Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""
import logging
import traceback
import functools
from typing import Dict, Any, Optional, Callable, Union, Type
from contextlib import contextmanager
import torch
import numpy as np
from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike
class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors."""

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        self.message = message
        self.error_code = error_code
        self.context = context or {}
        super().__init__(f"[{error_code}] {message}")


class ModelError(BitTransformerError):
    """Errors related to model operations."""
    pass


class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
    pass


class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
    pass


class DataError(BitTransformerError):
    """Errors related to data processing."""
    pass


class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
    pass
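
# Usage sketch (illustrative only): each error carries an error code and an
# optional context dictionary; the values shown here are hypothetical.
#
#   raise ModelError("attention head count mismatch",
#                    error_code="BTLM_MODEL_CONFIG",
#                    context={"expected_heads": 8, "got": 6})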
class ErrorRecoveryManager:
    """Manages error recovery strategies and fallback mechanisms."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self,
                                   error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register a recovery strategy for a specific error type."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                     error: Exception,
                     context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Handle an error with potential recovery."""
        error_key = f"{type(error).__name__}:{str(error)}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )

        if allow_recovery and self.error_counts[error_key] <= self.max_retries:
            # Try recovery strategy
            for error_type, strategy in self.recovery_strategies.items():
                if isinstance(error, error_type):
                    try:
                        self.logger.info(f"Attempting recovery for {type(error).__name__}")
                        return strategy()
                    except Exception as recovery_error:
                        self.logger.error(f"Recovery failed: {recovery_error}")
                        break

        # If no recovery or recovery failed, raise the original error
        raise error
# Global error recovery manager instance
error_manager = ErrorRecoveryManager()
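
# Usage sketch: registering a custom recovery strategy on the global manager.
# `identity_decompress` is a hypothetical fallback used only for this example.
#
#   def identity_decompress() -> torch.Tensor:
#       return torch.zeros(1, dtype=torch.long)
#
#   error_manager.register_recovery_strategy(CompressionError, identity_decompress)
#   try:
#       raise CompressionError("stream truncated")
#   except CompressionError as exc:
#       fallback = error_manager.handle_error(exc, context={"stage": "decode"})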
def with_error_recovery(recovery_value: Any = None,
                        max_retries: int = 3,
                        error_types: Optional[tuple] = None):
    """Decorator for adding error recovery to functions."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    # Check if we should handle this error type
                    if error_types and not isinstance(e, error_types):
                        raise
                    if attempt < max_retries:
                        error_manager.logger.warning(
                            f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                        )
                        continue
                    # Final attempt failed
                    error_manager.logger.error(
                        f"Function {func.__name__} failed after {max_retries + 1} attempts"
                    )
                    break

            # Return recovery value or raise last error
            if recovery_value is not None:
                return recovery_value
            raise last_error
        return wrapper
    return decorator
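
# Usage sketch: retrying a flaky data-loading step and falling back to an empty
# batch. The function `load_bit_batch` and the fallback shape are hypothetical.
#
#   @with_error_recovery(recovery_value=torch.zeros(1, 8, dtype=torch.long),
#                        max_retries=2,
#                        error_types=(DataError, OSError))
#   def load_bit_batch(path: str) -> torch.Tensor:
#       ...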
@contextmanager
def safe_operation(operation_name: str,
                   context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager for safe operations with error handling.

    Note: a context manager cannot substitute a return value for the wrapped
    block. If a registered recovery strategy handles the error, or if
    ``recovery_value`` is provided, the error is logged and suppressed.
    """
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)
        try:
            # If a recovery strategy handles the error, fall through and suppress it.
            error_manager.handle_error(e, error_context)
        except Exception:
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return
            raise
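
# Usage sketch: wrapping a non-critical step so a failure is logged (and, if a
# recovery strategy or recovery_value applies, suppressed) instead of aborting
# training. `update_telemetry` and `metrics` are hypothetical.
#
#   with safe_operation("telemetry_update", context={"step": 42}, recovery_value=0):
#       update_telemetry(metrics)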
def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrapper for tensor operations with safety checks."""
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Validate input tensor
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")
        if tensor.numel() == 0:
            if fallback_value is not None:
                return fallback_value
            raise DataError("Cannot operate on empty tensor")

        # Check for NaN or Inf values
        if torch.isnan(tensor).any():
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        if torch.isinf(tensor).any():
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as e:
            if "out of memory" in str(e).lower():
                # OOM recovery: try with smaller chunks
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            elif "device" in str(e).lower():
                # Device mismatch recovery
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            else:
                raise
    return wrapper
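
# Usage sketch: guarding an elementwise transform against NaN/Inf inputs, OOM,
# and device mismatches. The lambda stands in for a real tensor op.
#
#   safe_sigmoid = safe_tensor_operation(lambda t: torch.sigmoid(t.float()))
#   probs = safe_sigmoid(torch.randn(4, 16))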
def _chunked_tensor_operation(tensor_op: Callable,
                              tensor: torch.Tensor,
                              *args,
                              chunk_size: int = 1024,
                              **kwargs) -> torch.Tensor:
    """Execute tensor operation in chunks along dim 0 to avoid OOM."""
    # chunk_size is keyword-only so positional args are forwarded to tensor_op intact.
    if tensor.size(0) <= chunk_size:
        return tensor_op(tensor, *args, **kwargs)

    results = []
    for i in range(0, tensor.size(0), chunk_size):
        chunk = tensor[i:i + chunk_size]
        chunk_result = tensor_op(chunk, *args, **kwargs)
        results.append(chunk_result)
    return torch.cat(results, dim=0)
def validate_model_inputs(inputs: torch.Tensor,
                          max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs."""
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    # Check dimensions
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)  # Add batch dimension
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")

    # Check sequence length
    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating")
        inputs = inputs[:, :max_seq_len]

    # Check dtype
    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)

    # Check value range for bit sequences
    if expected_dtype == torch.long:
        invalid_values = (inputs < 0) | (inputs > 1)
        if invalid_values.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs
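
# Usage sketch: a 1-D bit sequence with out-of-range values is given a batch
# dimension and clamped to {0, 1} before reaching the model.
#
#   raw = torch.tensor([0, 1, 2, -1], dtype=torch.long)
#   clean = validate_model_inputs(raw)   # shape (1, 4), values in [0, 1]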
def safe_model_forward(model: torch.nn.Module,
                       inputs: torch.Tensor,
                       **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery."""
    inputs = validate_model_inputs(inputs)

    try:
        with safe_operation("model_forward"):
            return model(inputs, **kwargs)
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            # Retry with gradient checkpointing to trade compute for activation memory.
            # use_reentrant=False is required for keyword arguments and needs a recent PyTorch.
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            from torch.utils.checkpoint import checkpoint
            return checkpoint(model, inputs, use_reentrant=False, **kwargs)
        elif "device" in str(e).lower():
            # Device mismatch recovery: move inputs to the model's device and retry.
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        else:
            raise
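
# Usage sketch: `model` is assumed to be a BitTransformerLM-style module that
# accepts a (batch, seq_len) tensor of bits.
#
#   bits = torch.randint(0, 2, (2, 128), dtype=torch.long)
#   logits = safe_model_forward(model, bits)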
def recovery_checkpoint_save(model: torch.nn.Module,
                             path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save model checkpoint with error recovery."""
    # Build the payload before the try block so the backup path below can reuse it.
    checkpoint_data = {
        'model_state_dict': model.state_dict(),
        'timestamp': time.time(),
    }
    if additional_data:
        checkpoint_data.update(additional_data)

    try:
        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")
        # Try backup location
        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False
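
# Usage sketch: the path is illustrative; a ".backup" sibling file is attempted
# automatically if the primary save fails.
#
#   ok = recovery_checkpoint_save(model, "checkpoints/step_1000.pt",
#                                 additional_data={"step": 1000})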
def setup_error_logging(log_level: LogLevel = "INFO",
                        log_file: Optional[str] = None) -> logging.Logger:
    """Set up comprehensive error logging."""
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))

    # Avoid stacking duplicate handlers when this is called more than once.
    if not logger.handlers:
        # Console handler
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

        # File handler if specified
        if log_file:
            file_handler = logging.FileHandler(log_file)
            file_formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
            )
            file_handler.setFormatter(file_formatter)
            logger.addHandler(file_handler)

    return logger
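
# Usage sketch: route module logs to both the console and a file, and reuse the
# same logger inside the global recovery manager. The file name is illustrative.
#
#   logger = setup_error_logging(log_level="DEBUG", log_file="btlm_errors.log")
#   error_manager.logger = logger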
# Default recovery strategies
def default_tensor_recovery() -> torch.Tensor:
    """Default recovery strategy for tensor operations."""
    return torch.zeros(1, dtype=torch.long)


def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Default recovery strategy for model operations."""
    return {"output": torch.zeros(1, dtype=torch.float32)}


# Register default recovery strategies
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)
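

# Minimal smoke-test sketch (assumes no real BitTransformerLM checkpoint is
# available): exercises input validation and the safe forward path with a tiny
# stand-in linear model. Everything below is illustrative only.
if __name__ == "__main__":
    demo_logger = setup_error_logging(log_level="DEBUG")

    class _ToyBitModel(torch.nn.Module):
        """Stand-in model: projects an 8-bit sequence to two logits."""

        def __init__(self) -> None:
            super().__init__()
            self.proj = torch.nn.Linear(8, 2)

        def forward(self, bits: torch.Tensor) -> torch.Tensor:
            return self.proj(bits.float())

    toy_model = _ToyBitModel()
    raw_bits = torch.randint(0, 2, (8,), dtype=torch.long)  # 1-D; batch dim added by validation
    output = safe_model_forward(toy_model, raw_bits)
    demo_logger.info("Toy forward output shape: %s", tuple(output.shape))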