Upload 65 files
This view is limited to 50 files because it contains too many changes.
- BitTransformerLM/.github/workflows/ci.yml +29 -0
- BitTransformerLM/.gitignore +103 -0
- BitTransformerLM/ABOUTME.md +110 -0
- BitTransformerLM/AGENTS.md +66 -0
- BitTransformerLM/BitTransformerLM_full_assessment.md +196 -0
- BitTransformerLM/Dockerfile +27 -0
- BitTransformerLM/LICENSE/ALIGNMENT_AND_TRANSPARENCY.txt +42 -0
- BitTransformerLM/LICENSE/COMMERCIAL_LICENSE.txt +34 -0
- BitTransformerLM/LICENSE/CONTRIBUTOR_LICENSE_AGREEMENT.txt +7 -0
- BitTransformerLM/LICENSE/DISCLAIMER.txt +93 -0
- BitTransformerLM/LICENSE/LICENSE.txt +12 -0
- BitTransformerLM/LICENSE/TRADEMARK_POLICY.txt +12 -0
- BitTransformerLM/NEW_CODEX_TASK.md +85 -0
- BitTransformerLM/README.md +177 -0
- BitTransformerLM/bit_transformer/__init__.py +86 -0
- BitTransformerLM/bit_transformer/bit_io.py +97 -0
- BitTransformerLM/bit_transformer/collapse.py +95 -0
- BitTransformerLM/bit_transformer/compression.py +82 -0
- BitTransformerLM/bit_transformer/dashboard.py +58 -0
- BitTransformerLM/bit_transformer/dashboard_app.py +927 -0
- BitTransformerLM/bit_transformer/distil.py +90 -0
- BitTransformerLM/bit_transformer/distributed.py +30 -0
- BitTransformerLM/bit_transformer/hf_checkpoint.py +76 -0
- BitTransformerLM/bit_transformer/model.py +875 -0
- BitTransformerLM/bit_transformer/optimization.py +37 -0
- BitTransformerLM/bit_transformer/parity.py +24 -0
- BitTransformerLM/bit_transformer/quantization.py +89 -0
- BitTransformerLM/bit_transformer/safety.py +149 -0
- BitTransformerLM/bit_transformer/scale.py +36 -0
- BitTransformerLM/bit_transformer/static/style.css +93 -0
- BitTransformerLM/bit_transformer/telemetry.py +95 -0
- BitTransformerLM/bit_transformer/templates/dashboard.html +454 -0
- BitTransformerLM/bit_transformer/torch_utils.py +21 -0
- BitTransformerLM/bit_transformer/training.py +250 -0
- BitTransformerLM/bit_transformer/utils.py +28 -0
- BitTransformerLM/bit_transformer_lm_codex_playbook.md +278 -0
- BitTransformerLM/build_full_bits.py +23 -0
- BitTransformerLM/context_extension.md +43 -0
- BitTransformerLM/example.py +6 -0
- BitTransformerLM/full_bits_train.py +51 -0
- BitTransformerLM/integration_flow.py +110 -0
- BitTransformerLM/integration_schedule.py +379 -0
- BitTransformerLM/mcp_server.py +322 -0
- BitTransformerLM/progressive_scaleup.py +216 -0
- BitTransformerLM/pyproject.toml +18 -0
- BitTransformerLM/recursive_integration_flow.py +128 -0
- BitTransformerLM/requirements.txt +16 -0
- BitTransformerLM/review_cli.py +49 -0
- BitTransformerLM/start.sh +5 -0
- BitTransformerLM/state_of_the_repo_audit.md +98 -0
BitTransformerLM/.github/workflows/ci.yml
ADDED
@@ -0,0 +1,29 @@
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
          pip install build
      - name: Run tests
        run: pytest -q
      - name: Build package
        run: python -m build --sdist --wheel -o dist
      - uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist
BitTransformerLM/.gitignore
ADDED
@@ -0,0 +1,103 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Pyre type checker
.pyre/

# mypy
.mypy_cache/

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# IDEs
.idea/
.vscode/

# macOS
.DS_Store

# Logs
*.log

# Plot outputs
*.png
figures/

# Model artifacts
*.pt
*.pth
*.bin
candidates/
approved/
review_log.jsonl

# Configurations
*.ini

# Local data
*.sqlite3

*.pt.gz
BitTransformerLM/ABOUTME.md
ADDED
@@ -0,0 +1,110 @@
Here’s a menu of additional, “pure-PyTorch” extensions that can close the gap even further to a production-grade LLM:

⸻

1. Native Low-Rank & MoE Layers (DO LAST)

Why: Expert mixtures and low-rank adapters let you balloon effective parameter count without proportional compute.
• Mixture-of-Experts: Implement a tiny gating network (one or two linear layers) that routes each token’s representation to one of E experts (each a small FFN). Only that expert runs on that position, so compute per token stays constant while total capacity grows by E×.
• PyTorch sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    def __init__(self, d_model, d_ff, n_experts=4):
        super().__init__()
        self.gate = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
             for _ in range(n_experts)]
        )

    def forward(self, x):
        # x: [T, B, D]
        logits = self.gate(x)                  # [T, B, E]
        w = F.softmax(logits, dim=-1)          # [T, B, E]
        y = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # y: [T, B, D, E] → weighted sum over experts:
        out = (y * w.unsqueeze(2)).sum(-1)
        return out
```

• Trade-off: You’ll need a load-balancing loss term (e.g. encourage the gate to spread load) and telemetry on expert usage, but the code stays pure PyTorch. A minimal balancing term is sketched below.
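A minimal sketch of such a balancing term (the function name and the 0.01 coefficient are assumptions for illustration, not repo code): penalize deviation of average expert usage from uniform and add it to the training loss with a small weight.

```python
import torch

def load_balance_loss(gate_weights: torch.Tensor) -> torch.Tensor:
    # gate_weights: [T, B, E] softmax output of the MoE gate.
    usage = gate_weights.mean(dim=(0, 1))             # average routing probability per expert
    n_experts = gate_weights.shape[-1]
    return ((usage - 1.0 / n_experts) ** 2).sum()     # zero when load is perfectly balanced

# e.g. total_loss = task_loss + 0.01 * load_balance_loss(w)
```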
⸻

2. [x] Adaptive Computation Time (ACT)

Why: Let the model learn to spend more depth on “hard” bits and skip layers on easier ones.
• Implementation: Add a tiny halting unit after each layer—e.g. a single linear + sigmoid per token that predicts whether to halt. Accumulate the halt probability across layers and stop processing tokens once they cross a threshold.
• Benefit: On average you’ll do fewer layer passes per token, reducing compute without touching PyTorch internals.
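A minimal halting-unit sketch of the idea (illustrative only; the class name and threshold handling are assumptions, not this repo's ACT implementation):

```python
import torch
import torch.nn as nn

class HaltingUnit(nn.Module):
    """Per-token halting probability, applied after each transformer layer."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [T, B, D] → halting probability per token in [0, 1]
        return torch.sigmoid(self.proj(x)).squeeze(-1)

# In the forward pass, accumulate these probabilities layer by layer and stop updating
# tokens whose cumulative probability crosses a threshold (e.g. 0.99); the remainder
# is assigned to the final layer so the per-token weights sum to 1.
```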
⸻

3. [x] Advanced PyTorch-Native Quantization

Why: Move beyond static 4-bit packing to full QAT / dynamic quant.
• FX-graph QAT: Use torch.quantization.prepare_qat_fx on your SparseQuantTransformerLayer with a custom 4-bit observer (we sketched one earlier). Then convert_fx to int8 or 4-bit for weights—no external libs needed.
• Dynamic quant for inference: Wrap your model in torch.quantization.quantize_dynamic(...), quantizing only Linear modules to int8 on-the-fly. Gives a big speed/memory win at inference time on CPU.
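For the dynamic-quant path, the call is a one-liner (standard PyTorch API; `model` stands in for whatever module you are serving):

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Quantize only Linear modules to int8 for CPU inference; weights are converted once,
# activations are quantized on the fly at runtime.
quantized_model = quantize_dynamic(model.eval(), {nn.Linear}, dtype=torch.qint8)
```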
⸻

4. [x] Chunked & Overlapping Attention

Why: Emulate sparse attention with pure PyTorch and no for-loops.
• How: Break your sequence into fixed-size chunks (e.g. 512 bits), attend within each chunk plus a small overlap window to neighbors.
• Pure PyTorch: Use unfold + batched torch.matmul to compute all chunked attention in parallel:

```python
# x: [B, L, D], chunk_size=C, overlap=O
x_padded = F.pad(x, (0, 0, O, O))            # pad the sequence dim on both sides
chunks = x_padded.unfold(1, C + 2 * O, C)    # [B, n_chunks, D, C+2O]
chunks = chunks.transpose(-1, -2)            # [B, n_chunks, C+2O, D]
# Then project Q, K, V per chunk and do fused matmuls batchwise.
```

• Benefit: You get an O(L·(C+2O)) algorithm without Python loops, all in tensor ops.
⸻

5. Functorch-Based Vectorization & vmap

Why: Fuse your per-head or per-expert loops automatically.
• Use functorch.vmap to turn your per-head attention code (the one inside the for t in range(T)) into a single batched kernel.
• Benefit: Cleaner code, fewer Python loops, and TorchInductor can fuse it just as well as hand-written loops.
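A tiny illustration of the pattern (in recent PyTorch, functorch's `vmap` is also exposed as `torch.vmap`; the per-head attention function below is a stand-in, not code from this repo):

```python
import torch

def single_head_attention(q, k, v):
    # q, k, v: [T, D] for one head
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# q, k, v: [H, T, D] → vectorize over the head dimension instead of a Python loop.
q = k = v = torch.randn(8, 128, 64)
out = torch.vmap(single_head_attention)(q, k, v)   # [H, T, D]
```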
⸻

6. [x] Fully-Sharded DataParallel & Pipeline Parallel (PyTorch-Native)

Why: Scale out to multiple GPUs without external frameworks.
• FSDP: Wrap your model in torch.distributed.fsdp.FullyShardedDataParallel to shard both parameters and optimizer state across GPUs.
• Pipe: Use torch.distributed.pipeline.sync.Pipe to split your 40+ layer model across GPUs as pipeline stages.
• Benefit: Zero external deps—pure PyTorch DDP/FSDP/Pipe—so you can train 100M+ parameter models.
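A minimal FSDP wrapping sketch (assumes a multi-process launch, e.g. via torchrun, and `model` is an illustrative placeholder):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")        # one process per GPU, typically launched with torchrun
sharded_model = FSDP(model.cuda())     # parameters and optimizer state are sharded across ranks
```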
⸻

7. [x] Mixed Precision & Autocast on CPU (bfloat16)

Why: PyTorch now supports `torch.amp.autocast('cpu')` for bfloat16 on some architectures.
• Wrap your forward pass in `torch.amp.autocast('cpu')` to cut memory and speed up linear/attention kernels, even on CPU.
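In context, a short sketch (the `model(bits)` call is illustrative):

```python
import torch

with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    logits = model(bits)    # eligible ops run in bfloat16 on supported CPUs
```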
⸻

8. [x] Optimized Learning-Rate Schedules & Optimizers

Why: Achieve GPT-level convergence behavior…
• Implement OneCycleLR or CosineAnnealingWarmRestarts directly via torch.optim.lr_scheduler.
• Swap to AdamW with decoupled weight decay (torch.optim.AdamW) and dynamic gradient clipping (torch.nn.utils.clip_grad_norm_).
• All of these live in core PyTorch.
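Put together, a training step with these pieces might look like this (a sketch; `model`, `loader`, and `criterion` are placeholders):

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = OneCycleLR(optimizer, max_lr=1e-3, total_steps=10_000)

for inputs, targets in loader:
    loss = criterion(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # gradient clipping
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```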
⸻

Putting It All Together
1. MoE + ACT will let you scale capacity (E× experts) while controlling average compute.
2. FX/QAT + dynamic quant gives you 4-bit int inference with no external libs.
3. Chunked attention + vmap replaces loops with giant fused tensor ops.
4. FSDP + Pipe moves you onto multi-GPU purely in torch.distributed.
5. Autocast (bfloat16) on CPU/GPU for mixed precision speed.

By layering these techniques, you can:
• Reach hundreds of millions (even billions) of effective parameters
• Maintain single-library purity (just PyTorch)
• Hit LLM-class throughputs (100s of tokens/sec on GPU, 10s on CPU)
• Keep full NRB telemetry available for safety checks
BitTransformerLM/AGENTS.md
ADDED
@@ -0,0 +1,66 @@
# AGENTS Guidelines for BitTransformerLM

## Repository Scope and Purpose
- **BitTransformerLM** models raw binary streams using reversible transformer blocks and safety telemetry. The project is the canonical implementation under WCNegentropy.
- Core capabilities include bit-native modeling, telemetry metrics (negentropy, LZ complexity, symbiosis), progressive scaling, compression, context extension, diffusion mode (linear/cosine/exp noise schedules with parity correction), dashboard control, distributed training, and quantization.
- Phase 1 optimizations provide configurable batch sizing, gradient accumulation, mixed-precision, memory-mapped dataset streaming, scheduled compression ramps, selective `torch.compile`, and an EMA-smoothed safety gate with burn-in.

## Environment Setup
- Requires **Python 3.10+**.
- Install dependencies:
  - CPU: `pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt`
  - Optional GPU: `pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118`
- The package name is `bit-transformer`; project metadata lives in `pyproject.toml`.

## Repository Layout
- `bit_transformer/` – core package (`model`, `compression`, `telemetry`, `safety`, `dashboard_app`, `quantization`, etc.).
- `tests/` – pytest suite and historical `TEST_RESULTS.md`.
- Scripts: `example.py`, `unified_workflow.py`, `full_bits_train.py`, `build_full_bits.py`, `mcp_server.py`, `wikitext_*` utilities. The legacy `progressive_scaleup.py` is retained for reference but superseded by `integration_schedule.py`.
- Docs and specs: `README.md`, `state_of_the_repo_audit.md`, licensing files in `LICENSE/`.

## Development Practices
- Follow snake_case for functions and CamelCase for classes.
- Keep functions under ~300 lines and minimize deeply nested control flow.
- Avoid reintroducing the deprecated dashboard `/exec` endpoint or other insecure code paths.
- Use the `/status` endpoint for model introspection; all routes return JSON and surface errors with stack traces.
- Ensure compression, decompression, and halting logic stay consistent with the current implementation.
- Use the `cpu_autocast()` helper for BF16 mixed precision on CPU instead of calling `torch.amp.autocast` directly (see the sketch at the end of this section).
- Adaptive training now expands depth, width, or context only when validation loss plateaus and automatically decays the base learning rate by √2 after each expansion with a 100‑step warm‑up.

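A minimal usage sketch of the `cpu_autocast()` guideline above (the import path and the forward-call shape are assumptions for illustration):

```python
from bit_transformer.torch_utils import cpu_autocast  # import path assumed

model.eval()
with cpu_autocast():          # BF16 autocast on CPU, per the guideline above
    output = model(bits)      # illustrative forward call on a bit tensor
```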
## Workflow & Commands
- Run the example: `python example.py`.
- Adaptive scaling now lives in `integration_schedule.py`; `progressive_scaleup.py` is deprecated.
- Unified workflow (optionally with dashboard or diffusion): `python unified_workflow.py --dashboard` or `python unified_workflow.py --diffusion --diffusion-steps 8 --dataset-size 32`.
- Increase `--diffusion-steps` for higher fidelity (8–16) and add `--diffusion-curriculum` to linearly decay noise over epochs.
- Disable checkpointing or reversible blocks when speed is prioritized over memory: `python unified_workflow.py --no-checkpoint --no-reversible`.
- Enable 4-bit quantization-aware training: `python unified_workflow.py --qat`.
- Skip full attention logging during chunked attention for memory savings by constructing the model with `full_attn_logging=False`.
- Start the MCP server with `python mcp_server.py` and launch the dashboard with `MCP_SERVER_ADDR=http://127.0.0.1:7000 python -m bit_transformer.dashboard_app`.
- `/metrics` and `/model_config` endpoints expose telemetry streams and hyperparameters.
- `/save_checkpoint` and `/download_checkpoint` sync weights with Hugging Face (token defaults to `HF_TOKEN`).
- Container build: `docker build -t bittransformerlm .` and run with exposed ports `5000` (dashboard) and `7000` (MCP).

## Telemetry Metrics
| Metric | Meaning | Range |
|--------|---------|-------|
| **K** | Negentropy – deviation from random noise | 0–1 (1 = ordered) |
| **C** | LZ Complexity – compressibility proxy | 0–1 (higher = more changes) |
| **S** | Symbiosis – agreement with reference distribution | 0–1 (1 = aligned) |

ACT halting exports `halt_probs` in telemetry showing how many layers executed. For robust sampling under safety constraints, call `safe_sample_with_retry(model, bits)`, which retries with diffusion mode and exponential backoff.

`TelemetrySynthesizer.cluster_sequences` can be used to select representative training samples before invoking `collapse_submodel`. The distillation helper deepens the model and widens it once (`width_scale` = 1.5) if floors are missed, and `save_distilled_model` emits a `metrics.json` summary beside the weights.

## Testing
- Run unit tests after any change: `pytest -q`.
- Use `watcher.py` for auto-reload and testing during local development if desired.
- During training, call `model.train()` and keep dropout probabilities around `0.1–0.2`.
- Before running tests, inference, or pushing weights, switch to `model.eval()` and set all dropout probabilities to `0` to avoid flaky results.
- The dashboard will warn if telemetry metrics drift by more than 0.2 over the last 10 steps. Adjust via `ModelManager(drift_window, drift_threshold)` as needed.

## Licensing
- Project governed by documents in `LICENSE/` (AGPLv3, commercial terms, disclaimers, etc.). Ensure compliance before contributing or distributing.

These guidelines keep the repository consistent with the project roadmap and previous audits. Maintain security, style, and testing discipline to keep BitTransformerLM production-ready.
BitTransformerLM/BitTransformerLM_full_assessment.md
ADDED
@@ -0,0 +1,196 @@
# BitTransformerLM Deep-Dive Assessment Report

*(Comprehensive technical review and optimization roadmap)*

---

## Completed Tasks
- [x] 3.1 Cosine noise schedule option
- [x] 3.2 Post-process parity correction
- [x] 2.3 Expose checkpoint & reversible toggles
- [x] 2.2 Update deprecated AMP call
- [x] 5.2 Metric-drift alerts
- [x] 1.3 Expand README / docstrings for telemetry & ACT
- [x] 3.3 Safety-gate soft-retry
- [x] 7.1 Add ACT halting unit test
- [x] 4.1 Integrate performance-based scaling
- [x] 4.2 Learning-rate decay on resize
- [x] 3.4 Chunked attention logging toggle
- [x] 3.5 Quantization-aware training toggle
- [x] 7.2 Quantization & QAT tests
- [x] 4.3 Dashboard flag wiring
- [x] 7.3 Dashboard smoke test
- [x] 2.1 Unify flag names & deprecate legacy scale script
- [x] 5.1 Telemetry λ and floor UI
- [x] 5.3 Cluster-based distillation data
- [x] 6.1 Allow width scaling in collapse loop
- [x] 6.2 Save distilled metrics summary

## 1. Overview of BitTransformerLM Architecture and Recent Additions
BitTransformerLM is a **reversible Transformer** that operates **directly on binary sequences (bits)**. The immutable core uses multi-head self-attention on bit embeddings with sinusoidal positional encoding and already supports:

* Safety-centric telemetry (negentropy *K*, LZ complexity *C*, symbiosis *S*)
* Run-length compression / decompression paths
* Progressive scaling (depth & width) with reversible layers + gradient checkpointing
* Quantization (dynamic INT8 + optional 4‑bit QAT)
* A non‑causal **Diffusion‑LM mode** for bidirectional, denoising generation
* Dashboard, MCP server, and FSDP/pipeline hooks for distributed or edge deployment

Recent commits locked in deterministic environment setup (ChatGPT Codex container), removed insecure `/exec` endpoints, and added a reliable *coarse‑to‑fine* diffusion sampler stub. The model now installs and trains reproducibly on CPU‑only hosts, yet scales to multi‑GPU with FSDP.

---

## 2. Consistent Naming & Documentation
* Codebase generally follows *snake_case* functions / *CamelCase* classes, but CLI flags & helper scripts drift (e.g. `--diffusion` vs internal `causal=False`).
  **Action:** unify flag names & docstrings; deprecate redundant scripts (`progressive_scaleup.py` vs `integration_schedule.py`).
* README and inline docs lack quick intuition for *K, C, S* metrics, ACT, and reversible internals.
  **Action:** add short metric primers and ACT demo snippets; update `AGENTS.md` quick‑start table.

---

## 3. Optimizing Module Interactions & Performance
| Area | Current State | Optimization | Outcome |
|------|---------------|--------------|---------|
| **Chunked attention** ✅ | Saves RAM but reconstructs full *T×T* matrix for telemetry | Skip full matrix when `chunk_size < seq_len` and user disables `full_attn_logging` | Same metrics, big memory + speed win on long sequences |
| **PyTorch 2 features** | Uses `torch.compile` & BF16 autocast inconsistently | Standardize `torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)`; wrap long loops | 10‑20 % CPU speed‑up, no deprecation warnings |
| **Reversible + checkpoint** | Always checkpoints → slower when RAM ample | Expose `--no-checkpoint` flag; document trade‑offs | User‑selectable speed vs memory |
| **Quantization** ✅ | INT8 dynamic works; 4‑bit QAT unused | Add `--qat` toggle in training scripts & unit‑test tiny model | Edge‑ready 4‑bit weights validated |
| **Compression loops** | Python for‑loops per sample | Batch or vectorized RLE when batch ≫ 8 | Marginal speed‑up for large batches |

---

## 4. Fully Leveraging Diffusion Mode
1. [x] **Noise schedule** – switchable linear ▸ cosine ▸ exponential; expose `--noise-schedule`.
2. [x] **Step count** – allow 8–16 steps for high‑fidelity generation; document compute trade‑off.
3. [x] **Parity safeguard** – post‑sampling parity‑bit fix or strict parity sampling to guarantee valid bytes (see the sketch after this list).
4. [x] **Training curriculum** – optional schedule: high‑noise → low‑noise over epochs; keep random‑noise fallback.
5. [x] **Safety integration** – run `hil_safe_inference(strict=False)` during diffusion; warn (not crash) on metric floor breaches.

---
|
71 |
+
|
72 |
+
## 5. Enhanced Training Workflow & Scaling Strategy
|
73 |
+
* **Adaptive scaling trigger** – adopt `progressive_scaleup.py` logic: scale only when val‑loss Δ < threshold; alternate width↔context↔depth.
|
74 |
+
* **Context extension** – use `double_length()` when plateau met; maintain chunked attention windows.
|
75 |
+
* **Warm‑up & plateau** – keep 5‑batch freeze after each expansion; add default final plateau epoch.
|
76 |
+
* **LR hygiene** – slight LR decay each scale‑up; document rationale.
|
77 |
+
|
78 |
+
---
|
79 |
+
|
80 |
+
## 6. Telemetry Metrics & Safety Integration
|
81 |
+
* **Metric coefficients** (`λ_K`, `λ_C`, `λ_S`) exposed via dashboard slider; floors (C ≥ 0.3, S ≥ 0.5) adjustable per deployment.
|
82 |
+
* **TelemetrySynthesizer** – cluster activations → representative sequences for distillation & drift detection.
|
83 |
+
* **Metric drift alert** – integrate `detect_metric_drift()` into training monitor; log if Δ > 0.2.
|
84 |
+
|
85 |
+
---
|
86 |
+
|
87 |
+
## 7. Distillation & Model Collapse Optimization
|
88 |
+
1. Use **cluster‑selected sequences** as `cluster_data` for `collapse_submodel` → better coverage.
|
89 |
+
2. Permit optional width growth (`width_scale > 1`) in iterative collapse rounds.
|
90 |
+
3. Log final vs floor metrics in `distilled_metrics.json` for audit trail.
|
91 |
+
4. Optionally auto‑invoke collapse at end of `integration_schedule` with `--auto-collapse`.
|
92 |
+
|
93 |
+
---
|
94 |
+
|
95 |
+
## 8. Additional Testing & Release Readiness
|
96 |
+
* Expand pytest suite: diffusion training/sampling, ACT halting, INT8 + QAT inference, dashboard API smoke tests.
|
97 |
+
* Add multi‑GPU CI job to validate FSDP + reversible layers.
|
98 |
+
* Strengthen debug logs: print mode (causal/diffusion/compression), scale‑up events, safety‑gate warnings.
|
99 |
+
|
100 |
+
---
|
101 |
+
|
102 |
+
## 9. Strategic Summary
|
103 |
+
BitTransformerLM already delivers an **orthogonal bundle of “firsts”**: bit‑native granularity, reversible memory efficiency, metric‑driven safety, and turnkey text diffusion.
|
104 |
+
Executing the roadmap **knits every module into a smooth, reproducible pipeline** without touching core architecture—preserving alignment while boosting usability.
|
105 |
+
|
106 |
+
**Bottom‑line:** With these refinements, BitTransformerLM becomes the reference for transparent, resource‑efficient, safety‑gated language modelling at the bit level—well beyond “just another model.”
|
107 |
+
|
108 |
+
|
109 |
+
Below is an **implementation playbook** that turns every recommendation in *“Overview of BitTransformerLM Architecture and Recent Additions”* into clear tasks and ready‑to‑copy Codex prompts. Where page numbers add context, I note them; all content is from the uploaded PDF. 
|
110 |
+
|
111 |
+
---
|
112 |
+
|
113 |
+
## 1 · Repository Consistency & Documentation
|
114 |
+
|
115 |
+
| # | Task | Key Steps | Codex Prompt (trim or expand as desired) |
|
116 |
+
| --- | -------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
117 |
+
| 1.1 | **Audit & unify public API names** | • Scan for duplicate / mis‑matched flags (e.g. `--diffusion` vs `causal=False`).<br>• Rename or deprecate aliases; update docs. | “List every function, class, and CLI flag whose name does **not** match the style‑guide (snake\_case for funcs, CamelCase for classes) in the BitTransformerLM repo. For each, propose a single canonical name and generate the automated `git mv` or refactor patches.” |
|
118 |
+
| 1.2 | **Consolidate scaling scripts** | • Merge `progressive_scaleup.py` logic into `integration_schedule.py`.<br>• Mark redundant script as example. | “Move the performance‑based scaling criterion from `progressive_scaleup.py` into `integration_schedule.py`. Preserve existing kwargs, add `--improve‑thresh` with default 0.01. Provide diff.” |
|
119 |
+
| 1.3 | **Expand README / docstrings for telemetry & ACT** (pp. 1 ‑ 2) | • Add one‑paragraph explanations of Negentropy (K), LZ Complexity (C), Symbiosis (S), and ACT halting to README.<br>• Link to equations in code comments. | “Insert a new subsection *‘Telemetry Metrics Explained’* into README after the quick‑start block, then add in‑line docstrings for `negentropy_score`, `lz_complexity`, and `symbiosis_score` explaining ranges and typical values.” |
|
120 |
+
|
121 |
+
---
|
122 |
+
|
123 |
+
## 2 · Performance Optimizations
|
124 |
+
|
125 |
+
| # | Task | Key Steps | Codex Prompt |
|
126 |
+
| --- | ------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
127 |
+
| 2.1 | **Vectorize chunked‑attention telemetry** (p. 2) | • Add flag `--attn‑summary`.<br>• When enabled and `chunked_attn=True`, compute per‑chunk entropy and skip full `T × T` map. | “Refactor `_chunked_attn` in `model.py` so that, if `attn_summary` is true, it returns `(attn_entropy_per_chunk, None)` instead of the stitched full map. Fall back to old behaviour otherwise. Update callers.” |
|
128 |
+
| 2.2 | **Update deprecated AMP call** | Replace `torch.cpu.amp.autocast` with `torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)` everywhere. | “Search repo for `torch.cpu.amp.autocast`, replace with the new API, and add a context‑manager wrapper `cpu_autocast` in `utils/torch_utils.py`.” |
|
129 |
+
| 2.3 | **Expose checkpoint & reversible toggles** (p. 2) | • Add CLI flags `--use-checkpoint / --no-checkpoint` and `--reversible`.<br>• Document memory/compute trade‑off. | “Modify `train.py` argparse to include mutually exclusive `--[no-]checkpoint` flags; wire to `use_checkpoint` in model init.” |
|
130 |
+
| 2.4 | **Batch run‑length encoding** (p. 3) | • Implement NumPy‑vectorised RLE for the full tensor.<br>• Fallback to Python loop if tensor < 1024 bits. | “Implement `batch_rle_encode` in `bit_io.py` using NumPy strides; write unit test comparing speed & correctness to existing per‑sequence encode.” |
|
131 |
+
|
132 |
+
---
|
133 |
+
|
134 |
+
## 3 · Diffusion‑Mode Enhancements
|
135 |
+
|
136 |
+
| # | Task | Key Steps | Codex Prompt | | |
|
137 |
+
| --- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
138 |
+
| 3.1 | **Cosine noise schedule option** (p. 4) | • Add \`schedule="linear | cosine | exp"`arg to`diffusion\_inference\`.<br>• Default remains linear. | “Extend `diffusion_inference` to support a cosine decay of `mask_prob` over `steps`. Provide math and update docstring.” |
|
139 |
+
| 3.2 | **Post‑process parity correction** (p. 4) | • After sampling, flip each parity bit if byte parity invalid.<br>• Log number of corrections. | “Write `enforce_parity(bits)` that patches 9th bit per byte to satisfy even‑parity, return corrected seq + stats.” | | |
|
140 |
+
| 3.3 | **Safety‑gate soft‑retry** | • On failed `hil_safe_inference(strict=True)`, auto‑retry up to 3× with diffusion or random seed.<br>• Surface warning in logs. | “Wrap `hil_safe_inference` in a helper `safe_sample_with_retry`; implement exponential back‑off and logging.” | | |
|
141 |
+
|
142 |
+
---
|
143 |
+
|
144 |
+
## 4 · Adaptive Training Workflow
|
145 |
+
|
146 |
+
| # | Task | Key Steps | Codex Prompt |
|
147 |
+
| --- | ------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
148 |
+
| 4.1 | **Integrate performance‑based scaling** (pp. 5‑6) | • Use `Δval_loss < thresh` as condition to trigger `add_layer()`/`double_width()`.<br>• Alternate occasional `double_length()` for context. | “Inside `integration_schedule.train_loop`, compute rolling val‑loss; if mean improvement < `args.improve_thresh`, call `model.scale_up(strategy=next_step)` where `next_step` cycles \[layer, width, context].” |
|
149 |
+
| 4.2 | **Learning‑rate decay on resize** | • After each scale‑up, reduce base LR by √2.<br>• Provide warm‑up of 100 steps. | “Add `adjust_learning_rate(optimizer, factor)` util; call it after every successful model expansion.” |
|
150 |
+
| 4.3 | **Dashboard flag wiring** | • Map UI toggles (compression, diffusion) to `compress_prob`, `diffusion` args in backend. | “In `dashboard_app.py`, when user toggles compression, pass `compress_prob=1.0` to `ModelManager.train()`.” |
|
151 |
+
|
152 |
+
---
|
153 |
+
|
154 |
+
## 5 · Telemetry & Safety
|
155 |
+
|
156 |
+
| # | Task | Key Steps | Codex Prompt |
|
157 |
+
| --- | -------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
158 |
+
| 5.1 | **Expose λ coefficients and safety floors in UI** (p. 7) | • Add sliders for `λ_K`, `λ_C`, `λ_S`, `C_floor`, `S_floor`.<br>• Persist to model state. | “Add REST endpoints `/config/telemetry` (GET/POST) that read or set lambda values and floors; bind to dashboard sliders.” |
|
159 |
+
| 5.2 | **Metric‑drift alerts** (p. 8) | • After every epoch, call `detect_metric_drift(history, window=100)`; if > 0.2 drift, log & optionally halt training. | “Integrate `detect_metric_drift` into `ModelManager._log_metrics`; raise `MetricDriftWarning` when threshold exceeded.” |
|
160 |
+
| 5.3 | **Cluster‑based distillation data** (pp. 8‑9) | • Use `TelemetrySynthesizer` to pick `k` cluster representatives (default 8).<br>• Feed to `collapse_submodel`. | “Before `collapse_submodel`, run `representatives = TelemetrySynthesizer(model).cluster(train_data, k=8)`. Replace `train_bits[:64]` with `representatives`.” |
|
161 |
+
|
162 |
+
---
|
163 |
+
|
164 |
+
## 6 · Distillation / Collapse Process
|
165 |
+
|
166 |
+
| # | Task | Key Steps | Codex Prompt |
|
167 |
+
| --- | ----------------------------------------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------- |
|
168 |
+
| 6.1 | **Allow width scaling in collapse loop** (p. 8) | • Add `width_scale` param; if metric floors unmet after deepening, double width once then retry. | “Modify `collapse_submodel`: on round‑2 failure, rebuild sub‑model with `hidden_dim *= width_scale` (default 1.5).” |
|
169 |
+
| 6.2 | **Save metrics summary** | • Extend `save_distilled_model` to write `metrics.json` with achieved vs floor values. | “Update `save_distilled_model` to dump `{‘C’:score_C, ‘S’:score_S, ‘floors’:{...}}` alongside weights.” |
|
170 |
+
|
171 |
+
---
|
172 |
+
|
173 |
+
## 7 · Testing & CI Hardening
|
174 |
+
|
175 |
+
| # | Task | Key Steps | Codex Prompt |
|
176 |
+
| --- | ------------------------------------- | ------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------- |
|
177 |
+
| 7.1 | **Add ACT halting unit test** (p. 10) | • Craft toy seq; assert `sum(halt_prob<1) < n_layers`. | “Write `tests/test_act.py` ensuring at least one layer halts early when `use_act=True, threshold=0.1`.” |
|
178 |
+
| 7.2 | **Quantization & QAT tests** | • After tiny train, run dynamic int8 + fake‑QAT path, assert same logits ±1e‑3. | “Add `pytest` case: train 2‑layer model 1 epoch, call `quantize_dynamic`, compare outputs on 10 random inputs.” |
|
179 |
+
| 7.3 | **Dashboard smoke test** | • In CI, launch Flask app with `pytest‑flask`, hit `/init`, `/train‑step`, `/infer`. | “Create `tests/test_dashboard.py` that starts server in a thread and exercises core endpoints.” |
|
180 |
+
|
181 |
+
---
|
182 |
+
|
183 |
+
## 8 · Packaging & Release
|
184 |
+
|
185 |
+
| # | Task | Key Steps | Codex Prompt |
|
186 |
+
| --- | ---------------------------------------- | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
|
187 |
+
| 8.1 | **Rename repository references** (p. 11) | • Replace `Test/` URL stubs with new repo slug.<br>• Update badges in README. | “Search‑replace all GitHub links from `WCNegentropy/Test` to `WCNegentropy/BitTransformerLM`; update badge SVGs.” |
|
188 |
+
| 8.2 | **PyPI build verification** | • Ensure `pyproject.toml` installs cleanly on 3.10 & 3.11 in CI. | “Add GitHub Action matrix for {macOS, ubuntu‑latest} × {3.10, 3.11}; run `pip install -e . && pytest`.” |
|
189 |
+
|
190 |
+
---
|
191 |
+
|
192 |
+
### How to Use These Prompts
|
193 |
+
|
194 |
+
**Run** unit tests; iterate if failures surface.
|
195 |
+
|
196 |
+
This checklist should bring BitTransformerLM to a polished, v1‑ready state while aligning with your NRB‑driven safety and telemetry philosophy. 
|
BitTransformerLM/Dockerfile
ADDED
@@ -0,0 +1,27 @@
FROM ubuntu:22.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y python3.11 python3-pip python3.11-venv curl && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

WORKDIR /opt/bit_transformer
COPY . .

ARG TORCH_CUDA=cpu
RUN pip3 install --no-cache-dir --upgrade pip && \
    if [ "$TORCH_CUDA" = "cu118" ]; then \
        pip3 install torch==2.7.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118; \
    else \
        pip3 install torch==2.7.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu; \
    fi && \
    pip3 install -r requirements.txt

ENV MCP_SERVER_ADDR=http://127.0.0.1:7000

EXPOSE 5000 7000

RUN chmod +x start.sh

HEALTHCHECK CMD curl -f http://localhost:7000/health || exit 1

CMD ["/opt/bit_transformer/start.sh"]
BitTransformerLM/LICENSE/ALIGNMENT_AND_TRANSPARENCY.txt
ADDED
@@ -0,0 +1,42 @@
# Alignment and Transparency Agreement for BitTransformerLM
This Alignment and Transparency Agreement ("Agreement") outlines requirements
for the responsible and aligned commercial deployment of BitTransformerLM,
developed by WCNEGENTROPY HOLDINGS LLC.

## Core Principles
1. **Alignment:** The deployment must actively maintain alignment with ethical
and epistemic integrity, avoiding harmful or coercive outcomes.
2. **Transparency:** Telemetry data (negentropy, complexity, symbiosis scores,
etc.) must be transparently maintained and available for audit and inspection.
3. **Safety and Stability:** Deployments must utilize provided safety gates
(e.g., `hil_safe_inference`) to prevent degenerate or harmful outputs.
4. **Epistemic Responsibility:** Adopters commit to responsible use, actively
avoiding misuse or unethical applications.

## Telemetry and Monitoring
Commercial license holders must maintain full transparency on telemetry metrics
collected from the software, as originally implemented in the BitTransformerLM
repository. Telemetry must be available upon request for audit by WCNEGENTROPY HOLDINGS LLC or authorized third parties.

## Modification and Derivatives
Commercial license holders may modify the software for internal commercial use
but must explicitly disclose any modifications or derivatives upon request by
WCNEGENTROPY HOLDINGS LLC.

## Violations
Non-compliance with any of the terms outlined in this Agreement may result in
revocation of commercial licensing rights at the sole discretion of WCNEGENTROPY HOLDINGS LLC.

### Audit Cadence (added v0.9.0)
WCNEGENTROPY HOLDINGS LLC may request a telemetry snapshot **no more than once per
calendar quarter**. Licensee must deliver the requested data within **30 days
of receipt**. Failure to comply may result in suspension or termination of
commercial rights.

---

For questions, clarification, or audits, contact:
**WCNegentropy Holdings**
Email: [email protected]
Website: [wcnegentropy.com](https://wcnegentropy.com)
BitTransformerLM/LICENSE/COMMERCIAL_LICENSE.txt
ADDED
@@ -0,0 +1,34 @@
# Commercial License for BitTransformerLM
© 2025 WCNEGENTROPY HOLDINGS LLC – All Rights Reserved

> **Clarification (dual‑license):** This clause applies only to commercial deployments.
> Open‑source users remain bound by the GNU AGPL v3 in `LICENSE`.
> For holders of a **paid Commercial License**, BitTransformerLM is also provided
> under the **Apache License 2.0** (the “Commercial License”), **subject to the
> Alignment & Transparency Agreement (ATA)**.

BitTransformerLM (the “Software”), including all source code, documentation,
and associated assets, is the exclusive property of WCNEGENTROPY HOLDINGS LLC
(“WCNH”). Commercial use, reproduction, modification, distribution, or
sublicensing requires an executed Commercial License Agreement with WCNH.

## Patent Grant (Defensive)
WCNH hereby grants Licensee a perpetual, worldwide, non‑exclusive, no‑charge
patent license to make, use, sell, offer to sell, import, and otherwise
exploit the Software **provided** Licensee complies with this Commercial
License and the ATA. **This patent license terminates automatically** if
Licensee initiates patent litigation alleging that the Software infringes any
patent claim.

## Export‑Control & Sanctions Compliance
Licensee shall comply with all applicable export‑control and sanctions laws,
including but not limited to U.S. EAR, EU dual‑use regulations, and OFAC
sanctions lists. Licensee must not export, re‑export, or provide the Software
(or derivatives) to any prohibited country, entity, or individual.

## Alignment & Transparency Obligation
Commercial usage is conditional upon adherence to the ATA (see
`ALIGNMENT_AND_TRANSPARENCY.txt`). Failure to comply constitutes a material
breach and grounds for immediate license revocation.

For commercial inquiries contact **[email protected]**.
BitTransformerLM/LICENSE/CONTRIBUTOR_LICENSE_AGREEMENT.txt
ADDED
@@ -0,0 +1,7 @@
# Individual Contributor License Agreement (ICLA)
By submitting a contribution to the BitTransformerLM repository you agree to
grant WCNEGENTROPY HOLDINGS LLC an irrevocable, worldwide, royalty‑free
copyright license to reproduce, prepare derivative works of, publicly display
and distribute your contribution. You certify that you have the right to make
this grant and that your contribution is original or you have secured the
appropriate rights.
BitTransformerLM/LICENSE/DISCLAIMER.txt
ADDED
@@ -0,0 +1,93 @@
# BitTransformerLM – Legal & Risk Disclaimer
_Last updated: 2025-08-04_

BitTransformerLM (the “Software”) is an **experimental, highly-capable, agentic AI
model** developed by WC Negentropy Holdings LLC (“WCNH”). By downloading,
installing, running, fine-tuning, or otherwise using the Software **you
acknowledge and agree to all terms below.**

---

## 1. No Warranty

THE SOFTWARE IS PROVIDED **“AS IS”** AND **WITHOUT WARRANTY** OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NON-INFRINGEMENT,
AND ERROR-FREE OR UNINTERRUPTED OPERATION.
WCNH DOES **NOT** WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS OR
THAT ITS OUTPUT WILL BE ACCURATE, COMPLETE, OR RELIABLE.

## 2. Limitation of Liability

TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL WCNH,
ITS AFFILIATES, CONTRIBUTORS, OR LICENSORS BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; BUSINESS INTERRUPTION; OR PERSONAL
INJURY) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THE SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
OF SUCH DAMAGE.

## 3. High-Risk & Regulated Uses

**DO NOT DEPLOY** the Software in environments where failure or malfunction
could lead to death, serious bodily injury, or severe property or
environmental damage, including but not limited to:

- Medical diagnosis or life-support systems
- Autonomous vehicles or aviation control
- Nuclear facilities, weapons development, or military combat systems
- Critical infrastructure (power, water, telecom)
- Legal, financial, or governmental decision-making without qualified
  human review

You remain solely responsible for conducting appropriate risk assessments,
validation, and human oversight before any production deployment.

## 4. Alignment & Transparency Obligations

If you hold a **Commercial License**, you must also comply with the
Alignment & Transparency Agreement (ATA), including:

- Logging K-C-S telemetry and retaining it for audit
- Supplying telemetry snapshots upon request (max once per quarter)
- Cooperating with reasonable misuse investigations

## 5. Data & Privacy

You are responsible for:

- Ensuring you have the legal right to process any data you supply to the
  Software
- Implementing technical and organizational measures to protect personal or
  sensitive data
- Complying with all applicable data-protection and privacy laws (e.g.,
  GDPR, CCPA, HIPAA)

## 6. Export & Sanctions Compliance

You may **not** use or transfer the Software in violation of U.S. export
control laws, EU dual-use regulations, or applicable sanctions regimes.
This includes, but is not limited to, prohibitions on use in or by
countries, entities, or individuals listed on U.S. or EU restricted-party
lists.

## 7. Third-Party Dependencies

The Software may incorporate open-source components licensed under separate
terms. Such components are provided **“as is”** and remain subject to their
respective licenses. A complete list is available in `THIRD_PARTY_LICENSES.txt`.

## 8. No Professional Advice

The Software’s outputs (including code, text, and recommendations) do **not**
constitute professional, legal, medical, financial, or safety advice. Always
consult a qualified expert before relying on the Software for any critical
decision.

---

**© 2023-2025 WCNEGENTROPY HOLDINGS LLC.**
All trademarks—including “BitTransformerLM” and the spiral-N logo—are property
of WCNH. Unauthorized use is prohibited.
BitTransformerLM/LICENSE/LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
# LICENSE (AGPLv3)
Copyright (C) 2025 WCNEGENTROPY HOLDINGS LLC
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
BitTransformerLM/LICENSE/TRADEMARK_POLICY.txt
ADDED
@@ -0,0 +1,12 @@
“BitTransformerLM” and the spiral‑N logo are trademarks of WCNEGENTROPY HOLDINGS LLC.

Permitted use:
• Describing, linking to, or referencing **unmodified, official builds** of
  BitTransformerLM.

Prohibited use without prior written permission:
• Branding or promoting modified forks, derivatives, or third‑party services.
• Creating confusingly similar marks or implying endorsement by WCNH.

Forks must remove or rename the marks to avoid confusion.
Contact **[email protected]** for licensing requests.
BitTransformerLM/NEW_CODEX_TASK.md
ADDED
@@ -0,0 +1,85 @@
# DEPRECATED

All tasks in this file have been implemented (Stages 1–5). The document remains for historical reference only.

Stage 1: Compression Algorithm Implementation

Task 1: Choose Compression Method

Prompt:

Codex: Provide a concise PyTorch-compatible implementation of lossless binary compression and decompression (e.g., RLE, Huffman, or LZ-based) suitable for binary input sequences represented as tensors of bits.

Task 2: Implement Compression Functions

Prompt:

Codex: Implement PyTorch functions compress_bits(input_tensor) and decompress_bits(compressed_tensor) that accept and return PyTorch tensors (dtype=torch.bool or torch.uint8). Ensure the compress → decompress cycle perfectly reconstructs the original data, and include simple unit tests.

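For illustration, a run-length-encoding pair of the kind these prompts describe might look like the sketch below (assumptions: non-empty 1-D bit tensors, counts capped at 255; this is not the implementation that ultimately landed in `bit_transformer/compression.py`).

```python
import torch

def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a non-empty 1-D bit tensor into (value, count) byte pairs."""
    bits = bits.to(torch.uint8)
    out = []
    run_val, run_len = int(bits[0]), 0
    for b in bits.tolist():
        if b == run_val and run_len < 255:
            run_len += 1
        else:
            out.extend([run_val, run_len])
            run_val, run_len = b, 1
    out.extend([run_val, run_len])
    return torch.tensor(out, dtype=torch.uint8)

def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Invert compress_bits, reconstructing the original bit tensor exactly."""
    pairs = compressed.view(-1, 2)
    return torch.cat([torch.full((int(n),), int(v), dtype=torch.uint8) for v, n in pairs])
```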
⸻

Stage 2: Encoder/Decoder Integration

Task 3: Add Compression to Encoder Input

Prompt:

Codex: Modify BitTransformerLM’s input pipeline by wrapping the existing model forward pass with a forward_compressed(bits_tensor) method. This method should decompress incoming compressed bit tensors before embedding. Ensure it returns identical outputs as existing uncompressed inputs for verification.

Task 4: Add Decompression to Decoder Output

Prompt:

Codex: Implement a PyTorch-compatible function model_output_decompress(output_bits_tensor) to decompress bit sequences output by BitTransformerLM. Integrate this function as an optional post-processing step after the model’s bitstream generation.

⸻

Stage 3: Training and Evaluation Enhancements

Task 5: Toggle Compression During Training

Prompt:

Codex: Modify the existing training loop to randomly compress input bit sequences with a configurable probability (compress_prob=0.5). Ensure that when compression is on, inputs are compressed and decompressed transparently, and when off, inputs bypass compression.

Task 6: Evaluate Compressed vs Raw Performance

Prompt:

Codex: Extend the current training evaluation metrics to separately track loss, accuracy, and compression ratio for both compressed and raw sequences. Log these metrics clearly in the training output.

⸻

Stage 4: Advanced Integration (Optional)

Task 7: Multi-task Training for Compression Learning

Prompt:

Codex: Implement an optional multi-task training mode where the model occasionally sees compressed inputs directly without decompression. Add a separate loss calculation to monitor its performance on these compressed inputs. Track and log separately from normal next-bit prediction loss.

Task 8: Compression-aware Safety Telemetry

Prompt:

Codex: Adjust the existing BitTransformerLM telemetry (K, C, and S metrics) to handle compressed sequences appropriately. Modify telemetry calculations to optionally apply metrics to decompressed outputs instead of raw bitstream when compression is enabled.

⸻

Stage 5: Dashboard and Runtime Integration

Task 9: Dashboard Compression UI Toggle

Prompt:

Codex: Add a simple UI toggle labeled “Enable Compression” to the existing BitTransformerLM dashboard, controlling whether inputs and outputs are automatically compressed and decompressed. Display compression ratio metrics when enabled.

Task 10: Error Handling and User Feedback

Prompt:

Codex: Implement graceful error handling in the dashboard for compression and decompression failures. Provide clear user-facing feedback in the UI if decompression fails, along with suggestions or fallbacks.

⸻

These ten tasks enable incremental, testable integration of binary compression/decompression into BitTransformerLM without fundamentally altering the core transformer model itself.
BitTransformerLM/README.md
ADDED
@@ -0,0 +1,177 @@
# BitTransformerLM

**Project Status:** Pre-release (v1 candidate)

BitTransformerLM is a bit-centric transformer language model built entirely in PyTorch. The project began as a small prototype but has matured into a near-production system capable of modeling raw binary streams with sophisticated safety telemetry and automated scale-up tooling. This repository now serves as the canonical implementation under WCNegentropy.

## Historical Background
- **Early Experiments** – Initial prototypes explored mapping text to parity-protected bits and training a minimal transformer on random data.
- **Telemetry & Safety** – Added negentropy, LZ complexity and symbiosis scoring to measure information flow and gate unsafe outputs.
- **Progressive Scaling** – Introduced reversible layers and automatic depth/width expansion for efficient curriculum training. The schedule now triggers expansions only when validation loss plateaus and decays the learning rate by √2 after each growth with a 100-step warm‑up.
- **Compression Support** – Integrated run-length encoding and packed bit I/O with optional multi-task training on compressed sequences.
- **Context Extension** – Implemented chunked attention and sliding-window inference for long sequences with optional overlapping windows.
- **Attention Logging Toggle** – ``full_attn_logging=False`` skips reconstructing full ``T×T`` attention maps during chunked attention, cutting memory use for very long sequences.
- **Diffusion LM Mode** – Enable bidirectional denoising by setting ``causal=False`` or toggling **Diffusion LM** in the dashboard. Chunked attention is automatically disabled in this mode and restored afterward.
- **Dashboard & MCP Server** – Built a lightweight web UI backed by a management server for real‑time training, inference and model collapse. New `/metrics` and `/model_config` endpoints surface live telemetry and hyperparameters, and `/save_checkpoint` and `/download_checkpoint` enable Hugging Face weight sync. The insecure `/exec` route has been removed.
- **Phase 1 Optimizations** – Configurable batch sizes with aligned OneCycle scheduling, gradient accumulation, mixed‑precision, memory‑mapped dataset streaming, scheduled compression ramps, selective ``torch.compile``, and an EMA‑smoothed safety gate with burn‑in to cut false positives.

The codebase has undergone multiple stress tests and synthetic benchmarks (see `tests/TEST_RESULTS.md`) and now approaches a stable release.

## Quick Start
Install dependencies using the CPU wheel of PyTorch (default):
```bash
pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
```
When GPU acceleration is toggled in the dashboard, the application automatically
installs the CUDA-enabled wheel:
```bash
pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118
```
Run the example script:
```bash
python example.py
```
Adaptive scaling demo:
The legacy `progressive_scaleup.py` script is retained for reference but has been
superseded by `integration_schedule.py`, which offers a more flexible scaling
workflow.

Run the unified workflow:
```bash
python unified_workflow.py --dashboard
# disable gradient checkpointing for faster but memory-hungry runs
python unified_workflow.py --no-checkpoint
# use standard (non-reversible) transformer blocks
python unified_workflow.py --no-reversible
# enable 4-bit quantization-aware training
python unified_workflow.py --qat
```

For faster CPU execution, BitTransformerLM exposes a `cpu_autocast()` helper
that enables bfloat16 mixed precision. Models created with
`use_autocast=True` apply this automatically, or you can wrap individual
forward passes:

```python
from bit_transformer.torch_utils import cpu_autocast

with cpu_autocast():
    logits, telemetry = model(bits)
```

Reduce memory use when chunked attention is active by disabling full
attention logging:

```python
model = BitTransformerLM(chunk_size=128, full_attn_logging=False)
```

Enable Diffusion LM training and sampling:
```bash
python unified_workflow.py --diffusion --diffusion-steps 8 --dataset-size 32
# choose noise schedule: linear, cosine, exp
python unified_workflow.py --diffusion --noise-schedule cosine --diffusion-steps 16 --dataset-size 32
# linearly decay noise over epochs
python unified_workflow.py --diffusion --diffusion-curriculum --dataset-size 32
```
Higher `--diffusion-steps` (8–16) improves sample quality at the cost of compute. When using the dashboard, enable the **Diffusion LM** toggle to run the model without causal masking or chunked attention.
Generated samples automatically fix parity bits so they can be decoded back to text.
To resume training across machines using Hugging Face storage:
```bash
python unified_workflow.py --hf-repo your-username/bittransformerlm --hf-token $HF_TOKEN
```
The dashboard exposes matching controls under **Hugging Face Checkpoints**. Provide a repository ID and optional token (falling back to the `HF_TOKEN` environment variable) and click **Upload weights** or **Download weights** to sync the model.
Run the unit tests:
```bash
pytest -q
```

### Mode management

During training, ensure the model is in training mode with dropout enabled:

```python
from bit_transformer.utils import set_dropout

model.train()
set_dropout(model, 0.1)
```

Before running tests, performing inference, or committing weights to the repository, switch the model to evaluation mode and disable dropout:

```python
model.eval()
set_dropout(model, 0.0)
```

This prevents CI failures from accidentally pushing weights that still have active dropout.

## Telemetry Metrics Explained
BitTransformerLM reports three bounded metrics in ``[0, 1]`` during training and inference:

- **Negentropy (K)** – departure from random noise; ``1`` denotes perfectly ordered bits while ``0`` is uniform randomness.
- **LZ Complexity (C)** – differentiable proxy for Lempel–Ziv compressibility; low values imply repetitive patterns and high values frequent transitions.
- **Symbiosis (S)** – agreement between model predictions and a reference distribution via KL divergence; scores near ``1`` show strong alignment.

An Adaptive Computation Time (ACT) mechanism lets layers halt early once confidence exceeds a threshold. Halt probabilities are exported as ``halt_probs`` in telemetry for inspection.

These metrics are logged alongside losses and can trigger safety gates when thresholds are violated. The dashboard monitors drift and emits warnings when recent values deviate beyond a configurable threshold.
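For intuition, a rough negentropy proxy for a bit tensor can be sketched as follows (this mirrors the idea behind K, not the exact formula implemented in `bit_transformer/telemetry.py`):

```python
import torch


def negentropy_proxy(bits: torch.Tensor) -> float:
    """1 - H(p)/H_max for the empirical bit distribution: 1 = fully ordered, 0 = uniform."""
    p1 = bits.float().mean().clamp(1e-6, 1 - 1e-6)
    p = torch.stack([1 - p1, p1])
    entropy = -(p * p.log2()).sum()  # in bits; at most 1 for a binary source
    return float(1.0 - entropy)      # bounded in [0, 1]
```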
## Core Features
- **Bit-Native Modeling** – Works directly on 0/1 inputs with positional encodings and parity-protected text helpers.
- **Telemetry Synthesizer** – Clusters activation summaries to surface coherent subspaces and detect drift.
- **Submodel Distillation** – `TelemetrySynthesizer` selects representative sequences for `collapse_submodel`, which deepens
  and widens once (`width_scale` = 1.5) if telemetry floors aren't met; `save_distilled_model` places a `metrics.json` summary
  beside the distilled weights.
- **Safety Gate** – `hil_safe_inference` enforces minimum complexity and symbiosis scores at runtime with EMA smoothing and a configurable burn‑in period.
- **Quantization** – CPU inference can be quantized to int8 or trained with 4-bit QAT using the `--qat` flag.
- **Distributed Training** – FSDP and pipeline helpers allow multi‑GPU scaling when hardware is available.
- **Interactive Dashboard** – Live control of training, scaling and compression with optional GPU acceleration. The dashboard now exposes reversible layers, gradient checkpointing, ACT thresholds, λ floors, 4‑bit QAT and Diffusion LM toggles, real‑time telemetry charts powered by Chart.js, and Hugging Face checkpoint upload/download controls with `HF_TOKEN` fallback. Settings persist via `localStorage`.
- **CI/CD Pipeline** – GitHub Actions install dependencies, run the tests and build distribution artifacts on every push.

## Development Workflow
1. Start the MCP server:
```bash
python mcp_server.py
```
2. Launch the dashboard in another terminal:
```bash
MCP_SERVER_ADDR=http://127.0.0.1:7000 python -m bit_transformer.dashboard_app
```
3. Submit training batches, scale the model and monitor telemetry from the web UI.
The dashboard's appearance is controlled by `bit_transformer/static/style.css`.

A `watcher.py` script can automatically restart the server and run tests when files change during local development.

## Container Deployment
A `Dockerfile` and `start.sh` script build a minimal VM image that launches both the MCP server and dashboard.

```bash
docker build -t bittransformerlm .
docker run -p 5000:5000 -p 7000:7000 bittransformerlm
```

By default the container installs the CPU-only PyTorch wheel. Set the build
argument `TORCH_CUDA=cu118` to preinstall the GPU version. The container sets
`MCP_SERVER_ADDR=http://127.0.0.1:7000` and exposes the dashboard on port 5000.

## Roadmap
- Finalize S attribution tools and metric drift detection.
- Publish an initial release package and rename the repository to **BitTransformerLM**.
- Continue benchmarking on real datasets and expanding context window capabilities.

## Licensing

This project is released under a combination of licenses and agreements to provide a clear framework for use, distribution, and contribution. All licensing documents can be found in the `LICENSE/` directory.

The key documents are:

* `LICENSE.txt`: The primary open-source license for the software, AGPLv3.
* `COMMERCIAL_LICENSE.txt`: Terms for commercial use of the software.
* `DISCLAIMER.txt`: Important legal disclaimers.
* `ALIGNMENT_AND_TRANSPARENCY.txt`: Our commitment to alignment and transparency.
* `TRADEMARK_POLICY.txt`: Guidelines for using the project's trademarks.
* `CONTRIBUTOR_LICENSE_AGREEMENT.txt`: The agreement for all contributors to sign.

Please review these documents carefully before using or contributing to the project.
BitTransformerLM/bit_transformer/__init__.py
ADDED
@@ -0,0 +1,86 @@
from .model import (
    PositionalEncoding,
    BitTransformerLM,
    ReversibleLoggingTransformerEncoderLayer,
    example_usage,
    example_training_step,
    infer_long_sequence,
    diffusion_inference,
)
from .telemetry import TelemetrySynthesizer, detect_metric_drift
from .dashboard import plot_telemetry
from .dashboard_app import run_dashboard
from .collapse import collapse_submodel, save_distilled_model
from .safety import hil_safe_inference, demo_hil_safety, safe_sample_with_retry
from .bit_io import (
    text_to_bits,
    bits_to_text,
    infer_text,
)
from .parity import enforce_parity
from .compression import (
    compress_bits,
    decompress_bits,
    model_output_decompress,
    pack_bits,
    unpack_bits,
)
from .distributed import wrap_fsdp, make_pipeline
from .optimization import configure_optimizer, adjust_learning_rate
from .scale import expand_model
from .distil import distill_step, TelemetryLog
from .quantization import (
    quantize_dynamic,
    prepare_qat_fx,
    convert_qat_fx,
)
from .training import train_loop
from .utils import save_model, load_model, set_dropout
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from .torch_utils import cpu_autocast

__all__ = [
    "PositionalEncoding",
    "BitTransformerLM",
    "ReversibleLoggingTransformerEncoderLayer",
    "example_usage",
    "example_training_step",
    "TelemetrySynthesizer",
    "detect_metric_drift",
    "collapse_submodel",
    "save_distilled_model",
    "hil_safe_inference",
    "demo_hil_safety",
    "safe_sample_with_retry",
    "text_to_bits",
    "bits_to_text",
    "infer_text",
    "enforce_parity",
    "plot_telemetry",
    "run_dashboard",
    "configure_optimizer",
    "adjust_learning_rate",
    "expand_model",
    "distill_step",
    "TelemetryLog",
    "quantize_dynamic",
    "prepare_qat_fx",
    "convert_qat_fx",
    "train_loop",
    "wrap_fsdp",
    "make_pipeline",
    "compress_bits",
    "decompress_bits",
    "model_output_decompress",
    "pack_bits",
    "unpack_bits",
    "infer_long_sequence",
    "diffusion_inference",
    "save_model",
    "load_model",
    "set_dropout",
    "hf_login",
    "save_checkpoint",
    "download_checkpoint",
    "cpu_autocast",
]
BitTransformerLM/bit_transformer/bit_io.py
ADDED
@@ -0,0 +1,97 @@
from typing import List, TYPE_CHECKING
import torch
import sys

try:  # torch.compile may be unavailable or unsupported
    if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
        compile_fn = torch.compile
    else:
        raise RuntimeError
except Exception:  # pragma: no cover

    def compile_fn(fn=None, **kwargs):
        if fn is None:
            return lambda f: f
        return fn


if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Convert bytes to bits with per-byte parity bit."""
    result: List[int] = []
    for b in data:
        bits = [(b >> i) & 1 for i in reversed(range(8))]
        parity = sum(bits) % 2
        result.extend(bits + [parity])
    return result


@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Convert parity-protected bits back to bytes."""
    if len(bits) % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    out = bytearray()
    for i in range(0, len(bits), 9):
        chunk = bits[i : i + 9]
        payload = chunk[:8]
        parity = chunk[8]
        if parity != sum(payload) % 2:
            raise ValueError("Parity check failed")
        value = 0
        for bit in payload:
            value = (value << 1) | bit
        out.append(value)
    return bytes(out)


def text_to_bits(text: str) -> List[int]:
    return bytes_to_bits(text.encode("utf-8"))


def bits_to_text(bits: List[int]) -> str:
    return bits_to_bytes(bits).decode("utf-8", errors="replace")


def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run text through the model using the safety gate."""
    from .safety import hil_safe_inference

    bits = text_to_bits(text)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
    return bits_to_text(out_bits.squeeze(0).tolist())


def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text from the model using simple top-p sampling."""
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    for _ in range(max_new_tokens * 9):
        if tensor.size(1) >= model.pos_enc.pe.size(0):
            break
        logits, _ = model(tensor, causal=True)
        # Apply temperature to logits before softmax; dividing probabilities afterwards
        # would cancel out once the distribution is renormalized.
        prob = (logits[0, -1] / temperature).softmax(-1)
        sorted_prob, sorted_idx = prob.sort(descending=True)
        cumulative = sorted_prob.cumsum(0)
        # Keep the minimal nucleus: always retain the first token so the mass never collapses to zero.
        mask = cumulative - sorted_prob > top_p
        sorted_prob[mask] = 0
        sorted_prob = sorted_prob / sorted_prob.sum()
        next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
        tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    return bits_to_text(tensor.squeeze(0).tolist())
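For reference, a short usage sketch of the parity-protected encoding defined above (each byte maps to 9 bits: 8 payload bits plus one even-parity bit):

```python
from bit_transformer.bit_io import text_to_bits, bits_to_text

bits = text_to_bits("Hi")          # 2 bytes -> 18 bits (9 per byte)
assert len(bits) == 18
assert bits_to_text(bits) == "Hi"  # parity checks pass, round trip is lossless

bits[3] ^= 1                       # flip a payload bit in the first byte
# bits_to_text(bits) now raises ValueError("Parity check failed")
```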
BitTransformerLM/bit_transformer/collapse.py
ADDED
@@ -0,0 +1,95 @@
import json
import os
from typing import Dict, List, Optional, Tuple

import torch

from .model import BitTransformerLM
from .training import train_loop


def collapse_submodel(
    cluster_data: List[List[int]],
    target_params: Dict,
    floors: Optional[Dict[str, float]] = None,
    max_rounds: int = 3,
    width_scale: float = 1.5,
    forward_kwargs: Optional[Dict] = None,
) -> Tuple[BitTransformerLM, Dict[str, float]]:
    """Distill a submodel from clustered bit sequences.

    The routine deepens the target model when telemetry floors are unmet and,
    after the first deepening fails, widens the hidden dimensions by
    ``width_scale`` once before retrying. Returns the distilled model and its
    final telemetry metrics.
    """
    if floors is None:
        floors = {"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}

    bit_tensor = torch.tensor(cluster_data, dtype=torch.long)
    n = len(bit_tensor)
    split = max(1, int(0.8 * n))
    train_bits = bit_tensor[:split]
    val_bits = bit_tensor[split:]
    if len(val_bits) == 0:
        val_bits = train_bits

    params = target_params.copy()
    metrics: Dict[str, float] = {}
    width_scaled = False
    for round_idx in range(max_rounds):
        model = BitTransformerLM(**params)
        train_loop(
            model,
            train_bits,
            epochs=2,
            compress_prob=0.5,
            direct_prob=0.0,
            log=False,
            forward_kwargs=forward_kwargs,
        )
        with torch.no_grad():
            logits, telemetry = model(val_bits, **(forward_kwargs or {}))
        neg_k = model.negentropy_logits(logits).mean().item()
        lz_c = model.lz_complexity_logits(logits).mean().item()
        sym_s = telemetry["symbiosis_score"].mean().item()
        metrics = {
            "negentropy": neg_k,
            "lz_complexity": lz_c,
            "symbiosis_score": sym_s,
        }
        if (
            neg_k >= floors["negentropy"]
            and lz_c >= floors["lz_complexity"]
            and sym_s >= floors["symbiosis_score"]
        ):
            break
        if round_idx == 0:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
        elif not width_scaled:
            params["d_model"] = int(params.get("d_model", 32) * width_scale)
            params["dim_feedforward"] = int(
                params.get("dim_feedforward", 64) * width_scale
            )
            width_scaled = True
        else:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
    return model, metrics


def save_distilled_model(
    model: BitTransformerLM,
    path: str,
    metrics: Dict[str, float],
    floors: Optional[Dict[str, float]] = None,
) -> None:
    """Serialize a distilled model and its metric summary to disk.

    Weights are written to ``path`` and a ``metrics.json`` file is placed in the
    same directory containing the achieved metrics alongside the target floors.
    """
    torch.save(model.state_dict(), path)
    payload = {"metrics": metrics, "floors": floors or {}}
    metrics_path = os.path.join(os.path.dirname(path), "metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(payload, f)
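A minimal usage sketch for the distillation helpers above (the cluster data and target hyperparameters are illustrative):

```python
import torch

from bit_transformer.collapse import collapse_submodel, save_distilled_model

# Toy "cluster": 16 sequences of 32 bits each (normally selected by TelemetrySynthesizer).
cluster = torch.randint(0, 2, (16, 32)).tolist()

model, metrics = collapse_submodel(
    cluster,
    target_params={"d_model": 32, "nhead": 4, "num_layers": 1,
                   "dim_feedforward": 64, "max_seq_len": 32},
    max_rounds=2,
)
# Assumes the distilled/ directory already exists; metrics.json lands next to the weights.
save_distilled_model(model, "distilled/model.pt", metrics)
```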
BitTransformerLM/bit_transformer/compression.py
ADDED
@@ -0,0 +1,82 @@
from typing import List

import numpy as np
import torch


def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor containing interleaved values and run lengths.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8).flatten()
    if b.numel() == 0:
        return b
    changes = torch.nonzero(b[1:] != b[:-1]).flatten().to(torch.long) + 1
    starts = torch.cat([b.new_tensor([0], dtype=torch.long), changes])
    ends = torch.cat([changes, b.new_tensor([b.numel()], dtype=torch.long)])
    values = b[starts.to(torch.long)]
    counts = ends - starts

    out_vals: List[int] = []
    out_counts: List[int] = []
    for v, c in zip(values.tolist(), counts.tolist()):
        while c > 255:
            out_vals.append(v)
            out_counts.append(255)
            c -= 255
        out_vals.append(v)
        out_counts.append(c)
    values_tensor = torch.tensor(out_vals, dtype=torch.uint8)
    counts_tensor = torch.tensor(out_counts, dtype=torch.uint8)
    out = torch.stack([values_tensor, counts_tensor], dim=1).flatten()
    return out


def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Decode a run-length encoded bit tensor."""
    if compressed.dim() != 1 or compressed.numel() % 2 != 0:
        raise ValueError("compressed tensor must be 1D even-length")
    data = compressed.to(torch.uint8)
    values = data[0::2]
    counts = data[1::2].to(torch.long)
    return torch.repeat_interleave(values, counts)


def model_output_decompress(compressed_batch) -> torch.Tensor:
    """Decompress a batch of compressed bit sequences."""
    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        sequences = [decompress_bits(compressed_batch)]
    else:
        sequences = [decompress_bits(row) for row in compressed_batch]
    lengths = [seq.numel() for seq in sequences]
    if len(set(lengths)) != 1:
        raise ValueError("Sequences decompress to different lengths")
    return torch.stack(sequences)


def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack groups of 8 bits into uint8 values using numpy.packbits."""
    if bits.dim() != 1:
        raise ValueError("pack_bits expects a 1D tensor")
    arr = bits.to(torch.uint8).cpu().numpy()
    packed = np.packbits(arr)
    return torch.from_numpy(packed)


def unpack_bits(packed: torch.Tensor, *, n_bits: int | None = None) -> torch.Tensor:
    """Unpack uint8 values back into a bit tensor."""
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")
    arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
    if n_bits is not None:
        arr = arr[:n_bits]
    return torch.from_numpy(arr)
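A short usage sketch contrasting the two codecs above (run-length encoding for sparse streams, `packbits` for fixed 8:1 packing):

```python
import torch

from bit_transformer.compression import compress_bits, pack_bits, unpack_bits

bits = torch.tensor([1] * 40 + [0] * 24, dtype=torch.uint8)

rle = compress_bits(bits)     # tensor([1, 40, 0, 24]): interleaved (value, run-length) pairs
packed = pack_bits(bits)      # 8 bytes, always an 8:1 reduction regardless of content
restored = unpack_bits(packed, n_bits=bits.numel())
assert torch.equal(restored, bits)
```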
BitTransformerLM/bit_transformer/dashboard.py
ADDED
@@ -0,0 +1,58 @@
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple


def plot_telemetry(
    metrics_log: Dict[str, List[float]],
    k_floor: float = 0.5,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """Plot K, C, S metrics over time with cluster transitions.

    Args:
        metrics_log: Dictionary with keys ``negentropy``, ``lz_complexity``,
            ``symbiosis_score`` and optional ``clusters`` listing cluster
            assignments per step.
        k_floor: Threshold for negentropy (K).
        c_floor: Threshold for LZ complexity (C).
        s_floor: Threshold for symbiosis score (S).

    Returns:
        (figure, axes) tuple for further customization or saving.
    """
    steps = list(range(len(metrics_log.get("negentropy", []))))
    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(10, 6))
    metrics = [
        ("negentropy", k_floor, "K"),
        ("lz_complexity", c_floor, "C"),
        ("symbiosis_score", s_floor, "S"),
    ]
    for ax, (key, floor, label) in zip(axes, metrics):
        values = metrics_log.get(key, [])
        ax.plot(steps, values, label=label)
        ax.axhline(floor, color="r", linestyle="--", linewidth=1)
        violations = [i for i, v in enumerate(values) if v < floor]
        if violations:
            ax.scatter(
                [steps[i] for i in violations],
                [values[i] for i in violations],
                color="r",
                zorder=5,
                label="violation",
            )
        ax.set_ylabel(label)
        ax.legend(loc="upper right")

    clusters = metrics_log.get("clusters")
    if clusters is not None:
        prev = clusters[0]
        for t, c in enumerate(clusters):
            if t > 0 and c != prev:
                for ax in axes:
                    ax.axvline(t, color="gray", linestyle=":", alpha=0.5)
            prev = c

    axes[-1].set_xlabel("step")
    plt.tight_layout()
    return fig, axes
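A quick usage sketch for the plotting helper above (values below a floor are highlighted and cluster switches draw vertical markers):

```python
from bit_transformer.dashboard import plot_telemetry

log = {
    "negentropy": [0.7, 0.6, 0.4],
    "lz_complexity": [0.5, 0.35, 0.2],   # last step dips below the 0.3 floor
    "symbiosis_score": [0.8, 0.7, 0.6],
    "clusters": [0, 0, 1],               # a vertical line marks the cluster switch
}
fig, axes = plot_telemetry(log)
fig.savefig("telemetry.png")
```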
BitTransformerLM/bit_transformer/dashboard_app.py
ADDED
@@ -0,0 +1,927 @@
1 |
+
import io
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
import inspect
|
6 |
+
from typing import Any, Dict, List
|
7 |
+
|
8 |
+
from flask import Flask, jsonify, request, render_template, send_file
|
9 |
+
import subprocess
|
10 |
+
import sys
|
11 |
+
import warnings
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import requests
|
16 |
+
import gzip
|
17 |
+
|
18 |
+
from .model import BitTransformerLM, infer_long_sequence
|
19 |
+
from .optimization import configure_optimizer
|
20 |
+
from .collapse import collapse_submodel
|
21 |
+
from .dashboard import plot_telemetry
|
22 |
+
from .scale import expand_model
|
23 |
+
from .bit_io import text_to_bits, bits_to_text
|
24 |
+
from .safety import hil_safe_inference
|
25 |
+
from .compression import model_output_decompress, compress_bits
|
26 |
+
from .distributed import wrap_fsdp
|
27 |
+
from .training import train_loop
|
28 |
+
from .telemetry import detect_metric_drift
|
29 |
+
from .quantization import prepare_qat_fx, convert_qat_fx
|
30 |
+
from torch.distributed.fsdp import FullyShardedDataParallel
|
31 |
+
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
|
32 |
+
|
33 |
+
|
34 |
+
app = Flask(__name__)
|
35 |
+
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024 # 1MB upload limit
|
36 |
+
|
37 |
+
MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")
|
38 |
+
|
39 |
+
|
40 |
+
@app.errorhandler(Exception)
|
41 |
+
def handle_exception(err):
|
42 |
+
"""Return JSON error responses with stack traces."""
|
43 |
+
return (
|
44 |
+
jsonify({"error": str(err), "trace": traceback.format_exc()}),
|
45 |
+
getattr(err, "code", 500),
|
46 |
+
)
|
47 |
+
|
48 |
+
class MetricDriftWarning(UserWarning):
|
49 |
+
"""Raised when telemetry metrics drift beyond the configured threshold."""
|
50 |
+
|
51 |
+
def _switch_torch(use_gpu: bool) -> None:
|
52 |
+
"""Install the appropriate PyTorch wheel and restart the process."""
|
53 |
+
have_cuda = torch.version.cuda is not None
|
54 |
+
if use_gpu == have_cuda:
|
55 |
+
return
|
56 |
+
wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
|
57 |
+
url = "https://download.pytorch.org/whl/cu118" if use_gpu else "https://download.pytorch.org/whl/cpu"
|
58 |
+
subprocess.run([
|
59 |
+
sys.executable,
|
60 |
+
"-m",
|
61 |
+
"pip",
|
62 |
+
"install",
|
63 |
+
"--extra-index-url",
|
64 |
+
url,
|
65 |
+
wheel,
|
66 |
+
], check=True)
|
67 |
+
os.execv(sys.executable, [sys.executable] + sys.argv)
|
68 |
+
|
69 |
+
def mcp_post(path: str, data=None):
|
70 |
+
if not MCP_SERVER_ADDR:
|
71 |
+
return None
|
72 |
+
url = MCP_SERVER_ADDR.rstrip("/") + path
|
73 |
+
resp = requests.post(url, json=data)
|
74 |
+
resp.raise_for_status()
|
75 |
+
if resp.headers.get("Content-Type", "").startswith("image/"):
|
76 |
+
return resp.content
|
77 |
+
return resp.json()
|
78 |
+
|
79 |
+
def mcp_get(path: str):
|
80 |
+
if not MCP_SERVER_ADDR:
|
81 |
+
return None
|
82 |
+
url = MCP_SERVER_ADDR.rstrip("/") + path
|
83 |
+
resp = requests.get(url)
|
84 |
+
resp.raise_for_status()
|
85 |
+
if resp.headers.get("Content-Type", "").startswith("image/"):
|
86 |
+
return resp.content
|
87 |
+
return resp.json()
|
88 |
+
|
89 |
+
class ModelManager:
|
90 |
+
"""Manage model state and training utilities for the dashboard."""
|
91 |
+
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
snapshot_dir: str | None = None,
|
95 |
+
telemetry_log: str | None = None,
|
96 |
+
*,
|
97 |
+
drift_window: int = 10,
|
98 |
+
drift_threshold: float = 0.2,
|
99 |
+
) -> None:
|
100 |
+
self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
|
101 |
+
self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
|
102 |
+
if self.telemetry_log is None:
|
103 |
+
self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
|
104 |
+
os.makedirs(self.snapshot_dir, exist_ok=True)
|
105 |
+
self.weights_path = os.path.join(self.snapshot_dir, "model.pt")
|
106 |
+
|
107 |
+
self.model: BitTransformerLM | None = None
|
108 |
+
self.optimizer: torch.optim.Optimizer | None = None
|
109 |
+
self.scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
|
110 |
+
self.total_steps = 100
|
111 |
+
self.metrics: Dict[str, List[float]] = {
|
112 |
+
"negentropy_logits": [],
|
113 |
+
"lz_complexity_logits": [],
|
114 |
+
"symbiosis_score": [],
|
115 |
+
}
|
116 |
+
self.drift_window = drift_window
|
117 |
+
self.drift_threshold = drift_threshold
|
118 |
+
self.lambda_K = 1.0
|
119 |
+
self.lambda_C = 1.0
|
120 |
+
self.lambda_S = 1.0
|
121 |
+
self.c_floor = 0.3
|
122 |
+
self.s_floor = 0.5
|
123 |
+
self.causal = True
|
124 |
+
self.diffusion = False
|
125 |
+
self.decompress_output = False
|
126 |
+
self.use_compression = False
|
127 |
+
self.use_gpu = False
|
128 |
+
self.qat = False
|
129 |
+
|
130 |
+
# Load any existing state
|
131 |
+
if os.path.exists(self.telemetry_log):
|
132 |
+
try:
|
133 |
+
with open(self.telemetry_log) as f:
|
134 |
+
saved = json.load(f)
|
135 |
+
for key in self.metrics:
|
136 |
+
self.metrics[key] = saved.get(key, [])
|
137 |
+
except Exception:
|
138 |
+
pass
|
139 |
+
if os.path.exists(self.weights_path):
|
140 |
+
try:
|
141 |
+
self.model = torch.load(self.weights_path, map_location="cpu")
|
142 |
+
self.optimizer, self.scheduler = configure_optimizer(
|
143 |
+
self.model, lr=1e-3, total_steps=self.total_steps
|
144 |
+
)
|
145 |
+
self._apply_device()
|
146 |
+
except Exception:
|
147 |
+
self.model = None
|
148 |
+
|
149 |
+
config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
|
150 |
+
if self.model is None and os.path.exists(config_path):
|
151 |
+
try:
|
152 |
+
with open(config_path) as f:
|
153 |
+
params = json.load(f)
|
154 |
+
self.init_model(params)
|
155 |
+
except Exception:
|
156 |
+
pass
|
157 |
+
|
158 |
+
def init_model(self, params: Dict) -> None:
|
159 |
+
int_fields = {
|
160 |
+
"d_model",
|
161 |
+
"nhead",
|
162 |
+
"num_layers",
|
163 |
+
"dim_feedforward",
|
164 |
+
"max_seq_len",
|
165 |
+
"chunk_size",
|
166 |
+
"overlap",
|
167 |
+
}
|
168 |
+
float_fields = {"act_threshold"}
|
169 |
+
bool_fields = {"reversible", "use_checkpoint"}
|
170 |
+
clean: Dict[str, Any] = {}
|
171 |
+
for k, v in params.items():
|
172 |
+
if v is None:
|
173 |
+
clean[k] = None
|
174 |
+
elif k in int_fields:
|
175 |
+
clean[k] = int(v)
|
176 |
+
elif k in float_fields:
|
177 |
+
clean[k] = float(v)
|
178 |
+
elif k in bool_fields:
|
179 |
+
clean[k] = bool(v)
|
180 |
+
else:
|
181 |
+
clean[k] = v
|
182 |
+
self.model = BitTransformerLM(
|
183 |
+
**clean,
|
184 |
+
lambda_K=self.lambda_K,
|
185 |
+
lambda_C=self.lambda_C,
|
186 |
+
lambda_S=self.lambda_S,
|
187 |
+
)
|
188 |
+
self.optimizer, self.scheduler = configure_optimizer(
|
189 |
+
self.model, lr=1e-3, total_steps=self.total_steps
|
190 |
+
)
|
191 |
+
self._apply_device()
|
192 |
+
for key in self.metrics:
|
193 |
+
self.metrics[key].clear()
|
194 |
+
|
195 |
+
def set_lambdas(self, k: float, c: float, s: float) -> None:
|
196 |
+
"""Update λ weights and propagate to the model."""
|
197 |
+
self.lambda_K = k
|
198 |
+
self.lambda_C = c
|
199 |
+
self.lambda_S = s
|
200 |
+
if self.model is not None:
|
201 |
+
self.model.set_lambdas(k, c, s)
|
202 |
+
|
203 |
+
def set_floors(self, c_floor: float, s_floor: float) -> None:
|
204 |
+
"""Update safety floors for complexity (C) and symbiosis (S)."""
|
205 |
+
self.c_floor = c_floor
|
206 |
+
self.s_floor = s_floor
|
207 |
+
|
208 |
+
def set_diffusion(self, flag: bool) -> None:
|
209 |
+
"""Toggle Diffusion LM mode which disables causal masking and chunking."""
|
210 |
+
self.diffusion = flag
|
211 |
+
self.causal = not flag
|
212 |
+
if self.model is not None and flag:
|
213 |
+
self.model.chunk_size = None
|
214 |
+
|
215 |
+
def set_decompress_output(self, flag: bool) -> None:
|
216 |
+
"""Enable or disable decompression of model outputs."""
|
217 |
+
self.decompress_output = flag
|
218 |
+
|
219 |
+
def set_compression(self, flag: bool) -> None:
|
220 |
+
"""Toggle automatic compression of inputs."""
|
221 |
+
self.use_compression = flag
|
222 |
+
|
223 |
+
def set_qat(self, flag: bool) -> None:
|
224 |
+
"""Enable or disable 4-bit quantization-aware training."""
|
225 |
+
self.qat = flag
|
226 |
+
if self.model is None:
|
227 |
+
return
|
228 |
+
if flag:
|
229 |
+
self.model = prepare_qat_fx(self.model)
|
230 |
+
else:
|
231 |
+
self.model = convert_qat_fx(self.model)
|
232 |
+
|
233 |
+
def set_gpu(self, flag: bool) -> None:
|
234 |
+
"""Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
|
235 |
+
_switch_torch(flag)
|
236 |
+
self.use_gpu = flag and torch.cuda.is_available()
|
237 |
+
self._apply_device()
|
238 |
+
|
239 |
+
def _apply_device(self) -> None:
|
240 |
+
"""Move the model to the selected device and wrap with FSDP if needed."""
|
241 |
+
if self.model is None:
|
242 |
+
return
|
243 |
+
if self.use_gpu:
|
244 |
+
device = torch.device("cuda")
|
245 |
+
if isinstance(self.model, FullyShardedDataParallel):
|
246 |
+
base = self.model.module
|
247 |
+
else:
|
248 |
+
base = self.model
|
249 |
+
base = base.to(device)
|
250 |
+
self.model = wrap_fsdp(base, device_id=device)
|
251 |
+
else:
|
252 |
+
device = torch.device("cpu")
|
253 |
+
if isinstance(self.model, FullyShardedDataParallel):
|
254 |
+
self.model = self.model.module
|
255 |
+
self.model = self.model.to(device)
|
256 |
+
|
257 |
+
def train_step(self, bits: torch.Tensor) -> tuple[float, float]:
|
258 |
+
assert (
|
259 |
+
self.model is not None
|
260 |
+
and self.optimizer is not None
|
261 |
+
and self.scheduler is not None
|
262 |
+
)
|
263 |
+
self.model.train()
|
264 |
+
device = next(self.model.parameters()).device
|
265 |
+
bits = bits.to(device)
|
266 |
+
ratio = 1.0
|
267 |
+
if self.use_compression:
|
268 |
+
comps = [compress_bits(row.to(torch.uint8)) for row in bits]
|
269 |
+
comp_len = sum(c.numel() for c in comps)
|
270 |
+
ratio = min(comp_len / bits.numel(), 1.0)
|
271 |
+
logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
|
272 |
+
else:
|
273 |
+
logits, telemetry = self.model(bits, causal=self.causal)
|
274 |
+
pred = logits[:, :-1, :].reshape(-1, 2)
|
275 |
+
target = bits[:, 1:].reshape(-1)
|
276 |
+
loss = F.cross_entropy(pred, target)
|
277 |
+
loss.backward()
|
278 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
279 |
+
self.optimizer.step()
|
280 |
+
self.scheduler.step()
|
281 |
+
self.optimizer.zero_grad()
|
282 |
+
self._log_metrics(telemetry)
|
283 |
+
self._save_state()
|
284 |
+
return loss.item(), ratio
|
285 |
+
|
286 |
+
def train_epochs(
|
287 |
+
self,
|
288 |
+
bits: torch.Tensor,
|
289 |
+
*,
|
290 |
+
epochs: int = 1,
|
291 |
+
compress_prob: float = 0.5,
|
292 |
+
direct_prob: float = 0.0,
|
293 |
+
batch_size: int = 8,
|
294 |
+
num_workers: int = 0,
|
295 |
+
accum_steps: int = 1,
|
296 |
+
amp: bool = False,
|
297 |
+
compile_model: bool = False,
|
298 |
+
) -> List[Dict[str, float]]:
|
299 |
+
"""Run ``train_loop`` on a batch tensor and persist the state."""
|
300 |
+
assert self.model is not None
|
301 |
+
device = next(self.model.parameters()).device
|
302 |
+
bits = bits.to(device)
|
303 |
+
import math
|
304 |
+
steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
|
305 |
+
self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
|
306 |
+
self.optimizer, self.scheduler = configure_optimizer(
|
307 |
+
self.model, lr=1e-3, total_steps=self.total_steps
|
308 |
+
)
|
309 |
+
metrics = train_loop(
|
310 |
+
self.model,
|
311 |
+
bits,
|
312 |
+
epochs=epochs,
|
313 |
+
compress_prob=compress_prob if self.use_compression else 0.0,
|
314 |
+
direct_prob=direct_prob,
|
315 |
+
batch_size=batch_size,
|
316 |
+
num_workers=num_workers,
|
317 |
+
accum_steps=accum_steps,
|
318 |
+
amp=amp,
|
319 |
+
compile_model=compile_model,
|
320 |
+
forward_kwargs={"causal": self.causal},
|
321 |
+
optimizer=self.optimizer,
|
322 |
+
scheduler=self.scheduler,
|
323 |
+
)
|
324 |
+
self._save_state()
|
325 |
+
return metrics
|
326 |
+
|
327 |
+
def scale_up(self, width_mult: float = 1.0) -> None:
|
328 |
+
assert self.model is not None
|
329 |
+
params = dict(
|
330 |
+
d_model=int(self.model.d_model * width_mult),
|
331 |
+
nhead=self.model.layers[0].self_attn.num_heads,
|
332 |
+
num_layers=self.model.num_layers * 2,
|
333 |
+
dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
|
334 |
+
max_seq_len=self.model.pos_enc.pe.size(0),
|
335 |
+
)
|
336 |
+
self.model = expand_model(self.model, {
|
337 |
+
**params,
|
338 |
+
"lambda_K": self.lambda_K,
|
339 |
+
"lambda_C": self.lambda_C,
|
340 |
+
"lambda_S": self.lambda_S,
|
341 |
+
})
|
342 |
+
self.optimizer, self.scheduler = configure_optimizer(
|
343 |
+
self.model, lr=1e-3, total_steps=self.total_steps
|
344 |
+
)
|
345 |
+
self._save_state()
|
346 |
+
|
347 |
+
def collapse(self, cluster_bits: List[List[int]], target_params: Dict, width_scale: float = 1.0) -> None:
|
348 |
+
self.model, _ = collapse_submodel(
|
349 |
+
cluster_bits,
|
350 |
+
target_params,
|
351 |
+
width_scale=width_scale,
|
352 |
+
forward_kwargs={"causal": self.causal},
|
353 |
+
)
|
354 |
+
self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
|
355 |
+
self.optimizer, self.scheduler = configure_optimizer(
|
356 |
+
self.model, lr=1e-3, total_steps=self.total_steps
|
357 |
+
)
|
358 |
+
self._apply_device()
|
359 |
+
for key in self.metrics:
|
360 |
+
self.metrics[key].clear()
|
361 |
+
|
362 |
+
def infer(self, bits: torch.Tensor) -> Dict:
|
363 |
+
assert self.model is not None
|
364 |
+
self.model.eval()
|
365 |
+
device = next(self.model.parameters()).device
|
366 |
+
bits = bits.to(device)
|
367 |
+
ratio = 1.0
|
368 |
+
with torch.no_grad():
|
369 |
+
if self.use_compression:
|
370 |
+
comps = [compress_bits(row.to(torch.uint8)) for row in bits]
|
371 |
+
comp_len = sum(c.numel() for c in comps)
|
372 |
+
ratio = min(comp_len / bits.numel(), 1.0)
|
373 |
+
logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
|
374 |
+
else:
|
375 |
+
logits, telemetry = self.model(bits, causal=self.causal)
|
376 |
+
self._log_metrics(telemetry)
|
377 |
+
pred_bits = logits.argmax(-1)
|
378 |
+
if self.decompress_output:
|
379 |
+
try:
|
380 |
+
pred_bits = model_output_decompress(pred_bits)
|
381 |
+
except Exception as e:
|
382 |
+
return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}
|
383 |
+
def _to_python(obj):
|
384 |
+
if isinstance(obj, torch.Tensor):
|
385 |
+
return obj.tolist()
|
386 |
+
if isinstance(obj, list):
|
387 |
+
return [_to_python(o) for o in obj]
|
388 |
+
if isinstance(obj, dict):
|
389 |
+
return {kk: _to_python(vv) for kk, vv in obj.items()}
|
390 |
+
return obj
|
391 |
+
tele = {k: _to_python(v) for k, v in telemetry.items()}
|
392 |
+
return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}
|
393 |
+
|
394 |
+
def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
|
395 |
+
"""Run sliding-window inference on a long sequence."""
|
396 |
+
assert self.model is not None
|
397 |
+
device = next(self.model.parameters()).device
|
398 |
+
bits = bits.to(device)
|
399 |
+
preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
|
400 |
+
for tele in logs:
|
401 |
+
self._log_metrics(tele)
|
402 |
+
return {"predicted": preds.tolist(), "windows": len(logs)}
|
403 |
+
|
404 |
+
def _log_metrics(self, telemetry: Dict) -> None:
|
405 |
+
for key in self.metrics:
|
406 |
+
val = telemetry[key].mean().item()
|
407 |
+
self.metrics[key].append(val)
|
408 |
+
drift = detect_metric_drift(
|
409 |
+
self.metrics, window=self.drift_window, threshold=self.drift_threshold
|
410 |
+
)
|
411 |
+
bad = [k for k, v in drift.items() if v]
|
412 |
+
if bad:
|
413 |
+
warnings.warn(
|
414 |
+
f"Metric drift detected: {', '.join(bad)}",
|
415 |
+
MetricDriftWarning,
|
416 |
+
)
|
417 |
+
|
418 |
+
def infer_text(self, text: str) -> Dict[str, Any]:
|
419 |
+
"""Run text through the model using the safety gate."""
|
420 |
+
assert self.model is not None
|
421 |
+
device = next(self.model.parameters()).device
|
422 |
+
bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
|
423 |
+
out_bits, telemetry = hil_safe_inference(
|
424 |
+
self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
|
425 |
+
)
|
426 |
+
self._log_metrics(telemetry)
|
427 |
+
return {
|
428 |
+
"output": bits_to_text(out_bits.squeeze(0).tolist()),
|
429 |
+
"telemetry": telemetry,
|
430 |
+
}
|
431 |
+
|
432 |
+
def get_status(self) -> Dict[str, Any]:
|
433 |
+
info: Dict[str, Any] = {
|
434 |
+
"use_gpu": self.use_gpu,
|
435 |
+
"diffusion": self.diffusion,
|
436 |
+
"compression": self.use_compression,
|
437 |
+
"lambda_K": self.lambda_K,
|
438 |
+
"lambda_C": self.lambda_C,
|
439 |
+
"lambda_S": self.lambda_S,
|
440 |
+
"c_floor": self.c_floor,
|
441 |
+
"s_floor": self.s_floor,
|
442 |
+
"qat": self.qat,
|
443 |
+
}
|
444 |
+
if self.model is not None:
|
445 |
+
info.update(
|
446 |
+
{
|
447 |
+
"d_model": self.model.d_model,
|
448 |
+
"num_layers": self.model.num_layers,
|
449 |
+
"d_ff": self.model.layers[0].linear1.out_features,
|
450 |
+
"nhead": self.model.layers[0].self_attn.num_heads,
|
451 |
+
"max_seq_len": self.model.pos_enc.pe.size(0),
|
452 |
+
}
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
info.update(
|
456 |
+
{
|
457 |
+
"d_model": None,
|
458 |
+
"num_layers": 0,
|
459 |
+
"d_ff": None,
|
460 |
+
"nhead": None,
|
461 |
+
"max_seq_len": None,
|
462 |
+
}
|
463 |
+
)
|
464 |
+
return info
|
465 |
+
|
466 |
+
def get_model_config(self) -> Dict[str, Any]:
|
467 |
+
"""Return current model hyperparameters and safety settings."""
|
468 |
+
cfg: Dict[str, Any] = {
|
469 |
+
"lambda_K": self.lambda_K,
|
470 |
+
"lambda_C": self.lambda_C,
|
471 |
+
"lambda_S": self.lambda_S,
|
472 |
+
"c_floor": self.c_floor,
|
473 |
+
"s_floor": self.s_floor,
|
474 |
+
}
|
475 |
+
if self.model is not None:
|
476 |
+
cfg.update(
|
477 |
+
{
|
478 |
+
"d_model": self.model.d_model,
|
479 |
+
"nhead": self.model.layers[0].self_attn.num_heads,
|
480 |
+
"num_layers": self.model.num_layers,
|
481 |
+
"dim_feedforward": self.model.layers[0].linear1.out_features,
|
482 |
+
"max_seq_len": self.model.pos_enc.pe.size(0),
|
483 |
+
"chunk_size": self.model.chunk_size,
|
484 |
+
"reversible": self.model.reversible,
|
485 |
+
"use_checkpoint": self.model.use_checkpoint,
|
486 |
+
}
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
cfg.update(
|
490 |
+
{
|
491 |
+
"d_model": None,
|
492 |
+
"nhead": None,
|
493 |
+
"num_layers": 0,
|
494 |
+
"dim_feedforward": None,
|
495 |
+
"max_seq_len": None,
|
496 |
+
"chunk_size": None,
|
497 |
+
"reversible": None,
|
498 |
+
"use_checkpoint": None,
|
499 |
+
}
|
500 |
+
)
|
501 |
+
return cfg
|
502 |
+
|
503 |
+
def get_metrics(self) -> Dict[str, Any]:
|
504 |
+
"""Return logged telemetry metrics with summary statistics."""
|
505 |
+
from statistics import mean, stdev
|
506 |
+
|
507 |
+
data = {
|
508 |
+
"negentropy": self.metrics["negentropy_logits"],
|
509 |
+
"lz_complexity": self.metrics["lz_complexity_logits"],
|
510 |
+
"symbiosis": self.metrics["symbiosis_score"],
|
511 |
+
}
|
512 |
+
summary: Dict[str, Dict[str, float | None]] = {}
|
513 |
+
for key, values in data.items():
|
514 |
+
if values:
|
515 |
+
m = mean(values)
|
516 |
+
s = stdev(values) if len(values) > 1 else 0.0
|
517 |
+
summary[key] = {"mean": m, "std": s}
|
518 |
+
else:
|
519 |
+
summary[key] = {"mean": None, "std": None}
|
520 |
+
data["summary"] = summary
|
521 |
+
return data
|
522 |
+
|
523 |
+
|
524 |
+
def _save_state(self) -> None:
|
525 |
+
if self.model is None:
|
526 |
+
return
|
527 |
+
torch.save(self.model, self.weights_path)
|
528 |
+
with open(self.telemetry_log, "w") as f:
|
529 |
+
json.dump(self.metrics, f)
|
530 |
+
|
531 |
+
|
532 |
+
manager: ModelManager | None = None
|
533 |
+
|
534 |
+
|
@app.route("/")
def index():
    return render_template(
        "dashboard.html",
        metrics=manager.metrics,
        lambdas={
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        },
        diffusion=manager.diffusion,
        compression=manager.use_compression,
        defaults={k: v.default for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items() if v.default is not inspect._empty},
        c_floor=manager.c_floor,
        s_floor=manager.s_floor,
        qat=manager.qat,
    )


@app.route("/status", methods=["GET"])
def status():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/status"))
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/model_config"))
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/metrics"))
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    if MCP_SERVER_ADDR:
        data = mcp_post("/init", params)
        return jsonify(data)
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})


@app.route("/train", methods=["POST"])
def train_model():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    if MCP_SERVER_ADDR:
        data = mcp_post("/train", {"bits": request.json["bits"]})
        return jsonify(data)
    loss, ratio = manager.train_step(bits)
    return jsonify({"loss": loss, "ratio": ratio})


@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    epochs = int(request.json.get("epochs", 1))
    compress_prob = float(request.json.get("compress_prob", 0.5))
    direct_prob = float(request.json.get("direct_prob", 0.0))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/train_epochs",
            {
                "bits": request.json["bits"],
                "epochs": epochs,
                "compress_prob": compress_prob,
                "direct_prob": direct_prob,
            },
        )
        return jsonify(data)
    metrics = manager.train_epochs(
        bits,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
    )
    return jsonify({"metrics": metrics})


@app.route("/scale_up", methods=["POST"])
def scale_up():
    width_mult = float(request.json.get("width_mult", 1.0))
    if MCP_SERVER_ADDR:
        data = mcp_post("/scale_up", {"width_mult": width_mult})
        return jsonify(data)
    manager.scale_up(width_mult)
    return jsonify({
        "status": "scaled",
        "layers": manager.model.num_layers,
        "d_model": manager.model.d_model,
    })


@app.route("/collapse", methods=["POST"])
def collapse_model():
    cluster_bits = request.json["clusters"]
    params = {k: int(v) for k, v in request.json["params"].items()}
    width_scale = float(request.json.get("width_scale", 1.0))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/collapse",
            {"clusters": cluster_bits, "params": params, "width_scale": width_scale},
        )
        return jsonify(data)
    manager.collapse(cluster_bits, params, width_scale)
    return jsonify({"status": "collapsed"})


@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    if request.method == "POST":
        data = request.json
        if MCP_SERVER_ADDR:
            res = mcp_post("/lambdas", data)
            return jsonify(res)
        manager.set_lambdas(
            float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"])
        )
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/lambdas"))
        return jsonify(
            {
                "lambda_K": manager.lambda_K,
                "lambda_C": manager.lambda_C,
                "lambda_S": manager.lambda_S,
            }
        )


@app.route("/config/telemetry", methods=["GET", "POST"])
def telemetry_config():
    """Get or update telemetry λ weights and safety floors."""
    if request.method == "POST":
        data = request.json
        if MCP_SERVER_ADDR:
            res = mcp_post("/config/telemetry", data)
            return jsonify(res)
        manager.set_lambdas(
            float(data.get("lambda_K", manager.lambda_K)),
            float(data.get("lambda_C", manager.lambda_C)),
            float(data.get("lambda_S", manager.lambda_S)),
        )
        manager.set_floors(
            float(data.get("c_floor", manager.c_floor)),
            float(data.get("s_floor", manager.s_floor)),
        )
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/config/telemetry"))
        return jsonify(
            {
                "lambda_K": manager.lambda_K,
                "lambda_C": manager.lambda_C,
                "lambda_S": manager.lambda_S,
                "c_floor": manager.c_floor,
                "s_floor": manager.s_floor,
            }
        )


@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/diffusion", request.json))
        manager.set_diffusion(bool(request.json.get("diffusion", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/diffusion"))
        return jsonify({"diffusion": manager.diffusion})


@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/gpu", request.json))
        manager.set_gpu(bool(request.json.get("use_gpu", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/gpu"))
        return jsonify({"use_gpu": manager.use_gpu})


@app.route("/compression", methods=["GET", "POST"])
def update_compression():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/compression", request.json))
        manager.set_compression(bool(request.json.get("compression", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/compression"))
        return jsonify({"compression": manager.use_compression})


@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/qat", request.json))
        manager.set_qat(bool(request.json.get("qat", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/qat"))
        return jsonify({"qat": manager.qat})


@app.route("/infer", methods=["POST"])
def inference():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    if MCP_SERVER_ADDR:
        data = mcp_post("/infer", {"bits": request.json["bits"]})
        return jsonify(data)
    result = manager.infer(bits)
    return jsonify(result)


@app.route("/infer_long", methods=["POST"])
def inference_long():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    ctx = int(request.json.get("ctx_bits", 4096))
    overlap = int(request.json.get("overlap", 256))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/infer_long",
            {"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap},
        )
        return jsonify(data)
    result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
    return jsonify(result)


@app.route("/infer_text", methods=["POST"])
def inference_text():
    text = request.json.get("text", "")
    if MCP_SERVER_ADDR:
        data = mcp_post("/infer_text", {"text": text})
        return jsonify(data)
    result = manager.infer_text(text)
    return jsonify(result)

@app.route("/plot.png")
def plot_png():
    if MCP_SERVER_ADDR:
        resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png")
        resp.raise_for_status()
        return send_file(io.BytesIO(resp.content), mimetype="image/png")
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


def run_dashboard(host: str | None = None, port: int | None = None,
                  snapshot_dir: str | None = None, telemetry_log: str | None = None) -> None:
    """Launch the Flask dashboard server."""
    env_host = os.getenv("HOST", "0.0.0.0")
    env_port = int(os.getenv("PORT", "5000"))
    host = host or env_host
    port = port or env_port
    global manager
    if manager is None:
        manager = ModelManager(snapshot_dir, telemetry_log)
    app.run(host=host, port=port, debug=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dashboard server")
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
    parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
    parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
    args = parser.parse_args()
    run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
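Taken together, these routes expose the ModelManager as a small REST API. Below is a minimal usage sketch, assuming the dashboard is running locally on the default port 5000 with no MCP_SERVER_ADDR proxy configured; the payload fields simply mirror the route handlers above.

# Sketch: drive the dashboard API from a client script using requests.
import requests

BASE = "http://localhost:5000"  # assumes default HOST/PORT

# Initialize a small model; numeric fields are coerced server-side.
requests.post(f"{BASE}/init", json={"d_model": 32, "nhead": 4, "num_layers": 1,
                                    "dim_feedforward": 64, "max_seq_len": 64}).raise_for_status()

# Convert text to bits, then run a single training step on one sequence.
bits = requests.post(f"{BASE}/text_to_bits", json={"text": "hello"}).json()["bits"]
step = requests.post(f"{BASE}/train", json={"bits": [bits]}).json()
print("loss:", step["loss"], "ratio:", step["ratio"])

# Telemetry and status are plain GET endpoints.
print(requests.get(f"{BASE}/status").json())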
BitTransformerLM/bit_transformer/distil.py
ADDED
@@ -0,0 +1,90 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from .model import BitTransformerLM


@dataclass
class TelemetryLog:
    """Telemetry container holding attention maps across steps.

    Attributes:
        attention_maps: Tensor of shape [steps, heads, seq, seq].
    """

    attention_maps: torch.Tensor


def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` according to attention telemetry.

    Args:
        model: Teacher model to distill from.
        scale: Fraction of weights to retain (0 < scale <= 1).
        telemetry: Logged attention maps used to estimate parameter importance.

    This function computes an importance score for each weight in the model's
    linear layers using the supplied attention maps. The score is the mean
    activation over time multiplied by the number of visits (non-zero
    attention). The bottom ``(1 - scale)`` fraction of weights in each layer are
    zeroed out, yielding a sparsified student model.
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

    # Clone the model so the teacher remains untouched.
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

    attn = telemetry.attention_maps  # [steps, heads, seq, seq]
    steps = attn.shape[0]
    heads = attn.shape[1]
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    head_importance = head_importance / head_importance.sum()

    prune_frac = 1.0 - scale

    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            if out_features % heads == 0:
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                row_scores = head_importance.mean().expand(out_features, 1)

            importance = weight.abs() * row_scores
            k = int(importance.numel() * prune_frac)
            if k > 0:
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)

    return student
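A short usage sketch for distill_step follows. The attention tensor shape follows the TelemetryLog docstring; the tiny model sizes and the package-style imports (bit_transformer.model, bit_transformer.distil) are illustrative assumptions rather than anything fixed by this file.

# Hypothetical sketch: prune a small model to 50% of its weights using
# synthetic attention telemetry shaped [steps, heads, seq, seq].
import torch
from bit_transformer.model import BitTransformerLM
from bit_transformer.distil import TelemetryLog, distill_step

teacher = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
attn = torch.rand(2, 4, 16, 16).softmax(-1)  # rows normalized like attention weights
student = distill_step(teacher, scale=0.5, telemetry=TelemetryLog(attention_maps=attn))

# Average fraction of zeroed weights across linear layers should be roughly 0.5.
linears = [m for m in student.modules() if isinstance(m, torch.nn.Linear)]
frac = sum((m.weight == 0).float().mean().item() for m in linears) / len(linears)
print("average zero fraction across linear layers:", frac)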
BitTransformerLM/bit_transformer/distributed.py
ADDED
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from typing import List, Optional

from torch.distributed.fsdp import FullyShardedDataParallel
try:
    from torch.distributed.pipeline.sync import Pipe
except Exception:  # pragma: no cover - Pipe may not be available in CPU builds
    Pipe = None

from .model import BitTransformerLM


def wrap_fsdp(model: BitTransformerLM, **kwargs) -> FullyShardedDataParallel:
    """Return a ``FullyShardedDataParallel`` wrapped model on the given device."""
    device = kwargs.pop("device_id", torch.device("cpu"))
    model = model.to(device)
    return FullyShardedDataParallel(model, device_id=device, **kwargs)


def make_pipeline(model: BitTransformerLM, chunks: int = 1) -> Pipe:
    """Wrap the model with ``Pipe`` for simple pipeline parallelism.

    The entire model is placed in an ``nn.Sequential`` so all existing telemetry
    remains available. ``chunks`` controls microbatch splitting.
    """
    if Pipe is None:
        raise RuntimeError("Pipeline parallelism not available in this build")
    seq = nn.Sequential(model)
    return Pipe(seq, chunks=chunks)
BitTransformerLM/bit_transformer/hf_checkpoint.py
ADDED
@@ -0,0 +1,76 @@
from __future__ import annotations

import gzip
import os
import shutil
import tempfile
from typing import Optional

import torch
from huggingface_hub import HfApi, hf_hub_download, login

REPO_ID = "architect/bittransformerlm"
FILENAME = "model.pt.gz"


def hf_login(token: Optional[str] = None) -> None:
    """Authenticate with Hugging Face.

    The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
    variable. If omitted entirely, the library will attempt an interactive login.
    """
    login(token=token)


def save_checkpoint(
    model: torch.nn.Module,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> None:
    """Upload the model weights to ``repo_id`` under ``filename``.

    The file within the repository is overwritten each time to avoid
    accumulating checkpoints.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp_pt = os.path.join(tmp, "model.pt")
        tmp_gz = os.path.join(tmp, filename)
        torch.save(model.state_dict(), tmp_pt)
        with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
            dst.write(src.read())
        HfApi().upload_file(
            path_or_fileobj=tmp_gz,
            path_in_repo=f"checkpoints/{filename}",
            repo_id=repo_id,
            repo_type="model",
            overwrite=True,
        )


def download_checkpoint(
    dest_path: str,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> bool:
    """Download the latest checkpoint to ``dest_path``.

    Returns ``True`` if the checkpoint was successfully retrieved.
    """
    try:
        buf = hf_hub_download(
            repo_id,
            f"checkpoints/{filename}",
            repo_type="model",
            force_download=True,
        )
    except Exception as exc:  # pragma: no cover - network errors
        print("Failed to download checkpoint", exc)
        return False
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    shutil.copyfile(buf, dest_path)
    return True


__all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]
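A hedged round-trip sketch for these helpers is shown below. It needs network access, a Hugging Face write token in HF_TOKEN, and a repository you control; the repo id used here is a placeholder, not the library default.

# Sketch only: upload the current weights, then pull them back and reload them,
# mirroring what the dashboard's /download_checkpoint route does.
import os
import gzip
import torch
from bit_transformer.model import BitTransformerLM
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint

repo = "your-username/bittransformerlm-test"  # hypothetical repo id
model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)

if os.getenv("HF_TOKEN"):
    hf_login(token=os.environ["HF_TOKEN"])
    save_checkpoint(model, repo_id=repo)  # uploads checkpoints/model.pt.gz
    if download_checkpoint("weights/model.pt.gz", repo_id=repo):
        with gzip.open("weights/model.pt.gz", "rb") as f:
            model.load_state_dict(torch.load(f, map_location="cpu"))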
BitTransformerLM/bit_transformer/model.py
ADDED
@@ -0,0 +1,875 @@
1 |
+
import math
|
2 |
+
import contextlib
|
3 |
+
import logging
|
4 |
+
from typing import Dict, List, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
import sys
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
|
13 |
+
from .torch_utils import cpu_autocast
|
14 |
+
|
15 |
+
from .optimization import configure_optimizer
|
16 |
+
from .compression import decompress_bits
|
17 |
+
from .parity import enforce_parity
|
18 |
+
|
19 |
+
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
|
20 |
+
|
21 |
+
|
22 |
+
def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
|
23 |
+
"""Return or create a cached upper-triangular mask."""
|
24 |
+
key = (seq_len, device)
|
25 |
+
if key not in _mask_cache:
|
26 |
+
_mask_cache[key] = torch.triu(
|
27 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
|
28 |
+
)
|
29 |
+
return _mask_cache[key]
|
30 |
+
|
31 |
+
try: # torch.compile may not work on all Python versions
|
32 |
+
if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
|
33 |
+
compile_fn = torch.compile
|
34 |
+
else:
|
35 |
+
raise RuntimeError
|
36 |
+
except Exception: # pragma: no cover - handle missing torch or unsupported version
|
37 |
+
|
38 |
+
def compile_fn(fn=None, **kwargs):
|
39 |
+
if fn is None:
|
40 |
+
return lambda f: f
|
41 |
+
return fn
|
42 |
+
|
43 |
+
|
44 |
+
class PositionalEncoding(nn.Module):
|
45 |
+
"""Sinusoidal positional encoding."""
|
46 |
+
|
47 |
+
def __init__(self, d_model: int, max_len: int = 1024) -> None:
|
48 |
+
super().__init__()
|
49 |
+
pe = torch.zeros(max_len, d_model)
|
50 |
+
pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
|
51 |
+
inv = torch.exp(
|
52 |
+
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
|
53 |
+
)
|
54 |
+
pe[:, 0::2] = torch.sin(pos * inv)
|
55 |
+
pe[:, 1::2] = torch.cos(pos * inv)
|
56 |
+
self.register_buffer("pe", pe.unsqueeze(1))
|
57 |
+
|
58 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
59 |
+
"""Add positional encoding to input tensor."""
|
60 |
+
return x + self.pe[: x.size(0)]
|
61 |
+
|
62 |
+
|
63 |
+
class LoggingTransformerEncoderLayer(nn.Module):
|
64 |
+
"""Transformer encoder layer that exposes attention weights.
|
65 |
+
|
66 |
+
It optionally performs chunked attention with a fixed window size.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
d_model: int,
|
72 |
+
nhead: int,
|
73 |
+
dim_feedforward: int = 512,
|
74 |
+
dropout: float = 0.1,
|
75 |
+
chunk_size: int | None = None,
|
76 |
+
overlap: int = 0,
|
77 |
+
full_attn_logging: bool | None = None,
|
78 |
+
) -> None:
|
79 |
+
super().__init__()
|
80 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
|
81 |
+
self.chunk_size = chunk_size
|
82 |
+
self.overlap = overlap
|
83 |
+
if full_attn_logging is None:
|
84 |
+
full_attn_logging = False if chunk_size is not None else True
|
85 |
+
self.full_attn_logging = full_attn_logging
|
86 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
87 |
+
self.dropout = nn.Dropout(dropout)
|
88 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
89 |
+
self.norm1 = nn.LayerNorm(d_model)
|
90 |
+
self.norm2 = nn.LayerNorm(d_model)
|
91 |
+
self.dropout1 = nn.Dropout(dropout)
|
92 |
+
self.dropout2 = nn.Dropout(dropout)
|
93 |
+
self.activation = F.relu
|
94 |
+
|
95 |
+
def _chunked_attn(
|
96 |
+
self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
|
97 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
98 |
+
"""Perform chunked self attention with overlap."""
|
99 |
+
T, B, D = src.shape
|
100 |
+
src_b = src.transpose(0, 1) # [B, T, D]
|
101 |
+
C = self.chunk_size or T
|
102 |
+
O = self.overlap
|
103 |
+
n_chunks = (T + C - 1) // C
|
104 |
+
pad_len = n_chunks * C - T
|
105 |
+
src_pad = F.pad(src_b, (0, 0, O, pad_len + O))
|
106 |
+
chunk_len = C + 2 * O
|
107 |
+
chunks = src_pad.unfold(1, chunk_len, C) # [B, n_chunks, chunk_len, D]
|
108 |
+
mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None
|
109 |
+
out, weights = self.self_attn(
|
110 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
111 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
112 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
113 |
+
attn_mask=mask,
|
114 |
+
need_weights=True,
|
115 |
+
average_attn_weights=False,
|
116 |
+
)
|
117 |
+
out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
|
118 |
+
weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
|
119 |
+
:, :, :, O : O + C
|
120 |
+
]
|
121 |
+
seq = out.reshape(B, n_chunks * C, D)[:, :T]
|
122 |
+
if self.full_attn_logging and C < T:
|
123 |
+
full_attn = torch.zeros(
|
124 |
+
B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=src.device
|
125 |
+
)
|
126 |
+
for idx in range(n_chunks):
|
127 |
+
s = idx * C
|
128 |
+
start = max(s - O, 0)
|
129 |
+
end = min(s + C, n_chunks * C)
|
130 |
+
src_start = O - (s - start)
|
131 |
+
src_end = src_start + (end - start)
|
132 |
+
full_attn[:, :, s : s + C, start:end] = weights[:, idx, :, src_start:src_end]
|
133 |
+
full_attn = full_attn[:, :, :T, :T]
|
134 |
+
attn_out = full_attn.detach()
|
135 |
+
else:
|
136 |
+
attn_out = torch.empty(0, device=src.device)
|
137 |
+
return seq.transpose(0, 1), attn_out
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
|
141 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
142 |
+
"""Return output and attention map."""
|
143 |
+
if self.chunk_size is not None:
|
144 |
+
attn_output, attn_weights = self._chunked_attn(src, attn_mask)
|
145 |
+
else:
|
146 |
+
qkv = src.transpose(0, 1)
|
147 |
+
attn_output, attn_weights = self.self_attn(
|
148 |
+
qkv,
|
149 |
+
qkv,
|
150 |
+
qkv,
|
151 |
+
attn_mask=attn_mask,
|
152 |
+
need_weights=True,
|
153 |
+
average_attn_weights=False,
|
154 |
+
)
|
155 |
+
attn_output = attn_output.transpose(0, 1)
|
156 |
+
src = src + self.dropout1(attn_output)
|
157 |
+
src = self.norm1(src)
|
158 |
+
out = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
159 |
+
src = src + self.dropout2(out)
|
160 |
+
src = self.norm2(src)
|
161 |
+
return src, attn_weights.detach()
|
162 |
+
|
163 |
+
|
164 |
+
class ReversibleLoggingTransformerEncoderLayer(nn.Module):
|
165 |
+
"""Reversible transformer encoder layer with checkpointing."""
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
d_model: int,
|
170 |
+
nhead: int,
|
171 |
+
dim_feedforward: int = 512,
|
172 |
+
dropout: float = 0.1,
|
173 |
+
chunk_size: int | None = None,
|
174 |
+
overlap: int = 0,
|
175 |
+
full_attn_logging: bool | None = None,
|
176 |
+
) -> None:
|
177 |
+
super().__init__()
|
178 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
|
179 |
+
self.chunk_size = chunk_size
|
180 |
+
self.overlap = overlap
|
181 |
+
if full_attn_logging is None:
|
182 |
+
full_attn_logging = False if chunk_size is not None else True
|
183 |
+
self.full_attn_logging = full_attn_logging
|
184 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
185 |
+
self.dropout = nn.Dropout(dropout)
|
186 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
187 |
+
self.norm1 = nn.LayerNorm(d_model)
|
188 |
+
self.norm2 = nn.LayerNorm(d_model)
|
189 |
+
self.dropout1 = nn.Dropout(dropout)
|
190 |
+
self.dropout2 = nn.Dropout(dropout)
|
191 |
+
self.activation = F.relu
|
192 |
+
|
193 |
+
@compile_fn
|
194 |
+
def _sa_block(
|
195 |
+
self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
|
196 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
197 |
+
if self.chunk_size is not None:
|
198 |
+
T, B, D = x.shape
|
199 |
+
x_b = x.transpose(0, 1)
|
200 |
+
C = self.chunk_size or T
|
201 |
+
O = self.overlap
|
202 |
+
n_chunks = (T + C - 1) // C
|
203 |
+
pad_len = n_chunks * C - T
|
204 |
+
src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
|
205 |
+
chunk_len = C + 2 * O
|
206 |
+
chunks = src_pad.unfold(1, chunk_len, C)
|
207 |
+
mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
|
208 |
+
out, weights = self.self_attn(
|
209 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
210 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
211 |
+
chunks.reshape(B * n_chunks, chunk_len, D),
|
212 |
+
attn_mask=mask,
|
213 |
+
need_weights=True,
|
214 |
+
average_attn_weights=False,
|
215 |
+
)
|
216 |
+
out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
|
217 |
+
weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
|
218 |
+
:, :, :, O : O + C
|
219 |
+
]
|
220 |
+
seq = out.reshape(B, n_chunks * C, D)[:, :T]
|
221 |
+
if self.full_attn_logging and C < T:
|
222 |
+
full_attn = torch.zeros(
|
223 |
+
B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device
|
224 |
+
)
|
225 |
+
for idx in range(n_chunks):
|
226 |
+
s = idx * C
|
227 |
+
start = max(s - O, 0)
|
228 |
+
end = min(s + C, n_chunks * C)
|
229 |
+
src_start = O - (s - start)
|
230 |
+
src_end = src_start + (end - start)
|
231 |
+
full_attn[:, :, s : s + C, start:end] = weights[
|
232 |
+
:, idx, :, src_start:src_end
|
233 |
+
]
|
234 |
+
full_attn = full_attn[:, :, :T, :T]
|
235 |
+
weights = full_attn.detach()
|
236 |
+
else:
|
237 |
+
weights = torch.empty(0, device=x.device)
|
238 |
+
attn_out = seq.transpose(0, 1)
|
239 |
+
else:
|
240 |
+
qkv = x.transpose(0, 1)
|
241 |
+
attn_out, weights = self.self_attn(
|
242 |
+
qkv,
|
243 |
+
qkv,
|
244 |
+
qkv,
|
245 |
+
attn_mask=attn_mask,
|
246 |
+
need_weights=True,
|
247 |
+
average_attn_weights=False,
|
248 |
+
)
|
249 |
+
attn_out = attn_out.transpose(0, 1)
|
250 |
+
x = self.norm1(x + self.dropout1(attn_out))
|
251 |
+
return x, weights.detach()
|
252 |
+
|
253 |
+
@compile_fn
|
254 |
+
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
|
255 |
+
out = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
256 |
+
x = self.norm2(x + self.dropout2(out))
|
257 |
+
return x
|
258 |
+
|
259 |
+
def forward(
|
260 |
+
self,
|
261 |
+
x1: torch.Tensor,
|
262 |
+
x2: torch.Tensor,
|
263 |
+
attn_mask: torch.Tensor | None = None,
|
264 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
265 |
+
y1, weights = self._sa_block(x2, attn_mask)
|
266 |
+
y1 = x1 + y1
|
267 |
+
y2 = x2 + self._ff_block(y1)
|
268 |
+
return y1, y2, weights
|
269 |
+
|
270 |
+
|
271 |
+
class BitTransformerLM(nn.Module):
|
272 |
+
"""Transformer language model that operates on raw bits (0/1) with telemetry."""
|
273 |
+
|
274 |
+
def __init__(
|
275 |
+
self,
|
276 |
+
d_model: int = 128,
|
277 |
+
nhead: int = 8,
|
278 |
+
num_layers: int = 4,
|
279 |
+
dim_feedforward: int = 512,
|
280 |
+
max_seq_len: int = 1024,
|
281 |
+
lambda_K: float = 1.0,
|
282 |
+
lambda_C: float = 1.0,
|
283 |
+
lambda_S: float = 1.0,
|
284 |
+
reversible: bool = False,
|
285 |
+
use_checkpoint: bool = True,
|
286 |
+
use_autocast: bool = False,
|
287 |
+
use_act: bool = False,
|
288 |
+
act_threshold: float = 0.9,
|
289 |
+
chunk_size: int | None = None,
|
290 |
+
overlap: int = 0,
|
291 |
+
full_attn_logging: bool | None = None,
|
292 |
+
) -> None:
|
293 |
+
"""Create a BitTransformer language model.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
full_attn_logging: When ``False`` and ``chunk_size`` is
|
297 |
+
smaller than the sequence length, the model skips
|
298 |
+
reconstructing the full ``T×T`` attention matrices for
|
299 |
+
telemetry to reduce memory use.
|
300 |
+
"""
|
301 |
+
super().__init__()
|
302 |
+
self.d_model = d_model
|
303 |
+
self.num_layers = num_layers
|
304 |
+
self.lambda_K = lambda_K
|
305 |
+
self.lambda_C = lambda_C
|
306 |
+
self.lambda_S = lambda_S
|
307 |
+
self.reversible = reversible
|
308 |
+
self.use_checkpoint = use_checkpoint
|
309 |
+
self.use_autocast = use_autocast
|
310 |
+
self.use_act = use_act
|
311 |
+
self.act_threshold = act_threshold
|
312 |
+
self.chunk_size = chunk_size
|
313 |
+
self.overlap = overlap
|
314 |
+
if full_attn_logging is None:
|
315 |
+
full_attn_logging = False if chunk_size is not None else True
|
316 |
+
self.full_attn_logging = full_attn_logging
|
317 |
+
|
318 |
+
# Bit embedding: two possible input values
|
319 |
+
self.embedding = nn.Embedding(2, d_model)
|
320 |
+
self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)
|
321 |
+
|
322 |
+
layer_cls = (
|
323 |
+
ReversibleLoggingTransformerEncoderLayer
|
324 |
+
if reversible
|
325 |
+
else LoggingTransformerEncoderLayer
|
326 |
+
)
|
327 |
+
self.layers = nn.ModuleList(
|
328 |
+
[
|
329 |
+
layer_cls(
|
330 |
+
d_model=d_model,
|
331 |
+
nhead=nhead,
|
332 |
+
dim_feedforward=dim_feedforward,
|
333 |
+
chunk_size=chunk_size,
|
334 |
+
overlap=overlap,
|
335 |
+
full_attn_logging=full_attn_logging,
|
336 |
+
)
|
337 |
+
for _ in range(num_layers)
|
338 |
+
]
|
339 |
+
)
|
340 |
+
|
341 |
+
if self.use_act:
|
342 |
+
self.halt_projs = nn.ModuleList(
|
343 |
+
[nn.Linear(d_model, 1) for _ in range(num_layers)]
|
344 |
+
)
|
345 |
+
|
346 |
+
self.out_head = nn.Linear(d_model, 2) # output logits for bit=0 or bit=1
|
347 |
+
|
348 |
+
def expand_positional_encoding(self, new_len: int) -> None:
|
349 |
+
"""Expand positional encoding to at least ``new_len``."""
|
350 |
+
cur_len = self.pos_enc.pe.size(0)
|
351 |
+
if new_len <= cur_len:
|
352 |
+
return
|
353 |
+
device = self.pos_enc.pe.device
|
354 |
+
d_model = self.d_model
|
355 |
+
pe = torch.zeros(new_len, d_model, device=device)
|
356 |
+
pe[:cur_len] = self.pos_enc.pe.squeeze(1)
|
357 |
+
pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1)
|
358 |
+
inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model))
|
359 |
+
pe[cur_len:, 0::2] = torch.sin(pos * inv)
|
360 |
+
pe[cur_len:, 1::2] = torch.cos(pos * inv)
|
361 |
+
self.pos_enc.pe = pe.unsqueeze(1)
|
362 |
+
|
363 |
+
def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
|
364 |
+
"""Update weighting coefficients for telemetry metrics."""
|
365 |
+
self.lambda_K = lambda_K
|
366 |
+
self.lambda_C = lambda_C
|
367 |
+
self.lambda_S = lambda_S
|
368 |
+
|
369 |
+
def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
|
370 |
+
"""Return raw bit sequences, decompressing if input appears run-length encoded."""
|
371 |
+
if codes.dim() <= 1:
|
372 |
+
return codes
|
373 |
+
needs_decompress = codes.max().item() > 1
|
374 |
+
if not needs_decompress and codes.size(1) % 2 == 0:
|
375 |
+
vals = codes[:, 0::2]
|
376 |
+
if torch.all(vals[:, 1:] != vals[:, :-1]):
|
377 |
+
needs_decompress = True
|
378 |
+
if not needs_decompress:
|
379 |
+
return codes
|
380 |
+
seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
|
381 |
+
max_len = max(seq.numel() for seq in seqs)
|
382 |
+
padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
|
383 |
+
return torch.stack(padded)
|
384 |
+
|
385 |
+
def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor:
|
386 |
+
"""Approximate negentropy of bit sequences.
|
387 |
+
|
388 |
+
Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
|
389 |
+
sequence (all zeros or ones) and ``0`` reflects maximal entropy.
|
390 |
+
"""
|
391 |
+
codes = self._maybe_decompress(codes)
|
392 |
+
p = codes.float().mean(dim=1)
|
393 |
+
entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
|
394 |
+
max_e = math.log(2.0)
|
395 |
+
return 1 - entropy / max_e
|
396 |
+
|
397 |
+
def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor:
|
398 |
+
"""Differentiable proxy for Lempel–Ziv complexity.
|
399 |
+
|
400 |
+
Values near ``0`` indicate highly compressible sequences while values
|
401 |
+
approaching ``1`` correspond to rapid bit alternation.
|
402 |
+
"""
|
403 |
+
codes = self._maybe_decompress(codes)
|
404 |
+
diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
|
405 |
+
return diffs.float().mean(dim=1)
|
406 |
+
|
407 |
+
def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
|
408 |
+
"""Negentropy computed from model logits.
|
409 |
+
|
410 |
+
Parameters
|
411 |
+
----------
|
412 |
+
logits: ``torch.Tensor``
|
413 |
+
Logit tensor of shape ``(B, T, 2)``.
|
414 |
+
detach: bool, default ``True``
|
415 |
+
When ``True`` the computation is detached from the autograd graph.
|
416 |
+
"""
|
417 |
+
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
|
418 |
+
prob = logits.softmax(-1)
|
419 |
+
if detach:
|
420 |
+
prob = prob.detach()
|
421 |
+
p = prob[..., 1].mean(dim=1)
|
422 |
+
entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
|
423 |
+
max_e = math.log(2.0)
|
424 |
+
return 1 - entropy / max_e
|
425 |
+
|
426 |
+
def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
|
427 |
+
"""LZ complexity proxy computed from logits.
|
428 |
+
|
429 |
+
Parameters
|
430 |
+
----------
|
431 |
+
logits: ``torch.Tensor``
|
432 |
+
Logit tensor of shape ``(B, T, 2)``.
|
433 |
+
detach: bool, default ``True``
|
434 |
+
When ``True`` the computation is detached from the autograd graph.
|
435 |
+
"""
|
436 |
+
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
|
437 |
+
prob = logits.softmax(-1)
|
438 |
+
if detach:
|
439 |
+
prob = prob.detach()
|
440 |
+
prob1 = prob[..., 1]
|
441 |
+
diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
|
442 |
+
return diffs.mean(dim=1)
|
443 |
+
|
444 |
+
def symbiosis_kl_logits(
|
445 |
+
self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
|
446 |
+
) -> torch.Tensor:
|
447 |
+
"""Symbiosis score from KL divergence to a reference distribution.
|
448 |
+
|
449 |
+
Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with
|
450 |
+
the reference distribution and ``0`` indicating maximal divergence.
|
451 |
+
"""
|
452 |
+
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
|
453 |
+
probs = logits.softmax(-1)
|
454 |
+
if detach:
|
455 |
+
probs = probs.detach()
|
456 |
+
ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
|
457 |
+
kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1)
|
458 |
+
max_kl = math.log(2.0)
|
459 |
+
return 1 - kl / max_kl
|
460 |
+
|
461 |
+
def _act_step(
|
462 |
+
self,
|
463 |
+
hidden: torch.Tensor,
|
464 |
+
idx: int,
|
465 |
+
halt_prob: torch.Tensor,
|
466 |
+
act_state: torch.Tensor,
|
467 |
+
halt_history: List[torch.Tensor],
|
468 |
+
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
|
469 |
+
"""Apply one step of ACT halting logic."""
|
470 |
+
p = torch.sigmoid(self.halt_projs[idx](hidden))
|
471 |
+
delta = (1 - halt_prob) * p
|
472 |
+
halt_prob = halt_prob + delta
|
473 |
+
act_state = act_state + hidden * delta
|
474 |
+
halt_history.append(halt_prob.detach())
|
475 |
+
min_prob = halt_prob.detach().min()
|
476 |
+
if dist.is_initialized():
|
477 |
+
dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
|
478 |
+
return halt_prob, act_state, min_prob.item() >= self.act_threshold
|
479 |
+
|
480 |
+
def forward(
|
481 |
+
self, bit_seq: torch.Tensor, causal: bool = True
|
482 |
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
483 |
+
"""Forward pass returning logits and telemetry from the same graph.
|
484 |
+
|
485 |
+
By default the model uses causal masking and (optional) chunked
|
486 |
+
attention. When ``causal`` is ``False`` the model operates in
|
487 |
+
"Diffusion LM" mode. In this mode chunked attention is temporarily
|
488 |
+
disabled so that every token can attend to the full sequence
|
489 |
+
bidirectionally. The original chunking configuration is restored after
|
490 |
+
the forward pass.
|
491 |
+
"""
|
492 |
+
|
493 |
+
# Disable chunking when running in bidirectional (non-causal) mode
|
494 |
+
orig_chunks = None
|
495 |
+
orig_model_chunk = None
|
496 |
+
if not causal and self.chunk_size is not None:
|
497 |
+
orig_model_chunk = self.chunk_size
|
498 |
+
orig_chunks = [layer.chunk_size for layer in self.layers]
|
499 |
+
self.chunk_size = None
|
500 |
+
for layer in self.layers:
|
501 |
+
layer.chunk_size = None
|
502 |
+
|
503 |
+
try:
|
504 |
+
ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
|
505 |
+
with ctx:
|
506 |
+
x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
|
507 |
+
x = self.pos_enc(x)
|
508 |
+
|
509 |
+
attn_mask = get_tri_mask(x.size(0), x.device) if causal else None
|
510 |
+
|
511 |
+
activations: List[torch.Tensor] = []
|
512 |
+
attn_maps: List[torch.Tensor] = []
|
513 |
+
halt_history: List[torch.Tensor] = []
|
514 |
+
if self.use_act:
|
515 |
+
halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
|
516 |
+
act_state = torch.zeros_like(x)
|
517 |
+
if self.reversible:
|
518 |
+
x1, x2 = x, x
|
519 |
+
for idx, layer in enumerate(self.layers):
|
520 |
+
if self.use_checkpoint:
|
521 |
+
x1, x2, attn = checkpoint.checkpoint(
|
522 |
+
layer, x1, x2, attn_mask
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
x1, x2, attn = layer(x1, x2, attn_mask)
|
526 |
+
combined = (x1 + x2) / 2
|
527 |
+
activations.append(combined)
|
528 |
+
if attn.numel() > 0:
|
529 |
+
attn_maps.append(attn)
|
530 |
+
if self.use_act:
|
531 |
+
halt_prob, act_state, should_break = self._act_step(
|
532 |
+
combined, idx, halt_prob, act_state, halt_history
|
533 |
+
)
|
534 |
+
if should_break:
|
535 |
+
break
|
536 |
+
x = (x1 + x2) / 2
|
537 |
+
else:
|
538 |
+
for idx, layer in enumerate(self.layers):
|
539 |
+
if self.use_checkpoint:
|
540 |
+
x, attn = checkpoint.checkpoint(layer, x, attn_mask)
|
541 |
+
else:
|
542 |
+
x, attn = layer(x, attn_mask)
|
543 |
+
activations.append(x)
|
544 |
+
if attn.numel() > 0:
|
545 |
+
attn_maps.append(attn)
|
546 |
+
if self.use_act:
|
547 |
+
halt_prob, act_state, should_break = self._act_step(
|
548 |
+
x, idx, halt_prob, act_state, halt_history
|
549 |
+
)
|
550 |
+
if should_break:
|
551 |
+
break
|
552 |
+
if self.use_act:
|
553 |
+
act_state = act_state + x * (1 - halt_prob)
|
554 |
+
x = act_state
|
555 |
+
logits = self.out_head(x)
|
556 |
+
|
557 |
+
# Per-layer entropy of activations
|
558 |
+
entropies = []
|
559 |
+
for act in activations:
|
560 |
+
prob = act.softmax(-1)
|
561 |
+
ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean()
|
562 |
+
entropies.append(ent)
|
563 |
+
|
564 |
+
attn_entropies = []
|
565 |
+
for attn in attn_maps:
|
566 |
+
prob = attn # weights are already softmaxed
|
567 |
+
ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1)
|
568 |
+
ent = ent.mean(1)
|
569 |
+
attn_entropies.append(ent)
|
570 |
+
if attn_entropies:
|
571 |
+
attn_entropy_map = torch.stack(attn_entropies).mean(0)
|
572 |
+
else:
|
573 |
+
attn_entropy_map = torch.zeros(
|
574 |
+
bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
|
575 |
+
)
|
576 |
+
max_ent = math.log(attn_entropy_map.size(-1))
|
577 |
+
attn_entropy_map = attn_entropy_map / max_ent
|
578 |
+
attn_entropy = attn_entropy_map.mean(1)
|
579 |
+
|
580 |
+
logits_bt = logits.transpose(0, 1)
|
581 |
+
negentropy_in = self.negentropy_kpi(bit_seq)
|
582 |
+
lz_in = self.lz_complexity(bit_seq.float())
|
583 |
+
negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
|
584 |
+
lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
|
585 |
+
kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)
|
586 |
+
|
587 |
+
raw_sym = (
|
588 |
+
(self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
|
589 |
+
+ negentropy_logits_b * lz_logits_b
|
590 |
+
- self.lambda_S * kl_div_b
|
591 |
+
- 0.1 * attn_entropy
|
592 |
+
)
|
593 |
+
weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach()
|
594 |
+
raw_sym = raw_sym - 0.01 * weight_norm
|
595 |
+
sym_score = torch.sigmoid(raw_sym)
|
596 |
+
|
597 |
+
B, T = bit_seq.shape
|
598 |
+
assert logits_bt.shape[:2] == (B, T)
|
599 |
+
assert attn_entropy_map.shape == (B, T)
|
600 |
+
|
601 |
+
telemetry = {
|
602 |
+
"activations": activations,
|
603 |
+
"attention_maps": attn_maps,
|
604 |
+
"attention_entropy": attn_entropy_map,
|
605 |
+
"entropy": entropies,
|
606 |
+
"attention_entropy_mean": attn_entropy,
|
607 |
+
"negentropy_input": negentropy_in.detach(),
|
608 |
+
"lz_complexity_input": lz_in.detach(),
|
609 |
+
"negentropy_logits": negentropy_logits_b.detach(),
|
610 |
+
"lz_complexity_logits": lz_logits_b.detach(),
|
611 |
+
"symbiosis_kl": kl_div_b.detach(),
|
612 |
+
"symbiosis_score": sym_score.detach(),
|
613 |
+
}
|
614 |
+
if self.use_act:
|
615 |
+
telemetry["halt_probs"] = halt_history
|
616 |
+
|
617 |
+
return logits_bt, telemetry
|
618 |
+
finally:
|
619 |
+
if orig_chunks is not None:
|
620 |
+
self.chunk_size = orig_model_chunk
|
621 |
+
for layer, chunk in zip(self.layers, orig_chunks):
|
622 |
+
layer.chunk_size = chunk
|
623 |
+
|
624 |
+
def forward_compressed(
|
625 |
+
self, compressed_bits, causal: bool = True
|
626 |
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
627 |
+
"""Decompress bit sequences then run the normal forward pass."""
|
628 |
+
if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
|
629 |
+
sequences = [decompress_bits(compressed_bits).to(torch.long)]
|
630 |
+
else:
|
631 |
+
sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
|
632 |
+
lengths = [seq.numel() for seq in sequences]
|
633 |
+
if len(set(lengths)) != 1:
|
634 |
+
raise ValueError("Sequences decompress to different lengths")
|
635 |
+
bits = torch.stack(sequences)
|
636 |
+
return self.forward(bits, causal=causal)
|
637 |
+
|
638 |
+
def _current_params(self) -> Dict:
|
639 |
+
"""Return a dictionary with the current model hyperparameters."""
|
640 |
+
return {
|
641 |
+
"d_model": self.d_model,
|
642 |
+
"nhead": self.layers[0].self_attn.num_heads,
|
643 |
+
"num_layers": self.num_layers,
|
644 |
+
"dim_feedforward": self.layers[0].linear1.out_features,
|
645 |
+
"max_seq_len": self.pos_enc.pe.size(0),
|
646 |
+
"lambda_K": self.lambda_K,
|
647 |
+
"lambda_C": self.lambda_C,
|
648 |
+
"lambda_S": self.lambda_S,
|
649 |
+
"reversible": self.reversible,
|
650 |
+
"use_checkpoint": self.use_checkpoint,
|
651 |
+
"use_autocast": self.use_autocast,
|
652 |
+
"use_act": self.use_act,
|
653 |
+
"act_threshold": self.act_threshold,
|
654 |
+
"chunk_size": self.chunk_size,
|
655 |
+
"overlap": self.overlap,
|
656 |
+
}
|
657 |
+
|
658 |
+
def double_width(self) -> "BitTransformerLM":
|
659 |
+
"""Return a copy of the model with doubled hidden size."""
|
660 |
+
from .scale import expand_model
|
661 |
+
|
662 |
+
params = self._current_params()
|
663 |
+
params["d_model"] *= 2
|
664 |
+
params["dim_feedforward"] *= 2
|
665 |
+
return expand_model(self, params)
|
666 |
+
|
667 |
+
def double_layers(self) -> "BitTransformerLM":
|
668 |
+
"""Return a copy of the model with twice as many layers."""
|
669 |
+
from .scale import expand_model
|
670 |
+
|
671 |
+
params = self._current_params()
|
672 |
+
params["num_layers"] *= 2
|
673 |
+
return expand_model(self, params)
|
674 |
+
|
675 |
+
def double_length(self) -> "BitTransformerLM":
|
676 |
+
"""Return a copy of the model with doubled maximum sequence length."""
|
677 |
+
from .scale import expand_model
|
678 |
+
|
679 |
+
params = self._current_params()
|
680 |
+
params["max_seq_len"] *= 2
|
681 |
+
params["chunk_size"] = params["max_seq_len"]
|
682 |
+
return expand_model(self, params)
|
683 |
+
|
684 |
+
def train_full_sequence(
|
685 |
+
self,
|
686 |
+
bits: torch.Tensor,
|
687 |
+
*,
|
688 |
+
ctx_bits: int = 4096,
|
689 |
+
detach_every_n: int = 1_048_576,
|
690 |
+
) -> float:
|
691 |
+
"""Train on a long bit tensor using sliding windows.
|
692 |
+
|
693 |
+
Parameters
|
694 |
+
----------
|
695 |
+
bits: ``torch.Tensor``
|
696 |
+
1D tensor containing the full bit sequence.
|
697 |
+
ctx_bits: int
|
698 |
+
Size of the training context window.
|
699 |
+
detach_every_n: int
|
700 |
+
Interval in bits for optimizer updates and graph detachment.
|
701 |
+
Returns
|
702 |
+
-------
|
703 |
+
float
|
704 |
+
Mean loss over all windows.
|
705 |
+
"""
|
706 |
+
self.train()
|
707 |
+
optimizer, scheduler = configure_optimizer(
|
708 |
+
self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
|
709 |
+
)
|
710 |
+
accum = 0
|
711 |
+
total_loss = 0.0
|
712 |
+
count = 0
|
713 |
+
for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits):
|
714 |
+
segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
|
715 |
+
logits, _ = self(segment)
|
716 |
+
pred = logits[:, :-1, :].reshape(-1, 2)
|
717 |
+
target = segment[:, 1:].reshape(-1)
|
718 |
+
loss = F.cross_entropy(pred, target)
|
719 |
+
loss.backward()
|
720 |
+
accum += ctx_bits
|
721 |
+
total_loss += loss.item()
|
722 |
+
count += 1
|
723 |
+
if accum >= detach_every_n:
|
724 |
+
torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
|
725 |
+
optimizer.step()
|
726 |
+
scheduler.step()
|
727 |
+
optimizer.zero_grad()
|
728 |
+
accum = 0
|
729 |
+
if accum > 0:
|
730 |
+
torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
|
731 |
+
optimizer.step()
|
732 |
+
scheduler.step()
|
733 |
+
optimizer.zero_grad()
|
734 |
+
return total_loss / max(1, count)
|
735 |
+
|
736 |
+
|
737 |
+
def infer_long_sequence(
|
738 |
+
model: BitTransformerLM,
|
739 |
+
bits: torch.Tensor,
|
740 |
+
*,
|
741 |
+
ctx_bits: int = 4096,
|
742 |
+
overlap: int = 256,
|
743 |
+
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
|
744 |
+
"""Infer a long bit sequence using sliding windows with overlap."""
|
745 |
+
model.eval()
|
746 |
+
device = next(model.parameters()).device
|
747 |
+
bits = bits.to(device)
|
748 |
+
step = ctx_bits - overlap
|
749 |
+
outputs: List[torch.Tensor] = []
|
750 |
+
logs: List[Dict[str, torch.Tensor]] = []
|
751 |
+
for start in range(0, bits.numel(), step):
|
752 |
+
window = bits[start : start + ctx_bits].unsqueeze(0)
|
753 |
+
logits, tele = model(window, causal=True)
|
754 |
+
pred = logits.argmax(-1).squeeze(0)
|
755 |
+
        outputs.append(pred)
        logs.append(tele)
    out = torch.cat(outputs)[: bits.numel()]
    return out, logs


def diffusion_inference(
    model: BitTransformerLM,
    *,
    length: int,
    steps: int = 8,
    batch_size: int = 1,
    init_bits: torch.Tensor | None = None,
    schedule: str = "linear",
) -> torch.Tensor:
    """Generate bit sequences using iterative denoising diffusion.

    Parameters
    ----------
    model: ``BitTransformerLM``
        The model used for denoising. It is run in non-causal mode with
        chunked attention disabled, enabling full-context bidirectional
        attention.
    length: int
        Length of the bit sequences to generate.
    steps: int, default ``8``
        Number of denoising iterations. More steps generally yield sharper
        samples at the cost of compute.
    batch_size: int, default ``1``
        Number of sequences to generate in parallel.
    init_bits: ``torch.Tensor`` | ``None``
        Optional initial noisy bits of shape ``(batch_size, length)``. When
        ``None`` random noise is used.
    schedule: str, default ``"linear"``
        Noise schedule for the denoising mask probability. Options are
        ``"linear"``, ``"cosine"``, and ``"exp"``.

    Returns
    -------
    ``torch.Tensor``
        A tensor of shape ``(batch_size, length)`` containing generated bits.
    """

    model.eval()
    device = next(model.parameters()).device
    if init_bits is None:
        bits = torch.randint(0, 2, (batch_size, length), device=device)
    else:
        bits = init_bits.to(device)
        if bits.shape != (batch_size, length):
            raise ValueError("init_bits must have shape (batch_size, length)")

    for step in range(steps):
        logits, _ = model(bits, causal=False)
        prob = logits.softmax(-1)[..., 1]
        t = (step + 1) / steps
        if schedule == "linear":
            mask_prob = 1.0 - t
        elif schedule == "cosine":
            mask_prob = math.cos(math.pi * t / 2)
        elif schedule == "exp":
            mask_prob = math.exp(-5 * t)
        else:
            raise ValueError(f"unknown schedule: {schedule}")
        mask = (torch.rand_like(bits.float()) < mask_prob).long()
        sampled = torch.bernoulli(prob).long()
        bits = torch.where(mask.bool(), sampled, bits)
    if bits.shape[-1] % 9 == 0:
        bits, corrections = enforce_parity(bits)
        if corrections:
            logging.info("Parity corrections applied: %d", corrections)
    try:
        from .safety import hil_safe_inference

        hil_safe_inference(model, bits, causal=False, strict=False)
    except RuntimeError as exc:
        logging.warning("Safety gate warning: %s", exc)
    return bits


def example_usage() -> float:
    """Run the example from the README and return the loss."""
    B, L = 4, 16
    model = BitTransformerLM(
        d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=L
    )
    bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    logits, _ = model(bits)
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    return loss.item()


def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
    """Demonstrate a training step where metrics do not affect gradients."""
    B, L = 4, 16
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L
    )
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1)

    bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    logits, telemetry = model(bits)

    pred = logits[:, :-1, :].reshape(-1, 2)
    target = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item(), telemetry


if __name__ == "__main__":
    loss, telemetry = example_training_step()
    print("Composite loss:", loss)
    print("Telemetry keys:", list(telemetry.keys()))
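For orientation, a minimal usage sketch of the diffusion API defined above. The module path and the constructor sizes (mirroring `example_usage`) are assumptions, not a recommended configuration:

```python
# Minimal sketch of diffusion_inference usage; model hyperparameters are
# illustrative assumptions copied from example_usage above.
import torch

from bit_transformer.model import BitTransformerLM, diffusion_inference

model = BitTransformerLM(d_model=64, nhead=4, num_layers=2,
                         dim_feedforward=256, max_seq_len=64)
bits = diffusion_inference(model, length=64, steps=8, batch_size=2,
                           schedule="cosine")
print(bits.shape)  # expected: torch.Size([2, 64])
```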
BitTransformerLM/bit_transformer/optimization.py
ADDED
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
):
    """Return AdamW optimizer with OneCycleLR scheduler."""
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler


def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted.
    factor:
        Multiplicative factor applied to the current learning rate.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor
    return optimizer.param_groups[0]["lr"]
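A short sketch of how these two helpers combine; the `nn.Linear` stand-in model and the step count are assumptions for illustration only:

```python
import torch
import torch.nn as nn

from bit_transformer.optimization import configure_optimizer, adjust_learning_rate

model = nn.Linear(8, 2)  # stand-in module, purely illustrative
optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=10)
for _ in range(10):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

# adjust_learning_rate applies a multiplicative factor directly; note that a
# OneCycleLR scheduler will overwrite this value on its next step.
lr_after = adjust_learning_rate(optimizer, 0.5)
```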
BitTransformerLM/bit_transformer/parity.py
ADDED
@@ -0,0 +1,24 @@
import torch


def enforce_parity(bits: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fix parity bits so each 9-bit chunk has even parity.

    Parameters
    ----------
    bits: ``torch.Tensor``
        Tensor of shape ``(..., length)`` where ``length`` is a multiple of 9.

    Returns
    -------
    tuple[torch.Tensor, int]
        Corrected tensor and number of bytes that were adjusted.
    """
    if bits.shape[-1] % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    flat = bits.clone().view(-1, 9)
    payload = flat[:, :8]
    parity = flat[:, 8]
    new_parity = payload.sum(dim=1) % 2
    corrections = (parity != new_parity).sum().item()
    flat[:, 8] = new_parity
    return flat.view_as(bits), corrections
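A worked one-chunk example of the parity rule above (the import path is an assumption based on the file location):

```python
import torch

from bit_transformer.parity import enforce_parity

# One 9-bit chunk: 8 payload bits with three ones (odd sum) plus a parity bit of 0.
chunk = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0]])
fixed, corrections = enforce_parity(chunk)
print(fixed.tolist())   # [[1, 1, 1, 0, 0, 0, 0, 0, 1]] -- parity bit set to payload sum % 2
print(corrections)      # 1
```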
BitTransformerLM/bit_transformer/quantization.py
ADDED
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import convert

from .model import BitTransformerLM


def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
    """Return a dynamically quantized copy of the model for inference."""
    quantized = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=dtype
    )
    return quantized


class FourBitObserver(MinMaxObserver):
    """Min-max observer configured for 4-bit quantization."""

    def __init__(self, **kwargs):
        super().__init__(
            quant_min=0,
            quant_max=15,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            **kwargs,
        )


FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)

four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)


class QATLinear(nn.Linear):
    """Linear layer with fake quantization for QAT."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        self.weight_fake_quant = FourBitFakeQuantize()
        self.activation_post_process = FourBitFakeQuantize()

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "QATLinear":
        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.activation_post_process(x)
        w = self.weight_fake_quant(self.weight)
        return nn.functional.linear(x, w, self.bias)


def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Prepare BitTransformerLM for quantization-aware training."""

    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QATLinear.from_float(module))
        else:
            prepare_qat_fx(module)
    return model


def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Convert a QAT-prepared model to a quantized version."""

    for name, module in model.named_children():
        if isinstance(module, QATLinear):
            w = module.weight.data
            qmin, qmax = 0, 15
            min_w = w.min()
            max_w = w.max()
            scale = (max_w - min_w) / (qmax - qmin + 1e-8)
            zero_point = qmin - torch.round(min_w / scale)
            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
            new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
            if module.bias is not None:
                new_mod.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_mod)
        else:
            convert_qat_fx(module)
    return model
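A sketch of the 4-bit QAT round trip implied by the helpers above: swap `nn.Linear` layers for `QATLinear`, run a fake-quantized training step, then fold the rounded weights back into plain linear layers. Model hyperparameters and the import paths are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

from bit_transformer.model import BitTransformerLM
from bit_transformer.quantization import prepare_qat_fx, convert_qat_fx

model = prepare_qat_fx(BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                                        dim_feedforward=64, max_seq_len=16))
bits = torch.randint(0, 2, (4, 16), dtype=torch.long)
logits, _ = model(bits)  # forward pass runs through fake-quantized linears
loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, 2), bits[:, 1:].reshape(-1))
loss.backward()
quantized = convert_qat_fx(model)  # QATLinear layers replaced by rounded nn.Linear
```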
BitTransformerLM/bit_transformer/safety.py
ADDED
@@ -0,0 +1,149 @@
import logging
import time
import torch
from typing import Dict, Optional, Tuple

from .model import BitTransformerLM


class SafetyGate:
    """Exponential moving average safety gate with burn-in."""

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        self.c_floor = c_floor
        self.s_floor = s_floor
        self.decay = decay
        self.burn_in = burn_in
        self.step = 0
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Update EMA scores and check if gating should trigger."""

        self.step += 1
        if self._c_ema is None:
            self._c_ema = c_val
            self._s_ema = s_val
        else:
            self._c_ema = self.decay * self._c_ema + (1 - self.decay) * c_val
            self._s_ema = self.decay * self._s_ema + (1 - self.decay) * s_val
        if self.step <= self.burn_in:
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor


def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
    causal:
        Whether to run the model in causal (autoregressive) mode. When ``False``
        the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors are
        not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.
    """
    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)
    c_val = float(telemetry["lz_complexity_logits"].mean().item())
    s_val = float(telemetry["symbiosis_score"].mean().item())
    c_val = max(0.0, min(1.0, c_val))
    s_val = max(0.0, min(1.0, s_val))
    if gate is not None:
        triggered = gate.should_trigger(c_val, s_val)
    else:
        triggered = c_val <= c_floor or s_val <= s_floor
    if strict and triggered:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry


def demo_hil_safety() -> None:
    """Demonstrate gating on random bits."""
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    try:
        out, _ = hil_safe_inference(model, bits, c_floor=0.0, s_floor=0.0)
        print("Safe output bits:", out.squeeze(0).tolist())
    except RuntimeError as e:
        print("Gate triggered:", e)


def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    The helper retries failed safety checks by toggling diffusion mode and
    refreshing the input bits. An exponential backoff is applied between
    attempts and warnings are logged for each retry.

    Parameters
    ----------
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.
    """

    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:  # safety gate triggered
            logging.warning("Safety gate failed (attempt %d/%d): %s", attempt + 1, max_retries, exc)
            if attempt >= max_retries - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
            causal = False  # retry in diffusion mode
            bit_seq = torch.randint(0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device)
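A minimal sketch of how the gate and the retry helper are meant to compose, with one `SafetyGate` shared across attempts so its EMA and burn-in state persist. Floors, model size, and import paths are illustrative assumptions:

```python
import torch

from bit_transformer.model import BitTransformerLM
from bit_transformer.safety import SafetyGate, safe_sample_with_retry

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
gate = SafetyGate(c_floor=0.3, s_floor=0.5, decay=0.9, burn_in=5)
bits = torch.randint(0, 2, (1, 16), dtype=torch.long)
try:
    out, telemetry = safe_sample_with_retry(model, bits, gate=gate, max_retries=2)
except RuntimeError:
    out = None  # gate still triggered after all retries
```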
BitTransformerLM/bit_transformer/scale.py
ADDED
@@ -0,0 +1,36 @@
import torch
from typing import Dict
from .model import BitTransformerLM
import torch.nn as nn


def expand_model(model: BitTransformerLM, new_params: Dict) -> BitTransformerLM:
    """Return a new model with updated params and copied weights."""
    new_model = BitTransformerLM(**new_params)
    new_state = new_model.state_dict()
    old_state = model.state_dict()

    for k, v in old_state.items():
        if k in new_state:
            dest = new_state[k]
            slices = tuple(slice(0, min(d, s)) for d, s in zip(dest.shape, v.shape))
            dest[slices].copy_(v[slices])
            if dest.shape != v.shape:
                mask = torch.ones_like(dest, dtype=torch.bool)
                mask[slices] = False
                if "bias" in k:
                    dest[mask] = 0.0
                else:
                    dest[mask] = 0.001 * torch.randn_like(dest[mask])

    for k, v in new_state.items():
        if k not in old_state:
            if "bias" in k:
                v.zero_()
            elif v.dim() > 1:
                nn.init.normal_(v, mean=0.0, std=1e-3)
            else:
                v.zero_()

    new_model.load_state_dict(new_state)
    return new_model
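A sketch of growing a model width-wise with `expand_model`; the parameter dicts are assumptions and must match whatever keyword arguments `BitTransformerLM` actually accepts:

```python
from bit_transformer.model import BitTransformerLM
from bit_transformer.scale import expand_model

small = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
# Overlapping weight slices are copied; newly created slots get near-zero init.
big = expand_model(small, {"d_model": 64, "nhead": 4, "num_layers": 1,
                           "dim_feedforward": 128, "max_seq_len": 16})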
BitTransformerLM/bit_transformer/static/style.css
ADDED
@@ -0,0 +1,93 @@
1 |
+
:root {
|
2 |
+
--primary: #1e40af;
|
3 |
+
--bg: #f5f6fa;
|
4 |
+
}
|
5 |
+
|
6 |
+
body {
|
7 |
+
font-family: Arial, sans-serif;
|
8 |
+
background-color: var(--bg);
|
9 |
+
margin: 0;
|
10 |
+
padding: 0;
|
11 |
+
line-height: 1.5;
|
12 |
+
color: #333;
|
13 |
+
}
|
14 |
+
|
15 |
+
.container {
|
16 |
+
max-width: 900px;
|
17 |
+
margin: 0 auto;
|
18 |
+
padding-bottom: 2rem;
|
19 |
+
}
|
20 |
+
|
21 |
+
h1 {
|
22 |
+
text-align: center;
|
23 |
+
background: var(--primary);
|
24 |
+
color: #fff;
|
25 |
+
margin: 0;
|
26 |
+
padding: 1rem 0;
|
27 |
+
}
|
28 |
+
|
29 |
+
section {
|
30 |
+
background: #fff;
|
31 |
+
margin: 1rem auto;
|
32 |
+
padding: 1rem 1.5rem;
|
33 |
+
border-radius: 8px;
|
34 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
35 |
+
width: 90%;
|
36 |
+
max-width: 800px;
|
37 |
+
}
|
38 |
+
|
39 |
+
section h2 {
|
40 |
+
margin-top: 0;
|
41 |
+
color: var(--primary);
|
42 |
+
font-size: 1.25rem;
|
43 |
+
}
|
44 |
+
|
45 |
+
form {
|
46 |
+
display: flex;
|
47 |
+
flex-wrap: wrap;
|
48 |
+
gap: 0.5rem 1rem;
|
49 |
+
}
|
50 |
+
|
51 |
+
form input[type="text"],
|
52 |
+
form input[type="number"],
|
53 |
+
form textarea {
|
54 |
+
flex: 1 1 200px;
|
55 |
+
padding: 0.4em;
|
56 |
+
border: 1px solid #ccc;
|
57 |
+
border-radius: 4px;
|
58 |
+
}
|
59 |
+
|
60 |
+
form button,
|
61 |
+
button#scaleBtn {
|
62 |
+
padding: 0.4em 0.8em;
|
63 |
+
border: none;
|
64 |
+
background: var(--primary);
|
65 |
+
color: #fff;
|
66 |
+
border-radius: 4px;
|
67 |
+
cursor: pointer;
|
68 |
+
}
|
69 |
+
|
70 |
+
form button:hover,
|
71 |
+
button#scaleBtn:hover {
|
72 |
+
background-color: #1d4ed8;
|
73 |
+
}
|
74 |
+
|
75 |
+
pre, p#trainOut {
|
76 |
+
background: #f0f0f0;
|
77 |
+
padding: 0.5rem;
|
78 |
+
border-radius: 4px;
|
79 |
+
overflow-x: auto;
|
80 |
+
}
|
81 |
+
|
82 |
+
label {
|
83 |
+
display: flex;
|
84 |
+
align-items: center;
|
85 |
+
gap: 0.5rem;
|
86 |
+
}
|
87 |
+
|
88 |
+
img#plot {
|
89 |
+
max-width: 100%;
|
90 |
+
height: auto;
|
91 |
+
display: block;
|
92 |
+
margin: auto;
|
93 |
+
}
|
BitTransformerLM/bit_transformer/telemetry.py
ADDED
@@ -0,0 +1,95 @@
import numpy as np
from typing import Dict, List, TYPE_CHECKING

import torch
from sklearn.cluster import KMeans

if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns."""

    def __init__(self, n_clusters: int = 2) -> None:
        self.n_clusters = n_clusters

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict."""
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            var = a.var(unbiased=False).item()
            prob = m.softmax(-1)
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info."""
        data = np.stack([self._summary(t) for t in telemetries])
        km = KMeans(n_clusters=self.n_clusters, n_init=1)
        labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            idx = np.where(labels == c)[0]
            if len(idx) > 0:
                representatives.append(bit_seqs[idx[0]].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]


def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare.
        threshold: Absolute difference required to flag drift.

    Returns:
        Dictionary mapping metric keys to a boolean drift indicator.
    """
    drift = {}
    for key, values in metrics_log.items():
        if len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        drift[key] = abs(recent - prev) > threshold
    return drift
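A small worked example of the drift rule above: a metric is flagged when the mean of its last `window` values differs from the previous window's mean by more than `threshold`. The import path is an assumption based on the file location:

```python
from bit_transformer.telemetry import detect_metric_drift

log = {"negentropy": [0.9] * 10 + [0.4] * 10, "symbiosis": [0.7] * 20}
print(detect_metric_drift(log, window=10, threshold=0.2))
# {'negentropy': True, 'symbiosis': False}
```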
BitTransformerLM/bit_transformer/templates/dashboard.html
ADDED
@@ -0,0 +1,454 @@
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<title>Bit Transformer Dashboard</title>
|
6 |
+
<link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
|
7 |
+
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<h1>Bit Transformer Dashboard</h1>
|
11 |
+
<div class="container">
|
12 |
+
<section>
|
13 |
+
<h2>Initialize Model</h2>
|
14 |
+
<form id="initForm">
|
15 |
+
d_model: <input type="number" name="d_model" value="{{ defaults.d_model }}" title="Model width (default {{ defaults.d_model }})"><br>
|
16 |
+
nhead: <input type="number" name="nhead" value="{{ defaults.nhead }}" title="Attention heads (default {{ defaults.nhead }})"><br>
|
17 |
+
num_layers: <input type="number" name="num_layers" value="{{ defaults.num_layers }}" title="Transformer layers (default {{ defaults.num_layers }})"><br>
|
18 |
+
dim_feedforward: <input type="number" name="dim_feedforward" value="{{ defaults.dim_feedforward }}" title="Feedforward dim (default {{ defaults.dim_feedforward }})"><br>
|
19 |
+
max_seq_len: <input type="number" name="max_seq_len" value="{{ defaults.max_seq_len }}" title="Max sequence length (default {{ defaults.max_seq_len }})"><br>
|
20 |
+
chunk_size: <input type="number" name="chunk_size" title="Chunked attention size"><br>
|
21 |
+
overlap: <input type="number" name="overlap" value="{{ defaults.overlap }}" title="Sliding window overlap"><br>
|
22 |
+
Reversible: <input type="checkbox" name="reversible" id="reversible_box" title="Use reversible layers (default {{ defaults.reversible }})"><br>
|
23 |
+
Gradient Checkpointing: <input type="checkbox" name="use_checkpoint" id="checkpoint_box" checked title="Enable gradient checkpointing (default {{ defaults.use_checkpoint }})"><br>
|
24 |
+
act_threshold: <input type="number" step="0.01" name="act_threshold" value="{{ defaults.act_threshold }}" title="ACT halt threshold (default {{ defaults.act_threshold }})"><br>
|
25 |
+
c_floor: <input type="number" step="0.01" name="c_floor" value="{{ c_floor }}" title="Complexity floor"><br>
|
26 |
+
s_floor: <input type="number" step="0.01" name="s_floor" value="{{ s_floor }}" title="Symbiosis floor"><br>
|
27 |
+
<button type="submit">Init</button>
|
28 |
+
</form>
|
29 |
+
</section>
|
30 |
+
<section>
|
31 |
+
<h2>Train Step</h2>
|
32 |
+
<form id="trainForm">
|
33 |
+
Bits (e.g. 0 1 0 1): <input type="text" name="bits" value="0 1 0 1"><br>
|
34 |
+
Upload file: <input type="file" id="train_file"><br>
|
35 |
+
<button type="submit">Train</button>
|
36 |
+
</form>
|
37 |
+
<label>Load sample dataset:
|
38 |
+
<select id="datasetSelect">
|
39 |
+
<option value="">--Select--</option>
|
40 |
+
<option value="wikitext2_train">Wikitext-2 (train)</option>
|
41 |
+
<option value="wikitext2_validation">Wikitext-2 (validation)</option>
|
42 |
+
</select>
|
43 |
+
</label>
|
44 |
+
<p id="trainOut"></p>
|
45 |
+
</section>
|
46 |
+
<section>
|
47 |
+
<h2>Scale Up</h2>
|
48 |
+
Width Mult: <input type="number" step="0.1" id="width_mult" value="1.0"><br>
|
49 |
+
<button id="scaleBtn">Scale Model</button>
|
50 |
+
</section>
|
51 |
+
<section>
|
52 |
+
<h2>Collapse Submodel</h2>
|
53 |
+
<form id="collapseForm">
|
54 |
+
Cluster Bits (JSON array of arrays):<br>
|
55 |
+
<textarea name="clusters" rows="3" cols="40">[[0,1,0,1],[1,1,0,0]]</textarea><br>
|
56 |
+
Target Params (JSON):<br>
|
57 |
+
<textarea name="params" rows="3" cols="40">{"d_model":32,"nhead":4,"num_layers":1,"dim_feedforward":64,"max_seq_len":16}</textarea><br>
|
58 |
+
Width Scale: <input type="number" step="0.1" id="width_scale" value="1.0"><br>
|
59 |
+
<button type="submit">Collapse</button>
|
60 |
+
</form>
|
61 |
+
</section>
|
62 |
+
<section>
|
63 |
+
<h2>Inference</h2>
|
64 |
+
<form id="inferForm">
|
65 |
+
Bits: <input type="text" name="bits" value="0 1 0 1"><br>
|
66 |
+
Upload file: <input type="file" id="infer_file"><br>
|
67 |
+
<button type="submit">Infer</button>
|
68 |
+
</form>
|
69 |
+
<pre id="inferOut"></pre>
|
70 |
+
</section>
|
71 |
+
<section>
|
72 |
+
<h2>Long Inference</h2>
|
73 |
+
<form id="inferLongForm">
|
74 |
+
Bits: <input type="text" name="bits" value="0 1 0 1"><br>
|
75 |
+
ctx_bits: <input type="number" name="ctx_bits" value="4096"><br>
|
76 |
+
overlap: <input type="number" name="overlap" value="256"><br>
|
77 |
+
<button type="submit">Infer Long</button>
|
78 |
+
</form>
|
79 |
+
<pre id="inferLongOut"></pre>
|
80 |
+
</section>
|
81 |
+
<section>
|
82 |
+
<h2>Text Inference</h2>
|
83 |
+
<form id="textInferForm">
|
84 |
+
Text: <input type="text" name="text" value="hello"><br>
|
85 |
+
<button type="submit">Infer Text</button>
|
86 |
+
</form>
|
87 |
+
<pre id="textInferOut"></pre>
|
88 |
+
</section>
|
89 |
+
<section>
|
90 |
+
<h2>λ Weights</h2>
|
91 |
+
<form id="lambdaForm">
|
92 |
+
λ<sub>K</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_K" oninput="lambda_K_val.innerText=value"><span id="lambda_K_val"></span><br>
|
93 |
+
λ<sub>C</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_C" oninput="lambda_C_val.innerText=value"><span id="lambda_C_val"></span><br>
|
94 |
+
λ<sub>S</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_S" oninput="lambda_S_val.innerText=value"><span id="lambda_S_val"></span><br>
|
95 |
+
<button type="submit">Update</button>
|
96 |
+
</form>
|
97 |
+
</section>
|
98 |
+
<section>
|
99 |
+
<h2>Diffusion LM</h2>
|
100 |
+
<label><input type="checkbox" id="diffusion_box"> Enable Diffusion Mode</label>
|
101 |
+
</section>
|
102 |
+
<section>
|
103 |
+
<h2>GPU Acceleration</h2>
|
104 |
+
<label><input type="checkbox" id="gpu_box"> Enable FSDP & CUDA</label>
|
105 |
+
</section>
|
106 |
+
<section>
|
107 |
+
<h2>Enable Compression</h2>
|
108 |
+
<label><input type="checkbox" id="compression_box"> Compress I/O</label>
|
109 |
+
<p>Ratio: <span id="comp_ratio">1.0</span></p>
|
110 |
+
</section>
|
111 |
+
<section>
|
112 |
+
<h2>Quantization Aware Training</h2>
|
113 |
+
<label><input type="checkbox" id="qat_box"> Enable 4-bit QAT</label>
|
114 |
+
</section>
|
115 |
+
<section>
|
116 |
+
<h2>Model Status</h2>
|
117 |
+
<pre id="statusOut"></pre>
|
118 |
+
</section>
|
119 |
+
<section>
|
120 |
+
<h2>Telemetry</h2>
|
121 |
+
<canvas id="metricChart" width="600" height="300"></canvas>
|
122 |
+
</section>
|
123 |
+
<section>
|
124 |
+
<h2>Hugging Face Checkpoints</h2>
|
125 |
+
Repo ID: <input type="text" id="hf_repo"><br>
|
126 |
+
Token: <input type="password" id="hf_token" placeholder="optional"><br>
|
127 |
+
<button id="uploadBtn">Upload weights</button>
|
128 |
+
<button id="downloadBtn">Download weights</button>
|
129 |
+
<p id="hfStatus"></p>
|
130 |
+
</section>
|
131 |
+
|
132 |
+
<script>
|
133 |
+
async function postJSON(url, data){
|
134 |
+
const resp = await fetch(url, {method:'POST', headers:{'Content-Type':'application/json'}, body:JSON.stringify(data)});
|
135 |
+
return resp.json();
|
136 |
+
}
|
137 |
+
|
138 |
+
async function pollJob(id){
|
139 |
+
while(true){
|
140 |
+
const job = await fetch(`/job/${id}`).then(r=>r.json());
|
141 |
+
if(job.status === 'completed') return job.result;
|
142 |
+
if(job.status === 'error') throw job.error || 'Job failed';
|
143 |
+
await new Promise(r=>setTimeout(r, 1000));
|
144 |
+
}
|
145 |
+
}
|
146 |
+
|
147 |
+
function loadInitParams(){
|
148 |
+
const saved = JSON.parse(localStorage.getItem('init_params')||'{}');
|
149 |
+
const form = document.getElementById('initForm');
|
150 |
+
for(const [k,v] of Object.entries(saved)){
|
151 |
+
const el = form.elements[k];
|
152 |
+
if(!el) continue;
|
153 |
+
if(el.type === 'checkbox') el.checked = v; else el.value = v;
|
154 |
+
}
|
155 |
+
}
|
156 |
+
loadInitParams();
|
157 |
+
|
158 |
+
function byteArrayToBits(arr){
|
159 |
+
const bits=[];
|
160 |
+
for(const b of arr){
|
161 |
+
for(let i=7;i>=0;i--) bits.push((b>>i)&1);
|
162 |
+
}
|
163 |
+
return bits;
|
164 |
+
}
|
165 |
+
|
166 |
+
let trainFileBits=null, inferFileBits=null, datasetBits=null;
|
167 |
+
|
168 |
+
async function fileToBits(file){
|
169 |
+
if(file.type.startsWith('text')){
|
170 |
+
const text = await file.text();
|
171 |
+
const res = await postJSON('/text_to_bits', {text});
|
172 |
+
return res.bits;
|
173 |
+
}
|
174 |
+
const buf = await file.arrayBuffer();
|
175 |
+
return byteArrayToBits(new Uint8Array(buf));
|
176 |
+
}
|
177 |
+
|
178 |
+
let metricChart;
|
179 |
+
async function initChart(){
|
180 |
+
const data = await fetch('/metrics').then(r=>r.json());
|
181 |
+
const labels = data.negentropy.map((_,i)=>i);
|
182 |
+
const ctx = document.getElementById('metricChart').getContext('2d');
|
183 |
+
metricChart = new Chart(ctx, {
|
184 |
+
type:'line',
|
185 |
+
data:{
|
186 |
+
labels:labels,
|
187 |
+
datasets:[
|
188 |
+
{label:'Negentropy', data:data.negentropy, borderColor:'blue', fill:false},
|
189 |
+
{label:'LZ Complexity', data:data.lz_complexity, borderColor:'orange', fill:false},
|
190 |
+
{label:'Symbiosis', data:data.symbiosis, borderColor:'green', fill:false}
|
191 |
+
]
|
192 |
+
},
|
193 |
+
options:{responsive:false, interaction:{mode:'index', intersect:false}}
|
194 |
+
});
|
195 |
+
}
|
196 |
+
|
197 |
+
async function updateChart(){
|
198 |
+
const data = await fetch('/metrics').then(r=>r.json());
|
199 |
+
const labels = data.negentropy.map((_,i)=>i);
|
200 |
+
metricChart.data.labels = labels;
|
201 |
+
metricChart.data.datasets[0].data = data.negentropy;
|
202 |
+
metricChart.data.datasets[1].data = data.lz_complexity;
|
203 |
+
metricChart.data.datasets[2].data = data.symbiosis;
|
204 |
+
metricChart.update();
|
205 |
+
}
|
206 |
+
|
207 |
+
initChart();
|
208 |
+
setInterval(updateChart, 2000);
|
209 |
+
|
210 |
+
async function refreshStatus(){
|
211 |
+
const [s, c] = await Promise.all([fetch('/status'), fetch('/model_config')]);
|
212 |
+
const status = await s.json();
|
213 |
+
const config = await c.json();
|
214 |
+
document.getElementById('statusOut').innerText = JSON.stringify({...status, ...config}, null, 2);
|
215 |
+
}
|
216 |
+
|
217 |
+
document.getElementById('initForm').addEventListener('submit', async (e)=>{
|
218 |
+
e.preventDefault();
|
219 |
+
const fd = new FormData(e.target);
|
220 |
+
const obj = Object.fromEntries(fd.entries());
|
221 |
+
const ints = ['d_model','nhead','num_layers','dim_feedforward','max_seq_len','chunk_size','overlap'];
|
222 |
+
ints.forEach(k=>{ if(obj[k]===''){ delete obj[k]; } else obj[k]=parseInt(obj[k]); });
|
223 |
+
obj.reversible = document.getElementById('reversible_box').checked;
|
224 |
+
obj.use_checkpoint = document.getElementById('checkpoint_box').checked;
|
225 |
+
obj.act_threshold = parseFloat(obj.act_threshold);
|
226 |
+
const floors = {c_floor: parseFloat(obj.c_floor), s_floor: parseFloat(obj.s_floor)};
|
227 |
+
delete obj.c_floor; delete obj.s_floor;
|
228 |
+
await postJSON('/init', obj);
|
229 |
+
await postJSON('/config/telemetry', floors);
|
230 |
+
localStorage.setItem('init_params', JSON.stringify({...obj, ...floors}));
|
231 |
+
refreshStatus();
|
232 |
+
updateChart();
|
233 |
+
});
|
234 |
+
|
235 |
+
document.getElementById('trainForm').addEventListener('submit', async (e)=>{
|
236 |
+
e.preventDefault();
|
237 |
+
const form = e.target;
|
238 |
+
let payload;
|
239 |
+
if(trainFileBits){
|
240 |
+
payload = trainFileBits;
|
241 |
+
} else if(datasetBits){
|
242 |
+
payload = datasetBits;
|
243 |
+
} else {
|
244 |
+
payload = [form.bits.value.trim().split(/\s+/).map(Number)];
|
245 |
+
}
|
246 |
+
for(const el of form.elements) el.disabled = true;
|
247 |
+
const out = document.getElementById('trainOut');
|
248 |
+
out.innerText = '⏳';
|
249 |
+
try{
|
250 |
+
const job = await postJSON('/train', {bits: payload});
|
251 |
+
const res = await pollJob(job.job_id);
|
252 |
+
out.innerText = 'Loss: '+res.loss.toFixed(4);
|
253 |
+
if(res.ratio !== undefined){
|
254 |
+
document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
|
255 |
+
}
|
256 |
+
} catch(err){
|
257 |
+
out.innerText = 'Error';
|
258 |
+
alert(err);
|
259 |
+
} finally {
|
260 |
+
for(const el of form.elements) el.disabled = false;
|
261 |
+
refreshStatus();
|
262 |
+
updateChart();
|
263 |
+
}
|
264 |
+
});
|
265 |
+
|
266 |
+
document.getElementById('train_file').addEventListener('change', async (e)=>{
|
267 |
+
const f = e.target.files[0];
|
268 |
+
if(!f) return;
|
269 |
+
const bits = await fileToBits(f);
|
270 |
+
trainFileBits = [bits];
|
271 |
+
datasetBits = null;
|
272 |
+
document.querySelector('#trainForm input[name="bits"]').value = bits.slice(0,64).join(' ');
|
273 |
+
});
|
274 |
+
|
275 |
+
document.querySelector('#trainForm input[name="bits"]').addEventListener('input', ()=>{
|
276 |
+
trainFileBits = null;
|
277 |
+
datasetBits = null;
|
278 |
+
});
|
279 |
+
|
280 |
+
document.getElementById('scaleBtn').addEventListener('click', async ()=>{
|
281 |
+
const btn = document.getElementById('scaleBtn');
|
282 |
+
const input = document.getElementById('width_mult');
|
283 |
+
const mult = parseFloat(input.value);
|
284 |
+
btn.disabled = true; input.disabled = true;
|
285 |
+
const original = btn.innerText; btn.innerText = '⏳';
|
286 |
+
try{
|
287 |
+
const job = await postJSON('/scale_up', {width_mult: mult});
|
288 |
+
await pollJob(job.job_id);
|
289 |
+
} catch(err){
|
290 |
+
alert(err);
|
291 |
+
} finally {
|
292 |
+
btn.innerText = original;
|
293 |
+
btn.disabled = false; input.disabled = false;
|
294 |
+
refreshStatus();
|
295 |
+
updateChart();
|
296 |
+
}
|
297 |
+
});
|
298 |
+
|
299 |
+
document.getElementById('collapseForm').addEventListener('submit', async (e)=>{
|
300 |
+
e.preventDefault();
|
301 |
+
const form = e.target;
|
302 |
+
const btn = form.querySelector('button');
|
303 |
+
for(const el of form.elements) el.disabled = true;
|
304 |
+
const clusters = JSON.parse(form.clusters.value);
|
305 |
+
const params = JSON.parse(form.params.value);
|
306 |
+
const w = parseFloat(document.getElementById('width_scale').value);
|
307 |
+
const original = btn.innerText; btn.innerText = '⏳';
|
308 |
+
try{
|
309 |
+
const job = await postJSON('/collapse', {clusters: clusters, params: params, width_scale: w});
|
310 |
+
await pollJob(job.job_id);
|
311 |
+
} catch(err){
|
312 |
+
alert(err);
|
313 |
+
} finally {
|
314 |
+
btn.innerText = original;
|
315 |
+
for(const el of form.elements) el.disabled = false;
|
316 |
+
refreshStatus();
|
317 |
+
updateChart();
|
318 |
+
}
|
319 |
+
});
|
320 |
+
|
321 |
+
document.getElementById('inferForm').addEventListener('submit', async (e)=>{
|
322 |
+
e.preventDefault();
|
323 |
+
let bits;
|
324 |
+
if(inferFileBits){
|
325 |
+
bits = inferFileBits;
|
326 |
+
} else if(datasetBits){
|
327 |
+
bits = [datasetBits[0]];
|
328 |
+
} else {
|
329 |
+
bits = [e.target.bits.value.trim().split(/\s+/).map(Number)];
|
330 |
+
}
|
331 |
+
const res = await postJSON('/infer', {bits});
|
332 |
+
if(res.error){
|
333 |
+
alert(res.error + '\n' + (res.suggestion||''));
|
334 |
+
} else {
|
335 |
+
document.getElementById('inferOut').innerText = JSON.stringify(res, null, 2);
|
336 |
+
if(res.ratio !== undefined){
|
337 |
+
document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
|
338 |
+
}
|
339 |
+
}
|
340 |
+
refreshStatus();
|
341 |
+
updateChart();
|
342 |
+
});
|
343 |
+
|
344 |
+
document.getElementById('infer_file').addEventListener('change', async (e)=>{
|
345 |
+
const f = e.target.files[0];
|
346 |
+
if(!f) return;
|
347 |
+
const bits = await fileToBits(f);
|
348 |
+
inferFileBits = [bits];
|
349 |
+
datasetBits = null;
|
350 |
+
document.querySelector('#inferForm input[name="bits"]').value = bits.slice(0,64).join(' ');
|
351 |
+
});
|
352 |
+
|
353 |
+
document.querySelector('#inferForm input[name="bits"]').addEventListener('input', ()=>{
|
354 |
+
inferFileBits = null;
|
355 |
+
datasetBits = null;
|
356 |
+
});
|
357 |
+
|
358 |
+
document.getElementById('datasetSelect').addEventListener('change', async (e)=>{
|
359 |
+
const val = e.target.value;
|
360 |
+
trainFileBits = null;
|
361 |
+
inferFileBits = null;
|
362 |
+
if(!val){ datasetBits = null; return; }
|
363 |
+
const [name, split] = val.split('_');
|
364 |
+
const resp = await fetch(`/dataset?name=${name}&split=${split}&size=4&seq_len=64`);
|
365 |
+
const data = await resp.json();
|
366 |
+
datasetBits = data.bits;
|
367 |
+
const preview = data.bits[0].slice(0,64).join(' ');
|
368 |
+
document.querySelector('#trainForm input[name="bits"]').value = preview;
|
369 |
+
document.querySelector('#inferForm input[name="bits"]').value = preview;
|
370 |
+
});
|
371 |
+
|
372 |
+
document.getElementById('inferLongForm').addEventListener('submit', async (e)=>{
|
373 |
+
e.preventDefault();
|
374 |
+
const bits = e.target.bits.value.trim().split(/\s+/).map(Number);
|
375 |
+
const ctx = parseInt(e.target.ctx_bits.value);
|
376 |
+
const ov = parseInt(e.target.overlap.value);
|
377 |
+
const res = await postJSON('/infer_long', {bits: bits, ctx_bits: ctx, overlap: ov});
|
378 |
+
document.getElementById('inferLongOut').innerText = JSON.stringify(res, null, 2);
|
379 |
+
refreshStatus();
|
380 |
+
updateChart();
|
381 |
+
});
|
382 |
+
|
383 |
+
document.getElementById('textInferForm').addEventListener('submit', async (e)=>{
|
384 |
+
e.preventDefault();
|
385 |
+
const text = e.target.text.value;
|
386 |
+
const res = await postJSON('/infer_text', {text:text});
|
387 |
+
document.getElementById('textInferOut').innerText = JSON.stringify(res, null, 2);
|
388 |
+
refreshStatus();
|
389 |
+
updateChart();
|
390 |
+
});
|
391 |
+
|
392 |
+
async function loadLambdas(){
|
393 |
+
const resp = await fetch('/lambdas');
|
394 |
+
const vals = await resp.json();
|
395 |
+
for(const k of ['lambda_K','lambda_C','lambda_S']){
|
396 |
+
document.getElementById(k).value = vals[k];
|
397 |
+
document.getElementById(k+"_val").innerText = vals[k];
|
398 |
+
}
|
399 |
+
}
|
400 |
+
|
401 |
+
document.getElementById('lambdaForm').addEventListener('submit', async (e)=>{
|
402 |
+
e.preventDefault();
|
403 |
+
const data = {
|
404 |
+
lambda_K: parseFloat(document.getElementById('lambda_K').value),
|
405 |
+
lambda_C: parseFloat(document.getElementById('lambda_C').value),
|
406 |
+
lambda_S: parseFloat(document.getElementById('lambda_S').value),
|
407 |
+
};
|
408 |
+
await postJSON('/lambdas', data);
|
409 |
+
for(const k in data){
|
410 |
+
document.getElementById(k+"_val").innerText = data[k];
|
411 |
+
}
|
412 |
+
refreshStatus();
|
413 |
+
});
|
414 |
+
|
415 |
+
loadLambdas();
|
416 |
+
|
417 |
+
function restoreToggle(id,key,endpoint,field){
|
418 |
+
const box = document.getElementById(id);
|
419 |
+
const saved = localStorage.getItem(key);
|
420 |
+
if(saved !== null){ box.checked = saved === 'true'; postJSON(endpoint,{[field]: box.checked}); }
|
421 |
+
box.addEventListener('change', async (e)=>{
|
422 |
+
await postJSON(endpoint, {[field]: e.target.checked});
|
423 |
+
localStorage.setItem(key, e.target.checked);
|
424 |
+
refreshStatus();
|
425 |
+
});
|
426 |
+
}
|
427 |
+
|
428 |
+
restoreToggle('diffusion_box','diffusion','/diffusion','diffusion');
|
429 |
+
restoreToggle('gpu_box','use_gpu','/gpu','use_gpu');
|
430 |
+
restoreToggle('compression_box','compression','/compression','compression');
|
431 |
+
restoreToggle('qat_box','qat','/qat','qat');
|
432 |
+
|
433 |
+
document.getElementById('uploadBtn').addEventListener('click', async ()=>{
|
434 |
+
const repo = document.getElementById('hf_repo').value;
|
435 |
+
const token = document.getElementById('hf_token').value;
|
436 |
+
const res = await postJSON('/save_checkpoint', {repo_id: repo, token: token||undefined});
|
437 |
+
document.getElementById('hfStatus').innerText = res.status || res.error;
|
438 |
+
});
|
439 |
+
|
440 |
+
document.getElementById('downloadBtn').addEventListener('click', async ()=>{
|
441 |
+
const repo = document.getElementById('hf_repo').value;
|
442 |
+
const token = document.getElementById('hf_token').value;
|
443 |
+
const res = await postJSON('/download_checkpoint', {repo_id: repo, token: token||undefined});
|
444 |
+
document.getElementById('hfStatus').innerText = res.status || res.error;
|
445 |
+
refreshStatus();
|
446 |
+
updateChart();
|
447 |
+
});
|
448 |
+
|
449 |
+
refreshStatus();
|
450 |
+
</script>
|
451 |
+
</div>
|
452 |
+
</body>
|
453 |
+
</html>
|
454 |
+
|
BitTransformerLM/bit_transformer/torch_utils.py
ADDED
@@ -0,0 +1,21 @@
from __future__ import annotations

from contextlib import contextmanager
import torch


@contextmanager
def cpu_autocast(enabled: bool = True):
    """Context manager for bfloat16 autocast on CPU.

    Parameters
    ----------
    enabled: bool, default True
        Whether to enable autocast. When ``False`` this context manager
        behaves like a no-op.
    """
    if enabled:
        with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
            yield
    else:
        yield
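A tiny sketch of the context manager in use, mirroring how `train_loop` toggles AMP; the stand-in `nn.Linear` module is an assumption:

```python
import torch
import torch.nn as nn

from bit_transformer.torch_utils import cpu_autocast

layer = nn.Linear(16, 16)
x = torch.randn(2, 16)
with cpu_autocast(enabled=True):
    y = layer(x)  # autocast-eligible ops run in bfloat16 on CPU
print(y.dtype)
```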
BitTransformerLM/bit_transformer/training.py
ADDED
@@ -0,0 +1,250 @@
"""Common training utilities for BitTransformer models."""

from __future__ import annotations

from typing import Callable, Dict, List, Optional
import contextlib
import sys
import warnings
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from .compression import compress_bits, pack_bits, unpack_bits
from .optimization import configure_optimizer
from .model import BitTransformerLM
from .utils import set_dropout
from .torch_utils import cpu_autocast


def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Cosine ramp from ``start`` to ``end`` over ``total_steps``."""
    if total_steps <= 0 or step >= total_steps:
        return end
    cos_inner = math.pi * step / total_steps
    return start + (end - start) * (1 - math.cos(cos_inner)) / 2


def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
    """Generic training loop supporting optional compression and diffusion.

    ``compress_prob`` controls the fraction of batches that are run through
    ``forward_compressed``. ``direct_prob`` instead feeds the model with the
    bit-packed result of ``compress_bits`` after converting back to a bit
    tensor. When enabled, metrics for direct-compressed batches are tracked
    separately.

    When ``diffusion`` is ``True`` the loop performs denoising training. Batches
    are noised by randomly flipping bits with a probability given by
    ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When
    ``diffusion_curriculum`` is ``True`` the noise probability decreases
    linearly from ``0.5`` to ``0.0`` over the training epochs. The model is
    then trained to recover the clean sequence using full-context attention
    (``causal=False``).

    Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow
    integration with long-running training sessions, otherwise new ones are
    created automatically.
    """
    if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
        model = torch.compile(model)
    elif compile_model:
        warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")

    model.train()
    set_dropout(model, 0.1)

    device = next(model.parameters()).device
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )
    metrics: List[Dict[str, float]] = []

    global_step = 0
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []

        last_batch = None
        for step, batch in enumerate(loader):
            batch = batch.to(device)
            last_batch = batch  # keep the device copy for the extra steps below
            cur_compress = (
                cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
                if not diffusion
                else compress_prob
            )
            if diffusion:
                if diffusion_curriculum:
                    p = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
                noise = (torch.rand_like(batch.float()) < p).long()
                noisy = batch ^ noise
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(noisy, causal=False)
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                global_step += 1
                continue

            r = torch.rand(())
            key = "raw"
            ratio = 1.0
            target = batch[:, 1:].reshape(-1)

            if r < direct_prob:
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(p, n_bits=batch.size(1)) for p in packed]
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len))) for u in unpacked]
                dc_batch = torch.stack(padded).long()
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(dc_batch, **(forward_kwargs or {}))
                ratio = sum(p.numel() for p in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model.forward_compressed(comp_batch, **(forward_kwargs or {}))
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                target = batch[:, 1:].reshape(-1)
                key = "compressed"
            else:
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(batch, **(forward_kwargs or {}))

            pred = logits[:, :-1, :].reshape(-1, 2)
            loss = F.cross_entropy(pred, target) / accum_steps
            acc = (pred.argmax(dim=-1) == target).float().mean().item()

            loss.backward()
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            global_step += 1

            if key == "compressed":
                comp_losses.append(loss.item() * accum_steps)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(loss.item() * accum_steps)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)

        # run extra gradient updates using the final batch
        if extra_steps > 0 and last_batch is not None and not diffusion:
            for step in range(extra_steps):
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(last_batch, **(forward_kwargs or {}))
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = last_batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                global_step += 1

        m = {
            "raw_loss": float(sum(raw_losses) / len(raw_losses)) if raw_losses else 0.0,
            "raw_acc": float(sum(raw_accs) / len(raw_accs)) if raw_accs else 0.0,
            "compressed_loss": float(sum(comp_losses) / len(comp_losses)) if comp_losses else 0.0,
            "compressed_acc": float(sum(comp_accs) / len(comp_accs)) if comp_accs else 0.0,
            "direct_loss": float(sum(direct_losses) / len(direct_losses)) if direct_losses else 0.0,
            "compression_ratio": float(sum(comp_ratios) / len(comp_ratios)) if comp_ratios else 0.0,
        }
        metrics.append(m)

        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )

    return metrics


__all__ = ["train_loop"]
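A minimal sketch of calling `train_loop` in denoising mode with the linear noise curriculum described in its docstring; the dataset size and model shape are illustrative assumptions:

```python
import torch

from bit_transformer.model import BitTransformerLM
from bit_transformer.training import train_loop

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
data = torch.randint(0, 2, (64, 16), dtype=torch.long)  # 64 random bit sequences
metrics = train_loop(model, data, epochs=2, batch_size=8,
                     diffusion=True, diffusion_curriculum=True, log=True)
print(metrics[-1]["raw_loss"])
```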
BitTransformerLM/bit_transformer/utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import gzip
import torch
import torch.nn as nn


def save_model(model: torch.nn.Module, path: str) -> None:
    """Save a model using gzip compression."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with gzip.open(path, 'wb') as f:
        torch.save(model, f)


def load_model(path: str) -> torch.nn.Module:
    """Load a model saved with ``save_model``."""
    with gzip.open(path, 'rb') as f:
        model = torch.load(f, map_location="cpu", weights_only=False)
    return model


def set_dropout(model: torch.nn.Module, p: float) -> None:
    """Set dropout probability ``p`` for all dropout layers in ``model``."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = p


__all__ = ["save_model", "load_model", "set_dropout"]

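A minimal usage sketch for these helpers (the constructor arguments and the `weights/` path are illustrative assumptions, not taken from this file):

```python
from bit_transformer import BitTransformerLM
from bit_transformer.utils import save_model, load_model, set_dropout

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=128)
set_dropout(model, 0.1)                       # enable dropout for training
save_model(model, "weights/model.pt.gz")      # gzip-compressed torch.save of the full module
restored = load_model("weights/model.pt.gz")  # loads back onto CPU
set_dropout(restored, 0.0)                    # switch dropout off for evaluation
```
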
BitTransformerLM/bit_transformer_lm_codex_playbook.md
ADDED
@@ -0,0 +1,278 @@
---

# 🧭 BitTransformerLM Codex Playbook (Merged)

A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.

---

## Phase 1 — Training Loop & Runtime Optimizations (apply these first)

### Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
```

✅ OneCycle’s horizon matches reality across runs.

---

### Task 2 — Remove hardcoded `total_steps=100` in dashboard/MCP — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file dashboard/manager.py \
  --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
```

✅ Aligns scheduler behavior between direct loop and MCP/dashboard.

---

### Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
    logits = model(batch)
    loss = criterion(logits, labels)
loss.backward()
```

✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.

---

### Task 4 — Add gradient accumulation — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
```

✅ Simulates larger effective batch sizes without extra memory.

---

### Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file data/wikitext_schedule.py \
  --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
```

✅ Removes conversion bottlenecks on large corpora.

---

### Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅

**Prompt (pseudo-code):**

```python
compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
```

✅ Prevents early instability from aggressive compression.

---

### Task 7 — Stabilize safety gate (EMA + burn‑in) — COMPLETED ✅

**Prompt (pseudo-patch):**

```python
ema_val = ema(val_loss, decay=0.9)
if step < burn_in_steps:
    allow_training = True
elif ema_val > threshold:
    trigger_gate()
```

✅ Reduces false positives from noisy early validations.

---

### Task 8 — Enable `torch.compile` selectively — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/training.py \
  --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
```

✅ Opportunistic speedup where supported.

---

### Task 9 — Integrate FlashAttention / SDPA

**Prompt (pseudo-patch):**

```python
from torch.nn import functional as F

def forward_attention(q, k, v, is_causal=True):
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.

---

### Task 10 — Cache causal masks — COMPLETED ✅

**Prompt (pseudo-code):**

```python
mask_cache = {}

def get_tri_mask(seq_len, device):
    key = (seq_len, device)
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
    return mask_cache[key]
```

✅ Avoids repeated `triu` allocations when masks are still needed.

---

### Task 11 — Fix stitched attention negative indexing — COMPLETED ✅

**Prompt (pseudo-code):**

```python
start = max(s - overlap, 0)
end = min(s + chunk_size, T)
canvas[..., start:end] = attn_chunk[..., : end - start]
```

✅ Prevents wrap-around misplacement during T×T map reconstruction.

---

### Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅

**Prompt:**

```bash
codex run bittransformerlm/patch \
  --file bit_transformer/model.py \
  --edit "Set full_attn_logging=False by default when chunk_size is set"
```

✅ Big memory/time savings without losing training signal.

---

## Phase 2 — Model Creation & Training Tasks (run after Phase 1)

### Task A — Train the best current baseline (8×256 with ACT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 64 \
  --accum_steps 2 \
  --amp bf16 \
  --lr_schedule progressive_plateau \
  --full_attn_logging false
```

✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.

---

### Task B — CPU‑friendly deployment (8×128, INT8 + optional QAT)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 8 \
  --d_model 128 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --quantization int8 \
  --qat true \
  --reversible true \
  --checkpointing true \
  --batch_size 128 \
  --accum_steps 1 \
  --amp bf16
```

✅ Efficient CPU target; QAT optional based on deployment constraints.

---

### Task C — Cautious scale‑up candidate (16×256)

**Prompt:**

```bash
codex run bittransformerlm/train \
  --layers 16 \
  --d_model 256 \
  --nhead 8 \
  --causal true \
  --chunk_size 128 \
  --act true \
  --reversible true \
  --checkpointing true \
  --batch_size 48 \
  --accum_steps 3 \
  --amp bf16 \
  --lr_schedule progressive_plateau
```

⚠️ Use only after data expansion and schedule retune.

---

## Recommended Execution Order

1. **Phase 1 Tasks 1–12** (apply all optimizations).
2. **Task A** baseline → validate.
3. **Task B** CPU build → validate + (optional) QAT.
4. **Task C** scale‑up **only** when data/schedule allow.

---

### Notes

- Pair Phase 1 changes with CI that runs a short sanity fit (few hundred steps) to confirm loss decreases and no scheduler drift.
- Keep `full_attn_logging=false` in chunked runs; enable selectively when inspecting attention.
- When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.

---

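Task 6 of the playbook references a `cosine_ramp` helper without defining it. A minimal sketch of one possible implementation (the name and signature come from the prompt above; the body is an assumption, not existing repo code):

```python
import math

def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Smoothly ramp a value from ``start`` to ``end`` over ``total_steps`` steps."""
    if total_steps <= 0 or step >= total_steps:
        return end
    progress = step / total_steps                        # 0.0 .. 1.0
    weight = 0.5 * (1.0 - math.cos(math.pi * progress))  # half-cosine easing
    return start + (end - start) * weight

# e.g. compress_prob = cosine_ramp(global_step, 0.0, 0.5, warmup_steps)
```
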
BitTransformerLM/build_full_bits.py
ADDED
@@ -0,0 +1,23 @@
import pathlib
import torch
from datasets import load_dataset

TXT_MB = 100
OUT = pathlib.Path('full_bits.pt')


def build_bits(out: pathlib.Path = OUT, txt_mb: int = TXT_MB) -> None:
    ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    buf = bytearray()
    for line in ds['text']:
        buf.extend(line.encode() + b"\n")
        if len(buf) >= txt_mb * 2 ** 20:
            break
    bits = []
    for byte in buf:
        bits.extend(int(b) for b in f'{byte:08b}')
    tensor = torch.tensor(bits, dtype=torch.uint8)
    torch.save(tensor, out)

if __name__ == '__main__':
    build_bits()

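A quick sanity check of the artifact this script writes might look like the sketch below (the file name follows `OUT` above; the regrouping of 8 bits per byte mirrors the `f'{byte:08b}'` encoding in `build_bits`, and only the first 100 bytes are decoded to keep it cheap):

```python
import torch

bits = torch.load('full_bits.pt', mmap=True)             # uint8 tensor of 0/1 values
print('total bits:', bits.numel())

# regroup the first 100 bytes (MSB first, matching f'{byte:08b}') and decode them
head = bits[: 100 * 8].reshape(-1, 8).to(torch.int64)
weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.int64)
byte_vals = (head * weights).sum(dim=1)
text = bytes(byte_vals.tolist()).decode('utf-8', errors='replace')
print(repr(text))
```
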
BitTransformerLM/context_extension.md
ADDED
@@ -0,0 +1,43 @@
Increasing the BitTransformerLM context window

Current limitations and mechanisms
The default max_seq_len in BitTransformerLM is 1,024 bits. Since text is encoded using parity bits (9 bits per byte), this translates to roughly 113 bytes (≈113 characters) of input. The model uses full self-attention, giving quadratic memory complexity in sequence length. To train on very long sequences, train_full_sequence slides a fixed-size context window along a long bit tensor, detaching the computation graph periodically. Compression can shorten sequences via run-length encoding, and chunked attention can divide long inputs into overlapping windows for attention calculations. However, the maximum positional encoding still defines an upper bound.

Strategies to reach ~2k-word context (~18k bits)
Increase max_seq_len and positional encoding. The positional encoding precomputes a [max_len, d_model] matrix. Raising max_len to accommodate ~18,000 bits (for ~2,000 words × 9 bits per word) is possible but memory-intensive. At d_model=128, the positional encoding would be ~18,000 × 128 ≈ 2.3 M floats (≈9 MB). That is reasonable for a CPU VM. Codex can modify the default max_seq_len and update any dependent tests.
Use chunked attention and overlapping windows. LoggingTransformerEncoderLayer already supports chunk_size and overlap parameters. Setting chunk_size (e.g., 2,048 bits) and an overlap of, e.g., 128 bits enables the model to handle sequences far longer than the attention window while still allowing information flow across chunks. Codex can expose chunk_size and overlap through the dashboard and CLI so users can tune them for longer contexts.
Codex prompt example: "Modify the dashboard /init endpoint to accept chunk_size and overlap fields and pass them to BitTransformerLM. Update the HTML template to include input fields for these parameters."
Apply sliding-window training and inference. The train_full_sequence method trains on long bit tensors by sliding a context window and detaching the graph every ctx_bits bits. For inference, a similar sliding approach could produce outputs for long sequences. Codex can add an infer_long_sequence method that divides a long bit sequence into overlapping windows, runs the model with causal=True to preserve order, and stitches the outputs.
Prompt example: "Implement def infer_long_sequence(model: BitTransformerLM, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256): that processes a long bit tensor in sliding windows with overlap, uses causal=True, and returns the concatenated output bits."
Exploit run-length compression more aggressively. Since binary data often contains runs of identical bits (e.g., long sequences of zeros), increasing the compression ratio reduces the effective sequence length. Codex could add additional compression schemes (e.g., bit-packing into bytes using numpy.packbits) and integrate them into the model's I/O pipeline. Care must be taken to maintain parity bits for error detection.
Prompt example: "Add functions pack_bits and unpack_bits that use numpy.packbits to pack 8 bits into a byte. Modify train_loop so that when direct_prob>0 the model is trained on packed bits with a suitable embedding."
Memory-efficient attention alternatives. For even larger contexts, one could replace full attention with sparse, local or linear attention mechanisms. However, this would change the core architecture, which the task seeks to avoid. Using chunked attention (already present) and reversible layers is therefore preferred.
Dynamic quantization and mixed precision. Larger context sizes increase model activations. Enabling use_autocast=True to compute in bfloat16 and applying quantize_dynamic after training reduces memory usage. Codex can create scripts that quantify memory usage and automatically toggle these features when large contexts are requested.

Proposed Codex tasks to implement context extension
Expose context parameters in the API/UI. Extend the dashboard and MCP server to allow clients to specify max_seq_len, chunk_size, overlap, and ctx_bits when initializing a model or running long inference.
Prompt example: "Add optional parameters max_seq_len, chunk_size and overlap to the /init endpoint and pass them into BitTransformerLM and ModelManager. Update the HTML template to include these fields."
Implement sliding-window inference. Add a function infer_long_sequence as described above and expose it via the dashboard and MCP server.
Prompt example: "Add a new endpoint /infer_long to mcp_server.py that accepts a list of bits and processes them using a sliding window with overlap. The endpoint should return the predicted bits and telemetry summaries for each window."
Allow dynamic context scaling. Add a method to BitTransformerLM to adjust its pos_enc buffer when the context exceeds the current max_seq_len. This can be done by creating a new positional encoding tensor with the new length and copying the existing values.
Prompt example: "Implement BitTransformerLM.expand_positional_encoding(new_len: int) that creates a new positional encoding buffer of size new_len and copies the existing encoding. Update the model's max_seq_len accordingly."
Integrate aggressive compression. Implement alternative compression schemes (e.g., bit-packing or general-purpose compressors) and add toggles for them in training and inference. Evaluate compression ratio and latency to decide when to use them.
Benchmark and tune hyperparameters. Write scripts to benchmark model memory use and throughput for various max_seq_len, chunk_size, reversible, use_act, and quantization settings. These benchmarks can inform safe defaults for the VM build.

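The document above only describes infer_long_sequence in a prompt. One possible shape is sketched below; the `(logits, telemetry)` return value and the `causal=True` keyword are taken from other files in this upload, while the overlap-stitching policy (drop the re-predicted overlap of every window after the first) is an assumption rather than existing code:

```python
import torch
from bit_transformer import BitTransformerLM

def infer_long_sequence(model: BitTransformerLM, bits: torch.Tensor,
                        ctx_bits: int = 4096, overlap: int = 256) -> torch.Tensor:
    """Greedily decode a long 1-D bit tensor in overlapping causal windows."""
    assert 0 <= overlap < ctx_bits
    stride = ctx_bits - overlap
    pieces = []
    for start in range(0, bits.numel(), stride):
        window = bits[start:start + ctx_bits].unsqueeze(0)      # [1, <=ctx_bits]
        logits, _ = model(window, causal=True)                  # logits: [1, T, 2]
        pred = logits.argmax(dim=-1).squeeze(0)                 # predicted bits for the window
        pieces.append(pred if start == 0 else pred[overlap:])   # drop re-predicted overlap
        if start + ctx_bits >= bits.numel():
            break
    return torch.cat(pieces)[: bits.numel()]
```
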
BitTransformerLM/example.py
ADDED
@@ -0,0 +1,6 @@
from bit_transformer import example_training_step

if __name__ == "__main__":
    loss, telemetry = example_training_step()
    print("Training loss:", loss)
    print("Available telemetry:", list(telemetry.keys()))

BitTransformerLM/full_bits_train.py
ADDED
@@ -0,0 +1,51 @@
import pathlib
import torch
from bit_transformer import BitTransformerLM

DATA_PATH = pathlib.Path('full_bits.pt')

class BitSeq(torch.utils.data.IterableDataset):
    def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
        self.bits = torch.load(path, mmap=True)
        self.seq = seq

    def __len__(self) -> int:
        return (self.bits.numel() // self.seq) - 1

    def __iter__(self):
        N = (self.bits.numel() // self.seq) - 1
        for i in range(N):
            s = i * self.seq
            yield (
                self.bits[s:s+self.seq].long(),
                self.bits[s+1:s+self.seq+1].long(),
            )

def main() -> None:
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,
        num_workers=0,
        pin_memory=False,
    )

    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )

    loss_fn = torch.nn.CrossEntropyLoss()
    xb, yb = next(iter(dl))
    logits, _ = model(xb)
    pred = logits.reshape(-1, 2)
    target = yb.reshape(-1)
    loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))

if __name__ == '__main__':
    main()

BitTransformerLM/integration_flow.py
ADDED
@@ -0,0 +1,110 @@
import torch
from torch.profiler import profile
from bit_transformer import (
    BitTransformerLM,
    quantize_dynamic,
    hil_safe_inference,
    collapse_submodel,
)
from bit_transformer.training import train_loop
from bit_transformer.torch_utils import cpu_autocast

def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 3,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train on bit sequences with optional random compression.

    If ``direct_prob`` is positive, some batches are fed using their
    run-length encoded representation packed into bits. Loss on these
    direct-compressed batches is tracked separately.

    Returns a list of per-epoch metric dictionaries containing raw and
    compressed loss/accuracy statistics and the mean compression ratio.
    """
    return train_loop(
        model,
        data,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
        log=log,
        forward_kwargs=forward_kwargs,
    )


def main() -> None:
    data = torch.randint(0, 2, (64, 128), dtype=torch.long)
    validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
    bit_sequence_data = data.tolist()

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=128,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=128,
    )

    for step in range(1, 13):
        if step % 2 == 0:
            model = model.double_width()
        else:
            model = model.double_layers()
        train(model, data, epochs=3, compress_prob=0.5, log=True)
        _, telemetry = model(validation_bits)
        K = telemetry["negentropy_logits"].mean().item()
        C = telemetry["lz_complexity_logits"].mean().item()
        S = telemetry["symbiosis_score"].mean().item()
        assert (
            K > 0.3 and C > 0.35 and S > 0.5
        ), f"Step {step} telemetry floor failure"

    with cpu_autocast():
        model(input_bits)

    quantized_model = quantize_dynamic(model)
    quantized_model.eval()

    safe_output, _ = hil_safe_inference(
        quantized_model, input_bits, c_floor=0.35, s_floor=0.5
    )

    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=128,
        ),
        floors={"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5},
    )

    compiled_model = (
        torch.compile(student_model)
        if hasattr(torch, "compile")
        else student_model
    )
    compiled_model.eval()

    with profile() as prof:
        compiled_model(input_bits)

    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output.squeeze(0).tolist())


if __name__ == "__main__":
    main()

BitTransformerLM/integration_schedule.py
ADDED
@@ -0,0 +1,379 @@
import os
import time
import math
from itertools import cycle
from typing import Optional

import torch
import torch.nn.functional as F
from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    quantize_dynamic,
    prepare_qat_fx,
    convert_qat_fx,
    hil_safe_inference,
    collapse_submodel,
    diffusion_inference,
    TelemetrySynthesizer,
    save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast


def lines_to_tensor(lines, max_len):
    seqs = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        seqs.append(bits)
    return torch.tensor(seqs, dtype=torch.long)


def load_wikitext(dataset_size=128, max_len=64):
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
        valid_split = max(1, dataset_size // 4)
        valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
        train = lines_to_tensor(train_lines, max_len)
        valid = lines_to_tensor(valid_lines, max_len)
        return train, valid, train_lines
    except Exception as e:
        print("Dataset load failed, using random bits", e)
        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
        valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
        return train, valid, ["" for _ in range(len(train))]


def _warmup(
    model: BitTransformerLM,
    data: torch.Tensor,
    steps: int = 5,
    freeze_old: bool = False,
    old_layers: int = 0,
    *,
    diffusion: bool = False,
    curriculum: bool = False,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
    """Run a short warm-up loop after expansion."""
    model.train()
    set_dropout(model, 0.1)
    if freeze_old:
        for idx, layer in enumerate(model.layers):
            if idx < old_layers:
                for p in layer.parameters():
                    p.requires_grad_(False)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
    it = iter(data.split(8))
    for idx in range(steps):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(data.split(8))
            batch = next(it)
        if diffusion:
            p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
            noise = (torch.rand_like(batch.float()) < p).long()
            noisy = batch ^ noise
            logits, _ = model(noisy, causal=False)
            pred = logits.reshape(-1, 2)
            target = batch.reshape(-1)
        else:
            logits, _ = model(batch)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    for p in model.parameters():
        p.requires_grad_(True)
    model.eval()
    set_dropout(model, 0.0)


def integration_schedule(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    weights_path: str = "weights/model.pt.gz",
    plateau_steps: int = 0,
    collapsed_path: str | None = None,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    improve_thresh: float = 0.01,
    qat: bool = False,
):
    start = time.time()
    train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
            model = BitTransformerLM(
                d_model=32,
                nhead=4,
                num_layers=1,
                dim_feedforward=64,
                max_seq_len=max_len,
                use_act=True,
                act_threshold=0.7,
                reversible=reversible,
                chunk_size=max_len,
                use_autocast=True,
                use_checkpoint=use_checkpoint,
            )
    else:
        model = BitTransformerLM(
            d_model=32,
            nhead=4,
            num_layers=1,
            dim_feedforward=64,
            max_seq_len=max_len,
            use_act=True,
            act_threshold=0.7,
            reversible=reversible,
            chunk_size=max_len,
            use_autocast=True,
            use_checkpoint=use_checkpoint,
        )
    if qat:
        model = prepare_qat_fx(model)
    results = []
    scale_cycle = cycle(["layers", "width", "context"])
    base_lr = 1e-3
    prev_val_loss: Optional[float] = None
    for step in range(steps):
        model.train()
        set_dropout(model, 0.1)
        opt, sched = configure_optimizer(
            model, lr=base_lr, total_steps=epochs_per_step
        )
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
            optimizer=opt,
            scheduler=sched,
        )

        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))

        if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
            strategy = next(scale_cycle)
            base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
            if strategy == "layers":
                old_layers = model.num_layers
                model = model.double_layers()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    freeze_old=True,
                    old_layers=old_layers,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            elif strategy == "width":
                model = model.double_width()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            else:
                max_len *= 2
                train_bits, valid_bits, train_lines = load_wikitext(
                    dataset_size, max_len
                )
                model = model.double_length()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )

        prev_val_loss = val_loss
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    # optional plateau phase at final size
    for p in range(plateau_steps):
        model.train()
        set_dropout(model, 0.1)
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
        )
        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        idx = steps + p
        print(
            f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
        )
        results.append((idx, val_loss, k, c, s))
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    # final validation after last step
    model.eval()
    set_dropout(model, 0.0)
    with torch.no_grad():
        logits, telemetry = model(valid_bits, causal=not diffusion)
        if diffusion:
            pred = logits.reshape(-1, 2)
            target = valid_bits.reshape(-1)
        else:
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
        val_loss = F.cross_entropy(pred, target).item()
    k = telemetry["negentropy_logits"].mean().item()
    c = telemetry["lz_complexity_logits"].mean().item()
    s = telemetry["symbiosis_score"].mean().item()

    print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
    results.append((steps + plateau_steps, val_loss, k, c, s))

    # persist final model weights for future runs
    save_model(model, weights_path)

    input_bits = valid_bits[:1]
    if qat:
        qmodel = convert_qat_fx(model)
    else:
        with cpu_autocast():
            model(input_bits)
        qmodel = quantize_dynamic(model)
    qmodel.eval()
    try:
        hil_safe_inference(
            qmodel,
            input_bits,
            c_floor=0.3,
            s_floor=0.5,
            causal=not diffusion,
            strict=not diffusion,
        )
    except RuntimeError as e:
        print("Safety gate triggered", e)
    collapsed = None
    if collapse:
        synth = TelemetrySynthesizer(n_clusters=8)
        reps = synth.cluster_sequences(model, train_bits[:64])
        floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
        collapsed, metrics = collapse_submodel(
            reps,
            target_params=dict(
                d_model=16,
                nhead=4,
                num_layers=1,
                dim_feedforward=32,
                max_seq_len=max_len,
            ),
            floors=floors,
        )
        collapsed.eval()
        with torch.no_grad():
            logits, _ = collapsed(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            c_loss = F.cross_entropy(pred, target).item()
        print("Collapsed model validation loss:", c_loss)
        if collapsed_path is not None:
            save_distilled_model(
                collapsed,
                collapsed_path,
                {**metrics, "val_loss": c_loss},
                floors=floors,
            )
    if diffusion:
        sample = diffusion_inference(
            model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
        )
        print("Diffusion sample:", sample[0].tolist())
    return results, collapsed


if __name__ == "__main__":
    integration_schedule()

BitTransformerLM/mcp_server.py
ADDED
@@ -0,0 +1,322 @@
import io
import os
import gzip
import uuid
import traceback
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify, send_file
import matplotlib.pyplot as plt
import torch

from bit_transformer.dashboard_app import ModelManager
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.optimization import configure_optimizer
from bit_transformer.bit_io import text_to_bits

app = Flask(__name__)
manager = ModelManager()

# background job management
executor = ThreadPoolExecutor(max_workers=4)
jobs: dict[str, dict] = {}


def _submit_job(fn, *args, **kwargs) -> str:
    """Schedule a function for background execution and return a job id."""
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}

    def wrapper():
        jobs[job_id]["status"] = "running"
        try:
            jobs[job_id]["result"] = fn(*args, **kwargs)
            jobs[job_id]["status"] = "completed"
        except Exception as err:  # pragma: no cover - captured for client
            jobs[job_id]["status"] = "error"
            jobs[job_id]["error"] = str(err)
            jobs[job_id]["trace"] = traceback.format_exc()

    executor.submit(wrapper)
    return job_id


@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})

@app.route("/train", methods=["POST"])
def train_model():
    bits = request.json["bits"]

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        loss, ratio = manager.train_step(tensor)
        return {"loss": loss, "ratio": ratio}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    data = request.json
    bits = data["bits"]
    epochs = int(data.get("epochs", 1))
    compress_prob = float(data.get("compress_prob", 0.5))
    direct_prob = float(data.get("direct_prob", 0.0))

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        metrics = manager.train_epochs(
            tensor,
            epochs=epochs,
            compress_prob=compress_prob,
            direct_prob=direct_prob,
        )
        return {"metrics": metrics}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})

@app.route("/scale_up", methods=["POST"])
def scale_up():
    width_mult = float(request.json.get("width_mult", 1.0))

    def task():
        manager.scale_up(width_mult)
        return {
            "status": "scaled",
            "layers": manager.model.num_layers,
            "d_model": manager.model.d_model,
        }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})

@app.route("/collapse", methods=["POST"])
def collapse_model():
    cluster_bits = request.json["clusters"]
    params = {k: int(v) for k, v in request.json["params"].items()}
    width_scale = float(request.json.get("width_scale", 1.0))

    def task():
        manager.collapse(cluster_bits, params, width_scale)
        return {"status": "collapsed"}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/job/<job_id>", methods=["GET"])
def get_job(job_id: str):
    job = jobs.get(job_id)
    if job is None:
        return jsonify({"error": "not found"}), 404
    return jsonify(job)


@app.route("/jobs", methods=["GET"])
def list_jobs():
    return jsonify(jobs)

@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    if request.method == "POST":
        data = request.json
        manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
        return jsonify({"status": "updated"})
    else:
        return jsonify({
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        })

@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    if request.method == "POST":
        manager.set_diffusion(bool(request.json.get("diffusion", False)))
        return jsonify({"status": "updated"})
    return jsonify({"diffusion": manager.diffusion})


@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    if request.method == "POST":
        manager.set_qat(bool(request.json.get("qat", False)))
        return jsonify({"status": "updated"})
    return jsonify({"qat": manager.qat})


@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    if request.method == "POST":
        manager.set_gpu(bool(request.json.get("use_gpu", False)))
        return jsonify({"status": "updated"})
    return jsonify({"use_gpu": manager.use_gpu})

@app.route("/infer", methods=["POST"])
def inference():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    result = manager.infer(bits)
    return jsonify(result)


@app.route("/infer_long", methods=["POST"])
def inference_long():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    ctx = int(request.json.get("ctx_bits", 4096))
    overlap = int(request.json.get("overlap", 256))
    result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
    return jsonify(result)

@app.route("/infer_text", methods=["POST"])
def inference_text():
    text = request.json.get("text", "")
    result = manager.infer_text(text)
    return jsonify(result)

@app.route("/status", methods=["GET"])
def status():
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})

@app.route("/plot.png")
def plot_png():
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400


@app.route("/health")
def health_check():
    return jsonify({"status": "ok"})


def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
    app.run(host=host, port=port, debug=True)


if __name__ == "__main__":
    import torch
    run_mcp_server()

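Because /train, /train_epochs, /scale_up and /collapse all hand work to a background thread and return only a job id, a client has to poll /job/<id> for the result. A minimal client sketch using the `requests` library (host, port, and the toy init/bit payloads are assumptions based on the defaults and fields in the file above):

```python
import time
import requests

BASE = "http://localhost:7000"   # run_mcp_server() default port

# initialise a tiny model, then queue a training step on a toy bit sequence
requests.post(f"{BASE}/init", json={"d_model": 32, "nhead": 4, "num_layers": 1,
                                    "dim_feedforward": 64, "max_seq_len": 64})
job_id = requests.post(f"{BASE}/train", json={"bits": [[0, 1] * 32]}).json()["job_id"]

# /train only returns a job id; poll /job/<id> until the background task finishes
while True:
    job = requests.get(f"{BASE}/job/{job_id}").json()
    if job["status"] in ("completed", "error"):
        break
    time.sleep(1)
print(job)
```
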
BitTransformerLM/progressive_scaleup.py
ADDED
@@ -0,0 +1,216 @@
"""Legacy progressive scale-up demo.

This script is retained for historical reference but has been superseded by
``integration_schedule.py`` which provides a more flexible scaling workflow.
"""

import argparse
import warnings
import torch
import torch.nn.functional as F
from bit_transformer import (
    BitTransformerLM,
    configure_optimizer,
    expand_model,
    text_to_bits,
)
from bit_transformer.training import train_loop as basic_train

warnings.warn(
    "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
    DeprecationWarning,
    stacklevel=2,
)


def progressive_scale_up(
    eps: float = 0.65,
    steps: int = 2,
    width_mult: float = 1.0,
    forward_kwargs: dict | None = None,
) -> None:
    """Demonstrate automatic scaling of the model on random data."""
    params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
    model = BitTransformerLM(**params)
    steps_per_epoch = 64 // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * steps_per_epoch
    )

    train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
    valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)

    for step in range(steps):
        # one epoch over train
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")
        if val_loss < eps:
            params["num_layers"] *= 2
            params["d_model"] = int(params["d_model"] * width_mult)
            params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
            model = expand_model(model, params)
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=steps * steps_per_epoch
            )
            print(
                "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
            )


def progressive_scale_up_text(
    improve_thresh: float = 0.01,
    steps: int = 2,
    width_mult: float = 2.0,
    max_len: int = 64,
    dataset_size: int = 512,
    forward_kwargs: dict | None = None,
) -> None:
    """Scale up using WikiText2 lines converted to bits.

    Parameters
    ----------
    improve_thresh: float
        Relative validation loss improvement required to avoid scaling.
        If improvement is <= this threshold, model size is increased.
    steps: int
        Number of training steps.
    width_mult: float
        Multiplier applied when increasing model width.
    max_len: int
        Initial sequence length.
    dataset_size: int
        Number of training lines to load from WikiText2.
    forward_kwargs: dict | None
        Extra keyword arguments for the forward pass.
    """
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_iter = ds["train"]["text"]
    valid_iter = ds["validation"]["text"]

    train_lines = []
    for line in train_iter:
        train_lines.append(line)
        if len(train_lines) >= dataset_size:
            break

    valid_lines = []
    for line in valid_iter:
        valid_lines.append(line)
        if len(valid_lines) >= dataset_size // 4:
            break

    def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
        seqs = []
        for text in lines:
            bits = text_to_bits(text)[:length]
            if len(bits) < length:
                bits.extend([0] * (length - len(bits)))
            seqs.append(bits)
        return torch.tensor(seqs, dtype=torch.long)

    train = lines_to_tensor(train_lines, max_len)
    valid = lines_to_tensor(valid_lines, max_len)

    params = dict(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
    )
    model = BitTransformerLM(**params)
    steps_per_epoch = len(train) // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
    )

    prev_loss: float | None = None
    scale_length = True

    for step in range(steps):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")
        if prev_loss is not None:
            improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
            if improvement <= improve_thresh:
                if scale_length:
                    params["max_seq_len"] *= 2
                    train = lines_to_tensor(train_lines, params["max_seq_len"])
                    valid = lines_to_tensor(valid_lines, params["max_seq_len"])
                    model = model.double_length()
                    steps_per_epoch = len(train) // 8
                    scale_type = "length"
                else:
                    params["d_model"] = int(params["d_model"] * width_mult)
                    params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
                    model = expand_model(model, params)
                    scale_type = "width"
                optimizer, scheduler = configure_optimizer(
                    model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
                )
                scale_length = not scale_length
                param_count = sum(p.numel() for p in model.parameters())
                print(
                    f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
                )
        prev_loss = val_loss


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Progressively scale model length and width")
    parser.add_argument("--steps", type=int, default=2, help="number of training steps")
    parser.add_argument(
        "--improve-thresh",
        type=float,
        default=0.01,
        help="relative loss improvement required to avoid scaling",
    )
    parser.add_argument(
        "--width-mult", type=float, default=2.0, help="width multiplier when scaling"
    )
    parser.add_argument("--causal", action="store_true", help="use causal attention during training")
    parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
    args = parser.parse_args()
    if args.wikitext:
        progressive_scale_up_text(
            improve_thresh=args.improve_thresh,
            steps=args.steps,
            width_mult=args.width_mult,
            forward_kwargs={"causal": args.causal} if args.causal else None,
|
208 |
+
)
|
209 |
+
else:
|
210 |
+
progressive_scale_up(
|
211 |
+
eps=args.improve_thresh,
|
212 |
+
steps=args.steps,
|
213 |
+
width_mult=args.width_mult,
|
214 |
+
forward_kwargs={"causal": args.causal} if args.causal else None,
|
215 |
+
)
|
216 |
+
|
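For reference, a minimal sketch of exercising the two demo entry points above directly from Python instead of the CLI; it assumes the repository root is on PYTHONPATH, and the argument values are illustrative rather than tuned defaults.

# Hedged usage sketch for progressive_scaleup.py (values are illustrative).
from progressive_scaleup import progressive_scale_up, progressive_scale_up_text

progressive_scale_up(eps=0.65, steps=2, width_mult=1.0)               # random-data demo
progressive_scale_up_text(improve_thresh=0.01, steps=2, max_len=64)   # WikiText2 variant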
BitTransformerLM/pyproject.toml
ADDED
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=67", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "bit-transformer"
version = "0.1.0"
description = "Bit-based transformer language model nearing v1 release with telemetry"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "All Rights Reserved"}
authors = [{name = "WCNegentropy"}]

[project.urls]
Homepage = "https://example.com"

[tool.setuptools.packages.find]
include = ["bit_transformer"]
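A quick way to confirm the metadata above once the package is installed (for example with an editable install); `importlib.metadata` is standard library and the distribution name matches the `[project]` table.

from importlib.metadata import version

# Expected: "0.1.0", as declared in pyproject.toml.
print(version("bit-transformer"))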
BitTransformerLM/recursive_integration_flow.py
ADDED
@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from torch.profiler import profile
from bit_transformer import (
    BitTransformerLM,
    quantize_dynamic,
    hil_safe_inference,
    collapse_submodel,
)
from bit_transformer.training import train_loop
from bit_transformer.torch_utils import cpu_autocast


def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 1,
    compress_prob: float = 0.5,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train with random compression; returns per-epoch metrics."""
    return train_loop(
        model,
        data,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=0.0,
        log=log,
        forward_kwargs=forward_kwargs,
    )


def recursive_integration_flow(steps: int = 4, max_len: int = 64) -> None:
    """Run a dynamic scale-up loop with telemetry-based gating."""
    train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
    valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)
    bit_sequence_data = train_bits.tolist()

    best_K = best_C = best_S = 0.0

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
    )

    results = []
    for step in range(steps + 1):
        epochs = min(10, 2 + step // 2)
        train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)

        with torch.no_grad():
            with cpu_autocast():
                logits, telemetry = model(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()

        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))

        if step > 0:
            if k < best_K - 0.3 or c < best_C - 0.3 or s < best_S - 0.3:
                print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
                break
        best_K = max(best_K, k)
        best_C = max(best_C, c)
        best_S = max(best_S, s)

        if step < steps:
            if step % 2 == 0:
                model = model.double_width()
            else:
                model = model.double_layers()

    # Post-scaling optimizations
    with cpu_autocast():
        model(input_bits)

    qmodel = quantize_dynamic(model)
    qmodel.eval()

    safe_output = hil_safe_inference(
        qmodel, input_bits, c_floor=0.5, s_floor=0.2
    )

    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=max_len,
        ),
        floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
    )

    if hasattr(torch, "compile"):
        try:
            compiled = torch.compile(student_model)
        except RuntimeError as exc:
            print(f"Compilation skipped: {exc}")
            compiled = student_model
    else:
        compiled = student_model
    compiled.eval()

    with profile() as prof:
        compiled(input_bits)
    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output[0].tolist())


if __name__ == "__main__":
    recursive_integration_flow()
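A minimal sketch of driving the flow above with smaller settings, for example as a quick smoke test; the parameter values are assumptions for illustration, not tuned defaults.

# Hedged smoke-test invocation of recursive_integration_flow.py
# (assumes the repository root is on PYTHONPATH).
from recursive_integration_flow import recursive_integration_flow

recursive_integration_flow(steps=2, max_len=32)  # fewer scale-up rounds, shorter sequences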
BitTransformerLM/requirements.txt
ADDED
@@ -0,0 +1,16 @@
-i https://pypi.org/simple
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.7.1+cpu

# --extra-index-url https://download.pytorch.org/whl/cu118
# torch==2.7.1+cu118

pytest==8.4.1
scikit-learn==1.7.1
matplotlib==3.10.3
datasets==4.0.0
flask==3.1.1
numpy==2.3.1
requests==2.32.3
watchdog==6.0.0
huggingface-hub==0.34.3
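A small sanity check that the pinned CPU-only PyTorch wheel is the one actually installed; the exact version string reported by the `+cpu` build is an assumption.

import torch

print(torch.__version__)          # expected to start with "2.7.1"
print(torch.cuda.is_available())  # False for the +cpu wheel pinned above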
BitTransformerLM/review_cli.py
ADDED
@@ -0,0 +1,49 @@
import argparse
import json
from pathlib import Path

import torch
from bit_transformer import BitTransformerLM


def list_candidates(path: Path):
    models = sorted(path.glob("*.pt"))
    for m in models:
        metrics_file = m.with_suffix(m.suffix + ".json")
        metrics = {}
        if metrics_file.exists():
            with open(metrics_file) as f:
                metrics = json.load(f)
        yield m, metrics


def main():
    parser = argparse.ArgumentParser(description="Review distilled submodels")
    parser.add_argument("directory", type=Path, help="Directory with candidate models")
    parser.add_argument("--approve-dir", type=Path, default=Path("approved"), help="Directory to store approved models")
    args = parser.parse_args()

    args.approve_dir.mkdir(exist_ok=True)
    log_file = args.approve_dir / "review_log.jsonl"

    for model_path, metrics in list_candidates(args.directory):
        print("Candidate:", model_path.name)
        for k, v in metrics.items():
            print(f"  {k}: {v}")
        ans = input("Approve this model? [y/N] ").strip().lower()
        if ans == "y":
            approved_path = args.approve_dir / model_path.name
            torch.save(torch.load(model_path), approved_path)
            entry = {"model": approved_path.name, "metrics": metrics, "approved": True}
            with open(log_file, "a") as lf:
                lf.write(json.dumps(entry) + "\n")
            print("Approved and saved to", approved_path)
        else:
            entry = {"model": model_path.name, "metrics": metrics, "approved": False}
            with open(log_file, "a") as lf:
                lf.write(json.dumps(entry) + "\n")
            print("Rejected", model_path.name)


if __name__ == "__main__":
    main()
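A hedged sketch of preparing a candidate directory that `list_candidates` can pick up: a `.pt` checkpoint plus a `.pt.json` metrics side-file. The directory name and metric keys are assumptions for illustration.

import json
from pathlib import Path

import torch
from bit_transformer import BitTransformerLM

candidates = Path("candidates")            # hypothetical candidate directory
candidates.mkdir(exist_ok=True)

model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=16)
torch.save(model.state_dict(), candidates / "student_0.pt")
# Side-file naming mirrors list_candidates(): "<name>.pt" -> "<name>.pt.json".
(candidates / "student_0.pt.json").write_text(json.dumps({"val_loss": 0.42}))
# Then review interactively with: python review_cli.py candidates --approve-dir approved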
BitTransformerLM/start.sh
ADDED
@@ -0,0 +1,5 @@
#!/bin/bash
set -e
python mcp_server.py &
sleep 2
python -m bit_transformer.dashboard_app
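Once start.sh has launched both processes, a hypothetical smoke test against the MCP server's `/health` endpoint (mentioned in the audit below) might look like this; the port is an assumption and should be adjusted to the server's actual bind address.

import requests

resp = requests.get("http://localhost:8000/health", timeout=5)  # port assumed
print(resp.status_code, resp.text)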
BitTransformerLM/state_of_the_repo_audit.md
ADDED
@@ -0,0 +1,98 @@
summary: |
  The **BitTransformerLM** repository is well-structured and aligns closely with the README's feature set.
  All core functionalities (bit-level modeling, telemetry metrics, progressive scaling, compression, context extension, diffusion mode, dashboard, etc.) are present and largely consistent with documentation.
  The code is generally clean and well-tested (no TODOs or obvious dead code) with an effective CI in place.
  We identified a few issues via static analysis: a critical **security flaw** where the dashboard's `/exec` endpoint executes arbitrary code, a missing import that breaks the compression toggle, and a rare edge case in the bit-sequence decompression logic.
  No functions exceed 300 lines, though the `BitTransformerLM.forward` method is complex, with deeply nested logic (~6 levels) and duplicated code blocks for the halting mechanism.
  Naming conventions are consistent (snake_case for functions, CamelCase for classes), and dependency versions are up to date.
  Documentation and code behavior are in sync – for example, the MCP server's `/health` endpoint described in the docs is implemented.
  Overall, the project appears **nearly production-ready**, with these fixes and refinements needed before a 1.0 release.

findings:
  - severity: P0
    effort: S
    category: security
    location: bit_transformer/dashboard_app.py:533
    description: "Unrestricted `/exec` HTTP endpoint allows arbitrary code execution."
    recommendation: "Disable or restrict the `/exec` route (e.g. remove it or require an admin token) to prevent remote code execution."
    status: completed ✅
  - severity: P1
    effort: S
    category: static
    location: bit_transformer/dashboard_app.py:195
    description: "NameError risk – `compress_bits` is used without being imported."
    recommendation: "Import the `compress_bits` function in `dashboard_app.py` (e.g. `from .compression import compress_bits`) so compression toggles don't crash."
    status: completed ✅
  - severity: P2
    effort: M
    category: static
    location: bit_transformer/model.py:320
    description: "Edge-case bug – `_maybe_decompress` skips decompression if all values are ≤ 1, which can misinterpret run-length encoding outputs consisting entirely of 1s."
    recommendation: "Adjust the decompress condition (e.g. track whether the input was compressed) to ensure even uniformly alternating bit sequences get properly decompressed."
    status: completed ✅
  - severity: P3
    effort: M
    category: static
    location: bit_transformer/model.py:415
    description: "Duplicate code – nearly identical halting logic is implemented in both the reversible and normal forward loops."
    recommendation: "Refactor the halting (ACT) mechanism into a helper function to avoid repetition and reduce maintenance effort."
    status: completed ✅
  - severity: P3
    effort: M
    category: static
    location: bit_transformer/model.py:368
    description: "Complex logic – `BitTransformerLM.forward` contains deeply nested control flow (up to 5–6 levels) for reversible layers, ACT, etc."
    recommendation: "Consider simplifying or breaking up the forward pass (e.g. separate functions for reversible vs. standard flow) to improve readability and maintainability."
    status: completed ✅
  - severity: P3
    effort: S
    category: static
    location: bit_transformer/dashboard_app.py:125
    description: "Config parsing quirk – booleans in `ModelManager.init_model` are cast to int (True→1) instead of preserved as bool."
    recommendation: "Handle boolean fields explicitly (e.g. do not cast values for keys like `reversible` or `use_act` to int) to avoid confusion and potential type issues."
    status: completed ✅

codex_tasks:
  - codex_prompt: "Remove or secure the dangerous `/exec` endpoint in the dashboard to prevent arbitrary code execution."
    acceptance_test: |
      import requests, subprocess
      # Attempt to call the /exec endpoint with a harmless command.
      try:
          resp = requests.post("http://localhost:5000/exec", json={"code": "print('OK')"}, timeout=5)
      except Exception as e:
          resp = e.response if hasattr(e, 'response') else None
      # The endpoint should be removed or secured, so it should either 404 or refuse access.
      assert resp is None or resp.status_code in (403, 404), "Exec endpoint still accessible!"
    status: completed ✅
  - codex_prompt: "Import the `compress_bits` function in `dashboard_app.py` so that enabling compression no longer raises a NameError."
    acceptance_test: |
      import torch
      from bit_transformer.dashboard_app import ModelManager
      mgr = ModelManager()
      mgr.set_compression(True)
      bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
      try:
          loss, ratio = mgr.train_step(bits)
      except NameError as e:
          raise AssertionError(f"NameError not resolved: {e}")
      assert isinstance(loss, float) and 0 <= ratio <= 1.0, "Compression training failed"
    status: completed ✅
  - codex_prompt: "Fix `_maybe_decompress` in `model.py` to always decompress run-length encoded sequences (even if all run lengths are 1) before computing metrics."
    acceptance_test: |
      import torch
      from bit_transformer import BitTransformerLM, compress_bits, decompress_bits
      # Create an alternating bit sequence where compress_bits yields only count=1 values.
      bits = torch.tensor([0, 1] * 8, dtype=torch.uint8)
      comp = compress_bits(bits)
      model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=len(bits))
      # Compute negentropy on compressed vs. original input and compare.
      neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
      neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
      assert torch.allclose(neg_comp, neg_raw, atol=1e-6), "Negentropy differs for compressed input – decompression fix failed"
    status: completed ✅

metrics:
  loc_total: 3770
  todo_count: 0
  duplicate_block_count: 3
  oversized_functions: 0