#!/usr/bin/env python3 """ BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED ================================================================= Fixed version that properly initializes 680M parameter model with all optimizations! Uses DataParallel for multi-GPU instead of FSDP to avoid initialization issues. """ import os import sys import time import json import logging from datetime import datetime from typing import Dict, Any, Optional import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader import datasets from datasets import load_dataset import numpy as np # BitTransformerLM imports from bit_transformer.model import BitTransformerLM from bit_transformer.bit_io import text_to_bits, bits_to_text from bit_transformer.utils import set_dropout # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') logger = logging.getLogger(__name__) class OptimizedConfig: """Optimized 680M parameter configuration with ALL BitTransformerLM features enabled.""" # Model Architecture (680M parameters - CONFIRMED) D_MODEL = 1536 NUM_LAYERS = 24 NUM_HEADS = 24 DIM_FEEDFORWARD = 6144 MAX_SEQ_LEN = 2048 # Training Configuration BATCH_SIZE_PER_GPU = 1 # Ultra conservative for 680M model NUM_GPUS = 4 TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS # 4 GRADIENT_ACCUMULATION_STEPS = 32 # Effective batch size = 128 LEARNING_RATE = 3e-4 # Optimal for 680M model WEIGHT_DECAY = 0.01 MAX_STEPS = 10000 WARMUP_STEPS = 500 # BitTransformerLM Optimizations - ALL ENABLED! USE_REVERSIBLE = True # 50% memory savings USE_GRADIENT_CHECKPOINTING = True # Additional memory savings USE_MIXED_PRECISION = True # FP16 training USE_AUTOCAST = True # CPU mixed precision when needed CHUNK_SIZE = None # Full attention (no chunking) FULL_ATTN_LOGGING = False # Memory optimization # Safety & Telemetry LAMBDA_K = 1.0 LAMBDA_C = 1.0 LAMBDA_S = 1.0 NEGENTROPY_THRESHOLD = 0.2 LZ_COMPLEXITY_THRESHOLD = 0.3 SYMBIOSIS_THRESHOLD = 0.5 @classmethod def get_model_config(cls) -> Dict[str, Any]: """Get optimized model configuration.""" return { "d_model": cls.D_MODEL, "nhead": cls.NUM_HEADS, "num_layers": cls.NUM_LAYERS, "dim_feedforward": cls.DIM_FEEDFORWARD, "max_seq_len": cls.MAX_SEQ_LEN, "lambda_K": cls.LAMBDA_K, "lambda_C": cls.LAMBDA_C, "lambda_S": cls.LAMBDA_S, "reversible": cls.USE_REVERSIBLE, "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING, "use_autocast": cls.USE_AUTOCAST, "chunk_size": cls.CHUNK_SIZE, "full_attn_logging": cls.FULL_ATTN_LOGGING, } class SimpleWikiTextDataset(torch.utils.data.Dataset): """Simplified WikiText dataset for bit-level training.""" def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 2048): self.max_length = max_length logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...") dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split) # Filter and limit samples texts = [item['text'] for item in dataset if len(item['text'].strip()) > 100][:max_samples] self.texts = texts logger.info(f"Loaded {len(self.texts)} text samples from {split}") def __len__(self) -> int: return len(self.texts) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: text = self.texts[idx] try: # Convert text to bits bits = text_to_bits(text) # Truncate or pad to max_length if len(bits) > self.max_length: bits = bits[:self.max_length] elif len(bits) < self.max_length: bits = bits + [0] * (self.max_length - len(bits)) # Convert to tensor input_bits = torch.tensor(bits[:-1], dtype=torch.long) target_bits = torch.tensor(bits[1:], dtype=torch.long) return { 'input_ids': input_bits, 'labels': target_bits, 'attention_mask': torch.ones_like(input_bits) } except Exception as e: logger.warning(f"Error processing text at index {idx}: {e}") # Fallback fallback_bits = [0, 1] * (self.max_length // 2) input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long) target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long) return { 'input_ids': input_bits, 'labels': target_bits, 'attention_mask': torch.ones_like(input_bits) } def create_optimized_model(config: OptimizedConfig) -> nn.Module: """Create properly optimized BitTransformerLM model.""" # Create model on CPU first logger.info("🏗️ Creating optimized BitTransformerLM model...") model_config = config.get_model_config() logger.info("Model configuration:") for k, v in model_config.items(): logger.info(f" {k}: {v}") model = BitTransformerLM(**model_config) # Count parameters params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"✅ Model created: {params:,} parameters ({params/1e6:.1f}M)") # Move to GPU and setup DataParallel if torch.cuda.is_available() and torch.cuda.device_count() >= config.NUM_GPUS: logger.info(f"🚀 Setting up multi-GPU training on {config.NUM_GPUS} GPUs...") # Move model to GPU 0 model = model.cuda() # Wrap with DataParallel for multi-GPU if config.NUM_GPUS > 1: model = nn.DataParallel(model, device_ids=list(range(config.NUM_GPUS))) logger.info(f"✅ DataParallel setup complete across GPUs: {list(range(config.NUM_GPUS))}") else: logger.warning("⚠️ Limited GPU availability - using single GPU or CPU") if torch.cuda.is_available(): model = model.cuda() return model def train_step(model: nn.Module, batch: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler, config: OptimizedConfig) -> tuple: """Optimized training step with all BitTransformerLM features.""" model.train() set_dropout(model, 0.1) # Enable dropout for training # Move batch to GPU input_ids = batch['input_ids'].cuda(non_blocking=True) labels = batch['labels'].cuda(non_blocking=True) # Forward pass with mixed precision with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION): outputs = model(input_ids) if isinstance(outputs, tuple): logits, telemetry = outputs else: logits, telemetry = outputs, {} # Compute loss loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction='mean') # Add safety penalties if enabled safety_penalty = 0.0 if telemetry: negentropy = telemetry.get('negentropy', 1.0) lz_complexity = telemetry.get('lz_complexity', 1.0) symbiosis = telemetry.get('symbiosis', 1.0) if (negentropy < config.NEGENTROPY_THRESHOLD or lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or symbiosis < config.SYMBIOSIS_THRESHOLD): safety_penalty = 0.1 loss = loss + safety_penalty # Scale for gradient accumulation loss = loss / config.GRADIENT_ACCUMULATION_STEPS # Backward pass scaler.scale(loss).backward() return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty def main(): """Main training function.""" logger.info("🚀 OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!") logger.info("=" * 60) config = OptimizedConfig() # Check CUDA if not torch.cuda.is_available(): logger.error("❌ CUDA not available!") return logger.info(f"🔥 Hardware: {torch.cuda.device_count()}x GPUs detected") for i in range(torch.cuda.device_count()): props = torch.cuda.get_device_properties(i) logger.info(f" GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)") # Create model model = create_optimized_model(config) # Create datasets logger.info("📚 Loading datasets...") train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN) val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN) # Create dataloaders train_loader = DataLoader( train_dataset, batch_size=config.BATCH_SIZE_PER_GPU, shuffle=True, num_workers=2, pin_memory=True ) val_loader = DataLoader( val_dataset, batch_size=config.BATCH_SIZE_PER_GPU, shuffle=False, num_workers=1, pin_memory=True ) # Setup optimizer and scheduler logger.info("⚙️ Setting up optimizer...") optimizer = torch.optim.AdamW( model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY, betas=(0.9, 0.95) ) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=config.LEARNING_RATE, total_steps=config.MAX_STEPS, pct_start=config.WARMUP_STEPS / config.MAX_STEPS, ) scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION) # Training loop logger.info("🎯 Starting training...") logger.info(f"Target steps: {config.MAX_STEPS}") logger.info(f"Effective batch size: {config.TOTAL_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}") step = 0 running_loss = 0.0 start_time = time.time() for epoch in range(100): # Large number for batch_idx, batch in enumerate(train_loader): # Training step loss, telemetry, safety_penalty = train_step( model, batch, optimizer, scaler, config ) running_loss += loss # Gradient accumulation if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0: # Gradient clipping scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Optimizer step scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad() step += 1 # Logging if step % 10 == 0: avg_loss = running_loss / 10 elapsed = time.time() - start_time samples_per_sec = (config.TOTAL_BATCH_SIZE * 10) / elapsed memory_used = torch.cuda.max_memory_allocated() / (1024**3) logger.info( f"Step {step:4d} | " f"Loss: {avg_loss:.4f} | " f"K: {telemetry.get('negentropy', 0):.3f} | " f"C: {telemetry.get('lz_complexity', 0):.3f} | " f"S: {telemetry.get('symbiosis', 0):.3f} | " f"LR: {scheduler.get_last_lr()[0]:.2e} | " f"Speed: {samples_per_sec:.1f} samp/s | " f"Mem: {memory_used:.1f}GB" + (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "") ) running_loss = 0.0 start_time = time.time() # Validation if step % 100 == 0: model.eval() set_dropout(model, 0.0) val_loss = 0 with torch.no_grad(): for val_batch in val_loader: val_input_ids = val_batch['input_ids'].cuda() val_labels = val_batch['labels'].cuda() with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION): val_outputs = model(val_input_ids) if isinstance(val_outputs, tuple): val_logits, _ = val_outputs else: val_logits = val_outputs val_loss += F.cross_entropy( val_logits.view(-1, 2), val_labels.view(-1) ).item() val_loss /= len(val_loader) logger.info(f"📊 Validation Loss: {val_loss:.4f}") # Save checkpoint if step % 500 == 0: checkpoint_dir = f"/data/checkpoints/massive_simple_{datetime.now().strftime('%Y%m%d_%H%M%S')}" os.makedirs(checkpoint_dir, exist_ok=True) torch.save({ 'step': step, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'config': config.get_model_config(), }, f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt") logger.info(f"💾 Checkpoint saved: step {step}") if step >= config.MAX_STEPS: logger.info("🏁 Training completed!") return if step >= config.MAX_STEPS: break if __name__ == "__main__": main()