"""
BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED
=================================================================

Fixed version that properly initializes the 680M-parameter model with all optimizations!
Uses DataParallel for multi-GPU training instead of FSDP to avoid initialization issues.
"""

import os
import sys
import time
import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset
import numpy as np

from bit_transformer.model import BitTransformerLM
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.utils import set_dropout

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


class OptimizedConfig:
    """Optimized 680M parameter configuration with ALL BitTransformerLM features enabled."""

    # Model architecture.
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048
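
    # Rough parameter budget (a back-of-envelope estimate assuming standard
    # multi-head attention and feed-forward parameter counts; embeddings and
    # norms add a small remainder):
    #   attention    ~ 4 * D_MODEL**2               = 4 * 1536**2      ~  9.4M / layer
    #   feed-forward ~ 2 * D_MODEL * DIM_FEEDFORWARD = 2 * 1536 * 6144 ~ 18.9M / layer
    #   total        ~ 24 layers * 28.3M            ~ 680M parameters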

    # Per-GPU batch size and gradient accumulation.
    BATCH_SIZE_PER_GPU = 1
    NUM_GPUS = 4
    TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS
    GRADIENT_ACCUMULATION_STEPS = 32
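
    # Effective batch per optimizer update:
    #   1 seq/GPU * 4 GPUs * 32 accumulation steps = 128 sequences of 2048 bits each.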

    # Optimization schedule.
    LEARNING_RATE = 3e-4
    WEIGHT_DECAY = 0.01
    MAX_STEPS = 10000
    WARMUP_STEPS = 500

    # Memory and throughput features.
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_AUTOCAST = True
    CHUNK_SIZE = None
    FULL_ATTN_LOGGING = False

    # Safety telemetry weights and thresholds.
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.2
    LZ_COMPLEXITY_THRESHOLD = 0.3
    SYMBIOSIS_THRESHOLD = 0.5
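
    # The lambdas are passed to BitTransformerLM and (presumably) weight the
    # K (negentropy), C (LZ complexity), and S (symbiosis) telemetry terms;
    # the thresholds gate the extra safety penalty applied in train_step()
    # whenever any reported metric drops below its floor.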

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Get optimized model configuration."""
        return {
            "d_model": cls.D_MODEL,
            "nhead": cls.NUM_HEADS,
            "num_layers": cls.NUM_LAYERS,
            "dim_feedforward": cls.DIM_FEEDFORWARD,
            "max_seq_len": cls.MAX_SEQ_LEN,
            "lambda_K": cls.LAMBDA_K,
            "lambda_C": cls.LAMBDA_C,
            "lambda_S": cls.LAMBDA_S,
            "reversible": cls.USE_REVERSIBLE,
            "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
            "use_autocast": cls.USE_AUTOCAST,
            "chunk_size": cls.CHUNK_SIZE,
            "full_attn_logging": cls.FULL_ATTN_LOGGING,
        }


class SimpleWikiTextDataset(torch.utils.data.Dataset):
    """Simplified WikiText dataset for bit-level training."""

    def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 2048):
        self.max_length = max_length

        logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...")
        dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

        # Keep only non-trivial passages (more than 100 characters after stripping).
        texts = [item['text'] for item in dataset if len(item['text'].strip()) > 100][:max_samples]
        self.texts = texts

        logger.info(f"Loaded {len(self.texts)} text samples from {split}")

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = self.texts[idx]

        try:
            # Convert the raw text into a flat bit sequence.
            bits = text_to_bits(text)

            # Truncate or zero-pad to a fixed window of max_length bits.
            if len(bits) > self.max_length:
                bits = bits[:self.max_length]
            elif len(bits) < self.max_length:
                bits = bits + [0] * (self.max_length - len(bits))

            # Shift by one position for next-bit prediction.
            input_bits = torch.tensor(bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }

        except Exception as e:
            logger.warning(f"Error processing text at index {idx}: {e}")

            # Fall back to a trivial alternating pattern of the same length.
            fallback_bits = [0, 1] * (self.max_length // 2)
            input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }
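
# Each item frames training as next-bit prediction over the binary vocabulary {0, 1}:
# with the default window of 2048 bits, 2047 input bits predict the same sequence
# shifted one position to the right, matching the 2-way cross-entropy in train_step().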


def create_optimized_model(config: OptimizedConfig) -> nn.Module:
    """Create properly optimized BitTransformerLM model."""

    logger.info("Creating optimized BitTransformerLM model...")
    model_config = config.get_model_config()

    logger.info("Model configuration:")
    for k, v in model_config.items():
        logger.info(f" {k}: {v}")

    model = BitTransformerLM(**model_config)

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model created: {params:,} parameters ({params/1e6:.1f}M)")

    if torch.cuda.is_available() and torch.cuda.device_count() >= config.NUM_GPUS:
        logger.info(f"Setting up multi-GPU training on {config.NUM_GPUS} GPUs...")

        model = model.cuda()

        if config.NUM_GPUS > 1:
            model = nn.DataParallel(model, device_ids=list(range(config.NUM_GPUS)))
            logger.info(f"DataParallel setup complete across GPUs: {list(range(config.NUM_GPUS))}")

    else:
        logger.warning("Limited GPU availability - using single GPU or CPU")
        if torch.cuda.is_available():
            model = model.cuda()

    return model
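
# DataParallel replicates the model on each visible GPU and splits every input batch
# along dim 0, so the training DataLoader in main() is sized with the total batch
# across GPUs rather than the per-GPU batch.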


def train_step(model: nn.Module, batch: Dict[str, torch.Tensor],
               optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
               config: OptimizedConfig) -> tuple:
    """Optimized training step with all BitTransformerLM features."""

    model.train()
    set_dropout(model, 0.1)

    input_ids = batch['input_ids'].cuda(non_blocking=True)
    labels = batch['labels'].cuda(non_blocking=True)

    with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
        outputs = model(input_ids)

        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits, telemetry = outputs, {}

        # Next-bit prediction loss over the binary vocabulary.
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction='mean')

        # Apply a fixed penalty whenever any safety telemetry metric falls below its threshold.
        safety_penalty = 0.0
        if telemetry:
            negentropy = telemetry.get('negentropy', 1.0)
            lz_complexity = telemetry.get('lz_complexity', 1.0)
            symbiosis = telemetry.get('symbiosis', 1.0)

            if (negentropy < config.NEGENTROPY_THRESHOLD or
                    lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                    symbiosis < config.SYMBIOSIS_THRESHOLD):
                safety_penalty = 0.1
                loss = loss + safety_penalty

        # Scale down so gradients accumulated over several micro-batches average correctly.
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS

    scaler.scale(loss).backward()

    # Report the unscaled loss for logging.
    return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty
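
# With GRADIENT_ACCUMULATION_STEPS = 32, backward() sees loss / 32 for each micro-batch
# and the optimizer only steps every 32 micro-batches, so the summed gradients
# approximate a single update over 4 * 32 = 128 sequences.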


def main():
    """Main training function."""

    logger.info("OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!")
    logger.info("=" * 60)

    config = OptimizedConfig()

    if not torch.cuda.is_available():
        logger.error("CUDA not available!")
        return

    logger.info(f"Hardware: {torch.cuda.device_count()}x GPUs detected")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        logger.info(f" GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)")

    model = create_optimized_model(config)

    logger.info("Loading datasets...")
    train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)
    val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.TOTAL_BATCH_SIZE,  # DataParallel splits this batch across the GPUs
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )
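
    # pin_memory=True keeps host batches in page-locked memory, letting the
    # .cuda(non_blocking=True) copies in train_step() overlap with compute.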

    logger.info("Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95)
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
    )

    scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)
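
    # OneCycleLR ramps the learning rate up over the first WARMUP_STEPS / MAX_STEPS
    # = 5% of training and anneals it afterwards; GradScaler rescales the fp16 loss
    # so small gradients do not underflow under mixed precision.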

    logger.info("Starting training...")
    logger.info(f"Target steps: {config.MAX_STEPS}")
    logger.info(f"Effective batch size: {config.TOTAL_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")

    step = 0
    running_loss = 0.0
    start_time = time.time()

    for epoch in range(100):
        for batch_idx, batch in enumerate(train_loader):

            loss, telemetry, safety_penalty = train_step(
                model, batch, optimizer, scaler, config
            )
            running_loss += loss

            # Optimizer step every GRADIENT_ACCUMULATION_STEPS micro-batches.
            if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:

                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

                step += 1

                if step % 10 == 0:
                    # running_loss holds one entry per micro-batch, i.e.
                    # 10 optimizer steps * GRADIENT_ACCUMULATION_STEPS entries.
                    micro_batches = 10 * config.GRADIENT_ACCUMULATION_STEPS
                    avg_loss = running_loss / micro_batches
                    elapsed = time.time() - start_time
                    samples_per_sec = (config.TOTAL_BATCH_SIZE * micro_batches) / elapsed
                    memory_used = torch.cuda.max_memory_allocated() / (1024**3)

                    logger.info(
                        f"Step {step:4d} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"K: {telemetry.get('negentropy', 0):.3f} | "
                        f"C: {telemetry.get('lz_complexity', 0):.3f} | "
                        f"S: {telemetry.get('symbiosis', 0):.3f} | "
                        f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                        f"Speed: {samples_per_sec:.1f} samp/s | "
                        f"Mem: {memory_used:.1f}GB"
                        + (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "")
                    )

                    running_loss = 0.0
                    start_time = time.time()

                # Periodic validation with dropout disabled.
                if step % 100 == 0:
                    model.eval()
                    set_dropout(model, 0.0)
                    val_loss = 0

                    with torch.no_grad():
                        for val_batch in val_loader:
                            val_input_ids = val_batch['input_ids'].cuda()
                            val_labels = val_batch['labels'].cuda()

                            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                                val_outputs = model(val_input_ids)
                                if isinstance(val_outputs, tuple):
                                    val_logits, _ = val_outputs
                                else:
                                    val_logits = val_outputs

                                val_loss += F.cross_entropy(
                                    val_logits.view(-1, 2),
                                    val_labels.view(-1)
                                ).item()

                    val_loss /= len(val_loader)
                    logger.info(f"Validation Loss: {val_loss:.4f}")

                # Periodic checkpointing.
                if step % 500 == 0:
                    checkpoint_dir = f"/data/checkpoints/massive_simple_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    # Unwrap DataParallel so the checkpoint loads into a bare model.
                    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
                    torch.save({
                        'step': step,
                        'model_state_dict': model_to_save.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'config': config.get_model_config(),
                    }, f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt")

                    logger.info(f"Checkpoint saved: step {step}")

                if step >= config.MAX_STEPS:
                    logger.info("Training completed!")
                    return

        if step >= config.MAX_STEPS:
            break


if __name__ == "__main__":
    main()