Phase 1 – Training Loop & Runtime Optimizations (apply these first)
Task 1 – Make batch size configurable & fix OneCycle accounting – COMPLETED ✅
Prompt:
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
→ OneCycle's horizon matches reality across runs.
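A minimal sketch of the intended wiring, assuming the dataset object and `configure_optimizer` already exist in `bit_transformer/training.py` (the helper name and signature below are illustrative, not the repo's actual API):

```python
from torch.utils.data import DataLoader, Dataset

def build_loader_and_horizon(dataset: Dataset, batch_size: int, num_workers: int,
                             epochs: int, extra_steps: int = 0):
    """Build the DataLoader and derive the OneCycle step budget from its real length."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,  # persistent workers require num_workers > 0
    )
    steps_per_epoch = len(loader)
    total_updates = epochs * (steps_per_epoch + extra_steps)
    return loader, total_updates
```

`total_updates` would then be passed into `configure_optimizer` so the OneCycle schedule is built against the true number of optimizer steps.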
Task 2 β Remove hardcoded total_steps=100
in dashboard/MCP β COMPLETED β
Prompt:
codex run bittransformerlm/patch \
--file dashboard/manager.py \
--edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
→ Aligns scheduler behavior between the direct training loop and MCP/dashboard.
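A sketch of the scheduler rebuild, assuming the dashboard manager can see the upcoming plan's epoch and step counts (the function and argument names are placeholders):

```python
from torch.optim.lr_scheduler import OneCycleLR

def rebuild_scheduler(optimizer, max_lr: float, epochs: int, steps_per_epoch: int):
    """Recreate OneCycleLR after init/scale_up/download with a computed horizon."""
    total_steps = epochs * steps_per_epoch  # derived from the training plan, not hardcoded to 100
    return OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)
```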
Task 3 – Add mixed-precision autocast (AMP, BF16) – COMPLETED ✅
Prompt (pseudo-patch):
with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()  # backward outside the autocast region; gradients stay in fp32
→ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip (BF16 autocast needs no GradScaler).
Task 4 – Add gradient accumulation – COMPLETED ✅
Prompt:
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
→ Simulates larger effective batch sizes without extra memory.
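A sketch of the accumulation loop, assuming the loader yields `(batch, labels)` pairs and the existing gradient clipping is kept (the clip value 1.0 is illustrative):

```python
import torch

def train_epoch(model, loader, criterion, optimizer, scheduler, accum_steps: int = 1):
    """Step the optimizer and scheduler once every accum_steps micro-batches."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for i, (batch, labels) in enumerate(loader):
        loss = criterion(model(batch), labels) / accum_steps  # scale so accumulated grads average
        loss.backward()
        if (i + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
```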
Task 5 – Optimize dataset pipeline (mmap + streaming) – COMPLETED ✅
Prompt:
codex run bittransformerlm/patch \
--file data/wikitext_schedule.py \
--edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
→ Removes conversion bottlenecks on large corpora.
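One way to back the Dataset with a memory-mapped file; the uint8 row layout, file path, and sizes below are assumptions, not the repo's actual format:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class BitMemmapDataset(Dataset):
    """Serves precomputed text->bit rows of length max_seq_len from a memory-mapped file."""

    def __init__(self, path: str, num_samples: int, max_seq_len: int):
        self.bits = np.memmap(path, dtype=np.uint8, mode="r",
                              shape=(num_samples, max_seq_len))

    def __len__(self) -> int:
        return len(self.bits)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Copy the row out of the mmap before converting so workers don't share views.
        return torch.from_numpy(np.array(self.bits[idx], dtype=np.int64))

# Example wiring (paths/sizes are placeholders):
# loader = DataLoader(BitMemmapDataset("bits.mmap", 1_000_000, 512),
#                     batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)
```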
Task 6 – Schedule compression probability (safer ramp) – COMPLETED ✅
Prompt (pseudo-code):
compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
→ Prevents early instability from aggressive compression.
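A possible implementation of the `cosine_ramp` helper referenced above:

```python
import math

def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Ramp from start to end over total_steps following a half-cosine, then hold at end."""
    if step >= total_steps:
        return end
    progress = step / max(total_steps, 1)
    return start + (end - start) * 0.5 * (1.0 - math.cos(math.pi * progress))
```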
Task 7 – Stabilize safety gate (EMA + burn-in) – COMPLETED ✅
Prompt (pseudo-patch):
ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()
→ Reduces false positives from noisy early validations.
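A small stateful version of the same gate logic; the default decay, burn-in, and threshold values are illustrative, not tuned:

```python
class EMAGate:
    """Smooth val_loss with an EMA and suppress the safety gate during a burn-in window."""

    def __init__(self, decay: float = 0.9, burn_in_steps: int = 200, threshold: float = 2.0):
        self.decay = decay
        self.burn_in_steps = burn_in_steps
        self.threshold = threshold
        self.ema = None

    def should_trigger(self, step: int, val_loss: float) -> bool:
        self.ema = val_loss if self.ema is None else self.decay * self.ema + (1.0 - self.decay) * val_loss
        if step < self.burn_in_steps:
            return False  # burn-in: never gate on noisy early validations
        return self.ema > self.threshold
```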
Task 8 – Enable torch.compile selectively – COMPLETED ✅
Prompt:
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
→ Opportunistic speedup where supported.
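A guard of the kind the patch describes; the version parsing is simplified and the cutoffs mirror the prompt:

```python
import sys
import warnings

import torch

def maybe_compile(model):
    """Use torch.compile only on torch>=2.1 with Python<3.12; otherwise return the eager model."""
    torch_version = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
    if torch_version >= (2, 1) and sys.version_info < (3, 12):
        return torch.compile(model)
    warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")
    return model
```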
Task 9 – Integrate FlashAttention / SDPA
Prompt (pseudo-patch):
from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    # SDPA dispatches to fused FlashAttention / memory-efficient kernels when available.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
→ Unlocks fused kernels; prefer is_causal=True over boolean masks.
Task 10 – Cache causal masks – COMPLETED ✅
Prompt (pseudo-code):
import torch

mask_cache = {}

def get_tri_mask(seq_len, device):
    """Return a cached upper-triangular boolean mask for (seq_len, device)."""
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]
→ Avoids repeated triu allocations when masks are still needed.
Task 11 – Fix stitched attention negative indexing – COMPLETED ✅
Prompt (pseudo-code):
start = max(s - overlap, 0)  # clamp so the first chunk's overlap cannot index negatively
end = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]
→ Prevents wrap-around misplacement during T×T map reconstruction.
Task 12 – Default off: full T×T attention logging in chunked runs – COMPLETED ✅
Prompt:
codex run bittransformerlm/patch \
--file bit_transformer/model.py \
--edit "Set full_attn_logging=False by default when chunk_size is set"
→ Big memory/time savings without losing training signal.
Phase 2 – Model Creation & Training Tasks (run after Phase 1)
Task A – Train the best current baseline (8×256 with ACT)
Prompt:
codex run bittransformerlm/train \
--layers 8 \
--d_model 256 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--act true \
--reversible true \
--checkpointing true \
--batch_size 64 \
--accum_steps 2 \
--amp bf16 \
--lr_schedule progressive_plateau \
--full_attn_logging false
→ Reproduces the validated sweet spot with newly enabled efficiency features.
Task B – CPU-friendly deployment (8×128, INT8 + optional QAT)
Prompt:
codex run bittransformerlm/train \
--layers 8 \
--d_model 128 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--quantization int8 \
--qat true \
--reversible true \
--checkpointing true \
--batch_size 128 \
--accum_steps 1 \
--amp bf16
→ Efficient CPU target; QAT optional based on deployment constraints.
Task C – Cautious scale-up candidate (16×256)
Prompt:
codex run bittransformerlm/train \
--layers 16 \
--d_model 256 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--act true \
--reversible true \
--checkpointing true \
--batch_size 48 \
--accum_steps 3 \
--amp bf16 \
--lr_schedule progressive_plateau
⚠️ Use only after data expansion and schedule retune.
Recommended Execution Order
- Phase 1 Tasks 1–12 (apply all optimizations).
- Task A baseline → validate.
- Task B CPU build → validate + (optional) QAT.
- Task C scale-up only when data/schedule allow.
Notes
- Pair Phase 1 changes with CI that runs a short sanity fit (a few hundred steps) to confirm loss decreases and there is no scheduler drift (see the sketch after these notes).
- Keep full_attn_logging=false in chunked runs; enable selectively when inspecting attention.
- When using SDPA, prefer is_causal=True and avoid passing dense masks unless required.
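A hedged sketch of the sanity-fit check from the first note; the model, loader, optimizer, and a freshly created scheduler are assumed to come from the repo's own helpers, and the thresholds are illustrative:

```python
def sanity_fit(model, loader, criterion, optimizer, scheduler, max_steps: int = 300) -> None:
    """Run a short fit and assert that loss trends down and the scheduler advanced in lockstep."""
    losses = []
    model.train()
    for step, (batch, labels) in enumerate(loader):
        if step == max_steps:
            break
        loss = criterion(model(batch), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        losses.append(loss.item())
    k = max(len(losses) // 4, 1)
    assert sum(losses[-k:]) / k < sum(losses[:k]) / k, "loss did not decrease over the sanity fit"
    assert scheduler.last_epoch == len(losses), "scheduler step count drifted from optimizer steps"
```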