BitTransformerLM Codex Playbook

Phase 1 — Training Loop & Runtime Optimizations (apply these first)

Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"

✅ OneCycle's horizon matches reality across runs.
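
A sketch of what the patched setup looks like; `train_dataset`, `args`, `extra_steps`, and `configure_optimizer` are stand-ins for the project's actual objects, not its exact API:

from torch.utils.data import DataLoader

# Hypothetical names: train_dataset, args, extra_steps, configure_optimizer.
loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    persistent_workers=args.num_workers > 0,  # persistent workers require num_workers > 0
)

steps_per_epoch = len(loader)
total_updates = args.epochs * (steps_per_epoch + extra_steps)

# OneCycleLR must be told the true number of optimizer updates it will perform.
optimizer, scheduler = configure_optimizer(model, total_steps=total_updates)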


Task 2 — Remove hardcoded total_steps=100 in dashboard/MCP — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file dashboard/manager.py \
  --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"

✅ Aligns scheduler behavior between direct loop and MCP/dashboard.


Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅

Prompt (pseudo-patch):

import torch

# Forward pass and loss run in bfloat16 autocast; backward happens outside the context.
with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()  # bf16 needs no GradScaler; keep gradient clipping before optimizer.step()

✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.


Task 4 — Add gradient accumulation — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"

✅ Simulates larger effective batch sizes without extra memory.
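
A minimal sketch of the accumulation loop, assuming --accum_steps has already been parsed into accum_steps and that loader, model, criterion, optimizer, and scheduler exist in scope:

import torch

optimizer.zero_grad(set_to_none=True)
for step, (batch, labels) in enumerate(loader):
    loss = criterion(model(batch), labels) / accum_steps  # scale so accumulated grads average out
    loss.backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # keep the existing grad-clip
        optimizer.step()
        scheduler.step()  # step the scheduler per optimizer update, not per micro-batch
        optimizer.zero_grad(set_to_none=True)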


Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file data/wikitext_schedule.py \
  --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"

✅ Removes conversion bottlenecks on large corpora.
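
One possible shape for the memory-mapped dataset; the file name, dtype, and sequence length below are assumptions, not the repo's actual layout:

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class BitMmapDataset(Dataset):
    """Serves precomputed bit sequences stored row-wise in a memory-mapped file."""

    def __init__(self, path="bits.mmap", max_seq_len=2048):
        self.rows = np.memmap(path, dtype=np.uint8, mode="r").reshape(-1, max_seq_len)

    def __len__(self):
        return self.rows.shape[0]

    def __getitem__(self, idx):
        # Copy the row out of the mmap so workers return ordinary tensors.
        return torch.from_numpy(np.array(self.rows[idx], dtype=np.int64))

loader = DataLoader(BitMmapDataset(), batch_size=64, shuffle=True,
                    num_workers=4, persistent_workers=True)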


Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅

Prompt (pseudo-code):

compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)

✅ Prevents early instability from aggressive compression.
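
cosine_ramp is used above as if it already exists; a small helper matching that call, under the assumption that the value should rise smoothly from start to end over total_steps and then hold:

import math

def cosine_ramp(step, start=0.0, end=0.5, total_steps=1000):
    """Cosine interpolation from start to end over total_steps, clamped afterwards."""
    if total_steps <= 0 or step >= total_steps:
        return end
    progress = step / total_steps                          # 0 -> 1
    weight = 0.5 * (1.0 - math.cos(math.pi * progress))    # cosine-shaped 0 -> 1
    return start + (end - start) * weight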


Task 7 — Stabilize safety gate (EMA + burn-in) — COMPLETED ✅

Prompt (pseudo-patch):

ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()

✅ Reduces false positives from noisy early validations.
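
The ema(...) call above is shorthand; one stateful way to express the gate, with burn_in_steps and threshold treated as existing config values rather than the repo's actual names:

class SafetyGate:
    """EMA of validation loss with a burn-in grace period before the gate can fire."""

    def __init__(self, decay=0.9, burn_in_steps=500, threshold=1.5):
        self.decay, self.burn_in_steps, self.threshold = decay, burn_in_steps, threshold
        self.ema_val = None

    def should_trigger(self, step, val_loss):
        # Update the EMA even during burn-in so it has history once the gate activates.
        self.ema_val = val_loss if self.ema_val is None else (
            self.decay * self.ema_val + (1.0 - self.decay) * val_loss
        )
        return step >= self.burn_in_steps and self.ema_val > self.threshold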


Task 8 — Enable torch.compile selectively — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"

✅ Opportunistic speedup where supported.
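
A sketch of the guard; the version cutoffs mirror the prompt and may need revisiting for newer releases:

import sys
import warnings

import torch
from packaging import version

def maybe_compile(model):
    """Return torch.compile(model) where supported, else the original model with a warning."""
    torch_ok = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.1")
    python_ok = sys.version_info < (3, 12)
    if torch_ok and python_ok:
        return torch.compile(model)
    warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")
    return model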


Task 9 — Integrate FlashAttention / SDPA

Prompt (pseudo-patch):

from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

✅ Unlocks fused kernels; prefer is_causal=True over boolean masks.


Task 10 — Cache causal masks — COMPLETED ✅

Prompt (pseudo-code):

mask_cache = {}

def get_tri_mask(seq_len, device):
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]

✅ Avoids repeated triu allocations when masks are still needed.


Task 11 — Fix stitched attention negative indexing — COMPLETED ✅

Prompt (pseudo-code):

start = max(s - overlap, 0)
end   = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]

✅ Prevents wrap-around misplacement during T×T map reconstruction.


Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅

Prompt:

codex run bittransformerlm/patch \
  --file bit_transformer/model.py \
  --edit "Set full_attn_logging=False by default when chunk_size is set"

✅ Big memory/time savings without losing training signal.


Phase 2 — Model Creation & Training Tasks (run after Phase 1)

Task A — Train the best current baseline (8×256 with ACT)

Prompt:

codex run bittransformerlm/train \
  --layers 8 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 64 \
  --accum_steps 2 \
  --amp bf16 \
  --lr_schedule progressive_plateau \
  --full_attn_logging false

✅ Reproduces the validated sweet spot with newly enabled efficiency features.


Task B — CPU-friendly deployment (8×128, INT8 + optional QAT)

Prompt:

codex run bittransformerlm/train \
  --layers 8 \
  --d_model 128 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --quantization int8 \
  --qat true \
  --reversible true \
  --checkpointing true \
  --batch_size 128 \
  --accum_steps 1 \
  --amp bf16

✅ Efficient CPU target; QAT optional based on deployment constraints.


Task C — Cautious scale-up candidate (16×256)

Prompt:

codex run bittransformerlm/train \
  --layers 16 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 48 \
  --accum_steps 3 \
  --amp bf16 \
  --lr_schedule progressive_plateau

⚠️ Use only after data expansion and schedule retune.


Recommended Execution Order

  1. Phase 1 Tasks 1–12 (apply all optimizations).
  2. Task A baseline → validate.
  3. Task B CPU build → validate + (optional) QAT.
  4. Task C scale-up only when data/schedule allow.

Notes

  • Pair Phase 1 changes with CI that runs a short sanity fit (few hundred steps) to confirm loss decreases and no scheduler drift.
  • Keep full_attn_logging=false in chunked runs; enable selectively when inspecting attention.
  • When using SDPA, prefer is_causal=True and avoid passing dense masks unless required.