A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.

---

## Phase 1 — Training Loop & Runtime Optimizations (apply these first)

### Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
```

✅ OneCycle’s horizon matches reality across runs.
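
A minimal sketch of the intended shape of this patch, assuming hypothetical stand-ins (`dataset`, `args`, `extra_steps`, `configure_optimizer`) for the real objects in `bit_transformer/training.py`:

```python
from torch.utils.data import DataLoader

def build_loader_and_horizon(dataset, args, extra_steps=0):
    """Hypothetical helper: swap data.split(8) for a DataLoader and derive the
    true OneCycle horizon from the loader length."""
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=args.num_workers > 0,  # persistent workers require num_workers > 0
    )
    steps_per_epoch = len(loader)
    total_updates = args.epochs * (steps_per_epoch + extra_steps)
    return loader, total_updates  # feed total_updates into configure_optimizer
```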

---

### Task 2 — Remove hardcoded `total_steps=100` in dashboard/MCP — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file dashboard/manager.py \
  --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
```

✅ Aligns scheduler behavior between direct loop and MCP/dashboard.
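
A minimal sketch of the recreation step, assuming a hypothetical `rebuild_scheduler` helper that is handed the upcoming training plan:

```python
from torch.optim.lr_scheduler import OneCycleLR

def rebuild_scheduler(optimizer, max_lr, epochs, steps_per_epoch):
    """Hypothetical helper: size OneCycleLR from the upcoming training plan
    instead of a hardcoded total_steps=100."""
    total_steps = epochs * steps_per_epoch
    return OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)
```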

---

### Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()  # run backward outside the autocast region
```

✅ 1.2–1.8× throughput on attention-heavy training. Keep gradient clipping enabled.

---

### Task 4 — Add gradient accumulation — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
```

✅ Simulates larger effective batch sizes without extra memory.
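
A minimal sketch of the accumulation loop, assuming hypothetical `model`, `loader`, `criterion`, `optimizer`, and `scheduler` objects:

```python
def train_epoch(model, loader, criterion, optimizer, scheduler, accum_steps=1):
    """Hypothetical loop: scale the loss by 1/accum_steps, step every accum_steps."""
    optimizer.zero_grad()
    for i, (batch, labels) in enumerate(loader):
        loss = criterion(model(batch), labels) / accum_steps
        loss.backward()
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
```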

---

### Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file data/wikitext_schedule.py \
  --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
```

✅ Removes conversion bottlenecks on large corpora.
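
A minimal sketch of the mmap-backed dataset, assuming the bit tensors have already been precomputed into a `(num_samples, max_seq_len)` uint8 array on disk (all names here are hypothetical):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class BitMemmapDataset(Dataset):
    """Hypothetical dataset over a memory-mapped file of precomputed bit sequences."""

    def __init__(self, path, num_samples, max_seq_len):
        self.bits = np.memmap(path, dtype=np.uint8, mode="r",
                              shape=(num_samples, max_seq_len))

    def __len__(self):
        return len(self.bits)

    def __getitem__(self, idx):
        # Copy out of the mmap so workers hand back regular tensors.
        return torch.from_numpy(np.array(self.bits[idx], dtype=np.int64))

# loader = DataLoader(BitMemmapDataset("bits.dat", 100_000, 512),
#                     batch_size=64, num_workers=4, persistent_workers=True)
```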

---

### Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅

**Prompt (pseudo-code):**

```python
compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
```

✅ Prevents early instability from aggressive compression.
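
`cosine_ramp` above is left abstract; a minimal sketch of one possible implementation:

```python
import math

def cosine_ramp(step, start, end, total_steps):
    """Hypothetical cosine ramp from `start` to `end` over `total_steps` steps."""
    if step >= total_steps:
        return end
    progress = step / max(total_steps, 1)
    return start + (end - start) * 0.5 * (1.0 - math.cos(math.pi * progress))
```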

---

### Task 7 — Stabilize safety gate (EMA + burn‑in) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()
```

✅ Reduces false positives from noisy early validations.
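
A self-contained sketch of the same idea as a small helper class (names are hypothetical, not the repo's API):

```python
class EMASafetyGate:
    """Hypothetical gate: EMA of validation loss plus a burn-in period."""

    def __init__(self, threshold, burn_in_steps, decay=0.9):
        self.threshold = threshold
        self.burn_in_steps = burn_in_steps
        self.decay = decay
        self.ema = None

    def should_trigger(self, step, val_loss):
        self.ema = val_loss if self.ema is None else (
            self.decay * self.ema + (1.0 - self.decay) * val_loss
        )
        # Never fire during burn-in; afterwards fire only on a sustained rise.
        return step >= self.burn_in_steps and self.ema > self.threshold
```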

---

### Task 8 — Enable `torch.compile` selectively — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
```

✅ Opportunistic speedup where supported.
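
A minimal sketch of the guard, assuming a hypothetical `maybe_compile` wrapper (a version parse is used instead of a raw string comparison):

```python
import sys
import warnings

import torch

def maybe_compile(model):
    """Hypothetical guard: compile only where torch.compile is expected to work."""
    torch_version = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    if torch_version >= (2, 1) and sys.version_info < (3, 12):
        return torch.compile(model)
    warnings.warn("Skipping torch.compile: requires torch>=2.1 and Python<3.12")
    return model
```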

---

### Task 9 — Integrate FlashAttention / SDPA

**Prompt (pseudo-patch):**

```python
from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.

---

### Task 10 — Cache causal masks — COMPLETED ✅

**Prompt (pseudo-code):**

```python
import torch

mask_cache = {}

def get_tri_mask(seq_len, device):
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]
```

✅ Avoids repeated `triu` allocations in paths that still need explicit masks.

---

### Task 11 — Fix stitched attention negative indexing — COMPLETED ✅

**Prompt (pseudo-code):**

```python
start = max(s - overlap, 0)
end = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]
```

✅ Prevents wrap-around misplacement during T×T map reconstruction.

---

### Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/model.py \
  --edit "Set full_attn_logging=False by default when chunk_size is set"
```

✅ Big memory/time savings without losing training signal.
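
A minimal sketch of the default resolution, assuming a hypothetical helper in the model constructor where `None` means "not explicitly set by the caller":

```python
def resolve_full_attn_logging(full_attn_logging, chunk_size):
    """Hypothetical default: disable full T×T attention logging whenever
    chunked attention is active, unless the caller explicitly opts in."""
    if full_attn_logging is None:
        return chunk_size is None
    return full_attn_logging
```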

---

## Phase 2 — Model Creation & Training Tasks (run after Phase 1)

### Task A — Train the best current baseline (8×256 with ACT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 64 \
  --accum_steps 2 \
  --amp bf16 \
  --lr_schedule progressive_plateau \
  --full_attn_logging false
```

✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.

---

### Task B — CPU‑friendly deployment (8×128, INT8 + optional QAT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 128 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --quantization int8 \
  --qat true \
  --reversible true \
  --checkpointing true \
  --batch_size 128 \
  --accum_steps 1 \
  --amp bf16
```

✅ Efficient CPU target; QAT optional based on deployment constraints.

---

### Task C — Cautious scale‑up candidate (16×256)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 16 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 48 \
  --accum_steps 3 \
  --amp bf16 \
  --lr_schedule progressive_plateau
```

⚠️ Use only after data expansion and schedule retune.

---

## Recommended Execution Order

1. **Phase 1 Tasks 1–12** (apply all optimizations).
2. **Task A** baseline → validate.
3. **Task B** CPU build → validate + (optional) QAT.
4. **Task C** scale‑up **only** when data/schedule allow.

---

### Notes

- Pair Phase 1 changes with CI that runs a short sanity fit (a few hundred steps) to confirm loss decreases and there is no scheduler drift; see the sketch after this list.
- Keep `full_attn_logging=false` in chunked runs; enable it selectively when inspecting attention.
- When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.
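
A minimal sketch of that sanity-fit assertion, assuming the short CI run exposes its per-step loss history as a plain list (everything here is hypothetical):

```python
def sanity_fit_check(loss_history, window=50, min_improvement=0.01):
    """Hypothetical CI check: mean loss over the last `window` steps must sit
    measurably below the mean over the first `window` steps."""
    assert len(loss_history) >= 2 * window, "run a few hundred steps first"
    early = sum(loss_history[:window]) / window
    late = sum(loss_history[-window:]) / window
    assert late < early - min_improvement, (
        f"loss did not decrease: {early:.4f} -> {late:.4f}"
    )
```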

---