"""Configuration management for BitTransformerLM.""" from __future__ import annotations import os from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional import torch from .types import ( AttentionMask, ChunkSize, DeviceType, DiffusionConfig, GenerationConfig, HiddenSize, NumHeads, NumLayers, QuantizationConfig, SafetyThresholds, SequenceLength, ) @dataclass class ModelConfig: """Configuration for BitTransformerLM model architecture. Attributes: d_model: Model dimension for embeddings and attention. nhead: Number of attention heads. num_layers: Number of transformer layers. dim_feedforward: Dimension of feedforward networks. max_seq_len: Maximum sequence length for positional encoding. lambda_K: Weight for negentropy metric in telemetry. lambda_C: Weight for complexity metric in telemetry. lambda_S: Weight for symbiosis metric in telemetry. reversible: Enable reversible layers for memory efficiency. use_checkpoint: Use gradient checkpointing. use_autocast: Use automatic mixed precision. use_act: Enable Adaptive Computation Time. act_threshold: ACT halting threshold. chunk_size: Chunk size for chunked attention (None for full attention). overlap: Overlap size for chunked attention. full_attn_logging: Log full attention matrices for telemetry. """ d_model: HiddenSize = 128 nhead: NumHeads = 8 num_layers: NumLayers = 4 dim_feedforward: int = 512 max_seq_len: SequenceLength = 1024 lambda_K: float = 1.0 lambda_C: float = 1.0 lambda_S: float = 1.0 reversible: bool = False use_checkpoint: bool = True use_autocast: bool = False use_act: bool = False act_threshold: float = 0.9 chunk_size: ChunkSize = None overlap: int = 0 full_attn_logging: Optional[bool] = None def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary.""" return { "d_model": self.d_model, "nhead": self.nhead, "num_layers": self.num_layers, "dim_feedforward": self.dim_feedforward, "max_seq_len": self.max_seq_len, "lambda_K": self.lambda_K, "lambda_C": self.lambda_C, "lambda_S": self.lambda_S, "reversible": self.reversible, "use_checkpoint": self.use_checkpoint, "use_autocast": self.use_autocast, "use_act": self.use_act, "act_threshold": self.act_threshold, "chunk_size": self.chunk_size, "overlap": self.overlap, "full_attn_logging": self.full_attn_logging, } @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> ModelConfig: """Create config from dictionary.""" return cls(**config_dict) @dataclass class TrainingConfig: """Configuration for training BitTransformerLM. Attributes: epochs: Number of training epochs. batch_size: Training batch size. learning_rate: Initial learning rate. weight_decay: Weight decay for regularization. gradient_clip_val: Gradient clipping value. warmup_steps: Number of warmup steps for learning rate. accumulate_grad_batches: Number of gradient accumulation steps. amp: Enable automatic mixed precision. compile_model: Enable PyTorch 2.0 compilation. log_every_n_steps: Logging frequency. val_check_interval: Validation check frequency. save_top_k: Number of best checkpoints to save. """ epochs: int = 10 batch_size: int = 8 learning_rate: float = 1e-3 weight_decay: float = 0.01 gradient_clip_val: float = 1.0 warmup_steps: int = 100 accumulate_grad_batches: int = 1 amp: bool = False compile_model: bool = False log_every_n_steps: int = 50 val_check_interval: float = 1.0 save_top_k: int = 3 @dataclass class SafetyConfig: """Configuration for safety monitoring and thresholds. Attributes: enable_safety: Enable safety monitoring. 
@dataclass
class TrainingConfig:
    """Configuration for training BitTransformerLM.

    Attributes:
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Initial learning rate.
        weight_decay: Weight decay for regularization.
        gradient_clip_val: Gradient clipping value.
        warmup_steps: Number of warmup steps for learning rate.
        accumulate_grad_batches: Number of gradient accumulation steps.
        amp: Enable automatic mixed precision.
        compile_model: Enable PyTorch 2.0 compilation.
        log_every_n_steps: Logging frequency.
        val_check_interval: Validation check frequency.
        save_top_k: Number of best checkpoints to save.
    """

    epochs: int = 10
    batch_size: int = 8
    learning_rate: float = 1e-3
    weight_decay: float = 0.01
    gradient_clip_val: float = 1.0
    warmup_steps: int = 100
    accumulate_grad_batches: int = 1
    amp: bool = False
    compile_model: bool = False
    log_every_n_steps: int = 50
    val_check_interval: float = 1.0
    save_top_k: int = 3


@dataclass
class SafetyConfig:
    """Configuration for safety monitoring and thresholds.

    Attributes:
        enable_safety: Enable safety monitoring.
        k_threshold: Negentropy threshold for safety gate.
        c_threshold: Complexity threshold for safety gate.
        s_threshold: Symbiosis threshold for safety gate.
        strict_mode: Enable strict safety enforcement.
        retry_attempts: Number of retry attempts for failed safety checks.
    """

    enable_safety: bool = True
    k_threshold: float = 0.1
    c_threshold: float = 0.3
    s_threshold: float = 0.5
    strict_mode: bool = False
    retry_attempts: int = 3

    def to_thresholds(self) -> SafetyThresholds:
        """Convert to SafetyThresholds type."""
        return {
            "k_threshold": self.k_threshold,
            "c_threshold": self.c_threshold,
            "s_threshold": self.s_threshold,
        }


@dataclass
class DataConfig:
    """Configuration for data processing and loading.

    Attributes:
        dataset_path: Path to training dataset.
        val_dataset_path: Path to validation dataset.
        num_workers: Number of data loader workers.
        pin_memory: Pin memory for data loading.
        prefetch_factor: Prefetch factor for data loading.
        max_sequence_length: Maximum sequence length to process.
        compression_prob: Probability of using compressed data.
        use_parity: Enable parity bit protection.
    """

    dataset_path: Optional[Path] = None
    val_dataset_path: Optional[Path] = None
    num_workers: int = 0
    pin_memory: bool = True
    prefetch_factor: int = 2
    max_sequence_length: int = 1024
    compression_prob: float = 0.5
    use_parity: bool = True


@dataclass
class ExperimentConfig:
    """Complete configuration for BitTransformerLM experiments.

    Attributes:
        model: Model configuration.
        training: Training configuration.
        safety: Safety configuration.
        data: Data configuration.
        device: Target device for training.
        seed: Random seed for reproducibility.
        experiment_name: Name of the experiment.
        output_dir: Directory for saving outputs.
        resume_from_checkpoint: Path to checkpoint to resume from.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)
    data: DataConfig = field(default_factory=DataConfig)
    device: DeviceType = "auto"
    seed: int = 42
    experiment_name: str = "bit_transformer_experiment"
    output_dir: Path = Path("./outputs")
    resume_from_checkpoint: Optional[Path] = None

    def __post_init__(self):
        """Post-initialization to handle device selection and path creation."""
        # Auto-detect device
        if self.device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Ensure output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> Dict[str, Any]:
        """Convert complete config to dictionary."""
        return {
            "model": self.model.to_dict(),
            "training": self.training.__dict__,
            "safety": self.safety.__dict__,
            "data": self.data.__dict__,
            "device": str(self.device),
            "seed": self.seed,
            "experiment_name": self.experiment_name,
            "output_dir": str(self.output_dir),
            "resume_from_checkpoint": str(self.resume_from_checkpoint) if self.resume_from_checkpoint else None,
        }
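
# Illustrative sketch (not part of the original module): with default field values
# the nested to_dict() output is JSON-serializable, so it can be logged alongside
# experiment results. The helper name below is hypothetical; note that constructing
# ExperimentConfig resolves the device and creates output_dir as a side effect.
def _experiment_config_json_sketch() -> str:
    """Serialize a default ExperimentConfig to a JSON string."""
    import json

    cfg = ExperimentConfig(experiment_name="demo_run")
    return json.dumps(cfg.to_dict(), indent=2)
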
# Preset configurations for common use cases
def get_small_config() -> ExperimentConfig:
    """Get configuration for small-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=256,
            max_seq_len=256,
        ),
        training=TrainingConfig(
            batch_size=4,
            learning_rate=1e-3,
            epochs=5,
        ),
    )


def get_medium_config() -> ExperimentConfig:
    """Get configuration for medium-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=128,
            nhead=8,
            num_layers=4,
            dim_feedforward=512,
            max_seq_len=1024,
        ),
        training=TrainingConfig(
            batch_size=8,
            learning_rate=1e-3,
            epochs=10,
        ),
    )


def get_large_config() -> ExperimentConfig:
    """Get configuration for large-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=256,
            nhead=16,
            num_layers=8,
            dim_feedforward=1024,
            max_seq_len=2048,
            reversible=True,
            chunk_size=512,
        ),
        training=TrainingConfig(
            batch_size=16,
            learning_rate=5e-4,
            epochs=20,
            amp=True,
            compile_model=True,
        ),
    )


def get_config_from_env() -> ExperimentConfig:
    """Load configuration from environment variables."""
    config = ExperimentConfig()

    # Model config from environment
    if os.getenv("BT_D_MODEL"):
        config.model.d_model = int(os.getenv("BT_D_MODEL"))
    if os.getenv("BT_NUM_LAYERS"):
        config.model.num_layers = int(os.getenv("BT_NUM_LAYERS"))
    if os.getenv("BT_NHEAD"):
        config.model.nhead = int(os.getenv("BT_NHEAD"))

    # Training config from environment
    if os.getenv("BT_BATCH_SIZE"):
        config.training.batch_size = int(os.getenv("BT_BATCH_SIZE"))
    if os.getenv("BT_LEARNING_RATE"):
        config.training.learning_rate = float(os.getenv("BT_LEARNING_RATE"))
    if os.getenv("BT_EPOCHS"):
        config.training.epochs = int(os.getenv("BT_EPOCHS"))

    # Device from environment
    if os.getenv("BT_DEVICE"):
        config.device = os.getenv("BT_DEVICE")

    # Output directory from environment
    if os.getenv("BT_OUTPUT_DIR"):
        config.output_dir = Path(os.getenv("BT_OUTPUT_DIR"))

    return config
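
# Hedged smoke-test sketch (not part of the original module). Because this file
# uses relative imports, it has to be run as a module, e.g.
# ``BT_D_MODEL=256 python -m <package>.config``, where <package> is whatever the
# BitTransformerLM package is named in your install.
if __name__ == "__main__":
    small = get_small_config()
    print(f"small preset: d_model={small.model.d_model}, device={small.device}")

    # Only the BT_* variables that are actually set override the defaults.
    env_cfg = get_config_from_env()
    print(
        f"env config: d_model={env_cfg.model.d_model}, "
        f"epochs={env_cfg.training.epochs}, device={env_cfg.device}"
    )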