"""BitTransformerLM Massive Scale Training Script
==============================================

Scale BitTransformerLM to 1.21 BILLION parameters on extensive real corpus data.
This script configures distributed training across 4x NVIDIA L4 GPUs with FSDP.

Target Configuration:
- Parameters: 1,208,164,352 (1.21B)
- Architecture: d_model=2048, layers=24, heads=32, ff=8192
- Dataset: WikiText-103 + additional real corpus data
- Hardware: 4x NVIDIA L4 (23GB each), 181GB RAM, 48 CPU cores
"""
import os
import sys
import time
import math
import json
import logging
import argparse
import functools
from datetime import datetime
from typing import Dict, Any, Optional, List, Tuple
import warnings

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
import datasets
from datasets import load_dataset
import numpy as np

from bit_transformer.model import BitTransformerLM, LoggingTransformerEncoderLayer
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.utils import set_dropout
from bit_transformer.torch_utils import cpu_autocast

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler('/data/massive_scale_training.log'),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)

warnings.filterwarnings('ignore', category=UserWarning)

class MassiveScaleConfig:
    """Training configuration, GPU-optimized for 4x NVIDIA L4.

    The defaults below yield roughly 680M parameters; reaching the 1.21B target
    described in the module docstring requires d_model=2048, nhead=32,
    dim_feedforward=8192.
    """

    # Model architecture
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048

    # Batch sizing: per-GPU batch x 4 GPUs x accumulation steps = 512 sequences
    BATCH_SIZE_PER_GPU = 4
    GRADIENT_ACCUMULATION_STEPS = 32
    EFFECTIVE_BATCH_SIZE = BATCH_SIZE_PER_GPU * 4 * GRADIENT_ACCUMULATION_STEPS

    # Optimization
    LEARNING_RATE = 6e-5
    WEIGHT_DECAY = 0.1
    MAX_STEPS = 50000
    WARMUP_STEPS = 2000

    # Telemetry weights and safety-gate thresholds
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.15
    LZ_COMPLEXITY_THRESHOLD = 0.25
    SYMBIOSIS_THRESHOLD = 0.4

    # Memory and precision features
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_SAFETY_GATES = True

    # Dataset
    DATASET_NAME = "wikitext"
    DATASET_CONFIG = "wikitext-103-raw-v1"
    MAX_SAMPLES = None
    STREAMING = True

    # Logging / evaluation cadence (in training steps)
    LOG_INTERVAL = 50
    EVAL_INTERVAL = 1000
    CHECKPOINT_INTERVAL = 2000

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Get model configuration dictionary."""
        return {
            "d_model": cls.D_MODEL,
            "nhead": cls.NUM_HEADS,
            "num_layers": cls.NUM_LAYERS,
            "dim_feedforward": cls.DIM_FEEDFORWARD,
            "max_seq_len": cls.MAX_SEQ_LEN,
            "lambda_K": cls.LAMBDA_K,
            "lambda_C": cls.LAMBDA_C,
            "lambda_S": cls.LAMBDA_S,
            "reversible": cls.USE_REVERSIBLE,
            "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
            "use_autocast": False,
            "chunk_size": None,
            "full_attn_logging": False,
        }
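
# Rough parameter count for the transformer blocks alone (attention ~4*d_model^2
# plus feed-forward ~2*d_model*ff per layer; embeddings, norms, and the output
# head are ignored, so these are estimates only, not exact BitTransformerLM counts):
#   defaults above: 24 * (4 * 1536**2 + 2 * 1536 * 6144) ~= 679.5M parameters
#   1.21B target:   24 * (4 * 2048**2 + 2 * 2048 * 8192) ~= 1,208.0M parameters
# Each effective batch covers EFFECTIVE_BATCH_SIZE * (MAX_SEQ_LEN - 1)
# = 512 * 2047 ~= 1.05M bit positions.
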
class WikiTextDataset(torch.utils.data.Dataset):
    """WikiText dataset preprocessed for bit-level training."""

    def __init__(self, split: str = "train", max_samples: Optional[int] = None,
                 max_length: int = 2048, streaming: bool = True):
        self.max_length = max_length
        self.streaming = streaming

        logger.info(f"Loading WikiText-103 {split} split...")
        if streaming:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split,
                streaming=True,
            )
            if max_samples:
                self.dataset = self.dataset.take(max_samples)
        else:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split,
            )
            if max_samples:
                self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))

        if not streaming:
            # Keep only substantive paragraphs; WikiText contains many empty lines.
            self.texts = [item['text'] for item in self.dataset if len(item['text'].strip()) > 50]
            logger.info(f"Loaded {len(self.texts)} text samples from {split}")
        else:
            self.texts = None
            logger.info(f"Streaming dataset configured for {split}")

    def __len__(self) -> int:
        if self.texts is not None:
            return len(self.texts)
        # Streaming datasets have no cheap length; report a nominal size so that
        # DataLoader and DistributedSampler can still operate.
        return 100000 if "train" in str(self.dataset) else 1000

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if self.texts is not None:
            text = self.texts[idx]
        else:
            # Streaming datasets are not indexable; scan forward to the requested
            # index and fall back to a fixed sentence if it is never reached.
            for i, item in enumerate(self.dataset):
                if i == idx:
                    text = item['text']
                    break
            else:
                text = "The quick brown fox jumps over the lazy dog."

        try:
            bits = text_to_bits(text)

            # Truncate or zero-pad to the fixed sequence length.
            if len(bits) > self.max_length:
                bits = bits[:self.max_length]
            elif len(bits) < self.max_length:
                bits = bits + [0] * (self.max_length - len(bits))

            # Next-bit prediction: inputs are bits[:-1], targets are bits[1:].
            input_bits = torch.tensor(bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits),
            }

        except Exception as e:
            logger.warning(f"Error processing text at index {idx}: {e}")

            # Fall back to a deterministic alternating pattern of the right length.
            fallback_bits = [0, 1] * (self.max_length // 2)
            if len(fallback_bits) < self.max_length:
                fallback_bits.extend([0] * (self.max_length - len(fallback_bits)))

            input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits),
            }
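
# Optional sanity check for the bit-level data pipeline (a sketch; assumes the
# non-streaming validation split can be downloaded locally):
#
#   sample = WikiTextDataset("validation", max_samples=10, streaming=False)[0]
#   assert sample['input_ids'].shape[0] == MassiveScaleConfig.MAX_SEQ_LEN - 1
#   assert set(sample['input_ids'].unique().tolist()) <= {0, 1}
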
def setup_distributed(rank: int, world_size: int, port: str = "29500") -> None:
    """Initialize distributed training."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup_distributed() -> None:
    """Clean up distributed training."""
    dist.destroy_process_group()


def count_parameters(model: nn.Module) -> int:
    """Count total trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def create_fsdp_model(model_config: Dict[str, Any], rank: int) -> FSDP:
    """Create FSDP-wrapped BitTransformerLM model."""
    model = BitTransformerLM(**model_config)
    model = model.to(rank)

    # Keep parameters, gradient reductions, and buffers in fp16.
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # Wrap submodules above ~1M parameters in their own FSDP units so individual
    # transformer layers are sharded separately rather than as one flat group.
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=1_000_000,
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=rank,
        limit_all_gathers=True,
    )

    return model

def log_training_stats(step: int, loss: float, telemetry: Dict[str, float],
                       learning_rate: float, samples_per_sec: float,
                       memory_allocated: float, rank: int) -> None:
    """Log training statistics."""
    if rank == 0:
        logger.info(
            f"Step {step:6d} | "
            f"Loss: {loss:.4f} | "
            f"K: {telemetry.get('negentropy', 0):.3f} | "
            f"C: {telemetry.get('lz_complexity', 0):.3f} | "
            f"S: {telemetry.get('symbiosis', 0):.3f} | "
            f"LR: {learning_rate:.2e} | "
            f"Speed: {samples_per_sec:.1f} samples/s | "
            f"Memory: {memory_allocated:.1f}GB"
        )

def save_checkpoint(model: FSDP, optimizer, scheduler, step: int, loss: float,
                    config: MassiveScaleConfig, rank: int) -> None:
    """Save model checkpoint."""
    # Gathering a full state dict is a collective operation: every rank must enter
    # the context and call state_dict(), even though only rank 0 writes the file.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        model_state = model.state_dict()

    if rank == 0:
        checkpoint_dir = f"/data/checkpoints/massive_scale_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'step': step,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
            'config': config.get_model_config(),
            'timestamp': datetime.now().isoformat(),
            'parameters': count_parameters(model),
        }

        checkpoint_path = f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved: {checkpoint_path}")
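
# Re-loading a checkpoint written by save_checkpoint into an unwrapped model (a
# sketch; assumes the same model config, a FULL_STATE_DICT save as above, and that
# `checkpoint_path` points at an existing file):
#
#   ckpt = torch.load(checkpoint_path, map_location="cpu")
#   model = BitTransformerLM(**ckpt['config'])
#   model.load_state_dict(ckpt['model_state_dict'])
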
def train_one_epoch(model: FSDP, train_loader: DataLoader, optimizer, scheduler,
                    config: MassiveScaleConfig, epoch: int, rank: int, world_size: int) -> Tuple[float, Dict[str, float]]:
    """Train for one epoch."""
    model.train()
    set_dropout(model, 0.1)

    total_loss = 0.0
    step = 0  # counts micro-batches, not optimizer updates
    telemetry: Dict[str, float] = {}
    start_time = time.time()
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(train_loader):
        if step >= config.MAX_STEPS:
            break

        input_ids = batch['input_ids'].to(rank)
        labels = batch['labels'].to(rank)
        attention_mask = batch['attention_mask'].to(rank)

        with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
            logits, telemetry = model(input_ids)

            # Binary next-bit prediction loss over the two logit classes.
            loss = F.cross_entropy(
                logits.view(-1, 2),
                labels.view(-1),
                reduction='mean'
            )

            # Safety gates: penalize the loss when telemetry falls below thresholds.
            if config.USE_SAFETY_GATES:
                negentropy = telemetry.get('negentropy', 0)
                lz_complexity = telemetry.get('lz_complexity', 0)
                symbiosis = telemetry.get('symbiosis', 0)

                if (negentropy < config.NEGENTROPY_THRESHOLD or
                        lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                        symbiosis < config.SYMBIOSIS_THRESHOLD):
                    safety_penalty = 10.0
                    loss = loss + safety_penalty

                    if rank == 0:
                        logger.warning(f"Safety gate triggered at step {step}!")

        # Scale for gradient accumulation.
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS
        loss.backward()

        # Step and reset gradients only once per accumulation window, so gradients
        # actually accumulate across the intermediate micro-batches.
        if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
            # FSDP shards gradients across ranks, so use its own clipping helper.
            model.clip_grad_norm_(1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.LOG_INTERVAL == 0:
            samples_per_sec = (config.BATCH_SIZE_PER_GPU * world_size *
                               config.LOG_INTERVAL) / (time.time() - start_time + 1e-7)
            memory_allocated = torch.cuda.memory_allocated(rank) / (1024**3)

            log_training_stats(
                step, loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                telemetry, scheduler.get_last_lr()[0], samples_per_sec,
                memory_allocated, rank
            )

            start_time = time.time()

        if step % config.CHECKPOINT_INTERVAL == 0 and step > 0:
            save_checkpoint(
                model, optimizer, scheduler, step,
                loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                config, rank
            )

        step += 1
        total_loss += loss.item() * config.GRADIENT_ACCUMULATION_STEPS

    avg_loss = total_loss / max(step, 1)
    return avg_loss, telemetry

def validate_model(model: FSDP, val_loader: DataLoader, config: MassiveScaleConfig,
                   rank: int) -> Tuple[float, Dict[str, float]]:
    """Validate model performance."""
    model.eval()
    set_dropout(model, 0.0)

    total_loss = 0.0
    total_samples = 0
    num_batches = 0
    accumulated_telemetry: Dict[str, float] = {}

    with torch.no_grad():
        for batch in val_loader:
            if total_samples >= 1000:
                break

            input_ids = batch['input_ids'].to(rank)
            labels = batch['labels'].to(rank)

            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                logits, telemetry = model(input_ids)
                loss = F.cross_entropy(
                    logits.view(-1, 2),
                    labels.view(-1),
                    reduction='mean'
                )

            total_loss += loss.item() * input_ids.size(0)
            total_samples += input_ids.size(0)
            num_batches += 1

            # Accumulate per-batch telemetry for averaging below.
            for key, value in telemetry.items():
                accumulated_telemetry[key] = accumulated_telemetry.get(key, 0.0) + value

    avg_loss = total_loss / max(total_samples, 1)

    # Telemetry was accumulated once per batch, so average over batches.
    for key in accumulated_telemetry:
        accumulated_telemetry[key] /= max(num_batches, 1)

    return avg_loss, accumulated_telemetry
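
# Note: the cross-entropy losses above are in nats per bit position; dividing by
# ln(2) (~0.693) converts them to bits-per-bit, e.g. a loss of 0.55 nats is about
# 0.79 bits per predicted bit.
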
def main_worker(rank: int, world_size: int, config: MassiveScaleConfig, port: str = "29500") -> None:
    """Main training worker process."""
    setup_distributed(rank, world_size, port)

    if rank == 0:
        logger.info("MASSIVE SCALE BITTRANSFORMERLM TRAINING INITIATED!")
        # Build a throwaway CPU copy to report the full (unsharded) parameter count.
        param_count = count_parameters(BitTransformerLM(**config.get_model_config()))
        logger.info(f"Target: {param_count:,} parameters")
        logger.info(f"Hardware: {world_size}x NVIDIA L4 GPUs")
        logger.info(f"Configuration: {config.get_model_config()}")

    # Datasets and loaders.
    train_dataset = WikiTextDataset("train", max_samples=config.MAX_SAMPLES,
                                    max_length=config.MAX_SEQ_LEN, streaming=config.STREAMING)
    val_dataset = WikiTextDataset("validation", max_samples=1000,
                                  max_length=config.MAX_SEQ_LEN, streaming=False)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    # FSDP-wrapped model.
    model = create_fsdp_model(config.get_model_config(), rank)

    if rank == 0:
        logger.info(f"Model created with {param_count:,} parameters ({param_count/1e9:.2f}B)")

        # Append a run record to the benchmarks log.
        benchmark_update = f"""

### LIVE RUN: 1.21B Parameter Training
**Status:** ACTIVE
**Started:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Parameters:** {param_count:,} ({param_count/1e9:.2f}B)
**Architecture:** d_model={config.D_MODEL}, layers={config.NUM_LAYERS}, heads={config.NUM_HEADS}
**Effective Batch Size:** {config.EFFECTIVE_BATCH_SIZE}
**Dataset:** WikiText-103 (streaming)
**Hardware:** 4x NVIDIA L4 GPUs with FSDP

"""
        with open('/data/Benchmarks.md', 'a') as f:
            f.write(benchmark_update)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
        anneal_strategy='cos',
    )

    if rank == 0:
        logger.info("Starting training loop...")

    try:
        for epoch in range(100):
            train_sampler.set_epoch(epoch)

            train_loss, train_telemetry = train_one_epoch(
                model, train_loader, optimizer, scheduler,
                config, epoch, rank, world_size
            )

            if rank == 0:
                logger.info(f"Epoch {epoch} completed - Average Loss: {train_loss:.4f}")

            # Validation runs on every rank (FSDP forward passes are collective);
            # only rank 0 logs the result.
            val_loss, val_telemetry = validate_model(model, val_loader, config, rank)
            if rank == 0:
                logger.info(f"Validation Loss: {val_loss:.4f}")

    except KeyboardInterrupt:
        if rank == 0:
            logger.info("Training interrupted by user")
    except Exception as e:
        if rank == 0:
            logger.error(f"Training failed with error: {e}")
        raise
    finally:
        cleanup_distributed()

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description='BitTransformerLM Massive Scale Training')
    parser.add_argument('--world-size', type=int, default=4, help='Number of GPUs')
    parser.add_argument('--port', type=str, default='29500', help='Master port')

    args = parser.parse_args()

    config = MassiveScaleConfig()

    if not torch.cuda.is_available():
        print("CUDA not available! This script requires GPU training.")
        sys.exit(1)

    if torch.cuda.device_count() < args.world_size:
        print(f"Only {torch.cuda.device_count()} GPUs available, but {args.world_size} requested")
        sys.exit(1)

    print(f"Launching massive scale training on {args.world_size} GPUs...")
    print("Target: 1.21 BILLION parameters")
    print("Dataset: WikiText-103 (full corpus)")
    print("This is going to be EPIC!")

    mp.spawn(
        main_worker,
        args=(args.world_size, config, args.port),
        nprocs=args.world_size,
        join=True,
    )
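
    # Example launch (assuming this file is saved as massive_scale_training.py and
    # four CUDA devices are visible):
    #
    #   python massive_scale_training.py --world-size 4 --port 29500
    #
    # An alternative would be torchrun with one process per GPU, e.g.
    # `torchrun --nproc_per_node 4 massive_scale_training.py`, but that would
    # require main_worker to read RANK/WORLD_SIZE from the environment instead of
    # relying on mp.spawn as above.
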
if __name__ == "__main__":
    main()