#!/usr/bin/env python3
"""
BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED
=================================================================
Fixed version that properly initializes the 680M-parameter model with all optimizations enabled.
Uses DataParallel for multi-GPU training instead of FSDP to avoid initialization issues.
"""
import os
import sys
import time
import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset
import numpy as np
# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.utils import set_dropout
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
class OptimizedConfig:
"""Optimized 680M parameter configuration with ALL BitTransformerLM features enabled."""
# Model Architecture (680M parameters - CONFIRMED)
D_MODEL = 1536
NUM_LAYERS = 24
NUM_HEADS = 24
DIM_FEEDFORWARD = 6144
MAX_SEQ_LEN = 2048
# Training Configuration
BATCH_SIZE_PER_GPU = 1 # Ultra conservative for 680M model
NUM_GPUS = 4
TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS # 4
GRADIENT_ACCUMULATION_STEPS = 32 # Effective batch size = 128
LEARNING_RATE = 3e-4 # Optimal for 680M model
WEIGHT_DECAY = 0.01
MAX_STEPS = 10000
WARMUP_STEPS = 500
# BitTransformerLM Optimizations - ALL ENABLED!
    USE_REVERSIBLE = True  # Reversible layers (~50% activation-memory savings)
USE_GRADIENT_CHECKPOINTING = True # Additional memory savings
USE_MIXED_PRECISION = True # FP16 training
USE_AUTOCAST = True # CPU mixed precision when needed
CHUNK_SIZE = None # Full attention (no chunking)
FULL_ATTN_LOGGING = False # Memory optimization
# Safety & Telemetry
LAMBDA_K = 1.0
LAMBDA_C = 1.0
LAMBDA_S = 1.0
NEGENTROPY_THRESHOLD = 0.2
LZ_COMPLEXITY_THRESHOLD = 0.3
SYMBIOSIS_THRESHOLD = 0.5
@classmethod
def get_model_config(cls) -> Dict[str, Any]:
"""Get optimized model configuration."""
return {
"d_model": cls.D_MODEL,
"nhead": cls.NUM_HEADS,
"num_layers": cls.NUM_LAYERS,
"dim_feedforward": cls.DIM_FEEDFORWARD,
"max_seq_len": cls.MAX_SEQ_LEN,
"lambda_K": cls.LAMBDA_K,
"lambda_C": cls.LAMBDA_C,
"lambda_S": cls.LAMBDA_S,
"reversible": cls.USE_REVERSIBLE,
"use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
"use_autocast": cls.USE_AUTOCAST,
"chunk_size": cls.CHUNK_SIZE,
"full_attn_logging": cls.FULL_ATTN_LOGGING,
}
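
# Back-of-the-envelope sanity check for the "680M parameters" figure above.
# This is only a sketch: it counts the attention and feed-forward projection
# matrices and ignores embeddings, layer norms, and biases, so the exact count
# reported by create_optimized_model() will differ slightly.
def estimate_core_params(cfg=OptimizedConfig) -> int:
    """Rough transformer parameter estimate derived from the config constants."""
    attn = 4 * cfg.D_MODEL * cfg.D_MODEL         # Q, K, V, and output projections per layer
    ffn = 2 * cfg.D_MODEL * cfg.DIM_FEEDFORWARD  # up- and down-projection of the feed-forward block
    return cfg.NUM_LAYERS * (attn + ffn)         # ~679M with the defaults above
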
class SimpleWikiTextDataset(torch.utils.data.Dataset):
"""Simplified WikiText dataset for bit-level training."""
def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 2048):
self.max_length = max_length
logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...")
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
# Filter and limit samples
texts = [item['text'] for item in dataset if len(item['text'].strip()) > 100][:max_samples]
self.texts = texts
logger.info(f"Loaded {len(self.texts)} text samples from {split}")
def __len__(self) -> int:
return len(self.texts)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
text = self.texts[idx]
try:
# Convert text to bits
bits = text_to_bits(text)
# Truncate or pad to max_length
if len(bits) > self.max_length:
bits = bits[:self.max_length]
elif len(bits) < self.max_length:
bits = bits + [0] * (self.max_length - len(bits))
# Convert to tensor
input_bits = torch.tensor(bits[:-1], dtype=torch.long)
target_bits = torch.tensor(bits[1:], dtype=torch.long)
return {
'input_ids': input_bits,
'labels': target_bits,
'attention_mask': torch.ones_like(input_bits)
}
except Exception as e:
logger.warning(f"Error processing text at index {idx}: {e}")
# Fallback
fallback_bits = [0, 1] * (self.max_length // 2)
input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)
return {
'input_ids': input_bits,
'labels': target_bits,
'attention_mask': torch.ones_like(input_bits)
}
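
# Quick shape check for the dataset (illustrative only; not executed during training):
# each item is a next-bit-prediction pair of length MAX_SEQ_LEN - 1, so the inputs
# shifted left by one position should match the labels shifted right by one, e.g.:
#   ds = SimpleWikiTextDataset("train", max_samples=10)
#   item = ds[0]
#   assert torch.equal(item["input_ids"][1:], item["labels"][:-1])
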
def create_optimized_model(config: OptimizedConfig) -> nn.Module:
"""Create properly optimized BitTransformerLM model."""
# Create model on CPU first
    logger.info("🏗️ Creating optimized BitTransformerLM model...")
model_config = config.get_model_config()
logger.info("Model configuration:")
for k, v in model_config.items():
logger.info(f" {k}: {v}")
model = BitTransformerLM(**model_config)
# Count parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"✅ Model created: {params:,} parameters ({params/1e6:.1f}M)")
# Move to GPU and setup DataParallel
if torch.cuda.is_available() and torch.cuda.device_count() >= config.NUM_GPUS:
        logger.info(f"🚀 Setting up multi-GPU training on {config.NUM_GPUS} GPUs...")
# Move model to GPU 0
model = model.cuda()
# Wrap with DataParallel for multi-GPU
if config.NUM_GPUS > 1:
model = nn.DataParallel(model, device_ids=list(range(config.NUM_GPUS)))
            logger.info(f"✅ DataParallel setup complete across GPUs: {list(range(config.NUM_GPUS))}")
else:
        logger.warning("⚠️ Limited GPU availability - using single GPU or CPU")
if torch.cuda.is_available():
model = model.cuda()
return model
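
# Note on the multi-GPU strategy (a deliberate simplification): nn.DataParallel
# replicates the model on each forward pass and scatters the input batch along
# dim 0, so the training DataLoader below feeds the combined batch
# (BATCH_SIZE_PER_GPU * NUM_GPUS). DistributedDataParallel or FSDP would be the
# usual choice at this scale, but DataParallel sidesteps the initialization
# issues mentioned in the module docstring.
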
def train_step(model: nn.Module, batch: Dict[str, torch.Tensor],
optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
config: OptimizedConfig) -> tuple:
"""Optimized training step with all BitTransformerLM features."""
model.train()
set_dropout(model, 0.1) # Enable dropout for training
# Move batch to GPU
input_ids = batch['input_ids'].cuda(non_blocking=True)
labels = batch['labels'].cuda(non_blocking=True)
# Forward pass with mixed precision
with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
outputs = model(input_ids)
if isinstance(outputs, tuple):
logits, telemetry = outputs
else:
logits, telemetry = outputs, {}
# Compute loss
loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction='mean')
# Add safety penalties if enabled
safety_penalty = 0.0
if telemetry:
negentropy = telemetry.get('negentropy', 1.0)
lz_complexity = telemetry.get('lz_complexity', 1.0)
symbiosis = telemetry.get('symbiosis', 1.0)
if (negentropy < config.NEGENTROPY_THRESHOLD or
lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
symbiosis < config.SYMBIOSIS_THRESHOLD):
safety_penalty = 0.1
loss = loss + safety_penalty
# Scale for gradient accumulation
loss = loss / config.GRADIENT_ACCUMULATION_STEPS
# Backward pass
scaler.scale(loss).backward()
return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty
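
# train_step() contract: the loss is scaled by 1 / GRADIENT_ACCUMULATION_STEPS before
# backward() so gradients accumulate to a full-batch average, and the *unscaled* loss
# is returned for logging. Unscaling, clipping, optimizer.step(), scaler.update(), and
# zero_grad() are all handled in main() at accumulation boundaries.
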
def main():
"""Main training function."""
    logger.info("🚀 OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!")
logger.info("=" * 60)
config = OptimizedConfig()
# Check CUDA
if not torch.cuda.is_available():
logger.error("❌ CUDA not available!")
return
    logger.info(f"🔥 Hardware: {torch.cuda.device_count()}x GPUs detected")
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
logger.info(f" GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)")
# Create model
model = create_optimized_model(config)
# Create datasets
    logger.info("📚 Loading datasets...")
train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)
val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN)
# Create dataloaders
    # DataParallel scatters each batch across the GPUs, so the loader must yield the
    # combined batch (BATCH_SIZE_PER_GPU per device after the split).
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.TOTAL_BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
val_loader = DataLoader(
val_dataset,
batch_size=config.BATCH_SIZE_PER_GPU,
shuffle=False,
num_workers=1,
pin_memory=True
)
# Setup optimizer and scheduler
    logger.info("⚙️ Setting up optimizer...")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.LEARNING_RATE,
weight_decay=config.WEIGHT_DECAY,
betas=(0.9, 0.95)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.LEARNING_RATE,
total_steps=config.MAX_STEPS,
pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
)
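    # OneCycleLR: pct_start = WARMUP_STEPS / MAX_STEPS (= 5% here) spends the first
    # 500 steps ramping the LR up to LEARNING_RATE, then anneals it back down
    # (cosine schedule by default).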
scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)
# Training loop
logger.info("🎯 Starting training...")
logger.info(f"Target steps: {config.MAX_STEPS}")
logger.info(f"Effective batch size: {config.TOTAL_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")
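    # With a combined loader batch of 4 and 32 accumulation steps, each optimizer
    # update sees 128 sequences of 2,047 bit targets (~262K positions).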
step = 0
running_loss = 0.0
    start_time = time.time()
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # one checkpoint directory per run
    for epoch in range(100):  # upper bound only; training stops once MAX_STEPS is reached
for batch_idx, batch in enumerate(train_loader):
# Training step
loss, telemetry, safety_penalty = train_step(
model, batch, optimizer, scaler, config
)
running_loss += loss
# Gradient accumulation
if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
# Gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Optimizer step
scaler.step(optimizer)
scaler.update()
scheduler.step()
optimizer.zero_grad()
step += 1
# Logging
if step % 10 == 0:
                    # running_loss accumulates one unscaled loss per micro-batch, so
                    # average over every micro-batch seen since the last log.
                    micro_batches = 10 * config.GRADIENT_ACCUMULATION_STEPS
                    avg_loss = running_loss / micro_batches
                    elapsed = time.time() - start_time
                    samples_per_sec = (config.TOTAL_BATCH_SIZE * micro_batches) / elapsed
memory_used = torch.cuda.max_memory_allocated() / (1024**3)
logger.info(
f"Step {step:4d} | "
f"Loss: {avg_loss:.4f} | "
f"K: {telemetry.get('negentropy', 0):.3f} | "
f"C: {telemetry.get('lz_complexity', 0):.3f} | "
f"S: {telemetry.get('symbiosis', 0):.3f} | "
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
f"Speed: {samples_per_sec:.1f} samp/s | "
f"Mem: {memory_used:.1f}GB"
+ (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "")
)
running_loss = 0.0
start_time = time.time()
# Validation
if step % 100 == 0:
model.eval()
set_dropout(model, 0.0)
val_loss = 0
with torch.no_grad():
for val_batch in val_loader:
val_input_ids = val_batch['input_ids'].cuda()
val_labels = val_batch['labels'].cuda()
with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
val_outputs = model(val_input_ids)
if isinstance(val_outputs, tuple):
val_logits, _ = val_outputs
else:
val_logits = val_outputs
val_loss += F.cross_entropy(
val_logits.view(-1, 2),
val_labels.view(-1)
).item()
val_loss /= len(val_loader)
                    logger.info(f"📊 Validation Loss: {val_loss:.4f}")
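                    # No explicit reset needed here: train_step() switches the model back
                    # to train mode and re-enables dropout at the start of the next step.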
# Save checkpoint
if step % 500 == 0:
                    checkpoint_dir = f"/data/checkpoints/massive_simple_{run_stamp}"
                    os.makedirs(checkpoint_dir, exist_ok=True)
torch.save({
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'config': config.get_model_config(),
}, f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt")
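                    # Note: DataParallel prefixes parameter names with "module.", so strip
                    # that prefix (or re-wrap the model) when loading this state_dict back
                    # into a bare BitTransformerLM.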
                    logger.info(f"💾 Checkpoint saved: step {step}")
if step >= config.MAX_STEPS:
logger.info("🏁 Training completed!")
return
if step >= config.MAX_STEPS:
break
if __name__ == "__main__":
main()