---

# 🧭 BitTransformerLM Codex Playbook (Merged)

A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.

---

## Phase 1 — Training Loop & Runtime Optimizations (apply these first)

### Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
```

✅ OneCycle’s horizon matches reality across runs.

---

### Task 2 — Remove hardcoded `total_steps=100` in dashboard/MCP — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file dashboard/manager.py \
  --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
```

✅ Aligns scheduler behavior between direct loop and MCP/dashboard.

---

### Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()
```

✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.

---

### Task 4 — Add gradient accumulation — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
```

✅ Simulates larger effective batch sizes without extra memory.

---

### Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file data/wikitext_schedule.py \
  --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
```

✅ Removes conversion bottlenecks on large corpora.

---

### Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅

**Prompt (pseudo-code):**

```python
compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
```

✅ Prevents early instability from aggressive compression.

---

### Task 7 — Stabilize safety gate (EMA + burn‑in) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()
```

✅ Reduces false positives from noisy early validations.

---

### Task 8 — Enable `torch.compile` selectively — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
```

✅ Opportunistic speedup where supported.
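
---

**Reference sketch (Tasks 1 + 3 + 4 combined).** A minimal, hypothetical illustration of how the configurable `DataLoader`, an OneCycle horizon counted in optimizer updates, BF16 autocast, and gradient accumulation fit together. The toy model, synthetic dataset, and hyperparameters below are placeholders, not the repo's actual `bit_transformer/training.py` implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder model/data; the real project builds these elsewhere.
batch_size, accum_steps, epochs, num_workers = 64, 2, 1, 2
dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 2, (1024,)))
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2))
criterion = nn.CrossEntropyLoss()

# Task 1: configurable batch size; workers stay alive between epochs.
# (On spawn-based platforms, wrap this in an `if __name__ == "__main__":` guard.)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                    num_workers=num_workers, persistent_workers=num_workers > 0)

# Tasks 1 + 4: OneCycle's horizon is the number of *optimizer updates*,
# not the number of micro-batches.
steps_per_epoch = len(loader)
total_updates = epochs * ((steps_per_epoch + accum_steps - 1) // accum_steps)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
                                                total_steps=total_updates)

device_type = "cuda" if torch.cuda.is_available() else "cpu"
for epoch in range(epochs):
    for step, (batch, labels) in enumerate(loader):
        # Task 3: BF16 autocast; Task 4: scale loss for accumulation.
        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = model(batch)
            loss = criterion(logits, labels) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # keep grad-clip
            optimizer.step()
            scheduler.step()  # advances once per update, so it ends exactly at total_updates
            optimizer.zero_grad(set_to_none=True)
```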
---

### Task 9 — Integrate FlashAttention / SDPA

**Prompt (pseudo-patch):**

```python
from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.

---

### Task 10 — Cache causal masks — COMPLETED ✅

**Prompt (pseudo-code):**

```python
mask_cache = {}

def get_tri_mask(seq_len, device):
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]
```

✅ Avoids repeated `triu` allocations when masks are still needed.

---

### Task 11 — Fix stitched attention negative indexing — COMPLETED ✅

**Prompt (pseudo-code):**

```python
start = max(s - overlap, 0)
end = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]
```

✅ Prevents wrap-around misplacement during T×T map reconstruction.

---

### Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/model.py \
  --edit "Set full_attn_logging=False by default when chunk_size is set"
```

✅ Big memory/time savings without losing training signal.

---

## Phase 2 — Model Creation & Training Tasks (run after Phase 1)

### Task A — Train the best current baseline (8×256 with ACT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 64 \
  --accum_steps 2 \
  --amp bf16 \
  --lr_schedule progressive_plateau \
  --full_attn_logging false
```

✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.

---

### Task B — CPU‑friendly deployment (8×128, INT8 + optional QAT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 128 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --quantization int8 \
  --qat true \
  --reversible true \
  --checkpointing true \
  --batch_size 128 \
  --accum_steps 1 \
  --amp bf16
```

✅ Efficient CPU target; QAT optional based on deployment constraints.

---

### Task C — Cautious scale‑up candidate (16×256)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 16 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 48 \
  --accum_steps 3 \
  --amp bf16 \
  --lr_schedule progressive_plateau
```

⚠️ Use only after data expansion and schedule retune.

---

## Recommended Execution Order

1. **Phase 1 Tasks 1–12** (apply all optimizations).
2. **Task A** baseline → validate.
3. **Task B** CPU build → validate + (optional) QAT.
4. **Task C** scale‑up **only** when data/schedule allow.

---

### Notes

- Pair Phase 1 changes with CI that runs a short sanity fit (a few hundred steps) to confirm loss decreases and no scheduler drift; see the sketch after these notes.
- Keep `full_attn_logging=false` in chunked runs; enable selectively when inspecting attention.
- When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.

---
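
**CI sanity fit (sketch).** The Notes above call for a short CI run that confirms loss decreases and the scheduler does not drift. Below is a minimal, hypothetical example of such a check; the `sanity_fit` helper, toy model, and synthetic data are placeholders, not existing project code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def sanity_fit(steps: int = 300, batch_size: int = 32, lr: float = 3e-3) -> None:
    """Run a few hundred steps on a learnable synthetic task and assert sanity."""
    torch.manual_seed(0)
    x = torch.randn(4096, 64)
    y = (x[:, 0] > 0).long()  # learnable target so loss must drop
    loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

    model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 2))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
                                                    total_steps=steps)

    losses, it = [], iter(loader)
    for _ in range(steps):
        try:
            batch, labels = next(it)
        except StopIteration:
            it = iter(loader)
            batch, labels = next(it)
        loss = criterion(model(batch), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        losses.append(loss.item())

    # Loss should trend down over the short fit.
    assert sum(losses[-50:]) / 50 < sum(losses[:50]) / 50, "loss did not decrease"
    # No scheduler drift: exactly the planned number of updates were consumed.
    assert scheduler.last_epoch == steps, "scheduler step count drifted from the plan"

if __name__ == "__main__":
    sanity_fit()
```

---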