---
# BitTransformerLM Codex Playbook (Merged)
A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.
---
## Phase 1 – Training Loop & Runtime Optimizations (apply these first)
### Task 1 – Make batch size configurable & fix OneCycle accounting – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
```
✅ OneCycle's horizon matches reality across runs.
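A minimal sketch of what the patched wiring could look like; `args`, `dataset`, `extra_steps`, and the exact `configure_optimizer` signature are taken from the prompt or assumed for illustration:

```python
# Sketch only: args, dataset, extra_steps, and configure_optimizer's signature are illustrative.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                    num_workers=args.num_workers, persistent_workers=True)
steps_per_epoch = len(loader)                               # real steps per epoch
total_updates = args.epochs * (steps_per_epoch + extra_steps)
optimizer, scheduler = configure_optimizer(model, total_updates=total_updates)
```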
---
### Task 2 – Remove hardcoded `total_steps=100` in dashboard/MCP – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file dashboard/manager.py \
--edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
```
✅ Aligns scheduler behavior between direct loop and MCP/dashboard.
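A sketch of the scheduler rebuild, assuming the manager can see the upcoming plan's epoch count and loader; `plan` and its attributes are hypothetical stand-ins, not the dashboard's actual API:

```python
# Sketch only: `plan` and its attributes are hypothetical stand-ins for the dashboard's state.
from torch.optim.lr_scheduler import OneCycleLR

total_steps = plan.epochs * len(plan.loader)  # derived from the upcoming training plan
scheduler = OneCycleLR(optimizer, max_lr=plan.max_lr, total_steps=total_steps)
```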
---
### Task 3 – Add mixed-precision autocast (AMP, BF16) – COMPLETED ✅
**Prompt (pseudo-patch):**
```python
# Keep the forward pass and loss in bf16 autocast; run backward outside the context.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()
```
✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.
---
### Task 4 – Add gradient accumulation – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
```
✅ Simulates larger effective batch sizes without extra memory.
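A minimal sketch of the accumulation loop; `loader`, `model`, `criterion`, `optimizer`, `scheduler`, and `accum_steps` are stand-ins for the real training objects:

```python
# Sketch only: all names are illustrative stand-ins for the real training objects.
optimizer.zero_grad()
for step, (batch, labels) in enumerate(loader):
    loss = criterion(model(batch), labels) / accum_steps  # scale so accumulated grads average
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()       # update weights every accum_steps micro-batches
        scheduler.step()       # keep the LR schedule in sync with optimizer updates
        optimizer.zero_grad()
```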
---
### Task 5 – Optimize dataset pipeline (mmap + streaming) – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file data/wikitext_schedule.py \
--edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
```
✅ Removes conversion bottlenecks on large corpora.
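One way the streaming dataset could be sketched with `numpy.memmap`; the file name, dtype, and row layout are assumptions, not the actual format used by `data/wikitext_schedule.py`:

```python
# Sketch only: "bits.npy", uint8 rows, and max_seq_len=512 are illustrative assumptions.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class BitDataset(Dataset):
    """Serves fixed-length bit sequences from a precomputed memory-mapped array."""

    def __init__(self, path="bits.npy", max_seq_len=512):
        self.bits = np.memmap(path, dtype=np.uint8, mode="r").reshape(-1, max_seq_len)

    def __len__(self):
        return self.bits.shape[0]

    def __getitem__(self, idx):
        return torch.from_numpy(np.asarray(self.bits[idx], dtype=np.int64))

loader = DataLoader(BitDataset(), batch_size=64, shuffle=True,
                    num_workers=4, persistent_workers=True)
```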
---
### Task 6 – Schedule compression probability (safer ramp) – COMPLETED ✅
**Prompt (pseudo-code):**
```python
compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
```
✅ Prevents early instability from aggressive compression.
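`cosine_ramp` is not defined in the snippet above; one possible helper with that signature (an assumption, not the repository's implementation) is:

```python
import math

def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Ramp a value from start to end over total_steps following a half-cosine curve."""
    if step >= total_steps:
        return end
    progress = step / max(total_steps, 1)
    return start + (end - start) * 0.5 * (1.0 - math.cos(math.pi * progress))
```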
---
### Task 7 – Stabilize safety gate (EMA + burn-in) – COMPLETED ✅
**Prompt (pseudo-patch):**
```python
ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()
```
✅ Reduces false positives from noisy early validations.
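`ema` above is pseudocode; a stateful sketch that matches its call shape (the factory name is an assumption) could be:

```python
def make_ema(decay: float = 0.9):
    """Return a closure that tracks an exponential moving average of a scalar metric."""
    state = {"value": None}

    def ema(x: float) -> float:
        state["value"] = x if state["value"] is None else decay * state["value"] + (1 - decay) * x
        return state["value"]

    return ema

ema = make_ema(decay=0.9)  # then the gate reads: ema_val = ema(val_loss)
```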
---
### Task 8 – Enable `torch.compile` selectively – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file bit_transformer/training.py \
--edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
```
✅ Opportunistic speedup where supported.
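A sketch of the version guard (the helper name is illustrative and the version parsing deliberately simple; the cutoffs mirror the prompt above):

```python
# Sketch only: maybe_compile is an illustrative helper, not an existing API.
import sys
import warnings
import torch

def maybe_compile(model):
    """Compile the model only on supported torch/Python versions; otherwise return it unchanged."""
    torch_ok = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (2, 1)
    python_ok = sys.version_info < (3, 12)
    if torch_ok and python_ok:
        return torch.compile(model)
    warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")
    return model
```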
---
### Task 9 – Integrate FlashAttention / SDPA
**Prompt (pseudo-patch):**
```python
from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    # SDPA dispatches to fused kernels (FlashAttention / memory-efficient attention) when available.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```
✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.
---
### Task 10 – Cache causal masks – COMPLETED ✅
**Prompt (pseudo-code):**
```python
import torch

mask_cache = {}

def get_tri_mask(seq_len, device):
    """Return a cached upper-triangular (causal) boolean mask for this length/device."""
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]
```
✅ Avoids repeated `triu` allocations when masks are still needed.
---
### Task 11 – Fix stitched attention negative indexing – COMPLETED ✅
**Prompt (pseudo-code):**
```python
# Clamp chunk boundaries so overlap handling never produces negative indices (which wrap around).
start = max(s - overlap, 0)
end = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]
```
✅ Prevents wrap-around misplacement during T×T map reconstruction.
---
### Task 12 – Default off: full T×T attention logging in chunked runs – COMPLETED ✅
**Prompt:**
```bash
codex run bittransformerlm/patch \
--file bit_transformer/model.py \
--edit "Set full_attn_logging=False by default when chunk_size is set"
```
✅ Big memory/time savings without losing training signal.
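A one-line sketch of the defaulting logic; treating `None` as "unset" is an assumption about the model config, not the actual code:

```python
# Sketch: if the caller didn't set full_attn_logging explicitly, turn it off for chunked runs.
if chunk_size is not None and full_attn_logging is None:
    full_attn_logging = False
```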
---
## Phase 2 – Model Creation & Training Tasks (run after Phase 1)
### Task A – Train the best current baseline (8×256 with ACT)
**Prompt:**
```bash
codex run bittransformerlm/train \
--layers 8 \
--d_model 256 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--act true \
--reversible true \
--checkpointing true \
--batch_size 64 \
--accum_steps 2 \
--amp bf16 \
--lr_schedule progressive_plateau \
--full_attn_logging false
```
✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.
---
### Task B – CPU-friendly deployment (8×128, INT8 + optional QAT)
**Prompt:**
```bash
codex run bittransformerlm/train \
--layers 8 \
--d_model 128 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--quantization int8 \
--qat true \
--reversible true \
--checkpointing true \
--batch_size 128 \
--accum_steps 1 \
--amp bf16
```
✅ Efficient CPU target; QAT optional based on deployment constraints.
---
### Task C – Cautious scale-up candidate (16×256)
**Prompt:**
```bash
codex run bittransformerlm/train \
--layers 16 \
--d_model 256 \
--nhead 8 \
--causal true \
--chunk_size 128 \
--act true \
--reversible true \
--checkpointing true \
--batch_size 48 \
--accum_steps 3 \
--amp bf16 \
--lr_schedule progressive_plateau
```
⚠️ Use only after data expansion and schedule retuning.
---
## Recommended Execution Order
1. **Phase 1 Tasks 1–12** (apply all optimizations).
2. **Task A** baseline → validate.
3. **Task B** CPU build → validate + (optional) QAT.
4. **Task C** scale-up **only** when data/schedule allow.
---
### Notes
- Pair Phase 1 changes with CI that runs a short sanity fit (a few hundred steps) to confirm loss decreases and no scheduler drift.
- Keep `full_attn_logging=false` in chunked runs; enable selectively when inspecting attention.
- When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.
---