"""Comprehensive error handling and recovery utilities for BitTransformerLM.

Provides robust error recovery mechanisms, graceful degradation, and detailed
error logging for production deployments.
"""

import logging
import traceback
import functools
import time
from typing import Dict, Any, Optional, Callable, Tuple, Type, Union
from contextlib import contextmanager

import torch
import numpy as np

from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike


class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors."""

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        self.message = message
        self.error_code = error_code
        self.context = context or {}
        super().__init__(f"[{error_code}] {message}")


class ModelError(BitTransformerError):
    """Errors related to model operations."""
    pass


class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
    pass


class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
    pass


class DataError(BitTransformerError):
    """Errors related to data processing."""
    pass


class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
    pass


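# Illustrative sketch only (not part of the library API): how the exception
# hierarchy is meant to be used. The error_code and context payload travel with
# the exception, so callers can log structured details after catching the base
# class. The codes and numbers below are hypothetical.
#
#     try:
#         raise CompressionError(
#             "Run-length payload exceeded budget",
#             error_code="BTLM_COMPRESS_OVERFLOW",
#             context={"payload_bits": 4096, "budget_bits": 1024},
#         )
#     except BitTransformerError as exc:
#         print(exc)          # "[BTLM_COMPRESS_OVERFLOW] Run-length payload exceeded budget"
#         print(exc.context)  # {'payload_bits': 4096, 'budget_bits': 1024}

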
class ErrorRecoveryManager:
    """Manages error recovery strategies and fallback mechanisms."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or logging.getLogger(__name__)
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self,
                                   error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register a recovery strategy for a specific error type."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                     error: Exception,
                     context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Handle an error with potential recovery."""
        error_key = f"{type(error).__name__}:{str(error)}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )

        if allow_recovery and self.error_counts[error_key] <= self.max_retries:
            for error_type, strategy in self.recovery_strategies.items():
                if isinstance(error, error_type):
                    try:
                        self.logger.info(f"Attempting recovery for {type(error).__name__}")
                        return strategy()
                    except Exception as recovery_error:
                        self.logger.error(f"Recovery failed: {recovery_error}")
                        break

        raise error


error_manager = ErrorRecoveryManager()


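# Illustrative sketch only: registering a custom recovery strategy on the shared
# error_manager. handle_error() re-raises unless a strategy matching the error
# type is registered and the retry budget for that exact error has not been
# exhausted. The fallback shape and error message are hypothetical.
#
#     def fallback_logits() -> torch.Tensor:
#         return torch.zeros(1, 8, dtype=torch.float32)
#
#     error_manager.register_recovery_strategy(SafetyError, fallback_logits)
#     try:
#         raise SafetyError("Telemetry floor violated", context={"gate": "safety"})
#     except SafetyError as exc:
#         logits = error_manager.handle_error(exc)  # returns fallback_logits()

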
def with_error_recovery(recovery_value: Any = None,
                        max_retries: int = 3,
                        error_types: Optional[Tuple[Type[Exception], ...]] = None):
    """Decorator for adding error recovery to functions."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None

            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e

                    # Only recover from the configured error types; re-raise the rest.
                    if error_types and not isinstance(e, error_types):
                        raise

                    if attempt < max_retries:
                        error_manager.logger.warning(
                            f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                        )
                        continue

                    error_manager.logger.error(
                        f"Function {func.__name__} failed after {max_retries + 1} attempts"
                    )
                    break

            if recovery_value is not None:
                return recovery_value
            raise last_error

        return wrapper
    return decorator


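# Illustrative sketch only: decorating a flaky bit-level loader so transient
# RuntimeErrors are retried twice before the recovery value is returned.
# `load_bit_window` is a hypothetical helper, not part of BitTransformerLM.
#
#     @with_error_recovery(recovery_value=torch.zeros(64, dtype=torch.long),
#                          max_retries=2,
#                          error_types=(RuntimeError, DataError))
#     def load_bit_window(path: str) -> torch.Tensor:
#         return torch.load(path)[:64]

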
@contextmanager
def safe_operation(operation_name: str,
                   context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager for safe operations with error handling.

    A context manager cannot hand a value back to the ``with`` block, so a
    successful recovery (or a supplied ``recovery_value``) only suppresses the
    exception; the recovered value itself is discarded.
    """
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)

        try:
            # handle_error re-raises unless a registered recovery strategy succeeds,
            # in which case the original exception is suppressed for the caller.
            error_manager.handle_error(e, error_context)
        except Exception:
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return
            raise


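# Illustrative sketch only: wrapping a risky block. If the call raises and no
# recovery path applies, the exception propagates; with recovery configured, it
# is logged and suppressed so the caller can keep the pre-assigned fallback.
# `decompress_bits` is a hypothetical helper here.
#
#     decoded = bits
#     with safe_operation("decompress_bits", context={"n_bits": int(bits.numel())},
#                         recovery_value=True):
#         decoded = decompress_bits(bits)

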
def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrapper for tensor operations with safety checks."""
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")

        if tensor.numel() == 0:
            if fallback_value is not None:
                return fallback_value
            raise DataError("Cannot operate on empty tensor")

        if torch.isnan(tensor).any():
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0)

        if torch.isinf(tensor).any():
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as e:
            if "out of memory" in str(e).lower():
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            elif "device" in str(e).lower():
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            else:
                raise

    return wrapper


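# Illustrative sketch only: hardening a plain tensor op. NaN/Inf inputs are
# cleaned before the op runs, and an empty input falls back to the given value.
#
#     safe_mean = safe_tensor_operation(lambda t: t.float().mean(dim=-1),
#                                       fallback_value=torch.zeros(1))
#     safe_mean(torch.tensor([[0, 1, 1, 0]]))   # tensor([0.5000])

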
def _chunked_tensor_operation(tensor_op: Callable,
                              tensor: torch.Tensor,
                              *args,
                              chunk_size: int = 1024,
                              **kwargs) -> torch.Tensor:
    """Execute tensor operation in chunks to avoid OOM.

    ``chunk_size`` is keyword-only so that positional arguments forwarded from
    ``safe_tensor_operation`` are passed through to ``tensor_op`` rather than
    being mistaken for the chunk size.
    """
    if tensor.size(0) <= chunk_size:
        return tensor_op(tensor, *args, **kwargs)

    results = []
    for i in range(0, tensor.size(0), chunk_size):
        chunk = tensor[i:i + chunk_size]
        chunk_result = tensor_op(chunk, *args, **kwargs)
        results.append(chunk_result)

    return torch.cat(results, dim=0)


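# Illustrative sketch only: chunking splits along dim 0, applies the op per
# chunk, and concatenates the results, so the op must preserve that layout.
#
#     big = torch.randint(0, 2, (4096, 128))
#     out = _chunked_tensor_operation(lambda t: t * 2, big, chunk_size=512)
#     assert out.shape == big.shape

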
def validate_model_inputs(inputs: torch.Tensor,
                          max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs."""
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)
    elif inputs.dim() > 2:
        raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}")

    if inputs.size(-1) > max_seq_len:
        error_manager.logger.warning(
            f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating"
        )
        inputs = inputs[:, :max_seq_len]

    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)

    if expected_dtype == torch.long:
        invalid_values = (inputs < 0) | (inputs > 1)
        if invalid_values.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs


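# Illustrative sketch only: a 1-D bit sequence gains a batch dimension and
# out-of-range values are clamped into {0, 1}.
#
#     bits = validate_model_inputs(torch.tensor([0, 1, 2, -1]))
#     # bits -> tensor([[0, 1, 1, 0]]), shape (1, 4), dtype torch.long

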
def safe_model_forward(model: torch.nn.Module,
                       inputs: torch.Tensor,
                       **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery."""
    inputs = validate_model_inputs(inputs)

    try:
        # Recovery is handled explicitly below; wrapping the call in
        # safe_operation() would let the default RuntimeError strategy swallow
        # OOM/device errors before these fallbacks ever run.
        return model(inputs, **kwargs)
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            from torch.utils.checkpoint import checkpoint
            return checkpoint(model, inputs, **kwargs)
        elif "device" in str(e).lower():
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        else:
            raise


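# Illustrative sketch only: `model` is assumed to be a BitTransformerLM-style
# module that accepts a (batch, seq_len) LongTensor of bits.
#
#     bits = torch.randint(0, 2, (2, 256))
#     logits = safe_model_forward(model, bits)

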
def recovery_checkpoint_save(model: torch.nn.Module,
                             path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save model checkpoint with error recovery."""
    try:
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'timestamp': time.time(),
        }
        if additional_data:
            checkpoint_data.update(additional_data)
    except Exception as e:
        error_manager.logger.error(f"Failed to assemble checkpoint data: {e}")
        return False

    try:
        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True
    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")

        # Fall back to a secondary location before giving up.
        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False


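# Illustrative sketch only: the checkpoint path and step counter are hypothetical.
#
#     ok = recovery_checkpoint_save(model, "checkpoints/btlm_step_1000.pt",
#                                   additional_data={"step": 1000})
#     if not ok:
#         error_manager.logger.error("Checkpoint could not be persisted anywhere")

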
def setup_error_logging(log_level: LogLevel = "INFO",
                        log_file: Optional[str] = None) -> logging.Logger:
    """Set up comprehensive error logging."""
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))

    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger


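# Illustrative sketch only: routing the shared error_manager through a
# file-backed logger. The log path is hypothetical.
#
#     error_manager.logger = setup_error_logging(log_level="DEBUG",
#                                                log_file="logs/btlm_errors.log")

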
def default_tensor_recovery() -> torch.Tensor:
    """Default recovery strategy for tensor operations."""
    return torch.zeros(1, dtype=torch.long)


def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Default recovery strategy for model operations."""
    return {"output": torch.zeros(1, dtype=torch.float32)}


error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)