WCNegentropy committed
Commit 12e8f96 · verified · 1 Parent(s): fca1b1f

Upload 65 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. BitTransformerLM/.github/workflows/ci.yml +29 -0
  2. BitTransformerLM/.gitignore +103 -0
  3. BitTransformerLM/ABOUTME.md +110 -0
  4. BitTransformerLM/AGENTS.md +66 -0
  5. BitTransformerLM/BitTransformerLM_full_assessment.md +196 -0
  6. BitTransformerLM/Dockerfile +27 -0
  7. BitTransformerLM/LICENSE/ALIGNMENT_AND_TRANSPARENCY.txt +42 -0
  8. BitTransformerLM/LICENSE/COMMERCIAL_LICENSE.txt +34 -0
  9. BitTransformerLM/LICENSE/CONTRIBUTOR_LICENSE_AGREEMENT.txt +7 -0
  10. BitTransformerLM/LICENSE/DISCLAIMER.txt +93 -0
  11. BitTransformerLM/LICENSE/LICENSE.txt +12 -0
  12. BitTransformerLM/LICENSE/TRADEMARK_POLICY.txt +12 -0
  13. BitTransformerLM/NEW_CODEX_TASK.md +85 -0
  14. BitTransformerLM/README.md +177 -0
  15. BitTransformerLM/bit_transformer/__init__.py +86 -0
  16. BitTransformerLM/bit_transformer/bit_io.py +97 -0
  17. BitTransformerLM/bit_transformer/collapse.py +95 -0
  18. BitTransformerLM/bit_transformer/compression.py +82 -0
  19. BitTransformerLM/bit_transformer/dashboard.py +58 -0
  20. BitTransformerLM/bit_transformer/dashboard_app.py +927 -0
  21. BitTransformerLM/bit_transformer/distil.py +90 -0
  22. BitTransformerLM/bit_transformer/distributed.py +30 -0
  23. BitTransformerLM/bit_transformer/hf_checkpoint.py +76 -0
  24. BitTransformerLM/bit_transformer/model.py +875 -0
  25. BitTransformerLM/bit_transformer/optimization.py +37 -0
  26. BitTransformerLM/bit_transformer/parity.py +24 -0
  27. BitTransformerLM/bit_transformer/quantization.py +89 -0
  28. BitTransformerLM/bit_transformer/safety.py +149 -0
  29. BitTransformerLM/bit_transformer/scale.py +36 -0
  30. BitTransformerLM/bit_transformer/static/style.css +93 -0
  31. BitTransformerLM/bit_transformer/telemetry.py +95 -0
  32. BitTransformerLM/bit_transformer/templates/dashboard.html +454 -0
  33. BitTransformerLM/bit_transformer/torch_utils.py +21 -0
  34. BitTransformerLM/bit_transformer/training.py +250 -0
  35. BitTransformerLM/bit_transformer/utils.py +28 -0
  36. BitTransformerLM/bit_transformer_lm_codex_playbook.md +278 -0
  37. BitTransformerLM/build_full_bits.py +23 -0
  38. BitTransformerLM/context_extension.md +43 -0
  39. BitTransformerLM/example.py +6 -0
  40. BitTransformerLM/full_bits_train.py +51 -0
  41. BitTransformerLM/integration_flow.py +110 -0
  42. BitTransformerLM/integration_schedule.py +379 -0
  43. BitTransformerLM/mcp_server.py +322 -0
  44. BitTransformerLM/progressive_scaleup.py +216 -0
  45. BitTransformerLM/pyproject.toml +18 -0
  46. BitTransformerLM/recursive_integration_flow.py +128 -0
  47. BitTransformerLM/requirements.txt +16 -0
  48. BitTransformerLM/review_cli.py +49 -0
  49. BitTransformerLM/start.sh +5 -0
  50. BitTransformerLM/state_of_the_repo_audit.md +98 -0
BitTransformerLM/.github/workflows/ci.yml ADDED
@@ -0,0 +1,29 @@
+ name: CI
+
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+     branches: [ main ]
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - uses: actions/setup-python@v4
+         with:
+           python-version: '3.11'
+       - name: Install dependencies
+         run: |
+           pip install --upgrade pip
+           pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
+           pip install build
+       - name: Run tests
+         run: pytest -q
+       - name: Build package
+         run: python -m build --sdist --wheel -o dist
+       - uses: actions/upload-artifact@v4
+         with:
+           name: dist
+           path: dist
BitTransformerLM/.gitignore ADDED
@@ -0,0 +1,103 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Pyre type checker
+ .pyre/
+
+ # mypy
+ .mypy_cache/
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # IDEs
+ .idea/
+ .vscode/
+
+ # macOS
+ .DS_Store
+
+ # Logs
+ *.log
+
+ # Plot outputs
+ *.png
+ figures/
+
+ # Model artifacts
+ *.pt
+ *.pth
+ *.bin
+ candidates/
+ approved/
+ review_log.jsonl
+
+ # Configurations
+ *.ini
+
+ # Local data
+ *.sqlite3
+
+ *.pt.gz
BitTransformerLM/ABOUTME.md ADDED
@@ -0,0 +1,110 @@
+ Here’s a menu of additional, “pure-PyTorch” extensions that can close the gap even further toward a production-grade LLM:
+
+
+ 1. Native Low-Rank & MoE Layers (DO LAST)
+
+ Why: Expert mixtures and low-rank adapters let you balloon effective parameter count without proportional compute.
+ • Mixture-of-Experts: Implement a tiny gating network (one or two linear layers) that routes each token’s representation to one of E experts (each a small FFN). Only that expert runs on that position, so compute per token stays constant while total capacity grows by E×.
+ • PyTorch sketch:
+
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+
+     class MoE(nn.Module):
+         def __init__(self, d_model, d_ff, n_experts=4):
+             super().__init__()
+             self.gate = nn.Linear(d_model, n_experts)
+             self.experts = nn.ModuleList(
+                 [nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
+                  for _ in range(n_experts)]
+             )
+
+         def forward(self, x):
+             # x: [T,B,D]
+             logits = self.gate(x)                  # [T,B,E]
+             w = F.softmax(logits, dim=-1)          # [T,B,E]
+             y = torch.stack([expert(x) for expert in self.experts], -1)
+             # y: [T,B,D,E] → weighted sum over experts:
+             out = (y * w.unsqueeze(2)).sum(-1)
+             return out
+
+ • Trade-off: You’ll need a load-balancing loss term (e.g. encourage the gate to spread load) and telemetry on expert usage, but the code stays pure PyTorch. A minimal sketch of such a loss follows below.
+
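+ The load-balancing term mentioned above can be as simple as penalizing the average gate weights for drifting away from a uniform split across experts. An illustrative sketch (the coefficient and exact form are assumptions, not part of the repository):
+
+     import torch
+
+     def load_balance_loss(gate_weights, coef=0.01):
+         # gate_weights: [T, B, E] softmax outputs from the MoE gate
+         mean_per_expert = gate_weights.mean(dim=(0, 1))          # [E] average routing mass
+         n_experts = gate_weights.shape[-1]
+         uniform = torch.full_like(mean_per_expert, 1.0 / n_experts)
+         return coef * ((mean_per_expert - uniform) ** 2).sum()   # penalize imbalance
+
+ Add this term to the main loss during training and log `mean_per_expert` as the expert-usage telemetry.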
+
+
+ 2. [x] Adaptive Computation Time (ACT)
+
+ Why: Let the model learn to spend more depth on “hard” bits and skip layers on easier ones.
+ • Implementation: Add a tiny halting unit after each layer—e.g. a single linear+sigmoid per token that predicts halt or continue. Accumulate “halt probability” across layers and stop processing tokens once they cross a threshold.
+ • Benefit: On average you’ll do fewer layer passes per token, reducing compute without touching PyTorch internals. (A minimal halting sketch follows below.)
+
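+ One way to realize the halting unit described above (an illustrative sketch only; the repository’s ACT implementation may differ):
+
+     import torch
+     import torch.nn as nn
+
+     class HaltingUnit(nn.Module):
+         """Per-token halting probability for ACT-style early exit."""
+         def __init__(self, d_model):
+             super().__init__()
+             self.proj = nn.Linear(d_model, 1)
+
+         def forward(self, x, cum_halt):
+             # x: [T, B, D]; cum_halt: [T, B] accumulated halting probability
+             p = torch.sigmoid(self.proj(x)).squeeze(-1)    # [T, B] halt probability this layer
+             cum_halt = cum_halt + p
+             still_running = (cum_halt < 1.0).float()       # tokens that keep computing
+             return cum_halt, still_running
+
+ In the layer stack, multiply each layer’s residual update by `still_running.unsqueeze(-1)` and stop iterating once every token has crossed the threshold.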
+
+
+ 3. [x] Advanced PyTorch-Native Quantization
+
+ Why: Move beyond static 4-bit packaging to full QAT / dynamic quant.
+ • FX-graph QAT: Use torch.quantization.prepare_qat_fx on your SparseQuantTransformerLayer with a custom 4-bit observer (we sketched one earlier). Then convert_fx to int8 or 4-bit for weights—no external libs needed.
+ • Dynamic quant for inference: Wrap your model in torch.quantization.quantize_dynamic(...), quantizing only Linear modules to int8 on-the-fly. Gives a big speed/memory win at inference time on CPU. (See the one-line sketch below.)
+
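+ For instance, the dynamic-quantization call on a trained model can look like this (`model` stands in for any BitTransformerLM instance):
+
+     import torch
+     import torch.nn as nn
+
+     # Quantize only the Linear layers to int8 for CPU inference.
+     quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
+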
+
+
+ 4. [x] Chunked & Overlapping Attention
+
+ Why: Emulate sparse attention with pure PyTorch and no for-loops.
+ • How: Break your sequence into fixed-size chunks (e.g. 512 bits), attend within each chunk plus a small overlap window to neighbors.
+ • Pure PyTorch: Use unfold + batched torch.matmul to compute all chunked attention in parallel:
+
+     # x: [B, L, D]; chunk_size = C, overlap = O
+     pads = (O, O)
+     x_padded = F.pad(x, (0, 0) + pads)         # pad the sequence dim
+     chunks = x_padded.unfold(1, C + 2*O, C)    # [B, n_chunks, D, C+2O]
+     # then project Q, K, V per chunk and do the matmuls batchwise (completed sketch below)
+
+
+ • Benefit: You get an O(L·(C+2O)) algorithm without Python loops, all in tensor ops.
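+
+ A completed, self-contained version of that sketch (illustrative only: learned Q/K/V projections are omitted for brevity, and L is assumed divisible by C):
+
+     import torch
+     import torch.nn.functional as F
+
+     def chunked_attention(x, C=512, O=64):
+         # x: [B, L, D]; attend within chunks of size C plus O overlapping tokens per side.
+         B, L, D = x.shape
+         x_padded = F.pad(x, (0, 0, O, O))                    # [B, L+2O, D]
+         windows = x_padded.unfold(1, C + 2 * O, C)           # [B, n_chunks, D, C+2O]
+         windows = windows.permute(0, 1, 3, 2)                # [B, n_chunks, C+2O, D]
+         q = windows[:, :, O:O + C, :]                        # queries: the central C positions
+         k = v = windows                                      # keys/values: the full window
+         scores = torch.matmul(q, k.transpose(-2, -1)) / D ** 0.5   # [B, n_chunks, C, C+2O]
+         attn = scores.softmax(dim=-1)
+         out = torch.matmul(attn, v)                          # [B, n_chunks, C, D]
+         return out.reshape(B, -1, D)[:, :L]                  # stitch chunks back to [B, L, D]
+
+     out = chunked_attention(torch.randn(2, 1024, 64), C=256, O=32)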
+
+
+
+ 5. Functorch-Based Vectorization & vmap
+
+ Why: Fuse your per-head or per-expert loops automatically.
+ • Use functorch.vmap (now exposed as torch.func.vmap) to turn your per-head attention code (the one inside the for t in range(T) loop) into a single batched kernel.
+ • Benefit: Cleaner code, fewer Python loops, and TorchInductor can fuse it just as well as hand-written loops.
+
+
+
+ 6. [x] Fully-Sharded DataParallel & Pipeline Parallel (PyTorch-Native)
+
+ Why: Scale out to multiple GPUs without external frameworks.
+ • FSDP: Wrap your model in torch.distributed.fsdp.FullyShardedDataParallel to shard both parameters and optimizer state across GPUs (minimal wrap shown below).
+ • Pipe: Use torch.distributed.pipeline.sync.Pipe to split your 40+ layer model across GPUs as pipeline stages.
+ • Benefit: Zero external deps—pure PyTorch DDP/FSDP/Pipe—so you can train 100M+ parameter models.
+
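+ A minimal FSDP wrap, assuming each rank is launched via torchrun and `model` is the BitTransformerLM instance built earlier (sketch, not the repository’s distributed helper):
+
+     import torch.distributed as dist
+     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+     dist.init_process_group("gloo")   # or "nccl" on GPUs; called once per rank
+     sharded = FSDP(model)             # shards parameters, gradients, and optimizer state
+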
+
+
+ 7. [x] Mixed Precision & Autocast on CPU (bfloat16)
+
+ Why: PyTorch now supports `torch.amp.autocast('cpu')` for bfloat16 on some architectures.
+ • Wrap your forward pass in `with torch.amp.autocast('cpu'):` to cut memory and speed up linear/attention kernels, even on CPU.
+
+
+
+ 8. [x] Optimized Learning-Rate Schedules & Optimizers
+
+ Why: Achieve GPT-level convergence behavior with schedulers that ship in core PyTorch.
+ • Implement OneCycleLR or CosineAnnealingWarmRestarts directly via torch.optim.lr_scheduler.
+ • Swap to AdamW with decoupled weight decay (torch.optim.AdamW) and dynamic gradient clipping (torch.nn.utils.clip_grad_norm_).
+ • All of these live in core PyTorch; a combined sketch follows below.
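+
+ Putting those pieces together in a training step might look like this (sketch; `loader` and `compute_loss` are placeholders for your data pipeline and loss):
+
+     import torch
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
+     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-4, total_steps=10_000)
+
+     for step, (bits, targets) in enumerate(loader):
+         optimizer.zero_grad()
+         loss = compute_loss(model, bits, targets)          # placeholder loss function
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+         optimizer.step()
+         scheduler.step()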
+
+
+
+ Putting It All Together
+ 1. MoE + ACT will let you scale capacity (E× experts) while controlling average compute.
+ 2. FX/QAT + dynamic quant gives you 4-bit int inference with no external libs.
+ 3. Chunked attention + vmap replaces loops with giant fused tensor ops.
+ 4. FSDP + Pipe moves you onto multi-GPU purely in torch.distributed.
+ 5. Autocast (bfloat16) on CPU/GPU for mixed precision speed.
+
+ By layering these techniques, you can:
+ • Reach hundreds of millions (even billions) of effective parameters
+ • Maintain single-library purity (just PyTorch)
+ • Hit LLM-class throughputs (hundreds of tokens/sec on GPU, tens on CPU)
+ • Keep full NRB telemetry available for safety checks
BitTransformerLM/AGENTS.md ADDED
@@ -0,0 +1,66 @@
+ # AGENTS Guidelines for BitTransformerLM
+
+ ## Repository Scope and Purpose
+ - **BitTransformerLM** models raw binary streams using reversible transformer blocks and safety telemetry. The project is the canonical implementation under WCNegentropy.
+ - Core capabilities include bit-native modeling, telemetry metrics (negentropy, LZ complexity, symbiosis), progressive scaling, compression, context extension, diffusion mode (linear/cosine/exp noise schedules with parity correction), dashboard control, distributed training, and quantization.
+ - Phase 1 optimizations provide configurable batch sizing, gradient accumulation, mixed-precision, memory-mapped dataset streaming, scheduled compression ramps, selective `torch.compile`, and an EMA-smoothed safety gate with burn-in.
+
+ ## Environment Setup
+ - Requires **Python 3.10+**.
+ - Install dependencies:
+   - CPU: `pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt`
+   - Optional GPU: `pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118`
+ - The package name is `bit-transformer`; project metadata lives in `pyproject.toml`.
+
+ ## Repository Layout
+ - `bit_transformer/` – core package (`model`, `compression`, `telemetry`, `safety`, `dashboard_app`, `quantization`, etc.).
+ - `tests/` – pytest suite and historical `TEST_RESULTS.md`.
+ - Scripts: `example.py`, `unified_workflow.py`, `full_bits_train.py`, `build_full_bits.py`, `mcp_server.py`, `wikitext_*` utilities. The legacy `progressive_scaleup.py` is retained for reference but superseded by `integration_schedule.py`.
+ - Docs and specs: `README.md`, `state_of_the_repo_audit.md`, licensing files in `LICENSE/`.
+
+ ## Development Practices
+ - Follow snake_case for functions and CamelCase for classes.
+ - Keep functions under ~300 lines and minimize deeply nested control flow.
+ - Avoid reintroducing the deprecated dashboard `/exec` endpoint or other insecure code paths.
+ - Use the `/status` endpoint for model introspection; all routes return JSON and surface errors with stack traces.
+ - Ensure compression, decompression, and halting logic stay consistent with current implementation.
+ - Use the `cpu_autocast()` helper for BF16 mixed precision on CPU instead of
+   calling `torch.amp.autocast` directly.
+ - Adaptive training now expands depth, width, or context only when validation loss plateaus and automatically decays the base learning rate by √2 after each expansion with a 100‑step warm‑up.
+
+ ## Workflow & Commands
+ - Run the example: `python example.py`.
+ - Adaptive scaling now lives in `integration_schedule.py`; `progressive_scaleup.py` is deprecated.
+ - Unified workflow (optionally with dashboard or diffusion): `python unified_workflow.py --dashboard` or `python unified_workflow.py --diffusion --diffusion-steps 8 --dataset-size 32`.
+ - Increase `--diffusion-steps` for higher fidelity (8–16) and add `--diffusion-curriculum` to linearly decay noise over epochs.
+ - Disable checkpointing or reversible blocks when speed is prioritized over memory: `python unified_workflow.py --no-checkpoint --no-reversible`.
+ - Enable 4-bit quantization-aware training: `python unified_workflow.py --qat`.
+ - Skip full attention logging during chunked attention for memory savings by constructing the model with `full_attn_logging=False`.
+ - Start MCP server: `python mcp_server.py` and launch dashboard: `MCP_SERVER_ADDR=http://127.0.0.1:7000 python -m bit_transformer.dashboard_app`.
+ - `/metrics` and `/model_config` endpoints expose telemetry streams and hyperparameters.
+ - `/save_checkpoint` and `/download_checkpoint` sync weights with Hugging Face (token defaults to `HF_TOKEN`).
+ - Container build: `docker build -t bittransformerlm .` and run with exposed ports `5000` (dashboard) and `7000` (MCP).
+
+ ## Telemetry Metrics
+ | Metric | Meaning | Range |
+ |--------|---------|-------|
+ | **K** | Negentropy – deviation from random noise | 0–1 (1 = ordered) |
+ | **C** | LZ Complexity – compressibility proxy | 0–1 (higher = more changes) |
+ | **S** | Symbiosis – agreement with reference distribution | 0–1 (1 = aligned) |
+
+ ACT halting exports `halt_probs` in telemetry showing how many layers executed. For robust sampling under safety constraints, call `safe_sample_with_retry(model, bits)` which retries with diffusion mode and exponential backoff.
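+
+ For example, a usage sketch (the import path is assumed; adjust it to the actual module layout):
+
+     import torch
+     from bit_transformer import safe_sample_with_retry  # assumed export location
+
+     bits = torch.randint(0, 2, (1, 128))        # toy bit sequence
+     out = safe_sample_with_retry(model, bits)   # falls back to diffusion mode with backoff on gate failures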
+
+ `TelemetrySynthesizer.cluster_sequences` can be used to select representative training samples before invoking `collapse_submodel`. The distillation helper deepens the model and widens once (`width_scale` = 1.5) if floors are missed, and `save_distilled_model` emits a `metrics.json` summary beside the weights.
+
+ ## Testing
+ - Run unit tests after any change: `pytest -q`.
+ - Use `watcher.py` for auto-reload and testing during local development if desired.
+ - During training, call `model.train()` and keep dropout probabilities around `0.1–0.2`.
+ - Before running tests, inference, or pushing weights, switch to `model.eval()` and set all dropout probabilities to `0` to avoid flaky results.
+ - Dashboard will warn if telemetry metrics drift by more than 0.2 over the last 10 steps. Adjust via `ModelManager(drift_window, drift_threshold)` as needed.
+
+ ## Licensing
+ - Project governed by documents in `LICENSE/` (AGPLv3, commercial terms, disclaimers, etc.). Ensure compliance before contributing or distributing.
+
+ These guidelines keep the repository consistent with the project roadmap and previous audits. Maintain security, style, and testing discipline to keep BitTransformerLM production-ready.
+
BitTransformerLM/BitTransformerLM_full_assessment.md ADDED
@@ -0,0 +1,196 @@
+
+ # BitTransformerLM Deep-Dive Assessment Report
+
+ *(Comprehensive technical review and optimization roadmap)*
+
+ ---
+
+ ## Completed Tasks
+ - [x] 3.1 Cosine noise schedule option
+ - [x] 3.2 Post-process parity correction
+ - [x] 2.3 Expose checkpoint & reversible toggles
+ - [x] 2.2 Update deprecated AMP call
+ - [x] 5.2 Metric-drift alerts
+ - [x] 1.3 Expand README / docstrings for telemetry & ACT
+ - [x] 3.3 Safety-gate soft-retry
+ - [x] 7.1 Add ACT halting unit test
+ - [x] 4.1 Integrate performance-based scaling
+ - [x] 4.2 Learning-rate decay on resize
+ - [x] 3.4 Chunked attention logging toggle
+ - [x] 3.5 Quantization-aware training toggle
+ - [x] 7.2 Quantization & QAT tests
+ - [x] 4.3 Dashboard flag wiring
+ - [x] 7.3 Dashboard smoke test
+ - [x] 2.1 Unify flag names & deprecate legacy scale script
+ - [x] 5.1 Telemetry λ and floor UI
+ - [x] 5.3 Cluster-based distillation data
+ - [x] 6.1 Allow width scaling in collapse loop
+ - [x] 6.2 Save distilled metrics summary
+
+ ## 1. Overview of BitTransformerLM Architecture and Recent Additions
+ BitTransformerLM is a **reversible Transformer** that operates **directly on binary sequences (bits)**. The immutable core uses multi-head self-attention on bit embeddings with sinusoidal positional encoding and already supports:
+
+ * Safety-centric telemetry (negentropy *K*, LZ complexity *C*, symbiosis *S*)
+ * Run-length compression / decompression paths
+ * Progressive scaling (depth & width) with reversible layers + gradient checkpointing
+ * Quantization (dynamic INT8 + optional 4‑bit QAT)
+ * A non‑causal **Diffusion‑LM mode** for bidirectional, denoising generation
+ * Dashboard, MCP server, and FSDP/pipeline hooks for distributed or edge deployment
+
+ Recent commits locked in deterministic environment setup (ChatGPT Codex container), removed insecure `/exec` endpoints, and added a reliable *coarse‑to‑fine* diffusion sampler stub. The model now installs and trains reproducibly on CPU‑only hosts, yet scales to multi‑GPU with FSDP.
+
+ ---
+
+ ## 2. Consistent Naming & Documentation
+ * Codebase generally follows *snake_case* functions / *CamelCase* classes, but CLI flags & helper scripts drift (e.g. `--diffusion` vs internal `causal=False`).
+   **Action:** unify flag names & docstrings; deprecate redundant scripts (`progressive_scaleup.py` vs `integration_schedule.py`).
+ * README and inline docs lack quick intuition for *K, C, S* metrics, ACT, and reversible internals.
+   **Action:** add short metric primers and ACT demo snippets; update `AGENTS.md` quick‑start table.
+
+ ---
+
+ ## 3. Optimizing Module Interactions & Performance
+ | Area | Current State | Optimization | Outcome |
+ |------|---------------|--------------|---------|
+ | **Chunked attention** ✅ | Saves RAM but reconstructs full *T×T* matrix for telemetry | Skip full matrix when `chunk_size < seq_len` and user disables `full_attn_logging` | Same metrics, big memory + speed win on long sequences |
+ | **PyTorch 2 features** | Uses `torch.compile` & BF16 autocast inconsistently | Standardize `torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)`; wrap long loops | 10‑20 % CPU speed‑up, no deprecation warnings |
+ | **Reversible + checkpoint** | Always checkpoints → slower when RAM ample | Expose `--no-checkpoint` flag; document trade‑offs | User‑selectable speed vs memory |
+ | **Quantization** ✅ | INT8 dynamic works; 4‑bit QAT unused | Add `--qat` toggle in training scripts & unit‑test tiny model | Edge‑ready 4‑bit weights validated |
+ | **Compression loops** | Python for‑loops per sample | Batch or vectorized RLE when batch ≫ 8 | Marginal speed‑up for large batches |
+
+ ---
+
+ ## 4. Fully Leveraging Diffusion Mode
+ 1. [x] **Noise schedule** – switchable linear ▸ cosine ▸ exponential; expose `--noise-schedule`.
+ 2. [x] **Step count** – allow 8–16 steps for high‑fidelity generation; document compute trade‑off.
+ 3. [x] **Parity safeguard** – post‑sampling parity‑bit fix or strict parity sampling to guarantee valid bytes.
+ 4. [x] **Training curriculum** – optional schedule: high‑noise → low‑noise over epochs; keep random‑noise fallback.
+ 5. [x] **Safety integration** – run `hil_safe_inference(strict=False)` during diffusion; warn (not crash) on metric floor breaches.
+
+ ---
+
+ ## 5. Enhanced Training Workflow & Scaling Strategy
+ * **Adaptive scaling trigger** – adopt `progressive_scaleup.py` logic: scale only when val‑loss Δ < threshold; alternate width↔context↔depth.
+ * **Context extension** – use `double_length()` when plateau met; maintain chunked attention windows.
+ * **Warm‑up & plateau** – keep 5‑batch freeze after each expansion; add default final plateau epoch.
+ * **LR hygiene** – slight LR decay each scale‑up; document rationale.
+
+ ---
+
+ ## 6. Telemetry Metrics & Safety Integration
+ * **Metric coefficients** (`λ_K`, `λ_C`, `λ_S`) exposed via dashboard slider; floors (C ≥ 0.3, S ≥ 0.5) adjustable per deployment.
+ * **TelemetrySynthesizer** – cluster activations → representative sequences for distillation & drift detection.
+ * **Metric drift alert** – integrate `detect_metric_drift()` into training monitor; log if Δ > 0.2.
+
+ ---
+
+ ## 7. Distillation & Model Collapse Optimization
+ 1. Use **cluster‑selected sequences** as `cluster_data` for `collapse_submodel` → better coverage.
+ 2. Permit optional width growth (`width_scale > 1`) in iterative collapse rounds.
+ 3. Log final vs floor metrics in `distilled_metrics.json` for audit trail.
+ 4. Optionally auto‑invoke collapse at end of `integration_schedule` with `--auto-collapse`.
+
+ ---
+
+ ## 8. Additional Testing & Release Readiness
+ * Expand pytest suite: diffusion training/sampling, ACT halting, INT8 + QAT inference, dashboard API smoke tests.
+ * Add multi‑GPU CI job to validate FSDP + reversible layers.
+ * Strengthen debug logs: print mode (causal/diffusion/compression), scale‑up events, safety‑gate warnings.
+
+ ---
+
+ ## 9. Strategic Summary
+ BitTransformerLM already delivers an **orthogonal bundle of “firsts”**: bit‑native granularity, reversible memory efficiency, metric‑driven safety, and turnkey text diffusion.
+ Executing the roadmap **knits every module into a smooth, reproducible pipeline** without touching core architecture—preserving alignment while boosting usability.
+
+ **Bottom‑line:** With these refinements, BitTransformerLM becomes the reference for transparent, resource‑efficient, safety‑gated language modelling at the bit level—well beyond “just another model.”
+
+
+ Below is an **implementation playbook** that turns every recommendation in *“Overview of BitTransformerLM Architecture and Recent Additions”* into clear tasks and ready‑to‑copy Codex prompts. Where page numbers add context, I note them; all content is from the uploaded PDF.
+
+ ---
+
+ ## 1 · Repository Consistency & Documentation
+
+ | # | Task | Key Steps | Codex Prompt (trim or expand as desired) |
+ | --- | --- | --- | --- |
+ | 1.1 | **Audit & unify public API names** | • Scan for duplicate / mis‑matched flags (e.g. `--diffusion` vs `causal=False`).<br>• Rename or deprecate aliases; update docs. | “List every function, class, and CLI flag whose name does **not** match the style‑guide (snake_case for funcs, CamelCase for classes) in the BitTransformerLM repo. For each, propose a single canonical name and generate the automated `git mv` or refactor patches.” |
+ | 1.2 | **Consolidate scaling scripts** | • Merge `progressive_scaleup.py` logic into `integration_schedule.py`.<br>• Mark redundant script as example. | “Move the performance‑based scaling criterion from `progressive_scaleup.py` into `integration_schedule.py`. Preserve existing kwargs, add `--improve-thresh` with default 0.01. Provide diff.” |
+ | 1.3 | **Expand README / docstrings for telemetry & ACT** (pp. 1‑2) | • Add one‑paragraph explanations of Negentropy (K), LZ Complexity (C), Symbiosis (S), and ACT halting to README.<br>• Link to equations in code comments. | “Insert a new subsection *‘Telemetry Metrics Explained’* into README after the quick‑start block, then add in‑line docstrings for `negentropy_score`, `lz_complexity`, and `symbiosis_score` explaining ranges and typical values.” |
+
+ ---
+
+ ## 2 · Performance Optimizations
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 2.1 | **Vectorize chunked‑attention telemetry** (p. 2) | • Add flag `--attn-summary`.<br>• When enabled and `chunked_attn=True`, compute per‑chunk entropy and skip full `T × T` map. | “Refactor `_chunked_attn` in `model.py` so that, if `attn_summary` is true, it returns `(attn_entropy_per_chunk, None)` instead of the stitched full map. Fall back to old behaviour otherwise. Update callers.” |
+ | 2.2 | **Update deprecated AMP call** | Replace `torch.cpu.amp.autocast` with `torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)` everywhere. | “Search repo for `torch.cpu.amp.autocast`, replace with the new API, and add a context‑manager wrapper `cpu_autocast` in `utils/torch_utils.py`.” |
+ | 2.3 | **Expose checkpoint & reversible toggles** (p. 2) | • Add CLI flags `--use-checkpoint / --no-checkpoint` and `--reversible`.<br>• Document memory/compute trade‑off. | “Modify `train.py` argparse to include mutually exclusive `--[no-]checkpoint` flags; wire to `use_checkpoint` in model init.” |
+ | 2.4 | **Batch run‑length encoding** (p. 3) | • Implement NumPy‑vectorised RLE for the full tensor.<br>• Fallback to Python loop if tensor < 1024 bits. | “Implement `batch_rle_encode` in `bit_io.py` using NumPy strides; write unit test comparing speed & correctness to existing per‑sequence encode.” |
+
+ ---
+
+ ## 3 · Diffusion‑Mode Enhancements
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 3.1 | **Cosine noise schedule option** (p. 4) | • Add `schedule="linear\|cosine\|exp"` arg to `diffusion_inference`.<br>• Default remains linear. | “Extend `diffusion_inference` to support a cosine decay of `mask_prob` over `steps`. Provide math and update docstring.” |
+ | 3.2 | **Post‑process parity correction** (p. 4; see the sketch below) | • After sampling, flip each parity bit if byte parity invalid.<br>• Log number of corrections. | “Write `enforce_parity(bits)` that patches 9th bit per byte to satisfy even‑parity, return corrected seq + stats.” |
+ | 3.3 | **Safety‑gate soft‑retry** | • On failed `hil_safe_inference(strict=True)`, auto‑retry up to 3× with diffusion or random seed.<br>• Surface warning in logs. | “Wrap `hil_safe_inference` in a helper `safe_sample_with_retry`; implement exponential back‑off and logging.” |
+
+
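+ For row 3.2, a minimal `enforce_parity` sketch, assuming 9-bit groups in which the 9th bit is an even-parity bit over the preceding 8 data bits (the repository’s byte layout may differ):
+
+     import torch
+
+     def enforce_parity(bits: torch.Tensor):
+         """Patch every 9th bit so each 8-bit payload has even parity. Returns (bits, n_fixed)."""
+         bits = bits.clone().view(-1, 9)            # [n_bytes, 9]; assumes length divisible by 9
+         expected = bits[:, :8].sum(dim=1) % 2      # even parity over the data bits
+         n_fixed = int((bits[:, 8] != expected).sum())
+         bits[:, 8] = expected
+         return bits.view(-1), n_fixed
+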
+ ---
+
+ ## 4 · Adaptive Training Workflow
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 4.1 | **Integrate performance‑based scaling** (pp. 5‑6) | • Use `Δval_loss < thresh` as condition to trigger `add_layer()`/`double_width()`.<br>• Alternate occasional `double_length()` for context. | “Inside `integration_schedule.train_loop`, compute rolling val‑loss; if mean improvement < `args.improve_thresh`, call `model.scale_up(strategy=next_step)` where `next_step` cycles [layer, width, context].” |
+ | 4.2 | **Learning‑rate decay on resize** (sketch below) | • After each scale‑up, reduce base LR by √2.<br>• Provide warm‑up of 100 steps. | “Add `adjust_learning_rate(optimizer, factor)` util; call it after every successful model expansion.” |
+ | 4.3 | **Dashboard flag wiring** | • Map UI toggles (compression, diffusion) to `compress_prob`, `diffusion` args in backend. | “In `dashboard_app.py`, when user toggles compression, pass `compress_prob=1.0` to `ModelManager.train()`.” |
+
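+ Row 4.2’s helper can be as small as the following sketch (the 1/√2 factor mirrors the text above; the 100-step warm‑up would be handled by the caller):
+
+     import math
+
+     def adjust_learning_rate(optimizer, factor=1 / math.sqrt(2)):
+         """Scale every param group's learning rate, e.g. after a model expansion."""
+         for group in optimizer.param_groups:
+             group["lr"] *= factor
+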
+ ---
+
+ ## 5 · Telemetry & Safety
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 5.1 | **Expose λ coefficients and safety floors in UI** (p. 7) | • Add sliders for `λ_K`, `λ_C`, `λ_S`, `C_floor`, `S_floor`.<br>• Persist to model state. | “Add REST endpoints `/config/telemetry` (GET/POST) that read or set lambda values and floors; bind to dashboard sliders.” |
+ | 5.2 | **Metric‑drift alerts** (p. 8; sketch below) | • After every epoch, call `detect_metric_drift(history, window=100)`; if > 0.2 drift, log & optionally halt training. | “Integrate `detect_metric_drift` into `ModelManager._log_metrics`; raise `MetricDriftWarning` when threshold exceeded.” |
+ | 5.3 | **Cluster‑based distillation data** (pp. 8‑9) | • Use `TelemetrySynthesizer` to pick `k` cluster representatives (default 8).<br>• Feed to `collapse_submodel`. | “Before `collapse_submodel`, run `representatives = TelemetrySynthesizer(model).cluster(train_data, k=8)`. Replace `train_bits[:64]` with `representatives`.” |
+
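+ A possible shape for the `detect_metric_drift` helper referenced in row 5.2 (illustrative only; the shipped implementation may differ):
+
+     def detect_metric_drift(history, window=100, threshold=0.2):
+         """Return metrics whose recent mean drifted more than `threshold` vs the previous window."""
+         drifted = {}
+         for name, values in history.items():           # history: {"K": [...], "C": [...], "S": [...]}
+             if len(values) < 2 * window:
+                 continue
+             recent = sum(values[-window:]) / window
+             previous = sum(values[-2 * window:-window]) / window
+             if abs(recent - previous) > threshold:
+                 drifted[name] = recent - previous
+         return drifted
+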
+ ---
+
+ ## 6 · Distillation / Collapse Process
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 6.1 | **Allow width scaling in collapse loop** (p. 8) | • Add `width_scale` param; if metric floors unmet after deepening, double width once then retry. | “Modify `collapse_submodel`: on round‑2 failure, rebuild sub‑model with `hidden_dim *= width_scale` (default 1.5).” |
+ | 6.2 | **Save metrics summary** | • Extend `save_distilled_model` to write `metrics.json` with achieved vs floor values. | “Update `save_distilled_model` to dump `{'C': score_C, 'S': score_S, 'floors': {...}}` alongside weights.” |
+
+ ---
+
+ ## 7 · Testing & CI Hardening
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 7.1 | **Add ACT halting unit test** (p. 10) | • Craft toy seq; assert `sum(halt_prob<1) < n_layers`. | “Write `tests/test_act.py` ensuring at least one layer halts early when `use_act=True, threshold=0.1`.” |
+ | 7.2 | **Quantization & QAT tests** | • After tiny train, run dynamic int8 + fake‑QAT path, assert same logits ±1e‑3. | “Add `pytest` case: train 2‑layer model 1 epoch, call `quantize_dynamic`, compare outputs on 10 random inputs.” |
+ | 7.3 | **Dashboard smoke test** | • In CI, launch Flask app with `pytest-flask`, hit `/init`, `/train-step`, `/infer`. | “Create `tests/test_dashboard.py` that starts server in a thread and exercises core endpoints.” |
+
+ ---
+
+ ## 8 · Packaging & Release
+
+ | # | Task | Key Steps | Codex Prompt |
+ | --- | --- | --- | --- |
+ | 8.1 | **Rename repository references** (p. 11) | • Replace `Test/` URL stubs with new repo slug.<br>• Update badges in README. | “Search‑replace all GitHub links from `WCNegentropy/Test` to `WCNegentropy/BitTransformerLM`; update badge SVGs.” |
+ | 8.2 | **PyPI build verification** | • Ensure `pyproject.toml` installs cleanly on 3.10 & 3.11 in CI. | “Add GitHub Action matrix for {macOS, ubuntu‑latest} × {3.10, 3.11}; run `pip install -e . && pytest`.” |
+
+ ---
+
+ ### How to Use These Prompts
+
+ **Run** unit tests; iterate if failures surface.
+
+ This checklist should bring BitTransformerLM to a polished, v1‑ready state while aligning with your NRB‑driven safety and telemetry philosophy.
BitTransformerLM/Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM ubuntu:22.04
+ ENV DEBIAN_FRONTEND=noninteractive
+ RUN apt-get update && \
+     apt-get install -y python3.11 python3-pip python3.11-venv curl && \
+     apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /opt/bit_transformer
+ COPY . .
+
+ ARG TORCH_CUDA=cpu
+ RUN pip3 install --no-cache-dir --upgrade pip && \
+     if [ "$TORCH_CUDA" = "cu118" ]; then \
+         pip3 install torch==2.7.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118; \
+     else \
+         pip3 install torch==2.7.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu; \
+     fi && \
+     pip3 install -r requirements.txt
+
+ ENV MCP_SERVER_ADDR=http://127.0.0.1:7000
+
+ EXPOSE 5000 7000
+
+ RUN chmod +x start.sh
+
+ HEALTHCHECK CMD curl -f http://localhost:7000/health || exit 1
+
+ CMD ["/opt/bit_transformer/start.sh"]
BitTransformerLM/LICENSE/ALIGNMENT_AND_TRANSPARENCY.txt ADDED
@@ -0,0 +1,42 @@
+ # Alignment and Transparency Agreement for BitTransformerLM
+ This Alignment and Transparency Agreement ("Agreement") outlines requirements
+ for the responsible and aligned commercial deployment of BitTransformerLM,
+ developed by WCNEGENTROPY HOLDINGS LLC.
+
+ ## Core Principles
+ 1. **Alignment:** The deployment must actively maintain alignment with ethical
+    and epistemic integrity, avoiding harmful or coercive outcomes.
+ 2. **Transparency:** Telemetry data (negentropy, complexity, symbiosis scores,
+    etc.) must be transparently maintained and available for audit and inspection.
+ 3. **Safety and Stability:** Deployments must utilize provided safety gates
+    (e.g., `hil_safe_inference`) to prevent degenerate or harmful outputs.
+ 4. **Epistemic Responsibility:** Adopters commit to responsible use, actively
+    avoiding misuse or unethical applications.
+
+ ## Telemetry and Monitoring
+ Commercial license holders must maintain full transparency on telemetry metrics
+ collected from the software, as originally implemented in the BitTransformerLM
+ repository. Telemetry must be available upon request for audit by WCNEGENTROPY HOLDINGS LLC or authorized third parties.
+
+ ## Modification and Derivatives
+ Commercial license holders may modify the software for internal commercial use
+ but must explicitly disclose any modifications or derivatives upon request by
+ WCNEGENTROPY HOLDINGS LLC.
+
+ ## Violations
+ Non-compliance with any of the terms outlined in this Agreement may result in
+ revocation of commercial licensing rights at the sole discretion of WCNEGENTROPY HOLDINGS LLC.
+
+ ### Audit Cadence (added v0.9.0)
+ WCNEGENTROPY HOLDINGS LLC may request a telemetry snapshot **no more than once per
+ calendar quarter**. Licensee must deliver the requested data within **30 days
+ of receipt**. Failure to comply may result in suspension or termination of
+ commercial rights.
+
+ ---
+
+ For questions, clarification, or audits, contact:
+ **WCNegentropy Holdings**
+ Website: [wcnegentropy.com](https://wcnegentropy.com)
BitTransformerLM/LICENSE/COMMERCIAL_LICENSE.txt ADDED
@@ -0,0 +1,34 @@
+ # Commercial License for BitTransformerLM
+ © 2025 WCNEGENTROPY HOLDINGS LLC – All Rights Reserved
+
+ > **Clarification (dual‑license):** This clause applies only to commercial deployments.
+ > Open‑source users remain bound by the GNU AGPL v3 in `LICENSE`.
+ > For holders of a **paid Commercial License**, BitTransformerLM is also provided
+ > under the **Apache License 2.0** (the “Commercial License”), **subject to the
+ > Alignment & Transparency Agreement (ATA)**.
+
+ BitTransformerLM (the “Software”), including all source code, documentation,
+ and associated assets, is the exclusive property of WCNEGENTROPY HOLDINGS LLC
+ (“WCNH”). Commercial use, reproduction, modification, distribution, or
+ sublicensing requires an executed Commercial License Agreement with WCNH.
+
+ ## Patent Grant (Defensive)
+ WCNH hereby grants Licensee a perpetual, worldwide, non‑exclusive, no‑charge
+ patent license to make, use, sell, offer to sell, import, and otherwise
+ exploit the Software **provided** Licensee complies with this Commercial
+ License and the ATA. **This patent license terminates automatically** if
+ Licensee initiates patent litigation alleging that the Software infringes any
+ patent claim.
+
+ ## Export‑Control & Sanctions Compliance
+ Licensee shall comply with all applicable export‑control and sanctions laws,
+ including but not limited to U.S. EAR, EU dual‑use regulations, and OFAC
+ sanctions lists. Licensee must not export, re‑export, or provide the Software
+ (or derivatives) to any prohibited country, entity, or individual.
+
+ ## Alignment & Transparency Obligation
+ Commercial usage is conditional upon adherence to the ATA (see
+ `ALIGNMENT_AND_TRANSPARENCY.txt`). Failure to comply constitutes a material
+ breach and grounds for immediate license revocation.
+
+ For commercial inquiries contact **[email protected]**.
BitTransformerLM/LICENSE/CONTRIBUTOR_LICENSE_AGREEMENT.txt ADDED
@@ -0,0 +1,7 @@
+ # Individual Contributor License Agreement (ICLA)
+ By submitting a contribution to the BitTransformerLM repository you agree to
+ grant WCNEGENTROPY HOLDINGS LLC an irrevocable, worldwide, royalty‑free
+ copyright license to reproduce, prepare derivative works of, publicly display
+ and distribute your contribution. You certify that you have the right to make
+ this grant and that your contribution is original or you have secured the
+ appropriate rights.
BitTransformerLM/LICENSE/DISCLAIMER.txt ADDED
@@ -0,0 +1,93 @@
+ # BitTransformerLM – Legal & Risk Disclaimer
+ _Last updated: 2025-08-04_
+
+ BitTransformerLM (the “Software”) is an **experimental, highly-capable, agentic AI
+ model** developed by WC Negentropy Holdings LLC (“WCNH”). By downloading,
+ installing, running, fine-tuning, or otherwise using the Software **you
+ acknowledge and agree to all terms below.**
+
+ ---
+
+ ## 1. No Warranty
+
+ THE SOFTWARE IS PROVIDED **“AS IS”** AND **WITHOUT WARRANTY** OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NON-INFRINGEMENT,
+ AND ERROR-FREE OR UNINTERRUPTED OPERATION.
+ WCNH DOES **NOT** WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS OR
+ THAT ITS OUTPUT WILL BE ACCURATE, COMPLETE, OR RELIABLE.
+
+ ## 2. Limitation of Liability
+
+ TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL WCNH,
+ ITS AFFILIATES, CONTRIBUTORS, OR LICENSORS BE LIABLE FOR ANY DIRECT,
+ INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; BUSINESS INTERRUPTION; OR PERSONAL
+ INJURY) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+ ANY WAY OUT OF THE USE OF THE SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+ OF SUCH DAMAGE.
+
+ ## 3. High-Risk & Regulated Uses
+
+ **DO NOT DEPLOY** the Software in environments where failure or malfunction
+ could lead to death, serious bodily injury, or severe property or
+ environmental damage, including but not limited to:
+
+ - Medical diagnosis or life-support systems
+ - Autonomous vehicles or aviation control
+ - Nuclear facilities, weapons development, or military combat systems
+ - Critical infrastructure (power, water, telecom)
+ - Legal, financial, or governmental decision-making without qualified
+   human review
+
+ You remain solely responsible for conducting appropriate risk assessments,
+ validation, and human oversight before any production deployment.
+
+ ## 4. Alignment & Transparency Obligations
+
+ If you hold a **Commercial License**, you must also comply with the
+ Alignment & Transparency Agreement (ATA), including:
+
+ - Logging K-C-S telemetry and retaining it for audit
+ - Supplying telemetry snapshots upon request (max once per quarter)
+ - Cooperating with reasonable misuse investigations
+
+ ## 5. Data & Privacy
+
+ You are responsible for:
+
+ - Ensuring you have the legal right to process any data you supply to the
+   Software
+ - Implementing technical and organizational measures to protect personal or
+   sensitive data
+ - Complying with all applicable data-protection and privacy laws (e.g.,
+   GDPR, CCPA, HIPAA)
+
+ ## 6. Export & Sanctions Compliance
+
+ You may **not** use or transfer the Software in violation of U.S. export
+ control laws, EU dual-use regulations, or applicable sanctions regimes.
+ This includes, but is not limited to, prohibitions on use in or by
+ countries, entities, or individuals listed on U.S. or EU restricted-party
+ lists.
+
+ ## 7. Third-Party Dependencies
+
+ The Software may incorporate open-source components licensed under separate
+ terms. Such components are provided **“as is”** and remain subject to their
+ respective licenses. A complete list is available in `THIRD_PARTY_LICENSES.txt`.
+
+ ## 8. No Professional Advice
+
+ The Software’s outputs (including code, text, and recommendations) do **not**
+ constitute professional, legal, medical, financial, or safety advice. Always
+ consult a qualified expert before relying on the Software for any critical
+ decision.
+
+ ---
+
+ **© 2023-2025 WCNEGENTROPY HOLDINGS LLC.**
+ All trademarks—including “BitTransformerLM” and the spiral-N logo—are property
+ of WCNH. Unauthorized use is prohibited.
BitTransformerLM/LICENSE/LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+ # LICENSE (AGPLv3)
+ Copyright (C) 2025 WCNEGENTROPY HOLDINGS LLC
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
BitTransformerLM/LICENSE/TRADEMARK_POLICY.txt ADDED
@@ -0,0 +1,12 @@
+ “BitTransformerLM” and the spiral‑N logo are trademarks of WCNEGENTROPY HOLDINGS LLC.
+
+ Permitted use:
+ • Describing, linking to, or referencing **unmodified, official builds** of
+   BitTransformerLM.
+
+ Prohibited use without prior written permission:
+ • Branding or promoting modified forks, derivatives, or third‑party services.
+ • Creating confusingly similar marks or implying endorsement by WCNH.
+
+ Forks must remove or rename the marks to avoid confusion.
+ Contact **[email protected]** for licensing requests.
BitTransformerLM/NEW_CODEX_TASK.md ADDED
@@ -0,0 +1,85 @@
+ # DEPRECATED
+
+ All tasks in this file have been implemented (Stages 1–5). The document remains for historical reference only.
+
+ Stage 1: Compression Algorithm Implementation
+
+ Task 1: Choose Compression Method
+
+ Prompt:
+
+ Codex: Provide a concise PyTorch-compatible implementation of lossless binary compression and decompression (e.g., RLE, Huffman, or LZ-based) suitable for binary input sequences represented as tensors of bits.
+
+ Task 2: Implement Compression Functions
+
+ Prompt:
+
+ Codex: Implement PyTorch functions compress_bits(input_tensor) and decompress_bits(compressed_tensor) that accept and return PyTorch tensors (dtype=torch.bool or torch.uint8). Ensure compress → decompress cycle perfectly reconstructs original data, and include simple unit tests.
+
+
+ Stage 2: Encoder/Decoder Integration
+
+ Task 3: Add Compression to Encoder Input
+
+ Prompt:
+
+ Codex: Modify BitTransformerLM’s input pipeline by wrapping the existing model forward pass with a forward_compressed(bits_tensor) method. This method should decompress incoming compressed bit tensors before embedding. Ensure it returns identical outputs as existing uncompressed inputs for verification.
+
+ Task 4: Add Decompression to Decoder Output
+
+ Prompt:
+
+ Codex: Implement a PyTorch-compatible function model_output_decompress(output_bits_tensor) to decompress bit sequences output by BitTransformerLM. Integrate this function as an optional post-processing step after the model’s bitstream generation.
+
+
+ Stage 3: Training and Evaluation Enhancements
+
+ Task 5: Toggle Compression During Training
+
+ Prompt:
+
+ Codex: Modify the existing training loop to randomly compress input bit sequences with a configurable probability (compress_prob=0.5). Ensure that when compression is on, inputs are compressed and decompressed transparently, and when off, inputs bypass compression.
+
+ Task 6: Evaluate Compressed vs Raw Performance
+
+ Prompt:
+
+ Codex: Extend the current training evaluation metrics to separately track loss, accuracy, and compression ratio for both compressed and raw sequences. Log these metrics clearly in the training output.
+
+
+ Stage 4: Advanced Integration (Optional)
+
+ Task 7: Multi-task Training for Compression Learning
+
+ Prompt:
+
+ Codex: Implement an optional multi-task training mode where the model occasionally sees compressed inputs directly without decompression. Add a separate loss calculation to monitor its performance on these compressed inputs. Track and log separately from normal next-bit prediction loss.
+
+ Task 8: Compression-aware Safety Telemetry
+
+ Prompt:
+
+ Codex: Adjust the existing BitTransformerLM telemetry (K, C, and S metrics) to handle compressed sequences appropriately. Modify telemetry calculations to optionally apply metrics to decompressed outputs instead of raw bitstream when compression is enabled.
+
+
+ Stage 5: Dashboard and Runtime Integration
+
+ Task 9: Dashboard Compression UI Toggle
+
+ Prompt:
+
+ Codex: Add a simple UI toggle labeled “Enable Compression” to the existing BitTransformerLM dashboard, controlling whether inputs and outputs are automatically compressed and decompressed. Display compression ratio metrics when enabled.
+
+ Task 10: Error Handling and User Feedback
+
+ Prompt:
+
+ Codex: Implement graceful error handling in the dashboard for compression and decompression failures. Provide clear user-facing feedback in the UI if decompression fails, along with suggestions or fallbacks.
+
+
+ These ten tasks enable incremental, testable integration of binary compression/decompression into BitTransformerLM without fundamentally altering the core transformer model itself.
BitTransformerLM/README.md ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BitTransformerLM
2
+
3
+ **Project Status:** Pre-release (v1 candidate)
4
+
5
+ BitTransformerLM is a bit-centric transformer language model built entirely in PyTorch. The project began as a small prototype but has matured into a near-production system capable of modeling raw binary streams with sophisticated safety telemetry and automated scale-up tooling. This repository now serves as the canonical implementation under WCNegentropy.
6
+
7
+ ## Historical Background
8
+ - **Early Experiments** – Initial prototypes explored mapping text to parity-protected bits and training a minimal transformer on random data.
9
+ - **Telemetry & Safety** – Added negentropy, LZ complexity and symbiosis scoring to measure information flow and gate unsafe outputs.
10
+ - **Progressive Scaling** – Introduced reversible layers and automatic depth/width expansion for efficient curriculum training. The schedule now triggers expansions only when validation loss plateaus and decays the learning rate by √2 after each growth with a 100-step warm‑up.
11
+ - **Compression Support** – Integrated run-length encoding and packed bit I/O with optional multi-task training on compressed sequences.
12
+ - **Context Extension** – Implemented chunked attention and sliding-window inference for long sequences with optional overlapping windows.
13
+ - **Attention Logging Toggle** – ``full_attn_logging=False`` skips reconstructing full ``T×T`` attention maps during chunked attention, cutting memory use for very long sequences.
14
+ - **Diffusion LM Mode** – Enable bidirectional denoising by setting ``causal=False`` or toggling **Diffusion LM** in the dashboard. Chunked attention is automatically disabled in this mode and restored afterward.
15
+ - **Dashboard & MCP Server** – Built a lightweight web UI backed by a management server for real‑time training, inference and model collapse. New `/metrics` and `/model_config` endpoints surface live telemetry and hyperparameters, and `/save_checkpoint` and `/download_checkpoint` enable Hugging Face weight sync. The insecure `/exec` route has been removed.
16
+ - **Phase 1 Optimizations** – Configurable batch sizes with aligned OneCycle scheduling, gradient accumulation, mixed‑precision, memory‑mapped dataset streaming, scheduled compression ramps, selective ``torch.compile``, and an EMA‑smoothed safety gate with burn‑in to cut false positives.
17
+
18
+ The codebase has undergone multiple stress tests and synthetic benchmarks (see `tests/TEST_RESULTS.md`) and now approaches a stable release.
19
+
20
+ ## Quick Start
21
+ Install dependencies using the CPU wheel of PyTorch (default):
22
+ ```bash
23
+ pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
24
+ ```
25
+ When GPU acceleration is toggled in the dashboard, the application automatically
26
+ installs the CUDA-enabled wheel:
27
+ ```bash
28
+ pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118
29
+ ```
30
+ Run the example script:
31
+ ```bash
32
+ python example.py
33
+ ```
34
+ Adaptive scaling demo:
35
+ The legacy `progressive_scaleup.py` script is retained for reference but has been
36
+ superseded by `integration_schedule.py`, which offers a more flexible scaling
37
+ workflow.
38
+
39
+ Run the unified workflow:
40
+ ```bash
41
+ python unified_workflow.py --dashboard
42
+ # disable gradient checkpointing for faster but memory-hungry runs
43
+ python unified_workflow.py --no-checkpoint
44
+ # use standard (non-reversible) transformer blocks
45
+ python unified_workflow.py --no-reversible
46
+ # enable 4-bit quantization-aware training
47
+ python unified_workflow.py --qat
48
+ ```
49
+
50
+ For faster CPU execution, BitTransformerLM exposes a `cpu_autocast()` helper
51
+ that enables bfloat16 mixed precision. Models created with
52
+ `use_autocast=True` apply this automatically, or you can wrap individual
53
+ forward passes:
54
+
55
+ ```python
56
+ from bit_transformer.torch_utils import cpu_autocast
57
+
58
+ with cpu_autocast():
59
+ logits, telemetry = model(bits)
60
+ ```
61
+
62
+ Reduce memory use when chunked attention is active by disabling full
63
+ attention logging:
64
+
65
+ ```python
66
+ model = BitTransformerLM(chunk_size=128, full_attn_logging=False)
67
+ ```
68
+
69
+ Enable Diffusion LM training and sampling:
70
+ ```bash
71
+ python unified_workflow.py --diffusion --diffusion-steps 8 --dataset-size 32
72
+ # choose noise schedule: linear, cosine, exp
73
+ python unified_workflow.py --diffusion --noise-schedule cosine --diffusion-steps 16 --dataset-size 32
74
+ # linearly decay noise over epochs
75
+ python unified_workflow.py --diffusion --diffusion-curriculum --dataset-size 32
76
+ ```
77
+ Higher `--diffusion-steps` (8–16) improves sample quality at the cost of compute. When using the dashboard, enable the **Diffusion LM** toggle to run the model without causal masking or chunked attention.
78
+ Generated samples automatically fix parity bits so they can be decoded back to text.
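+
+ The parity scheme itself is easy to inspect: every byte is encoded as nine bits
+ (eight payload bits plus one parity bit), so text round-trips losslessly:
+
+ ```python
+ from bit_transformer import text_to_bits, bits_to_text
+
+ bits = text_to_bits("hi")        # 2 bytes -> 18 bits, one parity bit per byte
+ assert len(bits) == 18
+ assert bits_to_text(bits) == "hi"
+ ```
+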
79
+ To resume training across machines using Hugging Face storage:
80
+ ```bash
81
+ python unified_workflow.py --hf-repo your-username/bittransformerlm --hf-token $HF_TOKEN
82
+ ```
83
+ The dashboard exposes matching controls under **Hugging Face Checkpoints**. Provide a repository ID and optional token (falling back to the `HF_TOKEN` environment variable) and click **Upload weights** or **Download weights** to sync the model.
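+
+ The same sync can be scripted against the checkpoint helpers directly; a short
+ sketch, assuming a trained `model` is in scope and using a placeholder repository ID:
+
+ ```python
+ import os
+ from bit_transformer import hf_login, save_checkpoint, download_checkpoint
+
+ hf_login(token=os.getenv("HF_TOKEN"))  # falls back to interactive login if unset
+ save_checkpoint(model, repo_id="your-username/bittransformerlm")
+ ok = download_checkpoint("snapshots/model.pt.gz", repo_id="your-username/bittransformerlm")  # True on success
+ ```
+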
84
+ Run the unit tests:
85
+ ```bash
86
+ pytest -q
87
+ ```
88
+
89
+ ### Mode management
90
+
91
+ During training, ensure the model is in training mode with dropout enabled:
92
+
93
+ ```python
94
+ from bit_transformer.utils import set_dropout
95
+
96
+ model.train()
97
+ set_dropout(model, 0.1)
98
+ ```
99
+
100
+ Before running tests, performing inference, or committing weights to the repository, switch the model to evaluation mode and disable dropout:
101
+
102
+ ```python
103
+ model.eval()
104
+ set_dropout(model, 0.0)
105
+ ```
106
+
107
+ This prevents CI failures from accidentally pushing weights that still have active dropout.
108
+
109
+ ## Telemetry Metrics Explained
110
+ BitTransformerLM reports three bounded metrics in ``[0, 1]`` during training and inference:
111
+
112
+ - **Negentropy (K)** – departure from random noise; ``1`` denotes perfectly ordered bits while ``0`` is uniform randomness.
113
+ - **LZ Complexity (C)** – differentiable proxy for Lempel–Ziv compressibility; low values indicate repetitive patterns, while high values indicate frequent transitions.
114
+ - **Symbiosis (S)** – agreement between model predictions and a reference distribution via KL divergence; scores near ``1`` show strong alignment.
115
+
116
+ An Adaptive Computation Time (ACT) mechanism lets layers halt early once confidence exceeds a threshold. Halt probabilities are exported as ``halt_probs`` in telemetry for inspection.
117
+
118
+ These metrics are logged alongside losses and can trigger safety gates when thresholds are violated. The dashboard monitors drift and emits warnings when recent values deviate beyond a configurable threshold.
119
+
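+ As a rough sketch, these metrics can be read straight from the telemetry dictionary
+ returned by a forward pass, and `detect_metric_drift` flags keys whose recent values
+ deviate from their history (key names follow the dashboard's logging):
+
+ ```python
+ import torch
+ from bit_transformer import BitTransformerLM, detect_metric_drift
+
+ model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
+                          dim_feedforward=64, max_seq_len=64)
+ bits = torch.randint(0, 2, (2, 64), dtype=torch.long)
+ logits, telemetry = model(bits)
+ k = telemetry["negentropy_logits"].mean().item()     # K
+ c = telemetry["lz_complexity_logits"].mean().item()  # C
+ s = telemetry["symbiosis_score"].mean().item()       # S
+
+ history = {"negentropy_logits": [k] * 12}
+ drift = detect_metric_drift(history, window=10, threshold=0.2)  # per-metric drift flags
+ ```
+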
120
+ ## Core Features
121
+ - **Bit-Native Modeling** – Works directly on 0/1 inputs with positional encodings and parity-protected text helpers.
122
+ - **Telemetry Synthesizer** – Clusters activation summaries to surface coherent subspaces and detect drift.
123
+ - **Submodel Distillation** – `TelemetrySynthesizer` selects representative sequences for `collapse_submodel`, which deepens
124
+ and widens once (`width_scale` = 1.5) if telemetry floors aren't met; `save_distilled_model` places a `metrics.json` summary
125
+ beside the distilled weights.
126
+ - **Safety Gate** – `hil_safe_inference` enforces minimum complexity and symbiosis scores at runtime with EMA smoothing and a configurable burn‑in period; a usage sketch follows this list.
127
+ - **Quantization** – CPU inference can be quantized to int8 or trained with 4-bit QAT using the `--qat` flag.
128
+ - **Distributed Training** – FSDP and pipeline helpers allow multi‑GPU scaling when hardware is available.
129
+ - **Interactive Dashboard** – Live control of training, scaling and compression with optional GPU acceleration. The dashboard now exposes reversible layers, gradient checkpointing, ACT thresholds, λ floors, 4‑bit QAT and Diffusion LM toggles, real‑time telemetry charts powered by Chart.js, and Hugging Face checkpoint upload/download controls with `HF_TOKEN` fallback. Settings persist via `localStorage`.
130
+ - **CI/CD Pipeline** – GitHub Actions install dependencies, run the tests and build distribution artifacts on every push.
131
+
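+ A minimal sketch of the safety gate (the floors shown are the defaults used by the
+ text helpers; an untrained model may legitimately trip the gate):
+
+ ```python
+ import torch
+ from bit_transformer import BitTransformerLM, hil_safe_inference, text_to_bits
+
+ model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
+                          dim_feedforward=64, max_seq_len=256)
+ model.eval()
+ bits = torch.tensor(text_to_bits("hello"), dtype=torch.long).unsqueeze(0)
+ out_bits, telemetry = hil_safe_inference(model, bits, c_floor=0.3, s_floor=0.5)
+ ```
+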
132
+ ## Development Workflow
133
+ 1. Start the MCP server:
134
+ ```bash
135
+ python mcp_server.py
136
+ ```
137
+ 2. Launch the dashboard in another terminal:
138
+ ```bash
139
+ MCP_SERVER_ADDR=http://127.0.0.1:7000 python -m bit_transformer.dashboard_app
140
+ ```
141
+ 3. Submit training batches, scale the model and monitor telemetry from the web UI.
142
+ The dashboard's appearance is controlled by `bit_transformer/static/style.css`.
143
+
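+ The same endpoints can be driven programmatically; a rough sketch with `requests`,
+ assuming the dashboard is running on its default port:
+
+ ```python
+ import requests
+
+ base = "http://127.0.0.1:5000"
+ requests.post(f"{base}/init", json={"d_model": 32, "nhead": 4, "num_layers": 1,
+                                     "dim_feedforward": 64, "max_seq_len": 64})
+ bits = [[0, 1] * 32]  # one 64-bit training sequence
+ print(requests.post(f"{base}/train", json={"bits": bits}).json())  # loss and compression ratio
+ print(requests.get(f"{base}/metrics").json()["summary"])
+ ```
+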
144
+ A `watcher.py` script can automatically restart the server and run tests when files change during local development.
145
+
146
+ ## Container Deployment
147
+ A `Dockerfile` and `start.sh` script build a minimal VM image that launches both the MCP server and dashboard.
148
+
149
+ ```bash
150
+ docker build -t bittransformerlm .
151
+ docker run -p 5000:5000 -p 7000:7000 bittransformerlm
152
+ ```
153
+
154
+ By default the container installs the CPU-only PyTorch wheel. Set the build
155
+ argument `TORCH_CUDA=cu118` (e.g. `docker build --build-arg TORCH_CUDA=cu118 -t bittransformerlm .`) to preinstall the GPU version. The container sets
156
+ `MCP_SERVER_ADDR=http://127.0.0.1:7000` and exposes the dashboard on port 5000.
157
+
158
+ ## Roadmap
159
+ - Finalize S attribution tools and metric drift detection.
160
+ - Publish an initial release package and rename the repository to **BitTransformerLM**.
161
+ - Continue benchmarking on real datasets and expanding context window capabilities.
162
+
163
+ ## Licensing
164
+
165
+ This project is released under a combination of licenses and agreements to provide a clear framework for use, distribution, and contribution. All licensing documents can be found in the `LICENSE/` directory.
166
+
167
+ The key documents are:
168
+
169
+ * `LICENSE.txt`: The primary open-source license for the software, AGPLv3.
170
+ * `COMMERCIAL_LICENSE.txt`: Terms for commercial use of the software.
171
+ * `DISCLAIMER.txt`: Important legal disclaimers.
172
+ * `ALIGNMENT_AND_TRANSPARENCY.txt`: Our commitment to alignment and transparency.
173
+ * `TRADEMARK_POLICY.txt`: Guidelines for using the project's trademarks.
174
+ * `CONTRIBUTOR_LICENSE_AGREEMENT.txt`: The agreement for all contributors to sign.
175
+
176
+ Please review these documents carefully before using or contributing to the project.
177
+
BitTransformerLM/bit_transformer/__init__.py ADDED
@@ -0,0 +1,86 @@
1
+ from .model import (
2
+ PositionalEncoding,
3
+ BitTransformerLM,
4
+ ReversibleLoggingTransformerEncoderLayer,
5
+ example_usage,
6
+ example_training_step,
7
+ infer_long_sequence,
8
+ diffusion_inference,
9
+ )
10
+ from .telemetry import TelemetrySynthesizer, detect_metric_drift
11
+ from .dashboard import plot_telemetry
12
+ from .dashboard_app import run_dashboard
13
+ from .collapse import collapse_submodel, save_distilled_model
14
+ from .safety import hil_safe_inference, demo_hil_safety, safe_sample_with_retry
15
+ from .bit_io import (
16
+ text_to_bits,
17
+ bits_to_text,
18
+ infer_text,
19
+ )
20
+ from .parity import enforce_parity
21
+ from .compression import (
22
+ compress_bits,
23
+ decompress_bits,
24
+ model_output_decompress,
25
+ pack_bits,
26
+ unpack_bits,
27
+ )
28
+ from .distributed import wrap_fsdp, make_pipeline
29
+ from .optimization import configure_optimizer, adjust_learning_rate
30
+ from .scale import expand_model
31
+ from .distil import distill_step, TelemetryLog
32
+ from .quantization import (
33
+ quantize_dynamic,
34
+ prepare_qat_fx,
35
+ convert_qat_fx,
36
+ )
37
+ from .training import train_loop
38
+ from .utils import save_model, load_model, set_dropout
39
+ from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
40
+ from .torch_utils import cpu_autocast
41
+
42
+ __all__ = [
43
+ "PositionalEncoding",
44
+ "BitTransformerLM",
45
+ "ReversibleLoggingTransformerEncoderLayer",
46
+ "example_usage",
47
+ "example_training_step",
48
+ "TelemetrySynthesizer",
49
+ "detect_metric_drift",
50
+ "collapse_submodel",
51
+ "save_distilled_model",
52
+ "hil_safe_inference",
53
+ "demo_hil_safety",
54
+ "safe_sample_with_retry",
55
+ "text_to_bits",
56
+ "bits_to_text",
57
+ "infer_text",
58
+ "enforce_parity",
59
+ "plot_telemetry",
60
+ "run_dashboard",
61
+ "configure_optimizer",
62
+ "adjust_learning_rate",
63
+ "expand_model",
64
+ "distill_step",
65
+ "TelemetryLog",
66
+ "quantize_dynamic",
67
+ "prepare_qat_fx",
68
+ "convert_qat_fx",
69
+ "train_loop",
70
+ "wrap_fsdp",
71
+ "make_pipeline",
72
+ "compress_bits",
73
+ "decompress_bits",
74
+ "model_output_decompress",
75
+ "pack_bits",
76
+ "unpack_bits",
77
+ "infer_long_sequence",
78
+ "diffusion_inference",
79
+ "save_model",
80
+ "load_model",
81
+ "set_dropout",
82
+ "hf_login",
83
+ "save_checkpoint",
84
+ "download_checkpoint",
85
+ "cpu_autocast",
86
+ ]
BitTransformerLM/bit_transformer/bit_io.py ADDED
@@ -0,0 +1,97 @@
1
+ from typing import List, TYPE_CHECKING
2
+ import torch
3
+ import sys
4
+
5
+ try: # torch.compile may be unavailable or unsupported
6
+ if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
7
+ compile_fn = torch.compile
8
+ else:
9
+ raise RuntimeError
10
+ except Exception: # pragma: no cover
11
+
12
+ def compile_fn(fn=None, **kwargs):
13
+ if fn is None:
14
+ return lambda f: f
15
+ return fn
16
+
17
+
18
+ if TYPE_CHECKING: # pragma: no cover
19
+ from .model import BitTransformerLM
20
+
21
+
22
+ @compile_fn
23
+ def bytes_to_bits(data: bytes) -> List[int]:
24
+ """Convert bytes to bits with per-byte parity bit."""
25
+ result: List[int] = []
26
+ for b in data:
27
+ bits = [(b >> i) & 1 for i in reversed(range(8))]
28
+ parity = sum(bits) % 2
29
+ result.extend(bits + [parity])
30
+ return result
31
+
32
+
33
+ @compile_fn
34
+ def bits_to_bytes(bits: List[int]) -> bytes:
35
+ """Convert parity-protected bits back to bytes."""
36
+ if len(bits) % 9 != 0:
37
+ raise ValueError("Bit stream length must be multiple of 9")
38
+ out = bytearray()
39
+ for i in range(0, len(bits), 9):
40
+ chunk = bits[i : i + 9]
41
+ payload = chunk[:8]
42
+ parity = chunk[8]
43
+ if parity != sum(payload) % 2:
44
+ raise ValueError("Parity check failed")
45
+ value = 0
46
+ for bit in payload:
47
+ value = (value << 1) | bit
48
+ out.append(value)
49
+ return bytes(out)
50
+
51
+
52
+ def text_to_bits(text: str) -> List[int]:
53
+ return bytes_to_bits(text.encode("utf-8"))
54
+
55
+
56
+ def bits_to_text(bits: List[int]) -> str:
57
+ return bits_to_bytes(bits).decode("utf-8", errors="replace")
58
+
59
+
60
+ def infer_text(
61
+ model: "BitTransformerLM",
62
+ text: str,
63
+ c_floor: float = 0.3,
64
+ s_floor: float = 0.5,
65
+ ) -> str:
66
+ """Run text through the model using the safety gate."""
67
+ from .safety import hil_safe_inference
68
+ bits = text_to_bits(text)
69
+ tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
70
+ out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
71
+ return bits_to_text(out_bits.squeeze(0).tolist())
72
+
73
+
74
+ def sample_text(
75
+ model: "BitTransformerLM",
76
+ prompt: str,
77
+ max_new_tokens: int = 16,
78
+ temperature: float = 1.0,
79
+ top_p: float = 1.0,
80
+ ) -> str:
81
+ """Generate text from the model using simple top-p sampling."""
82
+ model.eval()
83
+ bits = text_to_bits(prompt)
84
+ tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
85
+ for _ in range(max_new_tokens * 9):
86
+ if tensor.size(1) >= model.pos_enc.pe.size(0):
87
+ break
88
+ logits, _ = model(tensor, causal=True)
89
+ prob = (logits[0, -1] / temperature).softmax(-1)  # temperature scales logits before softmax
90
+ sorted_prob, sorted_idx = prob.sort(descending=True)
91
+ cumulative = sorted_prob.cumsum(0)
92
+ mask = cumulative > top_p
93
+ sorted_prob[mask] = 0
94
+ sorted_prob = sorted_prob / sorted_prob.sum()
95
+ next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
96
+ tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
97
+ return bits_to_text(tensor.squeeze(0).tolist())
BitTransformerLM/bit_transformer/collapse.py ADDED
@@ -0,0 +1,95 @@
1
+ import json
2
+ import os
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+
7
+ from .model import BitTransformerLM
8
+ from .training import train_loop
9
+
10
+
11
+ def collapse_submodel(
12
+ cluster_data: List[List[int]],
13
+ target_params: Dict,
14
+ floors: Optional[Dict[str, float]] = None,
15
+ max_rounds: int = 3,
16
+ width_scale: float = 1.5,
17
+ forward_kwargs: Optional[Dict] = None,
18
+ ) -> Tuple[BitTransformerLM, Dict[str, float]]:
19
+ """Distill a submodel from clustered bit sequences.
20
+
21
+ The routine deepens the target model when telemetry floors are unmet and,
22
+ after the first deepening fails, widens the hidden dimensions by
23
+ ``width_scale`` once before retrying. Returns the distilled model and its
24
+ final telemetry metrics.
25
+ """
26
+ if floors is None:
27
+ floors = {"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}
28
+
29
+ bit_tensor = torch.tensor(cluster_data, dtype=torch.long)
30
+ n = len(bit_tensor)
31
+ split = max(1, int(0.8 * n))
32
+ train_bits = bit_tensor[:split]
33
+ val_bits = bit_tensor[split:]
34
+ if len(val_bits) == 0:
35
+ val_bits = train_bits
36
+
37
+ params = target_params.copy()
38
+ metrics: Dict[str, float] = {}
39
+ width_scaled = False
40
+ for round_idx in range(max_rounds):
41
+ model = BitTransformerLM(**params)
42
+ train_loop(
43
+ model,
44
+ train_bits,
45
+ epochs=2,
46
+ compress_prob=0.5,
47
+ direct_prob=0.0,
48
+ log=False,
49
+ forward_kwargs=forward_kwargs,
50
+ )
51
+ with torch.no_grad():
52
+ logits, telemetry = model(val_bits, **(forward_kwargs or {}))
53
+ neg_k = model.negentropy_logits(logits).mean().item()
54
+ lz_c = model.lz_complexity_logits(logits).mean().item()
55
+ sym_s = telemetry["symbiosis_score"].mean().item()
56
+ metrics = {
57
+ "negentropy": neg_k,
58
+ "lz_complexity": lz_c,
59
+ "symbiosis_score": sym_s,
60
+ }
61
+ if (
62
+ neg_k >= floors["negentropy"]
63
+ and lz_c >= floors["lz_complexity"]
64
+ and sym_s >= floors["symbiosis_score"]
65
+ ):
66
+ break
67
+ if round_idx == 0:
68
+ params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
69
+ elif not width_scaled:
70
+ params["d_model"] = int(params.get("d_model", 32) * width_scale)
71
+ params["dim_feedforward"] = int(
72
+ params.get("dim_feedforward", 64) * width_scale
73
+ )
74
+ width_scaled = True
75
+ else:
76
+ params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
77
+ return model, metrics
78
+
79
+
80
+ def save_distilled_model(
81
+ model: BitTransformerLM,
82
+ path: str,
83
+ metrics: Dict[str, float],
84
+ floors: Optional[Dict[str, float]] = None,
85
+ ) -> None:
86
+ """Serialize a distilled model and its metric summary to disk.
87
+
88
+ Weights are written to ``path`` and a ``metrics.json`` file is placed in the
89
+ same directory containing the achieved metrics alongside the target floors.
90
+ """
91
+ torch.save(model.state_dict(), path)
92
+ payload = {"metrics": metrics, "floors": floors or {}}
93
+ metrics_path = os.path.join(os.path.dirname(path), "metrics.json")
94
+ with open(metrics_path, "w") as f:
95
+ json.dump(payload, f)
BitTransformerLM/bit_transformer/compression.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
+ import torch
2
+ from typing import List
3
+
4
+
5
+ def compress_bits(bits: torch.Tensor) -> torch.Tensor:
6
+ """Run-length encode a 1D tensor of bits.
7
+
8
+ Args:
9
+ bits: 1D tensor with values 0 or 1 (bool or uint8).
10
+
11
+ Returns:
12
+ 1D uint8 tensor containing interleaved values and run lengths.
13
+ """
14
+ if bits.dim() != 1:
15
+ raise ValueError("compress_bits expects a 1D tensor")
16
+ b = bits.to(torch.uint8).flatten()
17
+ if b.numel() == 0:
18
+ return b
19
+ changes = torch.nonzero(b[1:] != b[:-1]).flatten().to(torch.long) + 1
20
+ starts = torch.cat([b.new_tensor([0], dtype=torch.long), changes])
21
+ ends = torch.cat([changes, b.new_tensor([b.numel()], dtype=torch.long)])
22
+ values = b[starts.to(torch.long)]
23
+ counts = ends - starts
24
+
25
+ out_vals: List[int] = []
26
+ out_counts: List[int] = []
27
+ for v, c in zip(values.tolist(), counts.tolist()):
28
+ while c > 255:
29
+ out_vals.append(v)
30
+ out_counts.append(255)
31
+ c -= 255
32
+ out_vals.append(v)
33
+ out_counts.append(c)
34
+ values_tensor = torch.tensor(out_vals, dtype=torch.uint8)
35
+ counts_tensor = torch.tensor(out_counts, dtype=torch.uint8)
36
+ out = torch.stack([values_tensor, counts_tensor], dim=1).flatten()
37
+ return out
38
+
39
+
40
+ def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
41
+ """Decode a run-length encoded bit tensor."""
42
+ if compressed.dim() != 1 or compressed.numel() % 2 != 0:
43
+ raise ValueError("compressed tensor must be 1D even-length")
44
+ data = compressed.to(torch.uint8)
45
+ values = data[0::2]
46
+ counts = data[1::2].to(torch.long)
47
+ return torch.repeat_interleave(values, counts)
48
+
49
+
50
+ def model_output_decompress(compressed_batch) -> torch.Tensor:
51
+ """Decompress a batch of compressed bit sequences."""
52
+ if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
53
+ sequences = [decompress_bits(compressed_batch)]
54
+ else:
55
+ sequences = [decompress_bits(row) for row in compressed_batch]
56
+ lengths = [seq.numel() for seq in sequences]
57
+ if len(set(lengths)) != 1:
58
+ raise ValueError("Sequences decompress to different lengths")
59
+ return torch.stack(sequences)
60
+
61
64
+
65
+ def pack_bits(bits: torch.Tensor) -> torch.Tensor:
66
+ """Pack groups of 8 bits into uint8 values using numpy.packbits."""
67
+ if bits.dim() != 1:
68
+ raise ValueError("pack_bits expects a 1D tensor")
69
+ arr = bits.to(torch.uint8).cpu().numpy()
70
+ packed = np.packbits(arr)
71
+ return torch.from_numpy(packed)
72
+
73
+
74
+ def unpack_bits(packed: torch.Tensor, *, n_bits: int | None = None) -> torch.Tensor:
75
+ """Unpack uint8 values back into a bit tensor."""
76
+ if packed.dim() != 1:
77
+ raise ValueError("unpack_bits expects a 1D tensor")
78
+ arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
79
+ if n_bits is not None:
80
+ arr = arr[:n_bits]
81
+ return torch.from_numpy(arr)
82
+
BitTransformerLM/bit_transformer/dashboard.py ADDED
@@ -0,0 +1,58 @@
1
+ import matplotlib.pyplot as plt
2
+ from typing import Dict, List, Tuple
3
+
4
+
5
+ def plot_telemetry(
6
+ metrics_log: Dict[str, List[float]],
7
+ k_floor: float = 0.5,
8
+ c_floor: float = 0.3,
9
+ s_floor: float = 0.5,
10
+ ) -> Tuple[plt.Figure, List[plt.Axes]]:
11
+ """Plot K, C, S metrics over time with cluster transitions.
12
+
13
+ Args:
14
+ metrics_log: Dictionary with keys ``negentropy``, ``lz_complexity``,
15
+ ``symbiosis_score`` and optional ``clusters`` listing cluster
16
+ assignments per step.
17
+ k_floor: Threshold for negentropy (K).
18
+ c_floor: Threshold for LZ complexity (C).
19
+ s_floor: Threshold for symbiosis score (S).
20
+
21
+ Returns:
22
+ (figure, axes) tuple for further customization or saving.
23
+ """
24
+ steps = list(range(len(metrics_log.get("negentropy", []))))
25
+ fig, axes = plt.subplots(3, 1, sharex=True, figsize=(10, 6))
26
+ metrics = [
27
+ ("negentropy", k_floor, "K"),
28
+ ("lz_complexity", c_floor, "C"),
29
+ ("symbiosis_score", s_floor, "S"),
30
+ ]
31
+ for ax, (key, floor, label) in zip(axes, metrics):
32
+ values = metrics_log.get(key, [])
33
+ ax.plot(steps, values, label=label)
34
+ ax.axhline(floor, color="r", linestyle="--", linewidth=1)
35
+ violations = [i for i, v in enumerate(values) if v < floor]
36
+ if violations:
37
+ ax.scatter(
38
+ [steps[i] for i in violations],
39
+ [values[i] for i in violations],
40
+ color="r",
41
+ zorder=5,
42
+ label="violation",
43
+ )
44
+ ax.set_ylabel(label)
45
+ ax.legend(loc="upper right")
46
+
47
+ clusters = metrics_log.get("clusters")
48
+ if clusters is not None:
49
+ prev = clusters[0]
50
+ for t, c in enumerate(clusters):
51
+ if t > 0 and c != prev:
52
+ for ax in axes:
53
+ ax.axvline(t, color="gray", linestyle=":", alpha=0.5)
54
+ prev = c
55
+
56
+ axes[-1].set_xlabel("step")
57
+ plt.tight_layout()
58
+ return fig, axes
BitTransformerLM/bit_transformer/dashboard_app.py ADDED
@@ -0,0 +1,927 @@
1
+ import io
2
+ import json
3
+ import os
4
+ import traceback
5
+ import inspect
6
+ from typing import Any, Dict, List
7
+
8
+ from flask import Flask, jsonify, request, render_template, send_file
9
+ import subprocess
10
+ import sys
11
+ import warnings
12
+ import matplotlib.pyplot as plt
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import requests
16
+ import gzip
17
+
18
+ from .model import BitTransformerLM, infer_long_sequence
19
+ from .optimization import configure_optimizer
20
+ from .collapse import collapse_submodel
21
+ from .dashboard import plot_telemetry
22
+ from .scale import expand_model
23
+ from .bit_io import text_to_bits, bits_to_text
24
+ from .safety import hil_safe_inference
25
+ from .compression import model_output_decompress, compress_bits
26
+ from .distributed import wrap_fsdp
27
+ from .training import train_loop
28
+ from .telemetry import detect_metric_drift
29
+ from .quantization import prepare_qat_fx, convert_qat_fx
30
+ from torch.distributed.fsdp import FullyShardedDataParallel
31
+ from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
32
+
33
+
34
+ app = Flask(__name__)
35
+ app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024 # 1MB upload limit
36
+
37
+ MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")
38
+
39
+
40
+ @app.errorhandler(Exception)
41
+ def handle_exception(err):
42
+ """Return JSON error responses with stack traces."""
43
+ return (
44
+ jsonify({"error": str(err), "trace": traceback.format_exc()}),
45
+ getattr(err, "code", 500),
46
+ )
47
+
48
+ class MetricDriftWarning(UserWarning):
49
+ """Raised when telemetry metrics drift beyond the configured threshold."""
50
+
51
+ def _switch_torch(use_gpu: bool) -> None:
52
+ """Install the appropriate PyTorch wheel and restart the process."""
53
+ have_cuda = torch.version.cuda is not None
54
+ if use_gpu == have_cuda:
55
+ return
56
+ wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
57
+ url = "https://download.pytorch.org/whl/cu118" if use_gpu else "https://download.pytorch.org/whl/cpu"
58
+ subprocess.run([
59
+ sys.executable,
60
+ "-m",
61
+ "pip",
62
+ "install",
63
+ "--extra-index-url",
64
+ url,
65
+ wheel,
66
+ ], check=True)
67
+ os.execv(sys.executable, [sys.executable] + sys.argv)
68
+
69
+ def mcp_post(path: str, data=None):
70
+ if not MCP_SERVER_ADDR:
71
+ return None
72
+ url = MCP_SERVER_ADDR.rstrip("/") + path
73
+ resp = requests.post(url, json=data)
74
+ resp.raise_for_status()
75
+ if resp.headers.get("Content-Type", "").startswith("image/"):
76
+ return resp.content
77
+ return resp.json()
78
+
79
+ def mcp_get(path: str):
80
+ if not MCP_SERVER_ADDR:
81
+ return None
82
+ url = MCP_SERVER_ADDR.rstrip("/") + path
83
+ resp = requests.get(url)
84
+ resp.raise_for_status()
85
+ if resp.headers.get("Content-Type", "").startswith("image/"):
86
+ return resp.content
87
+ return resp.json()
88
+
89
+ class ModelManager:
90
+ """Manage model state and training utilities for the dashboard."""
91
+
92
+ def __init__(
93
+ self,
94
+ snapshot_dir: str | None = None,
95
+ telemetry_log: str | None = None,
96
+ *,
97
+ drift_window: int = 10,
98
+ drift_threshold: float = 0.2,
99
+ ) -> None:
100
+ self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
101
+ self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
102
+ if self.telemetry_log is None:
103
+ self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
104
+ os.makedirs(self.snapshot_dir, exist_ok=True)
105
+ self.weights_path = os.path.join(self.snapshot_dir, "model.pt")
106
+
107
+ self.model: BitTransformerLM | None = None
108
+ self.optimizer: torch.optim.Optimizer | None = None
109
+ self.scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
110
+ self.total_steps = 100
111
+ self.metrics: Dict[str, List[float]] = {
112
+ "negentropy_logits": [],
113
+ "lz_complexity_logits": [],
114
+ "symbiosis_score": [],
115
+ }
116
+ self.drift_window = drift_window
117
+ self.drift_threshold = drift_threshold
118
+ self.lambda_K = 1.0
119
+ self.lambda_C = 1.0
120
+ self.lambda_S = 1.0
121
+ self.c_floor = 0.3
122
+ self.s_floor = 0.5
123
+ self.causal = True
124
+ self.diffusion = False
125
+ self.decompress_output = False
126
+ self.use_compression = False
127
+ self.use_gpu = False
128
+ self.qat = False
129
+
130
+ # Load any existing state
131
+ if os.path.exists(self.telemetry_log):
132
+ try:
133
+ with open(self.telemetry_log) as f:
134
+ saved = json.load(f)
135
+ for key in self.metrics:
136
+ self.metrics[key] = saved.get(key, [])
137
+ except Exception:
138
+ pass
139
+ if os.path.exists(self.weights_path):
140
+ try:
141
+ self.model = torch.load(self.weights_path, map_location="cpu")
142
+ self.optimizer, self.scheduler = configure_optimizer(
143
+ self.model, lr=1e-3, total_steps=self.total_steps
144
+ )
145
+ self._apply_device()
146
+ except Exception:
147
+ self.model = None
148
+
149
+ config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
150
+ if self.model is None and os.path.exists(config_path):
151
+ try:
152
+ with open(config_path) as f:
153
+ params = json.load(f)
154
+ self.init_model(params)
155
+ except Exception:
156
+ pass
157
+
158
+ def init_model(self, params: Dict) -> None:
159
+ int_fields = {
160
+ "d_model",
161
+ "nhead",
162
+ "num_layers",
163
+ "dim_feedforward",
164
+ "max_seq_len",
165
+ "chunk_size",
166
+ "overlap",
167
+ }
168
+ float_fields = {"act_threshold"}
169
+ bool_fields = {"reversible", "use_checkpoint"}
170
+ clean: Dict[str, Any] = {}
171
+ for k, v in params.items():
172
+ if v is None:
173
+ clean[k] = None
174
+ elif k in int_fields:
175
+ clean[k] = int(v)
176
+ elif k in float_fields:
177
+ clean[k] = float(v)
178
+ elif k in bool_fields:
179
+ clean[k] = bool(v)
180
+ else:
181
+ clean[k] = v
182
+ self.model = BitTransformerLM(
183
+ **clean,
184
+ lambda_K=self.lambda_K,
185
+ lambda_C=self.lambda_C,
186
+ lambda_S=self.lambda_S,
187
+ )
188
+ self.optimizer, self.scheduler = configure_optimizer(
189
+ self.model, lr=1e-3, total_steps=self.total_steps
190
+ )
191
+ self._apply_device()
192
+ for key in self.metrics:
193
+ self.metrics[key].clear()
194
+
195
+ def set_lambdas(self, k: float, c: float, s: float) -> None:
196
+ """Update λ weights and propagate to the model."""
197
+ self.lambda_K = k
198
+ self.lambda_C = c
199
+ self.lambda_S = s
200
+ if self.model is not None:
201
+ self.model.set_lambdas(k, c, s)
202
+
203
+ def set_floors(self, c_floor: float, s_floor: float) -> None:
204
+ """Update safety floors for complexity (C) and symbiosis (S)."""
205
+ self.c_floor = c_floor
206
+ self.s_floor = s_floor
207
+
208
+ def set_diffusion(self, flag: bool) -> None:
209
+ """Toggle Diffusion LM mode which disables causal masking and chunking."""
210
+ self.diffusion = flag
211
+ self.causal = not flag
212
+ if self.model is not None and flag:
213
+ self.model.chunk_size = None
214
+
215
+ def set_decompress_output(self, flag: bool) -> None:
216
+ """Enable or disable decompression of model outputs."""
217
+ self.decompress_output = flag
218
+
219
+ def set_compression(self, flag: bool) -> None:
220
+ """Toggle automatic compression of inputs."""
221
+ self.use_compression = flag
222
+
223
+ def set_qat(self, flag: bool) -> None:
224
+ """Enable or disable 4-bit quantization-aware training."""
225
+ self.qat = flag
226
+ if self.model is None:
227
+ return
228
+ if flag:
229
+ self.model = prepare_qat_fx(self.model)
230
+ else:
231
+ self.model = convert_qat_fx(self.model)
232
+
233
+ def set_gpu(self, flag: bool) -> None:
234
+ """Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
235
+ _switch_torch(flag)
236
+ self.use_gpu = flag and torch.cuda.is_available()
237
+ self._apply_device()
238
+
239
+ def _apply_device(self) -> None:
240
+ """Move the model to the selected device and wrap with FSDP if needed."""
241
+ if self.model is None:
242
+ return
243
+ if self.use_gpu:
244
+ device = torch.device("cuda")
245
+ if isinstance(self.model, FullyShardedDataParallel):
246
+ base = self.model.module
247
+ else:
248
+ base = self.model
249
+ base = base.to(device)
250
+ self.model = wrap_fsdp(base, device_id=device)
251
+ else:
252
+ device = torch.device("cpu")
253
+ if isinstance(self.model, FullyShardedDataParallel):
254
+ self.model = self.model.module
255
+ self.model = self.model.to(device)
256
+
257
+ def train_step(self, bits: torch.Tensor) -> tuple[float, float]:
258
+ assert (
259
+ self.model is not None
260
+ and self.optimizer is not None
261
+ and self.scheduler is not None
262
+ )
263
+ self.model.train()
264
+ device = next(self.model.parameters()).device
265
+ bits = bits.to(device)
266
+ ratio = 1.0
267
+ if self.use_compression:
268
+ comps = [compress_bits(row.to(torch.uint8)) for row in bits]
269
+ comp_len = sum(c.numel() for c in comps)
270
+ ratio = min(comp_len / bits.numel(), 1.0)
271
+ logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
272
+ else:
273
+ logits, telemetry = self.model(bits, causal=self.causal)
274
+ pred = logits[:, :-1, :].reshape(-1, 2)
275
+ target = bits[:, 1:].reshape(-1)
276
+ loss = F.cross_entropy(pred, target)
277
+ loss.backward()
278
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
279
+ self.optimizer.step()
280
+ self.scheduler.step()
281
+ self.optimizer.zero_grad()
282
+ self._log_metrics(telemetry)
283
+ self._save_state()
284
+ return loss.item(), ratio
285
+
286
+ def train_epochs(
287
+ self,
288
+ bits: torch.Tensor,
289
+ *,
290
+ epochs: int = 1,
291
+ compress_prob: float = 0.5,
292
+ direct_prob: float = 0.0,
293
+ batch_size: int = 8,
294
+ num_workers: int = 0,
295
+ accum_steps: int = 1,
296
+ amp: bool = False,
297
+ compile_model: bool = False,
298
+ ) -> List[Dict[str, float]]:
299
+ """Run ``train_loop`` on a batch tensor and persist the state."""
300
+ assert self.model is not None
301
+ device = next(self.model.parameters()).device
302
+ bits = bits.to(device)
303
+ import math
304
+ steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
305
+ self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
306
+ self.optimizer, self.scheduler = configure_optimizer(
307
+ self.model, lr=1e-3, total_steps=self.total_steps
308
+ )
309
+ metrics = train_loop(
310
+ self.model,
311
+ bits,
312
+ epochs=epochs,
313
+ compress_prob=compress_prob if self.use_compression else 0.0,
314
+ direct_prob=direct_prob,
315
+ batch_size=batch_size,
316
+ num_workers=num_workers,
317
+ accum_steps=accum_steps,
318
+ amp=amp,
319
+ compile_model=compile_model,
320
+ forward_kwargs={"causal": self.causal},
321
+ optimizer=self.optimizer,
322
+ scheduler=self.scheduler,
323
+ )
324
+ self._save_state()
325
+ return metrics
326
+
327
+ def scale_up(self, width_mult: float = 1.0) -> None:
328
+ assert self.model is not None
329
+ params = dict(
330
+ d_model=int(self.model.d_model * width_mult),
331
+ nhead=self.model.layers[0].self_attn.num_heads,
332
+ num_layers=self.model.num_layers * 2,
333
+ dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
334
+ max_seq_len=self.model.pos_enc.pe.size(0),
335
+ )
336
+ self.model = expand_model(self.model, {
337
+ **params,
338
+ "lambda_K": self.lambda_K,
339
+ "lambda_C": self.lambda_C,
340
+ "lambda_S": self.lambda_S,
341
+ })
342
+ self.optimizer, self.scheduler = configure_optimizer(
343
+ self.model, lr=1e-3, total_steps=self.total_steps
344
+ )
345
+ self._save_state()
346
+
347
+ def collapse(self, cluster_bits: List[List[int]], target_params: Dict, width_scale: float = 1.0) -> None:
348
+ self.model, _ = collapse_submodel(
349
+ cluster_bits,
350
+ target_params,
351
+ width_scale=width_scale,
352
+ forward_kwargs={"causal": self.causal},
353
+ )
354
+ self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
355
+ self.optimizer, self.scheduler = configure_optimizer(
356
+ self.model, lr=1e-3, total_steps=self.total_steps
357
+ )
358
+ self._apply_device()
359
+ for key in self.metrics:
360
+ self.metrics[key].clear()
361
+
362
+ def infer(self, bits: torch.Tensor) -> Dict:
363
+ assert self.model is not None
364
+ self.model.eval()
365
+ device = next(self.model.parameters()).device
366
+ bits = bits.to(device)
367
+ ratio = 1.0
368
+ with torch.no_grad():
369
+ if self.use_compression:
370
+ comps = [compress_bits(row.to(torch.uint8)) for row in bits]
371
+ comp_len = sum(c.numel() for c in comps)
372
+ ratio = min(comp_len / bits.numel(), 1.0)
373
+ logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
374
+ else:
375
+ logits, telemetry = self.model(bits, causal=self.causal)
376
+ self._log_metrics(telemetry)
377
+ pred_bits = logits.argmax(-1)
378
+ if self.decompress_output:
379
+ try:
380
+ pred_bits = model_output_decompress(pred_bits)
381
+ except Exception as e:
382
+ return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}
383
+ def _to_python(obj):
384
+ if isinstance(obj, torch.Tensor):
385
+ return obj.tolist()
386
+ if isinstance(obj, list):
387
+ return [_to_python(o) for o in obj]
388
+ if isinstance(obj, dict):
389
+ return {kk: _to_python(vv) for kk, vv in obj.items()}
390
+ return obj
391
+ tele = {k: _to_python(v) for k, v in telemetry.items()}
392
+ return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}
393
+
394
+ def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
395
+ """Run sliding-window inference on a long sequence."""
396
+ assert self.model is not None
397
+ device = next(self.model.parameters()).device
398
+ bits = bits.to(device)
399
+ preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
400
+ for tele in logs:
401
+ self._log_metrics(tele)
402
+ return {"predicted": preds.tolist(), "windows": len(logs)}
403
+
404
+ def _log_metrics(self, telemetry: Dict) -> None:
405
+ for key in self.metrics:
406
+ val = telemetry[key].mean().item()
407
+ self.metrics[key].append(val)
408
+ drift = detect_metric_drift(
409
+ self.metrics, window=self.drift_window, threshold=self.drift_threshold
410
+ )
411
+ bad = [k for k, v in drift.items() if v]
412
+ if bad:
413
+ warnings.warn(
414
+ f"Metric drift detected: {', '.join(bad)}",
415
+ MetricDriftWarning,
416
+ )
417
+
418
+ def infer_text(self, text: str) -> Dict[str, Any]:
419
+ """Run text through the model using the safety gate."""
420
+ assert self.model is not None
421
+ device = next(self.model.parameters()).device
422
+ bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
423
+ out_bits, telemetry = hil_safe_inference(
424
+ self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
425
+ )
426
+ self._log_metrics(telemetry)
427
+ return {
428
+ "output": bits_to_text(out_bits.squeeze(0).tolist()),
429
+ "telemetry": telemetry,
430
+ }
431
+
432
+ def get_status(self) -> Dict[str, Any]:
433
+ info: Dict[str, Any] = {
434
+ "use_gpu": self.use_gpu,
435
+ "diffusion": self.diffusion,
436
+ "compression": self.use_compression,
437
+ "lambda_K": self.lambda_K,
438
+ "lambda_C": self.lambda_C,
439
+ "lambda_S": self.lambda_S,
440
+ "c_floor": self.c_floor,
441
+ "s_floor": self.s_floor,
442
+ "qat": self.qat,
443
+ }
444
+ if self.model is not None:
445
+ info.update(
446
+ {
447
+ "d_model": self.model.d_model,
448
+ "num_layers": self.model.num_layers,
449
+ "d_ff": self.model.layers[0].linear1.out_features,
450
+ "nhead": self.model.layers[0].self_attn.num_heads,
451
+ "max_seq_len": self.model.pos_enc.pe.size(0),
452
+ }
453
+ )
454
+ else:
455
+ info.update(
456
+ {
457
+ "d_model": None,
458
+ "num_layers": 0,
459
+ "d_ff": None,
460
+ "nhead": None,
461
+ "max_seq_len": None,
462
+ }
463
+ )
464
+ return info
465
+
466
+ def get_model_config(self) -> Dict[str, Any]:
467
+ """Return current model hyperparameters and safety settings."""
468
+ cfg: Dict[str, Any] = {
469
+ "lambda_K": self.lambda_K,
470
+ "lambda_C": self.lambda_C,
471
+ "lambda_S": self.lambda_S,
472
+ "c_floor": self.c_floor,
473
+ "s_floor": self.s_floor,
474
+ }
475
+ if self.model is not None:
476
+ cfg.update(
477
+ {
478
+ "d_model": self.model.d_model,
479
+ "nhead": self.model.layers[0].self_attn.num_heads,
480
+ "num_layers": self.model.num_layers,
481
+ "dim_feedforward": self.model.layers[0].linear1.out_features,
482
+ "max_seq_len": self.model.pos_enc.pe.size(0),
483
+ "chunk_size": self.model.chunk_size,
484
+ "reversible": self.model.reversible,
485
+ "use_checkpoint": self.model.use_checkpoint,
486
+ }
487
+ )
488
+ else:
489
+ cfg.update(
490
+ {
491
+ "d_model": None,
492
+ "nhead": None,
493
+ "num_layers": 0,
494
+ "dim_feedforward": None,
495
+ "max_seq_len": None,
496
+ "chunk_size": None,
497
+ "reversible": None,
498
+ "use_checkpoint": None,
499
+ }
500
+ )
501
+ return cfg
502
+
503
+ def get_metrics(self) -> Dict[str, Any]:
504
+ """Return logged telemetry metrics with summary statistics."""
505
+ from statistics import mean, stdev
506
+
507
+ data = {
508
+ "negentropy": self.metrics["negentropy_logits"],
509
+ "lz_complexity": self.metrics["lz_complexity_logits"],
510
+ "symbiosis": self.metrics["symbiosis_score"],
511
+ }
512
+ summary: Dict[str, Dict[str, float | None]] = {}
513
+ for key, values in data.items():
514
+ if values:
515
+ m = mean(values)
516
+ s = stdev(values) if len(values) > 1 else 0.0
517
+ summary[key] = {"mean": m, "std": s}
518
+ else:
519
+ summary[key] = {"mean": None, "std": None}
520
+ data["summary"] = summary
521
+ return data
522
+
523
+
524
+ def _save_state(self) -> None:
525
+ if self.model is None:
526
+ return
527
+ torch.save(self.model, self.weights_path)
528
+ with open(self.telemetry_log, "w") as f:
529
+ json.dump(self.metrics, f)
530
+
531
+
532
+ manager: ModelManager | None = None
533
+
534
+
535
+ @app.route("/")
536
+ def index():
537
+ return render_template(
538
+ "dashboard.html",
539
+ metrics=manager.metrics,
540
+ lambdas={
541
+ "lambda_K": manager.lambda_K,
542
+ "lambda_C": manager.lambda_C,
543
+ "lambda_S": manager.lambda_S,
544
+ },
545
+ diffusion=manager.diffusion,
546
+ compression=manager.use_compression,
547
+ defaults={k: v.default for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items() if v.default is not inspect._empty},
548
+ c_floor=manager.c_floor,
549
+ s_floor=manager.s_floor,
550
+ qat=manager.qat,
551
+ )
552
+
553
+
554
+ @app.route("/status", methods=["GET"])
555
+ def status():
556
+ if MCP_SERVER_ADDR:
557
+ return jsonify(mcp_get("/status"))
558
+ return jsonify(manager.get_status())
559
+
560
+
561
+ @app.route("/model_config", methods=["GET"])
562
+ def model_config():
563
+ if MCP_SERVER_ADDR:
564
+ return jsonify(mcp_get("/model_config"))
565
+ return jsonify(manager.get_model_config())
566
+
567
+
568
+ @app.route("/metrics", methods=["GET"])
569
+ def metrics():
570
+ if MCP_SERVER_ADDR:
571
+ return jsonify(mcp_get("/metrics"))
572
+ return jsonify(manager.get_metrics())
573
+
574
+
575
+ @app.route("/save_checkpoint", methods=["POST"])
576
+ def save_checkpoint_route():
577
+ repo_id = request.json.get("repo_id")
578
+ token = request.json.get("token") or os.getenv("HF_TOKEN")
579
+ if MCP_SERVER_ADDR:
580
+ return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
581
+ if manager.model is None:
582
+ return jsonify({"error": "model not initialized"}), 400
583
+ if token:
584
+ hf_login(token=token)
585
+ save_checkpoint(manager.model, repo_id=repo_id)
586
+ return jsonify({"status": "saved"})
587
+
588
+
589
+ @app.route("/download_checkpoint", methods=["POST"])
590
+ def download_checkpoint_route():
591
+ repo_id = request.json.get("repo_id")
592
+ token = request.json.get("token") or os.getenv("HF_TOKEN")
593
+ if MCP_SERVER_ADDR:
594
+ return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
595
+ if token:
596
+ hf_login(token=token)
597
+ dest = manager.weights_path + ".gz"
598
+ ok = download_checkpoint(dest, repo_id=repo_id)
599
+ if not ok:
600
+ return jsonify({"status": "failed"}), 500
601
+ if manager.model is None:
602
+ return jsonify({"status": "downloaded", "loaded": False})
603
+ with gzip.open(dest, "rb") as f:
604
+ state = torch.load(f, map_location="cpu")
605
+ manager.model.load_state_dict(state)
606
+ manager.optimizer, manager.scheduler = configure_optimizer(
607
+ manager.model, lr=1e-3, total_steps=manager.total_steps
608
+ )
609
+ manager._apply_device()
610
+ manager._save_state()
611
+ return jsonify({"status": "downloaded", "loaded": True})
612
+
613
+
614
+ @app.route("/text_to_bits", methods=["POST"])
615
+ def text_to_bits_route():
616
+ text = request.json.get("text", "")
617
+ if len(text) > 100_000:
618
+ return jsonify({"error": "text too large"}), 413
619
+ return jsonify({"bits": text_to_bits(text)})
620
+
621
+
622
+ @app.route("/dataset", methods=["GET"])
623
+ def dataset_route():
624
+ name = request.args.get("name", "")
625
+ split = request.args.get("split", "train")
626
+ size = int(request.args.get("size", 1))
627
+ seq_len = int(request.args.get("seq_len", 64))
628
+ if size * seq_len > 1_000_000:
629
+ return jsonify({"error": "dataset too large"}), 413
630
+ if name == "wikitext2":
631
+ try:
632
+ from datasets import load_dataset
633
+
634
+ ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
635
+ lines = [t for t in ds["text"] if t.strip()][:size]
636
+ except Exception:
637
+ bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
638
+ return jsonify({"bits": bits.tolist()})
639
+ bits_list = []
640
+ for text in lines:
641
+ b = text_to_bits(text)[:seq_len]
642
+ if len(b) < seq_len:
643
+ b.extend([0] * (seq_len - len(b)))
644
+ bits_list.append(b)
645
+ if len(bits_list) < size:
646
+ pad = size - len(bits_list)
647
+ bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
648
+ return jsonify({"bits": bits_list})
649
+ return jsonify({"error": "unknown dataset"}), 400
650
+
651
+
652
+ @app.route("/init", methods=["POST"])
653
+ def init_model():
654
+ data = request.json or {}
655
+ int_fields = {
656
+ "d_model",
657
+ "nhead",
658
+ "num_layers",
659
+ "dim_feedforward",
660
+ "max_seq_len",
661
+ "chunk_size",
662
+ "overlap",
663
+ }
664
+ float_fields = {"act_threshold"}
665
+ bool_fields = {"reversible", "use_checkpoint"}
666
+ params = {}
667
+ for k, v in data.items():
668
+ if v is None:
669
+ params[k] = None
670
+ elif k in int_fields:
671
+ params[k] = int(v)
672
+ elif k in float_fields:
673
+ params[k] = float(v)
674
+ elif k in bool_fields:
675
+ params[k] = bool(v)
676
+ else:
677
+ params[k] = v
678
+ if MCP_SERVER_ADDR:
679
+ data = mcp_post("/init", params)
680
+ return jsonify(data)
681
+ manager.init_model(params)
682
+ return jsonify({"status": "initialized", "params": params})
683
+
684
+
685
+ @app.route("/train", methods=["POST"])
686
+ def train_model():
687
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
688
+ if MCP_SERVER_ADDR:
689
+ data = mcp_post("/train", {"bits": request.json["bits"]})
690
+ return jsonify(data)
691
+ loss, ratio = manager.train_step(bits)
692
+ return jsonify({"loss": loss, "ratio": ratio})
693
+
694
+
695
+ @app.route("/train_epochs", methods=["POST"])
696
+ def train_epochs_route():
697
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
698
+ epochs = int(request.json.get("epochs", 1))
699
+ compress_prob = float(request.json.get("compress_prob", 0.5))
700
+ direct_prob = float(request.json.get("direct_prob", 0.0))
701
+ if MCP_SERVER_ADDR:
702
+ data = mcp_post(
703
+ "/train_epochs",
704
+ {
705
+ "bits": request.json["bits"],
706
+ "epochs": epochs,
707
+ "compress_prob": compress_prob,
708
+ "direct_prob": direct_prob,
709
+ },
710
+ )
711
+ return jsonify(data)
712
+ metrics = manager.train_epochs(
713
+ bits,
714
+ epochs=epochs,
715
+ compress_prob=compress_prob,
716
+ direct_prob=direct_prob,
717
+ )
718
+ return jsonify({"metrics": metrics})
719
+
720
+
721
+ @app.route("/scale_up", methods=["POST"])
722
+ def scale_up():
723
+ width_mult = float(request.json.get("width_mult", 1.0))
724
+ if MCP_SERVER_ADDR:
725
+ data = mcp_post("/scale_up", {"width_mult": width_mult})
726
+ return jsonify(data)
727
+ manager.scale_up(width_mult)
728
+ return jsonify({
729
+ "status": "scaled",
730
+ "layers": manager.model.num_layers,
731
+ "d_model": manager.model.d_model,
732
+ })
733
+
734
+
735
+ @app.route("/collapse", methods=["POST"])
736
+ def collapse_model():
737
+ cluster_bits = request.json["clusters"]
738
+ params = {k: int(v) for k, v in request.json["params"].items()}
739
+ width_scale = float(request.json.get("width_scale", 1.0))
740
+ if MCP_SERVER_ADDR:
741
+ data = mcp_post(
742
+ "/collapse",
743
+ {"clusters": cluster_bits, "params": params, "width_scale": width_scale},
744
+ )
745
+ return jsonify(data)
746
+ manager.collapse(cluster_bits, params, width_scale)
747
+ return jsonify({"status": "collapsed"})
748
+
749
+
750
+ @app.route("/lambdas", methods=["GET", "POST"])
751
+ def update_lambdas():
752
+ if request.method == "POST":
753
+ data = request.json
754
+ if MCP_SERVER_ADDR:
755
+ res = mcp_post("/lambdas", data)
756
+ return jsonify(res)
757
+ manager.set_lambdas(
758
+ float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"])
759
+ )
760
+ return jsonify({"status": "updated"})
761
+ else:
762
+ if MCP_SERVER_ADDR:
763
+ return jsonify(mcp_get("/lambdas"))
764
+ return jsonify(
765
+ {
766
+ "lambda_K": manager.lambda_K,
767
+ "lambda_C": manager.lambda_C,
768
+ "lambda_S": manager.lambda_S,
769
+ }
770
+ )
771
+
772
+
773
+ @app.route("/config/telemetry", methods=["GET", "POST"])
774
+ def telemetry_config():
775
+ """Get or update telemetry λ weights and safety floors."""
776
+ if request.method == "POST":
777
+ data = request.json
778
+ if MCP_SERVER_ADDR:
779
+ res = mcp_post("/config/telemetry", data)
780
+ return jsonify(res)
781
+ manager.set_lambdas(
782
+ float(data.get("lambda_K", manager.lambda_K)),
783
+ float(data.get("lambda_C", manager.lambda_C)),
784
+ float(data.get("lambda_S", manager.lambda_S)),
785
+ )
786
+ manager.set_floors(
787
+ float(data.get("c_floor", manager.c_floor)),
788
+ float(data.get("s_floor", manager.s_floor)),
789
+ )
790
+ return jsonify({"status": "updated"})
791
+ else:
792
+ if MCP_SERVER_ADDR:
793
+ return jsonify(mcp_get("/config/telemetry"))
794
+ return jsonify(
795
+ {
796
+ "lambda_K": manager.lambda_K,
797
+ "lambda_C": manager.lambda_C,
798
+ "lambda_S": manager.lambda_S,
799
+ "c_floor": manager.c_floor,
800
+ "s_floor": manager.s_floor,
801
+ }
802
+ )
803
+
804
+
805
+ @app.route("/diffusion", methods=["GET", "POST"])
806
+ def update_diffusion():
807
+ if request.method == "POST":
808
+ if MCP_SERVER_ADDR:
809
+ return jsonify(mcp_post("/diffusion", request.json))
810
+ manager.set_diffusion(bool(request.json.get("diffusion", False)))
811
+ return jsonify({"status": "updated"})
812
+ else:
813
+ if MCP_SERVER_ADDR:
814
+ return jsonify(mcp_get("/diffusion"))
815
+ return jsonify({"diffusion": manager.diffusion})
816
+
817
+
818
+ @app.route("/gpu", methods=["GET", "POST"])
819
+ def update_gpu():
820
+ if request.method == "POST":
821
+ if MCP_SERVER_ADDR:
822
+ return jsonify(mcp_post("/gpu", request.json))
823
+ manager.set_gpu(bool(request.json.get("use_gpu", False)))
824
+ return jsonify({"status": "updated"})
825
+ else:
826
+ if MCP_SERVER_ADDR:
827
+ return jsonify(mcp_get("/gpu"))
828
+ return jsonify({"use_gpu": manager.use_gpu})
829
+
830
+
831
+ @app.route("/compression", methods=["GET", "POST"])
832
+ def update_compression():
833
+ if request.method == "POST":
834
+ if MCP_SERVER_ADDR:
835
+ return jsonify(mcp_post("/compression", request.json))
836
+ manager.set_compression(bool(request.json.get("compression", False)))
837
+ return jsonify({"status": "updated"})
838
+ else:
839
+ if MCP_SERVER_ADDR:
840
+ return jsonify(mcp_get("/compression"))
841
+ return jsonify({"compression": manager.use_compression})
842
+
843
+
844
+ @app.route("/qat", methods=["GET", "POST"])
845
+ def update_qat():
846
+ if request.method == "POST":
847
+ if MCP_SERVER_ADDR:
848
+ return jsonify(mcp_post("/qat", request.json))
849
+ manager.set_qat(bool(request.json.get("qat", False)))
850
+ return jsonify({"status": "updated"})
851
+ else:
852
+ if MCP_SERVER_ADDR:
853
+ return jsonify(mcp_get("/qat"))
854
+ return jsonify({"qat": manager.qat})
855
+
856
+
857
+ @app.route("/infer", methods=["POST"])
858
+ def inference():
859
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
860
+ if MCP_SERVER_ADDR:
861
+ data = mcp_post("/infer", {"bits": request.json["bits"]})
862
+ return jsonify(data)
863
+ result = manager.infer(bits)
864
+ return jsonify(result)
865
+
866
+
867
+ @app.route("/infer_long", methods=["POST"])
868
+ def inference_long():
869
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
870
+ ctx = int(request.json.get("ctx_bits", 4096))
871
+ overlap = int(request.json.get("overlap", 256))
872
+ if MCP_SERVER_ADDR:
873
+ data = mcp_post(
874
+ "/infer_long",
875
+ {"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap},
876
+ )
877
+ return jsonify(data)
878
+ result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
879
+ return jsonify(result)
880
+
881
+
882
+ @app.route("/infer_text", methods=["POST"])
883
+ def inference_text():
884
+ text = request.json.get("text", "")
885
+ if MCP_SERVER_ADDR:
886
+ data = mcp_post("/infer_text", {"text": text})
887
+ return jsonify(data)
888
+ result = manager.infer_text(text)
889
+ return jsonify(result)
890
+
891
+ @app.route("/plot.png")
892
+ def plot_png():
893
+ if MCP_SERVER_ADDR:
894
+ resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png")
895
+ resp.raise_for_status()
896
+ return send_file(io.BytesIO(resp.content), mimetype="image/png")
897
+ fig, _ = plot_telemetry(manager.metrics)
898
+ buf = io.BytesIO()
899
+ fig.savefig(buf, format="png")
900
+ plt.close(fig)
901
+ buf.seek(0)
902
+ return send_file(buf, mimetype="image/png")
903
+
904
+
905
+ def run_dashboard(host: str | None = None, port: int | None = None,
906
+ snapshot_dir: str | None = None, telemetry_log: str | None = None) -> None:
907
+ """Launch the Flask dashboard server."""
908
+ env_host = os.getenv("HOST", "0.0.0.0")
909
+ env_port = int(os.getenv("PORT", "5000"))
910
+ host = host or env_host
911
+ port = port or env_port
912
+ global manager
913
+ if manager is None:
914
+ manager = ModelManager(snapshot_dir, telemetry_log)
915
+ app.run(host=host, port=port, debug=True)
916
+
917
+
918
+ if __name__ == "__main__":
919
+ import argparse
920
+
921
+ parser = argparse.ArgumentParser(description="Run dashboard server")
922
+ parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
923
+ parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
924
+ parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
925
+ parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
926
+ args = parser.parse_args()
927
+ run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
BitTransformerLM/bit_transformer/distil.py ADDED
@@ -0,0 +1,90 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .model import BitTransformerLM
10
+
11
+
12
+ @dataclass
13
+ class TelemetryLog:
14
+ """Telemetry container holding attention maps across steps.
15
+
16
+ Attributes:
17
+ attention_maps: Tensor of shape [steps, heads, seq, seq].
18
+ """
19
+
20
+ attention_maps: torch.Tensor
21
+
22
+
23
+ def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
24
+ """Return a pruned copy of ``model`` according to attention telemetry.
25
+
26
+ Args:
27
+ model: Teacher model to distill from.
28
+ scale: Fraction of weights to retain (0 < scale <= 1).
29
+ telemetry: Logged attention maps used to estimate parameter importance.
30
+
31
+ This function computes an importance score for each weight in the model's
32
+ linear layers using the supplied attention maps. The score is the mean
33
+ activation over time multiplied by the number of visits (non-zero
34
+ attention). The bottom ``(1 - scale)`` fraction of weights in each layer are
35
+ zeroed out, yielding a sparsified student model.
36
+ """
37
+ if not (0.0 < scale <= 1.0):
38
+ raise ValueError("scale must lie in (0, 1].")
39
+
40
+ # Clone the model so the teacher remains untouched.
41
+ student = BitTransformerLM(
42
+ d_model=model.d_model,
43
+ nhead=model.layers[0].self_attn.num_heads,
44
+ num_layers=model.num_layers,
45
+ dim_feedforward=model.layers[0].linear1.out_features,
46
+ max_seq_len=model.pos_enc.pe.size(0),
47
+ lambda_K=model.lambda_K,
48
+ lambda_C=model.lambda_C,
49
+ lambda_S=model.lambda_S,
50
+ reversible=model.reversible,
51
+ use_checkpoint=model.use_checkpoint,
52
+ use_autocast=model.use_autocast,
53
+ use_act=model.use_act,
54
+ act_threshold=model.act_threshold,
55
+ chunk_size=model.chunk_size,
56
+ overlap=model.overlap,
57
+ )
58
+ student.load_state_dict(model.state_dict())
59
+
60
+ attn = telemetry.attention_maps # [steps, heads, seq, seq]
61
+ steps = attn.shape[0]
62
+ heads = attn.shape[1]
63
+ mean_act = attn.mean(dim=(0, 2, 3))
64
+ visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
65
+ head_importance = mean_act * visits
66
+ head_importance = head_importance / head_importance.sum()
67
+
68
+ prune_frac = 1.0 - scale
69
+
70
+ for module in student.modules():
71
+ if isinstance(module, nn.Linear):
72
+ weight = module.weight.data
73
+ out_features = weight.size(0)
74
+ if out_features % heads == 0:
75
+ repeats = out_features // heads
76
+ row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
77
+ else:
78
+ row_scores = head_importance.mean().expand(out_features, 1)
79
+
80
+ importance = weight.abs() * row_scores
81
+ k = int(importance.numel() * prune_frac)
82
+ if k > 0:
83
+ thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
84
+ mask = importance > thresh
85
+ weight.mul_(mask)
86
+ if module.bias is not None:
87
+ row_mask = mask.view(out_features, -1).any(dim=1)
88
+ module.bias.data.mul_(row_mask)
89
+
90
+ return student
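
A small usage sketch of `distill_step` (assuming the package imports as `bit_transformer`). The attention telemetry here is random and only shape-compatible with the expected `[steps, heads, seq, seq]` layout, not real logged data:

```python
import torch

from bit_transformer.model import BitTransformerLM
from bit_transformer.distil import TelemetryLog, distill_step

teacher = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                           dim_feedforward=64, max_seq_len=16)
# Placeholder telemetry with shape [steps, heads, seq, seq]; a real run would
# collect this from the model's "attention_maps" output.
attn = torch.rand(3, 4, 16, 16)
student = distill_step(teacher, scale=0.5,
                       telemetry=TelemetryLog(attention_maps=attn))
zeroed = sum((p == 0).sum().item() for p in student.parameters())
print("zeroed weights:", zeroed)
```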
BitTransformerLM/bit_transformer/distributed.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional
4
+
5
+ from torch.distributed.fsdp import FullyShardedDataParallel
6
+ try:
7
+ from torch.distributed.pipeline.sync import Pipe
8
+ except Exception: # pragma: no cover - Pipe may not be available in CPU builds
9
+ Pipe = None
10
+
11
+ from .model import BitTransformerLM
12
+
13
+
14
+ def wrap_fsdp(model: BitTransformerLM, **kwargs) -> FullyShardedDataParallel:
15
+ """Return a ``FullyShardedDataParallel`` wrapped model on the given device."""
16
+ device = kwargs.pop("device_id", torch.device("cpu"))
17
+ model = model.to(device)
18
+ return FullyShardedDataParallel(model, device_id=device, **kwargs)
19
+
20
+
21
+ def make_pipeline(model: BitTransformerLM, chunks: int = 1) -> Pipe:
22
+ """Wrap the model with ``Pipe`` for simple pipeline parallelism.
23
+
24
+ The entire model is placed in an ``nn.Sequential`` so all existing telemetry
25
+ remains available. ``chunks`` controls microbatch splitting.
26
+ """
27
+ if Pipe is None:
28
+ raise RuntimeError("Pipeline parallelism not available in this build")
29
+ seq = nn.Sequential(model)
30
+ return Pipe(seq, chunks=chunks)
BitTransformerLM/bit_transformer/hf_checkpoint.py ADDED
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ import gzip
4
+ import os
5
+ import shutil
6
+ import tempfile
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from huggingface_hub import HfApi, hf_hub_download, login
11
+
12
+ REPO_ID = "architect/bittransformerlm"
13
+ FILENAME = "model.pt.gz"
14
+
15
+
16
+ def hf_login(token: Optional[str] = None) -> None:
17
+ """Authenticate with Hugging Face.
18
+
19
+ The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
20
+ variable. If omitted entirely, the library will attempt an interactive login.
21
+ """
22
+ login(token=token)
23
+
24
+
25
+ def save_checkpoint(
26
+ model: torch.nn.Module,
27
+ *,
28
+ repo_id: str = REPO_ID,
29
+ filename: str = FILENAME,
30
+ ) -> None:
31
+ """Upload the model weights to ``repo_id`` under ``filename``.
32
+
33
+ The file within the repository is overwritten each time to avoid
34
+ accumulating checkpoints.
35
+ """
36
+ with tempfile.TemporaryDirectory() as tmp:
37
+ tmp_pt = os.path.join(tmp, "model.pt")
38
+ tmp_gz = os.path.join(tmp, filename)
39
+ torch.save(model.state_dict(), tmp_pt)
40
+ with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
41
+ dst.write(src.read())
42
+ HfApi().upload_file(
43
+ path_or_fileobj=tmp_gz,
44
+ path_in_repo=f"checkpoints/{filename}",
45
+ repo_id=repo_id,
46
+ repo_type="model",
47
+ overwrite=True,
48
+ )
49
+
50
+
51
+ def download_checkpoint(
52
+ dest_path: str,
53
+ *,
54
+ repo_id: str = REPO_ID,
55
+ filename: str = FILENAME,
56
+ ) -> bool:
57
+ """Download the latest checkpoint to ``dest_path``.
58
+
59
+ Returns ``True`` if the checkpoint was successfully retrieved.
60
+ """
61
+ try:
62
+ buf = hf_hub_download(
63
+ repo_id,
64
+ f"checkpoints/{filename}",
65
+ repo_type="model",
66
+ force_download=True,
67
+ )
68
+ except Exception as exc: # pragma: no cover - network errors
69
+ print("Failed to download checkpoint", exc)
70
+ return False
71
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
72
+ shutil.copyfile(buf, dest_path)
73
+ return True
74
+
75
+
76
+ __all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]
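
A hedged round-trip sketch for the checkpoint helpers. It requires a Hugging Face token with write access in the `HF_TOKEN` environment variable, and `your-org/bittransformerlm-test` is a placeholder repo id, not the canonical `REPO_ID`:

```python
import os

from bit_transformer.model import BitTransformerLM
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint

hf_login(os.environ["HF_TOKEN"])  # assumes HF_TOKEN is set
model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
save_checkpoint(model, repo_id="your-org/bittransformerlm-test")
# The downloaded file stays gzipped, so the destination keeps the .gz suffix.
ok = download_checkpoint("weights/model.pt.gz",
                         repo_id="your-org/bittransformerlm-test")
print("restored:", ok)
```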
BitTransformerLM/bit_transformer/model.py ADDED
@@ -0,0 +1,875 @@
1
+ import math
2
+ import contextlib
3
+ import logging
4
+ from typing import Dict, List, Tuple
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ import sys
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+
13
+ from .torch_utils import cpu_autocast
14
+
15
+ from .optimization import configure_optimizer
16
+ from .compression import decompress_bits
17
+ from .parity import enforce_parity
18
+
19
+ _mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
20
+
21
+
22
+ def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
23
+ """Return or create a cached upper-triangular mask."""
24
+ key = (seq_len, device)
25
+ if key not in _mask_cache:
26
+ _mask_cache[key] = torch.triu(
27
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
28
+ )
29
+ return _mask_cache[key]
30
+
31
+ try: # torch.compile may not work on all Python versions
32
+ if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
33
+ compile_fn = torch.compile
34
+ else:
35
+ raise RuntimeError
36
+ except Exception: # pragma: no cover - handle missing torch or unsupported version
37
+
38
+ def compile_fn(fn=None, **kwargs):
39
+ if fn is None:
40
+ return lambda f: f
41
+ return fn
42
+
43
+
44
+ class PositionalEncoding(nn.Module):
45
+ """Sinusoidal positional encoding."""
46
+
47
+ def __init__(self, d_model: int, max_len: int = 1024) -> None:
48
+ super().__init__()
49
+ pe = torch.zeros(max_len, d_model)
50
+ pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
51
+ inv = torch.exp(
52
+ torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
53
+ )
54
+ pe[:, 0::2] = torch.sin(pos * inv)
55
+ pe[:, 1::2] = torch.cos(pos * inv)
56
+ self.register_buffer("pe", pe.unsqueeze(1))
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ """Add positional encoding to input tensor."""
60
+ return x + self.pe[: x.size(0)]
61
+
62
+
63
+ class LoggingTransformerEncoderLayer(nn.Module):
64
+ """Transformer encoder layer that exposes attention weights.
65
+
66
+ It optionally performs chunked attention with a fixed window size.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ d_model: int,
72
+ nhead: int,
73
+ dim_feedforward: int = 512,
74
+ dropout: float = 0.1,
75
+ chunk_size: int | None = None,
76
+ overlap: int = 0,
77
+ full_attn_logging: bool | None = None,
78
+ ) -> None:
79
+ super().__init__()
80
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
81
+ self.chunk_size = chunk_size
82
+ self.overlap = overlap
83
+ if full_attn_logging is None:
84
+ full_attn_logging = False if chunk_size is not None else True
85
+ self.full_attn_logging = full_attn_logging
86
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
87
+ self.dropout = nn.Dropout(dropout)
88
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
89
+ self.norm1 = nn.LayerNorm(d_model)
90
+ self.norm2 = nn.LayerNorm(d_model)
91
+ self.dropout1 = nn.Dropout(dropout)
92
+ self.dropout2 = nn.Dropout(dropout)
93
+ self.activation = F.relu
94
+
95
+ def _chunked_attn(
96
+ self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
97
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
98
+ """Perform chunked self attention with overlap."""
99
+ T, B, D = src.shape
100
+ src_b = src.transpose(0, 1) # [B, T, D]
101
+ C = self.chunk_size or T
102
+ O = self.overlap
103
+ n_chunks = (T + C - 1) // C
104
+ pad_len = n_chunks * C - T
105
+ src_pad = F.pad(src_b, (0, 0, O, pad_len + O))
106
+ chunk_len = C + 2 * O
107
+ chunks = src_pad.unfold(1, chunk_len, C) # [B, n_chunks, chunk_len, D]
108
+ mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None
109
+ out, weights = self.self_attn(
110
+ chunks.reshape(B * n_chunks, chunk_len, D),
111
+ chunks.reshape(B * n_chunks, chunk_len, D),
112
+ chunks.reshape(B * n_chunks, chunk_len, D),
113
+ attn_mask=mask,
114
+ need_weights=True,
115
+ average_attn_weights=False,
116
+ )
117
+ out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
118
+ weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
119
+ :, :, :, O : O + C
120
+ ]
121
+ seq = out.reshape(B, n_chunks * C, D)[:, :T]
122
+ if self.full_attn_logging and C < T:
123
+ full_attn = torch.zeros(
124
+ B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=src.device
125
+ )
126
+ for idx in range(n_chunks):
127
+ s = idx * C
128
+ start = max(s - O, 0)
129
+ end = min(s + C, n_chunks * C)
130
+ src_start = O - (s - start)
131
+ src_end = src_start + (end - start)
132
+ full_attn[:, :, s : s + C, start:end] = weights[:, idx, :, src_start:src_end]
133
+ full_attn = full_attn[:, :, :T, :T]
134
+ attn_out = full_attn.detach()
135
+ else:
136
+ attn_out = torch.empty(0, device=src.device)
137
+ return seq.transpose(0, 1), attn_out
138
+
139
+ def forward(
140
+ self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
141
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
142
+ """Return output and attention map."""
143
+ if self.chunk_size is not None:
144
+ attn_output, attn_weights = self._chunked_attn(src, attn_mask)
145
+ else:
146
+ qkv = src.transpose(0, 1)
147
+ attn_output, attn_weights = self.self_attn(
148
+ qkv,
149
+ qkv,
150
+ qkv,
151
+ attn_mask=attn_mask,
152
+ need_weights=True,
153
+ average_attn_weights=False,
154
+ )
155
+ attn_output = attn_output.transpose(0, 1)
156
+ src = src + self.dropout1(attn_output)
157
+ src = self.norm1(src)
158
+ out = self.linear2(self.dropout(self.activation(self.linear1(src))))
159
+ src = src + self.dropout2(out)
160
+ src = self.norm2(src)
161
+ return src, attn_weights.detach()
162
+
163
+
164
+ class ReversibleLoggingTransformerEncoderLayer(nn.Module):
165
+ """Reversible transformer encoder layer with checkpointing."""
166
+
167
+ def __init__(
168
+ self,
169
+ d_model: int,
170
+ nhead: int,
171
+ dim_feedforward: int = 512,
172
+ dropout: float = 0.1,
173
+ chunk_size: int | None = None,
174
+ overlap: int = 0,
175
+ full_attn_logging: bool | None = None,
176
+ ) -> None:
177
+ super().__init__()
178
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
179
+ self.chunk_size = chunk_size
180
+ self.overlap = overlap
181
+ if full_attn_logging is None:
182
+ full_attn_logging = False if chunk_size is not None else True
183
+ self.full_attn_logging = full_attn_logging
184
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
185
+ self.dropout = nn.Dropout(dropout)
186
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
187
+ self.norm1 = nn.LayerNorm(d_model)
188
+ self.norm2 = nn.LayerNorm(d_model)
189
+ self.dropout1 = nn.Dropout(dropout)
190
+ self.dropout2 = nn.Dropout(dropout)
191
+ self.activation = F.relu
192
+
193
+ @compile_fn
194
+ def _sa_block(
195
+ self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
196
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
197
+ if self.chunk_size is not None:
198
+ T, B, D = x.shape
199
+ x_b = x.transpose(0, 1)
200
+ C = self.chunk_size or T
201
+ O = self.overlap
202
+ n_chunks = (T + C - 1) // C
203
+ pad_len = n_chunks * C - T
204
+ src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
205
+ chunk_len = C + 2 * O
206
+ chunks = src_pad.unfold(1, chunk_len, C)
207
+ mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
208
+ out, weights = self.self_attn(
209
+ chunks.reshape(B * n_chunks, chunk_len, D),
210
+ chunks.reshape(B * n_chunks, chunk_len, D),
211
+ chunks.reshape(B * n_chunks, chunk_len, D),
212
+ attn_mask=mask,
213
+ need_weights=True,
214
+ average_attn_weights=False,
215
+ )
216
+ out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
217
+ weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
218
+ :, :, :, O : O + C
219
+ ]
220
+ seq = out.reshape(B, n_chunks * C, D)[:, :T]
221
+ if self.full_attn_logging and C < T:
222
+ full_attn = torch.zeros(
223
+ B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device
224
+ )
225
+ for idx in range(n_chunks):
226
+ s = idx * C
227
+ start = max(s - O, 0)
228
+ end = min(s + C, n_chunks * C)
229
+ src_start = O - (s - start)
230
+ src_end = src_start + (end - start)
231
+ full_attn[:, :, s : s + C, start:end] = weights[
232
+ :, idx, :, src_start:src_end
233
+ ]
234
+ full_attn = full_attn[:, :, :T, :T]
235
+ weights = full_attn.detach()
236
+ else:
237
+ weights = torch.empty(0, device=x.device)
238
+ attn_out = seq.transpose(0, 1)
239
+ else:
240
+ qkv = x.transpose(0, 1)
241
+ attn_out, weights = self.self_attn(
242
+ qkv,
243
+ qkv,
244
+ qkv,
245
+ attn_mask=attn_mask,
246
+ need_weights=True,
247
+ average_attn_weights=False,
248
+ )
249
+ attn_out = attn_out.transpose(0, 1)
250
+ x = self.norm1(x + self.dropout1(attn_out))
251
+ return x, weights.detach()
252
+
253
+ @compile_fn
254
+ def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
255
+ out = self.linear2(self.dropout(self.activation(self.linear1(x))))
256
+ x = self.norm2(x + self.dropout2(out))
257
+ return x
258
+
259
+ def forward(
260
+ self,
261
+ x1: torch.Tensor,
262
+ x2: torch.Tensor,
263
+ attn_mask: torch.Tensor | None = None,
264
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
265
+ y1, weights = self._sa_block(x2, attn_mask)
266
+ y1 = x1 + y1
267
+ y2 = x2 + self._ff_block(y1)
268
+ return y1, y2, weights
269
+
270
+
271
+ class BitTransformerLM(nn.Module):
272
+ """Transformer language model that operates on raw bits (0/1) with telemetry."""
273
+
274
+ def __init__(
275
+ self,
276
+ d_model: int = 128,
277
+ nhead: int = 8,
278
+ num_layers: int = 4,
279
+ dim_feedforward: int = 512,
280
+ max_seq_len: int = 1024,
281
+ lambda_K: float = 1.0,
282
+ lambda_C: float = 1.0,
283
+ lambda_S: float = 1.0,
284
+ reversible: bool = False,
285
+ use_checkpoint: bool = True,
286
+ use_autocast: bool = False,
287
+ use_act: bool = False,
288
+ act_threshold: float = 0.9,
289
+ chunk_size: int | None = None,
290
+ overlap: int = 0,
291
+ full_attn_logging: bool | None = None,
292
+ ) -> None:
293
+ """Create a BitTransformer language model.
294
+
295
+ Args:
296
+ full_attn_logging: When ``False`` and ``chunk_size`` is
297
+ smaller than the sequence length, the model skips
298
+ reconstructing the full ``T×T`` attention matrices for
299
+ telemetry to reduce memory use.
300
+ """
301
+ super().__init__()
302
+ self.d_model = d_model
303
+ self.num_layers = num_layers
304
+ self.lambda_K = lambda_K
305
+ self.lambda_C = lambda_C
306
+ self.lambda_S = lambda_S
307
+ self.reversible = reversible
308
+ self.use_checkpoint = use_checkpoint
309
+ self.use_autocast = use_autocast
310
+ self.use_act = use_act
311
+ self.act_threshold = act_threshold
312
+ self.chunk_size = chunk_size
313
+ self.overlap = overlap
314
+ if full_attn_logging is None:
315
+ full_attn_logging = False if chunk_size is not None else True
316
+ self.full_attn_logging = full_attn_logging
317
+
318
+ # Bit embedding: two possible input values
319
+ self.embedding = nn.Embedding(2, d_model)
320
+ self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)
321
+
322
+ layer_cls = (
323
+ ReversibleLoggingTransformerEncoderLayer
324
+ if reversible
325
+ else LoggingTransformerEncoderLayer
326
+ )
327
+ self.layers = nn.ModuleList(
328
+ [
329
+ layer_cls(
330
+ d_model=d_model,
331
+ nhead=nhead,
332
+ dim_feedforward=dim_feedforward,
333
+ chunk_size=chunk_size,
334
+ overlap=overlap,
335
+ full_attn_logging=full_attn_logging,
336
+ )
337
+ for _ in range(num_layers)
338
+ ]
339
+ )
340
+
341
+ if self.use_act:
342
+ self.halt_projs = nn.ModuleList(
343
+ [nn.Linear(d_model, 1) for _ in range(num_layers)]
344
+ )
345
+
346
+ self.out_head = nn.Linear(d_model, 2) # output logits for bit=0 or bit=1
347
+
348
+ def expand_positional_encoding(self, new_len: int) -> None:
349
+ """Expand positional encoding to at least ``new_len``."""
350
+ cur_len = self.pos_enc.pe.size(0)
351
+ if new_len <= cur_len:
352
+ return
353
+ device = self.pos_enc.pe.device
354
+ d_model = self.d_model
355
+ pe = torch.zeros(new_len, d_model, device=device)
356
+ pe[:cur_len] = self.pos_enc.pe.squeeze(1)
357
+ pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1)
358
+ inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model))
359
+ pe[cur_len:, 0::2] = torch.sin(pos * inv)
360
+ pe[cur_len:, 1::2] = torch.cos(pos * inv)
361
+ self.pos_enc.pe = pe.unsqueeze(1)
362
+
363
+ def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
364
+ """Update weighting coefficients for telemetry metrics."""
365
+ self.lambda_K = lambda_K
366
+ self.lambda_C = lambda_C
367
+ self.lambda_S = lambda_S
368
+
369
+ def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
370
+ """Return raw bit sequences, decompressing if input appears run-length encoded."""
371
+ if codes.dim() <= 1:
372
+ return codes
373
+ needs_decompress = codes.max().item() > 1
374
+ if not needs_decompress and codes.size(1) % 2 == 0:
375
+ vals = codes[:, 0::2]
376
+ if torch.all(vals[:, 1:] != vals[:, :-1]):
377
+ needs_decompress = True
378
+ if not needs_decompress:
379
+ return codes
380
+ seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
381
+ max_len = max(seq.numel() for seq in seqs)
382
+ padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
383
+ return torch.stack(padded)
384
+
385
+ def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor:
386
+ """Approximate negentropy of bit sequences.
387
+
388
+ Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
389
+ sequence (all zeros or ones) and ``0`` reflects maximal entropy.
390
+ """
391
+ codes = self._maybe_decompress(codes)
392
+ p = codes.float().mean(dim=1)
393
+ entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
394
+ max_e = math.log(2.0)
395
+ return 1 - entropy / max_e
396
+
397
+ def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor:
398
+ """Differentiable proxy for Lempel–Ziv complexity.
399
+
400
+ Values near ``0`` indicate highly compressible sequences while values
401
+ approaching ``1`` correspond to rapid bit alternation.
402
+ """
403
+ codes = self._maybe_decompress(codes)
404
+ diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
405
+ return diffs.float().mean(dim=1)
406
+
407
+ def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
408
+ """Negentropy computed from model logits.
409
+
410
+ Parameters
411
+ ----------
412
+ logits: ``torch.Tensor``
413
+ Logit tensor of shape ``(B, T, 2)``.
414
+ detach: bool, default ``True``
415
+ When ``True`` the computation is detached from the autograd graph.
416
+ """
417
+ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
418
+ prob = logits.softmax(-1)
419
+ if detach:
420
+ prob = prob.detach()
421
+ p = prob[..., 1].mean(dim=1)
422
+ entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
423
+ max_e = math.log(2.0)
424
+ return 1 - entropy / max_e
425
+
426
+ def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
427
+ """LZ complexity proxy computed from logits.
428
+
429
+ Parameters
430
+ ----------
431
+ logits: ``torch.Tensor``
432
+ Logit tensor of shape ``(B, T, 2)``.
433
+ detach: bool, default ``True``
434
+ When ``True`` the computation is detached from the autograd graph.
435
+ """
436
+ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
437
+ prob = logits.softmax(-1)
438
+ if detach:
439
+ prob = prob.detach()
440
+ prob1 = prob[..., 1]
441
+ diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
442
+ return diffs.mean(dim=1)
443
+
444
+ def symbiosis_kl_logits(
445
+ self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
446
+ ) -> torch.Tensor:
447
+ """Symbiosis score from KL divergence to a reference distribution.
448
+
449
+ Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with
450
+ the reference distribution and ``0`` indicating maximal divergence.
451
+ """
452
+ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
453
+ probs = logits.softmax(-1)
454
+ if detach:
455
+ probs = probs.detach()
456
+ ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
457
+ kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1)
458
+ max_kl = math.log(2.0)
459
+ return 1 - kl / max_kl
460
+
461
+ def _act_step(
462
+ self,
463
+ hidden: torch.Tensor,
464
+ idx: int,
465
+ halt_prob: torch.Tensor,
466
+ act_state: torch.Tensor,
467
+ halt_history: List[torch.Tensor],
468
+ ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
469
+ """Apply one step of ACT halting logic."""
470
+ p = torch.sigmoid(self.halt_projs[idx](hidden))
471
+ delta = (1 - halt_prob) * p
472
+ halt_prob = halt_prob + delta
473
+ act_state = act_state + hidden * delta
474
+ halt_history.append(halt_prob.detach())
475
+ min_prob = halt_prob.detach().min()
476
+ if dist.is_initialized():
477
+ dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
478
+ return halt_prob, act_state, min_prob.item() >= self.act_threshold
479
+
480
+ def forward(
481
+ self, bit_seq: torch.Tensor, causal: bool = True
482
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
483
+ """Forward pass returning logits and telemetry from the same graph.
484
+
485
+ By default the model uses causal masking and (optional) chunked
486
+ attention. When ``causal`` is ``False`` the model operates in
487
+ "Diffusion LM" mode. In this mode chunked attention is temporarily
488
+ disabled so that every token can attend to the full sequence
489
+ bidirectionally. The original chunking configuration is restored after
490
+ the forward pass.
491
+ """
492
+
493
+ # Disable chunking when running in bidirectional (non-causal) mode
494
+ orig_chunks = None
495
+ orig_model_chunk = None
496
+ if not causal and self.chunk_size is not None:
497
+ orig_model_chunk = self.chunk_size
498
+ orig_chunks = [layer.chunk_size for layer in self.layers]
499
+ self.chunk_size = None
500
+ for layer in self.layers:
501
+ layer.chunk_size = None
502
+
503
+ try:
504
+ ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
505
+ with ctx:
506
+ x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
507
+ x = self.pos_enc(x)
508
+
509
+ attn_mask = get_tri_mask(x.size(0), x.device) if causal else None
510
+
511
+ activations: List[torch.Tensor] = []
512
+ attn_maps: List[torch.Tensor] = []
513
+ halt_history: List[torch.Tensor] = []
514
+ if self.use_act:
515
+ halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
516
+ act_state = torch.zeros_like(x)
517
+ if self.reversible:
518
+ x1, x2 = x, x
519
+ for idx, layer in enumerate(self.layers):
520
+ if self.use_checkpoint:
521
+ x1, x2, attn = checkpoint.checkpoint(
522
+ layer, x1, x2, attn_mask
523
+ )
524
+ else:
525
+ x1, x2, attn = layer(x1, x2, attn_mask)
526
+ combined = (x1 + x2) / 2
527
+ activations.append(combined)
528
+ if attn.numel() > 0:
529
+ attn_maps.append(attn)
530
+ if self.use_act:
531
+ halt_prob, act_state, should_break = self._act_step(
532
+ combined, idx, halt_prob, act_state, halt_history
533
+ )
534
+ if should_break:
535
+ break
536
+ x = (x1 + x2) / 2
537
+ else:
538
+ for idx, layer in enumerate(self.layers):
539
+ if self.use_checkpoint:
540
+ x, attn = checkpoint.checkpoint(layer, x, attn_mask)
541
+ else:
542
+ x, attn = layer(x, attn_mask)
543
+ activations.append(x)
544
+ if attn.numel() > 0:
545
+ attn_maps.append(attn)
546
+ if self.use_act:
547
+ halt_prob, act_state, should_break = self._act_step(
548
+ x, idx, halt_prob, act_state, halt_history
549
+ )
550
+ if should_break:
551
+ break
552
+ if self.use_act:
553
+ act_state = act_state + x * (1 - halt_prob)
554
+ x = act_state
555
+ logits = self.out_head(x)
556
+
557
+ # Per-layer entropy of activations
558
+ entropies = []
559
+ for act in activations:
560
+ prob = act.softmax(-1)
561
+ ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean()
562
+ entropies.append(ent)
563
+
564
+ attn_entropies = []
565
+ for attn in attn_maps:
566
+ prob = attn # weights are already softmaxed
567
+ ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1)
568
+ ent = ent.mean(1)
569
+ attn_entropies.append(ent)
570
+ if attn_entropies:
571
+ attn_entropy_map = torch.stack(attn_entropies).mean(0)
572
+ else:
573
+ attn_entropy_map = torch.zeros(
574
+ bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
575
+ )
576
+ max_ent = math.log(attn_entropy_map.size(-1))
577
+ attn_entropy_map = attn_entropy_map / max_ent
578
+ attn_entropy = attn_entropy_map.mean(1)
579
+
580
+ logits_bt = logits.transpose(0, 1)
581
+ negentropy_in = self.negentropy_kpi(bit_seq)
582
+ lz_in = self.lz_complexity(bit_seq.float())
583
+ negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
584
+ lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
585
+ kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)
586
+
587
+ raw_sym = (
588
+ (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
589
+ + negentropy_logits_b * lz_logits_b
590
+ - self.lambda_S * kl_div_b
591
+ - 0.1 * attn_entropy
592
+ )
593
+ weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach()
594
+ raw_sym = raw_sym - 0.01 * weight_norm
595
+ sym_score = torch.sigmoid(raw_sym)
596
+
597
+ B, T = bit_seq.shape
598
+ assert logits_bt.shape[:2] == (B, T)
599
+ assert attn_entropy_map.shape == (B, T)
600
+
601
+ telemetry = {
602
+ "activations": activations,
603
+ "attention_maps": attn_maps,
604
+ "attention_entropy": attn_entropy_map,
605
+ "entropy": entropies,
606
+ "attention_entropy_mean": attn_entropy,
607
+ "negentropy_input": negentropy_in.detach(),
608
+ "lz_complexity_input": lz_in.detach(),
609
+ "negentropy_logits": negentropy_logits_b.detach(),
610
+ "lz_complexity_logits": lz_logits_b.detach(),
611
+ "symbiosis_kl": kl_div_b.detach(),
612
+ "symbiosis_score": sym_score.detach(),
613
+ }
614
+ if self.use_act:
615
+ telemetry["halt_probs"] = halt_history
616
+
617
+ return logits_bt, telemetry
618
+ finally:
619
+ if orig_chunks is not None:
620
+ self.chunk_size = orig_model_chunk
621
+ for layer, chunk in zip(self.layers, orig_chunks):
622
+ layer.chunk_size = chunk
623
+
624
+ def forward_compressed(
625
+ self, compressed_bits, causal: bool = True
626
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
627
+ """Decompress bit sequences then run the normal forward pass."""
628
+ if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
629
+ sequences = [decompress_bits(compressed_bits).to(torch.long)]
630
+ else:
631
+ sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
632
+ lengths = [seq.numel() for seq in sequences]
633
+ if len(set(lengths)) != 1:
634
+ raise ValueError("Sequences decompress to different lengths")
635
+ bits = torch.stack(sequences)
636
+ return self.forward(bits, causal=causal)
637
+
638
+ def _current_params(self) -> Dict:
639
+ """Return a dictionary with the current model hyperparameters."""
640
+ return {
641
+ "d_model": self.d_model,
642
+ "nhead": self.layers[0].self_attn.num_heads,
643
+ "num_layers": self.num_layers,
644
+ "dim_feedforward": self.layers[0].linear1.out_features,
645
+ "max_seq_len": self.pos_enc.pe.size(0),
646
+ "lambda_K": self.lambda_K,
647
+ "lambda_C": self.lambda_C,
648
+ "lambda_S": self.lambda_S,
649
+ "reversible": self.reversible,
650
+ "use_checkpoint": self.use_checkpoint,
651
+ "use_autocast": self.use_autocast,
652
+ "use_act": self.use_act,
653
+ "act_threshold": self.act_threshold,
654
+ "chunk_size": self.chunk_size,
655
+ "overlap": self.overlap,
656
+ }
657
+
658
+ def double_width(self) -> "BitTransformerLM":
659
+ """Return a copy of the model with doubled hidden size."""
660
+ from .scale import expand_model
661
+
662
+ params = self._current_params()
663
+ params["d_model"] *= 2
664
+ params["dim_feedforward"] *= 2
665
+ return expand_model(self, params)
666
+
667
+ def double_layers(self) -> "BitTransformerLM":
668
+ """Return a copy of the model with twice as many layers."""
669
+ from .scale import expand_model
670
+
671
+ params = self._current_params()
672
+ params["num_layers"] *= 2
673
+ return expand_model(self, params)
674
+
675
+ def double_length(self) -> "BitTransformerLM":
676
+ """Return a copy of the model with doubled maximum sequence length."""
677
+ from .scale import expand_model
678
+
679
+ params = self._current_params()
680
+ params["max_seq_len"] *= 2
681
+ params["chunk_size"] = params["max_seq_len"]
682
+ return expand_model(self, params)
683
+
684
+ def train_full_sequence(
685
+ self,
686
+ bits: torch.Tensor,
687
+ *,
688
+ ctx_bits: int = 4096,
689
+ detach_every_n: int = 1_048_576,
690
+ ) -> float:
691
+ """Train on a long bit tensor using sliding windows.
692
+
693
+ Parameters
694
+ ----------
695
+ bits: ``torch.Tensor``
696
+ 1D tensor containing the full bit sequence.
697
+ ctx_bits: int
698
+ Size of the training context window.
699
+ detach_every_n: int
700
+ Interval in bits for optimizer updates and graph detachment.
701
+ Returns
702
+ -------
703
+ float
704
+ Mean loss over all windows.
705
+ """
706
+ self.train()
707
+ optimizer, scheduler = configure_optimizer(
708
+ self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
709
+ )
710
+ accum = 0
711
+ total_loss = 0.0
712
+ count = 0
713
+ for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits):
714
+ segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
715
+ logits, _ = self(segment)
716
+ pred = logits[:, :-1, :].reshape(-1, 2)
717
+ target = segment[:, 1:].reshape(-1)
718
+ loss = F.cross_entropy(pred, target)
719
+ loss.backward()
720
+ accum += ctx_bits
721
+ total_loss += loss.item()
722
+ count += 1
723
+ if accum >= detach_every_n:
724
+ torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
725
+ optimizer.step()
726
+ scheduler.step()
727
+ optimizer.zero_grad()
728
+ accum = 0
729
+ if accum > 0:
730
+ torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
731
+ optimizer.step()
732
+ scheduler.step()
733
+ optimizer.zero_grad()
734
+ return total_loss / max(1, count)
735
+
736
+
737
+ def infer_long_sequence(
738
+ model: BitTransformerLM,
739
+ bits: torch.Tensor,
740
+ *,
741
+ ctx_bits: int = 4096,
742
+ overlap: int = 256,
743
+ ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
744
+ """Infer a long bit sequence using sliding windows with overlap."""
745
+ model.eval()
746
+ device = next(model.parameters()).device
747
+ bits = bits.to(device)
748
+ step = ctx_bits - overlap
749
+ outputs: List[torch.Tensor] = []
750
+ logs: List[Dict[str, torch.Tensor]] = []
751
+ for start in range(0, bits.numel(), step):
752
+ window = bits[start : start + ctx_bits].unsqueeze(0)
753
+ logits, tele = model(window, causal=True)
754
+ pred = logits.argmax(-1).squeeze(0)
755
+ outputs.append(pred)
756
+ logs.append(tele)
757
+ out = torch.cat(outputs)[: bits.numel()]
758
+ return out, logs
759
+
760
+
761
+ def diffusion_inference(
762
+ model: BitTransformerLM,
763
+ *,
764
+ length: int,
765
+ steps: int = 8,
766
+ batch_size: int = 1,
767
+ init_bits: torch.Tensor | None = None,
768
+ schedule: str = "linear",
769
+ ) -> torch.Tensor:
770
+ """Generate bit sequences using iterative denoising diffusion.
771
+
772
+ Parameters
773
+ ----------
774
+ model: ``BitTransformerLM``
775
+ The model used for denoising. It is run in non-causal mode with
776
+ chunked attention disabled, enabling full-context bidirectional
777
+ attention.
778
+ length: int
779
+ Length of the bit sequences to generate.
780
+ steps: int, default ``8``
781
+ Number of denoising iterations. More steps generally yield sharper
782
+ samples at the cost of compute.
783
+ batch_size: int, default ``1``
784
+ Number of sequences to generate in parallel.
785
+ init_bits: ``torch.Tensor`` | ``None``
786
+ Optional initial noisy bits of shape ``(batch_size, length)``. When
787
+ ``None`` random noise is used.
788
+ schedule: str, default ``"linear"``
789
+ Noise schedule for the denoising mask probability. Options are
790
+ ``"linear"``, ``"cosine"``, and ``"exp"``.
791
+
792
+ Returns
793
+ -------
794
+ ``torch.Tensor``
795
+ A tensor of shape ``(batch_size, length)`` containing generated bits.
796
+ """
797
+
798
+ model.eval()
799
+ device = next(model.parameters()).device
800
+ if init_bits is None:
801
+ bits = torch.randint(0, 2, (batch_size, length), device=device)
802
+ else:
803
+ bits = init_bits.to(device)
804
+ if bits.shape != (batch_size, length):
805
+ raise ValueError("init_bits must have shape (batch_size, length)")
806
+
807
+ for step in range(steps):
808
+ logits, _ = model(bits, causal=False)
809
+ prob = logits.softmax(-1)[..., 1]
810
+ t = (step + 1) / steps
811
+ if schedule == "linear":
812
+ mask_prob = 1.0 - t
813
+ elif schedule == "cosine":
814
+ mask_prob = math.cos(math.pi * t / 2)
815
+ elif schedule == "exp":
816
+ mask_prob = math.exp(-5 * t)
817
+ else:
818
+ raise ValueError(f"unknown schedule: {schedule}")
819
+ mask = (torch.rand_like(bits.float()) < mask_prob).long()
820
+ sampled = torch.bernoulli(prob).long()
821
+ bits = torch.where(mask.bool(), sampled, bits)
822
+ if bits.shape[-1] % 9 == 0:
823
+ bits, corrections = enforce_parity(bits)
824
+ if corrections:
825
+ logging.info("Parity corrections applied: %d", corrections)
826
+ try:
827
+ from .safety import hil_safe_inference
828
+
829
+ hil_safe_inference(model, bits, causal=False, strict=False)
830
+ except RuntimeError as exc:
831
+ logging.warning("Safety gate warning: %s", exc)
832
+ return bits
833
+
834
+
835
+ def example_usage() -> float:
836
+ """Run the example from the README and return the loss."""
837
+ B, L = 4, 16
838
+ model = BitTransformerLM(
839
+ d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=L
840
+ )
841
+ bits = torch.randint(0, 2, (B, L), dtype=torch.long)
842
+ logits, _ = model(bits)
843
+ pred = logits[:, :-1, :].reshape(-1, 2)
844
+ target = bits[:, 1:].reshape(-1)
845
+ loss = F.cross_entropy(pred, target)
846
+ return loss.item()
847
+
848
+
849
+ def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
850
+ """Demonstrate a training step where metrics do not affect gradients."""
851
+ B, L = 4, 16
852
+ model = BitTransformerLM(
853
+ d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L
854
+ )
855
+ optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1)
856
+
857
+ bits = torch.randint(0, 2, (B, L), dtype=torch.long)
858
+ logits, telemetry = model(bits)
859
+
860
+ pred = logits[:, :-1, :].reshape(-1, 2)
861
+ target = bits[:, 1:].reshape(-1)
862
+ loss = F.cross_entropy(pred, target)
863
+
864
+ loss.backward()
865
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
866
+ optimizer.step()
867
+ scheduler.step()
868
+ optimizer.zero_grad()
869
+ return loss.item(), telemetry
870
+
871
+
872
+ if __name__ == "__main__":
873
+ loss, telemetry = example_training_step()
874
+ print("Composite loss:", loss)
875
+ print("Telemetry keys:", list(telemetry.keys()))
BitTransformerLM/bit_transformer/optimization.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.optim import AdamW
4
+ from torch.optim.lr_scheduler import OneCycleLR
5
+
6
+
7
+ def configure_optimizer(
8
+ model: nn.Module,
9
+ lr: float = 1e-3,
10
+ weight_decay: float = 0.01,
11
+ total_steps: int = 100
12
+ ):
13
+ """Return AdamW optimizer with OneCycleLR scheduler."""
14
+ optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
15
+ scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
16
+ return optimizer, scheduler
17
+
18
+
19
+ def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
20
+ """Scale the learning rate of all param groups by ``factor``.
21
+
22
+ Parameters
23
+ ----------
24
+ optimizer:
25
+ The optimizer whose learning rate will be adjusted.
26
+ factor:
27
+ Multiplicative factor applied to the current learning rate.
28
+
29
+ Returns
30
+ -------
31
+ float
32
+ The updated learning rate of the first parameter group.
33
+ """
34
+ for param_group in optimizer.param_groups:
35
+ param_group["lr"] *= factor
36
+ return optimizer.param_groups[0]["lr"]
37
+
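
A short training-loop sketch wiring `configure_optimizer` and `adjust_learning_rate` to the model; hyperparameters and step counts are illustrative:

```python
import torch
import torch.nn.functional as F

from bit_transformer.model import BitTransformerLM
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=10)

for _ in range(10):
    bits = torch.randint(0, 2, (4, 16), dtype=torch.long)
    logits, _ = model(bits)
    # Next-bit prediction loss: shift targets by one position.
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, 2), bits[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

print("lr after manual decay:", adjust_learning_rate(optimizer, 0.5))
```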
BitTransformerLM/bit_transformer/parity.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+ def enforce_parity(bits: torch.Tensor) -> tuple[torch.Tensor, int]:
4
+ """Fix parity bits so each 9-bit chunk has even parity.
5
+
6
+ Parameters
7
+ ----------
8
+ bits: ``torch.Tensor``
9
+ Tensor of shape ``(..., length)`` where ``length`` is a multiple of 9.
10
+
11
+ Returns
12
+ -------
13
+ tuple[torch.Tensor, int]
14
+ Corrected tensor and number of bytes that were adjusted.
15
+ """
16
+ if bits.shape[-1] % 9 != 0:
17
+ raise ValueError("Bit stream length must be multiple of 9")
18
+ flat = bits.clone().view(-1, 9)
19
+ payload = flat[:, :8]
20
+ parity = flat[:, 8]
21
+ new_parity = payload.sum(dim=1) % 2
22
+ corrections = (parity != new_parity).sum().item()
23
+ flat[:, 8] = new_parity
24
+ return flat.view_as(bits), corrections
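
A quick check of `enforce_parity` on a single 9-bit frame with a deliberately wrong parity bit:

```python
import torch

from bit_transformer.parity import enforce_parity

# 8 payload bits containing four ones (even), so the trailing parity bit should be 0.
frame = torch.tensor([[1, 0, 1, 1, 0, 0, 1, 0, 1]])
fixed, corrections = enforce_parity(frame)
print(fixed.tolist())   # [[1, 0, 1, 1, 0, 0, 1, 0, 0]]
print(corrections)      # 1
```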
BitTransformerLM/bit_transformer/quantization.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.ao.quantization.fake_quantize import FakeQuantize
5
+ from torch.ao.quantization.observer import MinMaxObserver
6
+ from torch.ao.quantization.qconfig import QConfig
7
+ from torch.ao.quantization import convert
8
+
9
+ from .model import BitTransformerLM
10
+
11
+
12
+ def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
13
+ """Return a dynamically quantized copy of the model for inference."""
14
+ quantized = torch.quantization.quantize_dynamic(
15
+ model, {nn.Linear}, dtype=dtype
16
+ )
17
+ return quantized
18
+
19
+
20
+ class FourBitObserver(MinMaxObserver):
21
+ """Min-max observer configured for 4-bit quantization."""
22
+
23
+ def __init__(self, **kwargs):
24
+ super().__init__(
25
+ quant_min=0,
26
+ quant_max=15,
27
+ dtype=torch.quint8,
28
+ qscheme=torch.per_tensor_affine,
29
+ **kwargs,
30
+ )
31
+
32
+
33
+ FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)
34
+
35
+ four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)
36
+
37
+
38
+ class QATLinear(nn.Linear):
39
+ """Linear layer with fake quantization for QAT."""
40
+
41
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
42
+ super().__init__(in_features, out_features, bias)
43
+ self.weight_fake_quant = FourBitFakeQuantize()
44
+ self.activation_post_process = FourBitFakeQuantize()
45
+
46
+ @classmethod
47
+ def from_float(cls, mod: nn.Linear) -> "QATLinear":
48
+ qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
49
+ qat.weight = mod.weight
50
+ qat.bias = mod.bias
51
+ return qat
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ x = self.activation_post_process(x)
55
+ w = self.weight_fake_quant(self.weight)
56
+ return nn.functional.linear(x, w, self.bias)
57
+
58
+
59
+ def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
60
+ """Prepare BitTransformerLM for quantization-aware training."""
61
+
62
+ for name, module in model.named_children():
63
+ if isinstance(module, nn.Linear):
64
+ setattr(model, name, QATLinear.from_float(module))
65
+ else:
66
+ prepare_qat_fx(module)
67
+ return model
68
+
69
+
70
+ def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
71
+ """Convert a QAT-prepared model to a quantized version."""
72
+
73
+ for name, module in model.named_children():
74
+ if isinstance(module, QATLinear):
75
+ w = module.weight.data
76
+ qmin, qmax = 0, 15
77
+ min_w = w.min()
78
+ max_w = w.max()
79
+ scale = (max_w - min_w) / (qmax - qmin + 1e-8)
80
+ zero_point = qmin - torch.round(min_w / scale)
81
+ q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
82
+ new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
83
+ new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
84
+ if module.bias is not None:
85
+ new_mod.bias = nn.Parameter(module.bias.data)
86
+ setattr(model, name, new_mod)
87
+ else:
88
+ convert_qat_fx(module)
89
+ return model
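
A sketch of both quantization paths: dynamic int8 for CPU inference and the 4-bit QAT emulation for training. The models below are small placeholders, and the QAT training loop itself is elided:

```python
from bit_transformer.model import BitTransformerLM
from bit_transformer.quantization import quantize_dynamic, prepare_qat_fx, convert_qat_fx

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
int8_model = quantize_dynamic(model)   # qint8 Linear layers for inference

qat_model = prepare_qat_fx(
    BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                     dim_feedforward=64, max_seq_len=16)
)
# ...train qat_model for a few steps here...
deployed = convert_qat_fx(qat_model)   # weights snapped to the 4-bit grid
```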
BitTransformerLM/bit_transformer/safety.py ADDED
@@ -0,0 +1,149 @@
1
+ import logging
2
+ import time
3
+ import torch
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ from .model import BitTransformerLM
7
+
8
+
9
+ class SafetyGate:
10
+ """Exponential moving average safety gate with burn-in."""
11
+
12
+ def __init__(
13
+ self,
14
+ *,
15
+ c_floor: float = 0.3,
16
+ s_floor: float = 0.5,
17
+ decay: float = 0.9,
18
+ burn_in: int = 10,
19
+ ) -> None:
20
+ self.c_floor = c_floor
21
+ self.s_floor = s_floor
22
+ self.decay = decay
23
+ self.burn_in = burn_in
24
+ self.step = 0
25
+ self._c_ema: Optional[float] = None
26
+ self._s_ema: Optional[float] = None
27
+
28
+ def should_trigger(self, c_val: float, s_val: float) -> bool:
29
+ """Update EMA scores and check if gating should trigger."""
30
+
31
+ self.step += 1
32
+ if self._c_ema is None:
33
+ self._c_ema = c_val
34
+ self._s_ema = s_val
35
+ else:
36
+ self._c_ema = self.decay * self._c_ema + (1 - self.decay) * c_val
37
+ self._s_ema = self.decay * self._s_ema + (1 - self.decay) * s_val
38
+ if self.step <= self.burn_in:
39
+ return False
40
+ return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor
41
+
42
+
43
+ def hil_safe_inference(
44
+ model: BitTransformerLM,
45
+ bit_seq: torch.Tensor,
46
+ c_floor: float = 0.3,
47
+ s_floor: float = 0.5,
48
+ *,
49
+ causal: bool = True,
50
+ strict: bool = True,
51
+ gate: Optional[SafetyGate] = None,
52
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
53
+ """Run inference with telemetry gating.
54
+
55
+ Parameters
56
+ ----------
57
+ model:
58
+ Model to run inference with.
59
+ bit_seq:
60
+ Input bit sequences.
61
+ c_floor, s_floor:
62
+ Minimum LZ complexity and symbiosis score required for safe output.
63
+ causal:
64
+ Whether to run the model in causal (autoregressive) mode. When ``False``
65
+ the model performs full-context Diffusion LM inference.
66
+ strict:
67
+ If ``False`` the function returns model outputs even when the floors are
68
+ not met instead of raising ``RuntimeError``.
69
+ gate:
70
+ Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
71
+ before enforcing the floors.
72
+ """
73
+ model.eval()
74
+ with torch.no_grad():
75
+ logits, telemetry = model(bit_seq, causal=causal)
76
+ c_val = float(telemetry["lz_complexity_logits"].mean().item())
77
+ s_val = float(telemetry["symbiosis_score"].mean().item())
78
+ c_val = max(0.0, min(1.0, c_val))
79
+ s_val = max(0.0, min(1.0, s_val))
80
+ if gate is not None:
81
+ triggered = gate.should_trigger(c_val, s_val)
82
+ else:
83
+ triggered = c_val <= c_floor or s_val <= s_floor
84
+ if strict and triggered:
85
+ raise RuntimeError(
86
+ f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
87
+ )
88
+ return logits.argmax(-1), telemetry
89
+
90
+
91
+ def demo_hil_safety() -> None:
92
+ """Demonstrate gating on random bits."""
93
+ bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
94
+ model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
95
+ try:
96
+ out, _ = hil_safe_inference(model, bits, c_floor=0.0, s_floor=0.0)
97
+ print("Safe output bits:", out.squeeze(0).tolist())
98
+ except RuntimeError as e:
99
+ print("Gate triggered:", e)
100
+
101
+
102
+ def safe_sample_with_retry(
103
+ model: BitTransformerLM,
104
+ bit_seq: torch.Tensor,
105
+ c_floor: float = 0.3,
106
+ s_floor: float = 0.5,
107
+ *,
108
+ causal: bool = True,
109
+ max_retries: int = 3,
110
+ backoff: float = 0.1,
111
+ gate: Optional[SafetyGate] = None,
112
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
113
+ """Run :func:`hil_safe_inference` with automatic retries.
114
+
115
+ The helper retries failed safety checks by toggling diffusion mode and
116
+ refreshing the input bits. An exponential backoff is applied between
117
+ attempts and warnings are logged for each retry.
118
+
119
+ Parameters
120
+ ----------
121
+ gate:
122
+ Optional :class:`SafetyGate` instance shared across retries to apply
123
+ EMA smoothing and burn-in.
124
+
125
+ Returns
126
+ -------
127
+ Tuple[torch.Tensor, Dict[str, torch.Tensor]]
128
+ The sampled bits and associated telemetry.
129
+ """
130
+
131
+ for attempt in range(max_retries):
132
+ try:
133
+ return hil_safe_inference(
134
+ model,
135
+ bit_seq,
136
+ c_floor,
137
+ s_floor,
138
+ causal=causal,
139
+ strict=True,
140
+ gate=gate,
141
+ )
142
+ except RuntimeError as exc: # safety gate triggered
143
+ logging.warning("Safety gate failed (attempt %d/%d): %s", attempt + 1, max_retries, exc)
144
+ if attempt >= max_retries - 1:
145
+ raise
146
+ time.sleep(backoff * (2 ** attempt))
147
+ causal = False # retry in diffusion mode
148
+ bit_seq = torch.randint(0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device)
149
+
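
A non-strict run of the safety gate on random bits. An untrained model may legitimately fall below the complexity and symbiosis floors, hence `strict=False` and a short burn-in here:

```python
import torch

from bit_transformer.model import BitTransformerLM
from bit_transformer.safety import SafetyGate, hil_safe_inference

model = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
bits = torch.randint(0, 2, (1, 16), dtype=torch.long)
gate = SafetyGate(c_floor=0.3, s_floor=0.5, burn_in=5)

out, telemetry = hil_safe_inference(model, bits, strict=False, gate=gate)
print(out.tolist())
print(float(telemetry["symbiosis_score"].mean()))
```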
BitTransformerLM/bit_transformer/scale.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ from typing import Dict
3
+ from .model import BitTransformerLM
4
+ import torch.nn as nn
5
+
6
+
7
+ def expand_model(model: BitTransformerLM, new_params: Dict) -> BitTransformerLM:
8
+ """Return a new model with updated params and copied weights."""
9
+ new_model = BitTransformerLM(**new_params)
10
+ new_state = new_model.state_dict()
11
+ old_state = model.state_dict()
12
+
13
+ for k, v in old_state.items():
14
+ if k in new_state:
15
+ dest = new_state[k]
16
+ slices = tuple(slice(0, min(d, s)) for d, s in zip(dest.shape, v.shape))
17
+ dest[slices].copy_(v[slices])
18
+ if dest.shape != v.shape:
19
+ mask = torch.ones_like(dest, dtype=torch.bool)
20
+ mask[slices] = False
21
+ if "bias" in k:
22
+ dest[mask] = 0.0
23
+ else:
24
+ dest[mask] = 0.001 * torch.randn_like(dest[mask])
25
+
26
+ for k, v in new_state.items():
27
+ if k not in old_state:
28
+ if "bias" in k:
29
+ v.zero_()
30
+ elif v.dim() > 1:
31
+ nn.init.normal_(v, mean=0.0, std=1e-3)
32
+ else:
33
+ v.zero_()
34
+
35
+ new_model.load_state_dict(new_state)
36
+ return new_model
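
`expand_model` is normally reached through the `double_*` helpers on the model; a small progressive-scaling sketch (sizes illustrative):

```python
from bit_transformer.model import BitTransformerLM

small = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
                         dim_feedforward=64, max_seq_len=16)
wider = small.double_width()     # d_model 32 -> 64, dim_feedforward 64 -> 128
deeper = wider.double_layers()   # num_layers 1 -> 2

print(sum(p.numel() for p in small.parameters()),
      sum(p.numel() for p in deeper.parameters()))
```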
BitTransformerLM/bit_transformer/static/style.css ADDED
@@ -0,0 +1,93 @@
1
+ :root {
2
+ --primary: #1e40af;
3
+ --bg: #f5f6fa;
4
+ }
5
+
6
+ body {
7
+ font-family: Arial, sans-serif;
8
+ background-color: var(--bg);
9
+ margin: 0;
10
+ padding: 0;
11
+ line-height: 1.5;
12
+ color: #333;
13
+ }
14
+
15
+ .container {
16
+ max-width: 900px;
17
+ margin: 0 auto;
18
+ padding-bottom: 2rem;
19
+ }
20
+
21
+ h1 {
22
+ text-align: center;
23
+ background: var(--primary);
24
+ color: #fff;
25
+ margin: 0;
26
+ padding: 1rem 0;
27
+ }
28
+
29
+ section {
30
+ background: #fff;
31
+ margin: 1rem auto;
32
+ padding: 1rem 1.5rem;
33
+ border-radius: 8px;
34
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
35
+ width: 90%;
36
+ max-width: 800px;
37
+ }
38
+
39
+ section h2 {
40
+ margin-top: 0;
41
+ color: var(--primary);
42
+ font-size: 1.25rem;
43
+ }
44
+
45
+ form {
46
+ display: flex;
47
+ flex-wrap: wrap;
48
+ gap: 0.5rem 1rem;
49
+ }
50
+
51
+ form input[type="text"],
52
+ form input[type="number"],
53
+ form textarea {
54
+ flex: 1 1 200px;
55
+ padding: 0.4em;
56
+ border: 1px solid #ccc;
57
+ border-radius: 4px;
58
+ }
59
+
60
+ form button,
61
+ button#scaleBtn {
62
+ padding: 0.4em 0.8em;
63
+ border: none;
64
+ background: var(--primary);
65
+ color: #fff;
66
+ border-radius: 4px;
67
+ cursor: pointer;
68
+ }
69
+
70
+ form button:hover,
71
+ button#scaleBtn:hover {
72
+ background-color: #1d4ed8;
73
+ }
74
+
75
+ pre, p#trainOut {
76
+ background: #f0f0f0;
77
+ padding: 0.5rem;
78
+ border-radius: 4px;
79
+ overflow-x: auto;
80
+ }
81
+
82
+ label {
83
+ display: flex;
84
+ align-items: center;
85
+ gap: 0.5rem;
86
+ }
87
+
88
+ img#plot {
89
+ max-width: 100%;
90
+ height: auto;
91
+ display: block;
92
+ margin: auto;
93
+ }
BitTransformerLM/bit_transformer/telemetry.py ADDED
@@ -0,0 +1,95 @@
1
+ import numpy as np
2
+ from typing import Dict, List, TYPE_CHECKING
3
+
4
+ import torch
5
+ from sklearn.cluster import KMeans
6
+
7
+ if TYPE_CHECKING: # pragma: no cover
8
+ from .model import BitTransformerLM
9
+
10
+
11
+ class TelemetrySynthesizer:
12
+ """Analyze telemetry batches and cluster activation patterns."""
13
+
14
+ def __init__(self, n_clusters: int = 2) -> None:
15
+ self.n_clusters = n_clusters
16
+
17
+ def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
18
+ """Compute activation/attention summaries for a single telemetry dict."""
19
+ acts = telemetry["activations"]
20
+ attn = telemetry["attention_maps"]
21
+ summaries = []
22
+ for a, m in zip(acts, attn):
23
+ mean = a.mean().item()
24
+ var = a.var(unbiased=False).item()
25
+ prob = m.softmax(-1)
26
+ entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
27
+ summaries.append([mean, var, entropy])
28
+ return np.array(summaries).ravel()
29
+
30
+ def synthesize(
31
+ self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
32
+ ) -> Dict[str, List]:
33
+ """Cluster telemetry summaries and return cluster info."""
34
+ data = np.stack([self._summary(t) for t in telemetries])
35
+ km = KMeans(n_clusters=self.n_clusters, n_init=1)
36
+ labels = km.fit_predict(data)
37
+ representatives: List[List[int]] = []
38
+ for c in range(self.n_clusters):
39
+ idx = np.where(labels == c)[0]
40
+ if len(idx) > 0:
41
+ representatives.append(bit_seqs[idx[0]].tolist())
42
+ else:
43
+ representatives.append([])
44
+ return {"cluster_assignments": labels.tolist(), "representatives": representatives}
45
+
46
+ def cluster_sequences(
47
+ self, model: "BitTransformerLM", bit_seqs: torch.Tensor
48
+ ) -> List[List[int]]:
49
+ """Run the model to gather telemetry and return representative sequences.
50
+
51
+ Parameters
52
+ ----------
53
+ model: BitTransformerLM
54
+ Model used to compute telemetry for each sequence.
55
+ bit_seqs: torch.Tensor
56
+ Tensor containing one bit sequence per row.
57
+
58
+ Returns
59
+ -------
60
+ list[list[int]]
61
+ Representative sequences chosen from KMeans clusters.
62
+ """
63
+ telemetries: List[Dict[str, List[torch.Tensor]]] = []
64
+ with torch.no_grad():
65
+ for seq in bit_seqs:
66
+ _, tele = model(seq.unsqueeze(0))
67
+ telemetries.append(tele)
68
+ info = self.synthesize(telemetries, bit_seqs)
69
+ return info["representatives"]
70
+
71
+
72
+ def detect_metric_drift(
73
+ metrics_log: Dict[str, List[float]],
74
+ window: int = 10,
75
+ threshold: float = 0.2,
76
+ ) -> Dict[str, bool]:
77
+ """Detect metric drift between consecutive windows.
78
+
79
+ Args:
80
+ metrics_log: History of scalar metrics keyed by name.
81
+ window: Number of recent steps to compare.
82
+ threshold: Absolute difference required to flag drift.
83
+
84
+ Returns:
85
+ Dictionary mapping metric keys to a boolean drift indicator.
86
+ """
87
+ drift = {}
88
+ for key, values in metrics_log.items():
89
+ if len(values) < window * 2:
90
+ drift[key] = False
91
+ continue
92
+ recent = np.mean(values[-window:])
93
+ prev = np.mean(values[-2 * window : -window])
94
+ drift[key] = abs(recent - prev) > threshold
95
+ return drift
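
A tiny drift-detection sketch: the synthetic metric history below shifts upward between the two comparison windows, so the flag comes back `True`:

```python
from bit_transformer.telemetry import detect_metric_drift

log = {"negentropy_logits": [0.2] * 10 + [0.6] * 10}
print(detect_metric_drift(log, window=10, threshold=0.2))
# {'negentropy_logits': True}
```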
BitTransformerLM/bit_transformer/templates/dashboard.html ADDED
@@ -0,0 +1,454 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>Bit Transformer Dashboard</title>
6
+ <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
7
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
8
+ </head>
9
+ <body>
10
+ <h1>Bit Transformer Dashboard</h1>
11
+ <div class="container">
12
+ <section>
13
+ <h2>Initialize Model</h2>
14
+ <form id="initForm">
15
+ d_model: <input type="number" name="d_model" value="{{ defaults.d_model }}" title="Model width (default {{ defaults.d_model }})"><br>
16
+ nhead: <input type="number" name="nhead" value="{{ defaults.nhead }}" title="Attention heads (default {{ defaults.nhead }})"><br>
17
+ num_layers: <input type="number" name="num_layers" value="{{ defaults.num_layers }}" title="Transformer layers (default {{ defaults.num_layers }})"><br>
18
+ dim_feedforward: <input type="number" name="dim_feedforward" value="{{ defaults.dim_feedforward }}" title="Feedforward dim (default {{ defaults.dim_feedforward }})"><br>
19
+ max_seq_len: <input type="number" name="max_seq_len" value="{{ defaults.max_seq_len }}" title="Max sequence length (default {{ defaults.max_seq_len }})"><br>
20
+ chunk_size: <input type="number" name="chunk_size" title="Chunked attention size"><br>
21
+ overlap: <input type="number" name="overlap" value="{{ defaults.overlap }}" title="Sliding window overlap"><br>
22
+ Reversible: <input type="checkbox" name="reversible" id="reversible_box" title="Use reversible layers (default {{ defaults.reversible }})"><br>
23
+ Gradient Checkpointing: <input type="checkbox" name="use_checkpoint" id="checkpoint_box" checked title="Enable gradient checkpointing (default {{ defaults.use_checkpoint }})"><br>
24
+ act_threshold: <input type="number" step="0.01" name="act_threshold" value="{{ defaults.act_threshold }}" title="ACT halt threshold (default {{ defaults.act_threshold }})"><br>
25
+ c_floor: <input type="number" step="0.01" name="c_floor" value="{{ c_floor }}" title="Complexity floor"><br>
26
+ s_floor: <input type="number" step="0.01" name="s_floor" value="{{ s_floor }}" title="Symbiosis floor"><br>
27
+ <button type="submit">Init</button>
28
+ </form>
29
+ </section>
30
+ <section>
31
+ <h2>Train Step</h2>
32
+ <form id="trainForm">
33
+ Bits (e.g. 0 1 0 1): <input type="text" name="bits" value="0 1 0 1"><br>
34
+ Upload file: <input type="file" id="train_file"><br>
35
+ <button type="submit">Train</button>
36
+ </form>
37
+ <label>Load sample dataset:
38
+ <select id="datasetSelect">
39
+ <option value="">--Select--</option>
40
+ <option value="wikitext2_train">Wikitext-2 (train)</option>
41
+ <option value="wikitext2_validation">Wikitext-2 (validation)</option>
42
+ </select>
43
+ </label>
44
+ <p id="trainOut"></p>
45
+ </section>
46
+ <section>
47
+ <h2>Scale Up</h2>
48
+ Width Mult: <input type="number" step="0.1" id="width_mult" value="1.0"><br>
49
+ <button id="scaleBtn">Scale Model</button>
50
+ </section>
51
+ <section>
52
+ <h2>Collapse Submodel</h2>
53
+ <form id="collapseForm">
54
+ Cluster Bits (JSON array of arrays):<br>
55
+ <textarea name="clusters" rows="3" cols="40">[[0,1,0,1],[1,1,0,0]]</textarea><br>
56
+ Target Params (JSON):<br>
57
+ <textarea name="params" rows="3" cols="40">{"d_model":32,"nhead":4,"num_layers":1,"dim_feedforward":64,"max_seq_len":16}</textarea><br>
58
+ Width Scale: <input type="number" step="0.1" id="width_scale" value="1.0"><br>
59
+ <button type="submit">Collapse</button>
60
+ </form>
61
+ </section>
62
+ <section>
63
+ <h2>Inference</h2>
64
+ <form id="inferForm">
65
+ Bits: <input type="text" name="bits" value="0 1 0 1"><br>
66
+ Upload file: <input type="file" id="infer_file"><br>
67
+ <button type="submit">Infer</button>
68
+ </form>
69
+ <pre id="inferOut"></pre>
70
+ </section>
71
+ <section>
72
+ <h2>Long Inference</h2>
73
+ <form id="inferLongForm">
74
+ Bits: <input type="text" name="bits" value="0 1 0 1"><br>
75
+ ctx_bits: <input type="number" name="ctx_bits" value="4096"><br>
76
+ overlap: <input type="number" name="overlap" value="256"><br>
77
+ <button type="submit">Infer Long</button>
78
+ </form>
79
+ <pre id="inferLongOut"></pre>
80
+ </section>
81
+ <section>
82
+ <h2>Text Inference</h2>
83
+ <form id="textInferForm">
84
+ Text: <input type="text" name="text" value="hello"><br>
85
+ <button type="submit">Infer Text</button>
86
+ </form>
87
+ <pre id="textInferOut"></pre>
88
+ </section>
89
+ <section>
90
+ <h2>&lambda; Weights</h2>
91
+ <form id="lambdaForm">
92
+ &lambda;<sub>K</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_K" oninput="lambda_K_val.innerText=value"><span id="lambda_K_val"></span><br>
93
+ &lambda;<sub>C</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_C" oninput="lambda_C_val.innerText=value"><span id="lambda_C_val"></span><br>
94
+ &lambda;<sub>S</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_S" oninput="lambda_S_val.innerText=value"><span id="lambda_S_val"></span><br>
95
+ <button type="submit">Update</button>
96
+ </form>
97
+ </section>
98
+ <section>
99
+ <h2>Diffusion LM</h2>
100
+ <label><input type="checkbox" id="diffusion_box"> Enable Diffusion Mode</label>
101
+ </section>
102
+ <section>
103
+ <h2>GPU Acceleration</h2>
104
+ <label><input type="checkbox" id="gpu_box"> Enable FSDP &amp; CUDA</label>
105
+ </section>
106
+ <section>
107
+ <h2>Enable Compression</h2>
108
+ <label><input type="checkbox" id="compression_box"> Compress I/O</label>
109
+ <p>Ratio: <span id="comp_ratio">1.0</span></p>
110
+ </section>
111
+ <section>
112
+ <h2>Quantization Aware Training</h2>
113
+ <label><input type="checkbox" id="qat_box"> Enable 4-bit QAT</label>
114
+ </section>
115
+ <section>
116
+ <h2>Model Status</h2>
117
+ <pre id="statusOut"></pre>
118
+ </section>
119
+ <section>
120
+ <h2>Telemetry</h2>
121
+ <canvas id="metricChart" width="600" height="300"></canvas>
122
+ </section>
123
+ <section>
124
+ <h2>Hugging Face Checkpoints</h2>
125
+ Repo ID: <input type="text" id="hf_repo"><br>
126
+ Token: <input type="password" id="hf_token" placeholder="optional"><br>
127
+ <button id="uploadBtn">Upload weights</button>
128
+ <button id="downloadBtn">Download weights</button>
129
+ <p id="hfStatus"></p>
130
+ </section>
131
+
132
+ <script>
133
+ async function postJSON(url, data){
134
+ const resp = await fetch(url, {method:'POST', headers:{'Content-Type':'application/json'}, body:JSON.stringify(data)});
135
+ return resp.json();
136
+ }
137
+
138
+ async function pollJob(id){
139
+ while(true){
140
+ const job = await fetch(`/job/${id}`).then(r=>r.json());
141
+ if(job.status === 'completed') return job.result;
142
+ if(job.status === 'error') throw job.error || 'Job failed';
143
+ await new Promise(r=>setTimeout(r, 1000));
144
+ }
145
+ }
146
+
147
+ function loadInitParams(){
148
+ const saved = JSON.parse(localStorage.getItem('init_params')||'{}');
149
+ const form = document.getElementById('initForm');
150
+ for(const [k,v] of Object.entries(saved)){
151
+ const el = form.elements[k];
152
+ if(!el) continue;
153
+ if(el.type === 'checkbox') el.checked = v; else el.value = v;
154
+ }
155
+ }
156
+ loadInitParams();
157
+
158
+ function byteArrayToBits(arr){
159
+ const bits=[];
160
+ for(const b of arr){
161
+ for(let i=7;i>=0;i--) bits.push((b>>i)&1);
162
+ }
163
+ return bits;
164
+ }
165
+
166
+ let trainFileBits=null, inferFileBits=null, datasetBits=null;
167
+
168
+ async function fileToBits(file){
169
+ if(file.type.startsWith('text')){
170
+ const text = await file.text();
171
+ const res = await postJSON('/text_to_bits', {text});
172
+ return res.bits;
173
+ }
174
+ const buf = await file.arrayBuffer();
175
+ return byteArrayToBits(new Uint8Array(buf));
176
+ }
177
+
178
+ let metricChart;
179
+ async function initChart(){
180
+ const data = await fetch('/metrics').then(r=>r.json());
181
+ const labels = data.negentropy.map((_,i)=>i);
182
+ const ctx = document.getElementById('metricChart').getContext('2d');
183
+ metricChart = new Chart(ctx, {
184
+ type:'line',
185
+ data:{
186
+ labels:labels,
187
+ datasets:[
188
+ {label:'Negentropy', data:data.negentropy, borderColor:'blue', fill:false},
189
+ {label:'LZ Complexity', data:data.lz_complexity, borderColor:'orange', fill:false},
190
+ {label:'Symbiosis', data:data.symbiosis, borderColor:'green', fill:false}
191
+ ]
192
+ },
193
+ options:{responsive:false, interaction:{mode:'index', intersect:false}}
194
+ });
195
+ }
196
+
197
+ async function updateChart(){
198
+ const data = await fetch('/metrics').then(r=>r.json());
199
+ const labels = data.negentropy.map((_,i)=>i);
200
+ metricChart.data.labels = labels;
201
+ metricChart.data.datasets[0].data = data.negentropy;
202
+ metricChart.data.datasets[1].data = data.lz_complexity;
203
+ metricChart.data.datasets[2].data = data.symbiosis;
204
+ metricChart.update();
205
+ }
206
+
207
+ initChart();
208
+ setInterval(updateChart, 2000);
209
+
210
+ async function refreshStatus(){
211
+ const [s, c] = await Promise.all([fetch('/status'), fetch('/model_config')]);
212
+ const status = await s.json();
213
+ const config = await c.json();
214
+ document.getElementById('statusOut').innerText = JSON.stringify({...status, ...config}, null, 2);
215
+ }
216
+
217
+ document.getElementById('initForm').addEventListener('submit', async (e)=>{
218
+ e.preventDefault();
219
+ const fd = new FormData(e.target);
220
+ const obj = Object.fromEntries(fd.entries());
221
+ const ints = ['d_model','nhead','num_layers','dim_feedforward','max_seq_len','chunk_size','overlap'];
222
+ ints.forEach(k=>{ if(obj[k]===''){ delete obj[k]; } else obj[k]=parseInt(obj[k]); });
223
+ obj.reversible = document.getElementById('reversible_box').checked;
224
+ obj.use_checkpoint = document.getElementById('checkpoint_box').checked;
225
+ obj.act_threshold = parseFloat(obj.act_threshold);
226
+ const floors = {c_floor: parseFloat(obj.c_floor), s_floor: parseFloat(obj.s_floor)};
227
+ delete obj.c_floor; delete obj.s_floor;
228
+ await postJSON('/init', obj);
229
+ await postJSON('/config/telemetry', floors);
230
+ localStorage.setItem('init_params', JSON.stringify({...obj, ...floors}));
231
+ refreshStatus();
232
+ updateChart();
233
+ });
234
+
235
+ document.getElementById('trainForm').addEventListener('submit', async (e)=>{
236
+ e.preventDefault();
237
+ const form = e.target;
238
+ let payload;
239
+ if(trainFileBits){
240
+ payload = trainFileBits;
241
+ } else if(datasetBits){
242
+ payload = datasetBits;
243
+ } else {
244
+ payload = [form.bits.value.trim().split(/\s+/).map(Number)];
245
+ }
246
+ for(const el of form.elements) el.disabled = true;
247
+ const out = document.getElementById('trainOut');
248
+ out.innerText = '⏳';
249
+ try{
250
+ const job = await postJSON('/train', {bits: payload});
251
+ const res = await pollJob(job.job_id);
252
+ out.innerText = 'Loss: '+res.loss.toFixed(4);
253
+ if(res.ratio !== undefined){
254
+ document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
255
+ }
256
+ } catch(err){
257
+ out.innerText = 'Error';
258
+ alert(err);
259
+ } finally {
260
+ for(const el of form.elements) el.disabled = false;
261
+ refreshStatus();
262
+ updateChart();
263
+ }
264
+ });
265
+
266
+ document.getElementById('train_file').addEventListener('change', async (e)=>{
267
+ const f = e.target.files[0];
268
+ if(!f) return;
269
+ const bits = await fileToBits(f);
270
+ trainFileBits = [bits];
271
+ datasetBits = null;
272
+ document.querySelector('#trainForm input[name="bits"]').value = bits.slice(0,64).join(' ');
273
+ });
274
+
275
+ document.querySelector('#trainForm input[name="bits"]').addEventListener('input', ()=>{
276
+ trainFileBits = null;
277
+ datasetBits = null;
278
+ });
279
+
280
+ document.getElementById('scaleBtn').addEventListener('click', async ()=>{
281
+ const btn = document.getElementById('scaleBtn');
282
+ const input = document.getElementById('width_mult');
283
+ const mult = parseFloat(input.value);
284
+ btn.disabled = true; input.disabled = true;
285
+ const original = btn.innerText; btn.innerText = '⏳';
286
+ try{
287
+ const job = await postJSON('/scale_up', {width_mult: mult});
288
+ await pollJob(job.job_id);
289
+ } catch(err){
290
+ alert(err);
291
+ } finally {
292
+ btn.innerText = original;
293
+ btn.disabled = false; input.disabled = false;
294
+ refreshStatus();
295
+ updateChart();
296
+ }
297
+ });
298
+
299
+ document.getElementById('collapseForm').addEventListener('submit', async (e)=>{
300
+ e.preventDefault();
301
+ const form = e.target;
302
+ const btn = form.querySelector('button');
303
+ for(const el of form.elements) el.disabled = true;
304
+ const clusters = JSON.parse(form.clusters.value);
305
+ const params = JSON.parse(form.params.value);
306
+ const w = parseFloat(document.getElementById('width_scale').value);
307
+ const original = btn.innerText; btn.innerText = '⏳';
308
+ try{
309
+ const job = await postJSON('/collapse', {clusters: clusters, params: params, width_scale: w});
310
+ await pollJob(job.job_id);
311
+ } catch(err){
312
+ alert(err);
313
+ } finally {
314
+ btn.innerText = original;
315
+ for(const el of form.elements) el.disabled = false;
316
+ refreshStatus();
317
+ updateChart();
318
+ }
319
+ });
320
+
321
+ document.getElementById('inferForm').addEventListener('submit', async (e)=>{
322
+ e.preventDefault();
323
+ let bits;
324
+ if(inferFileBits){
325
+ bits = inferFileBits;
326
+ } else if(datasetBits){
327
+ bits = [datasetBits[0]];
328
+ } else {
329
+ bits = [e.target.bits.value.trim().split(/\s+/).map(Number)];
330
+ }
331
+ const res = await postJSON('/infer', {bits});
332
+ if(res.error){
333
+ alert(res.error + '\n' + (res.suggestion||''));
334
+ } else {
335
+ document.getElementById('inferOut').innerText = JSON.stringify(res, null, 2);
336
+ if(res.ratio !== undefined){
337
+ document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
338
+ }
339
+ }
340
+ refreshStatus();
341
+ updateChart();
342
+ });
343
+
344
+ document.getElementById('infer_file').addEventListener('change', async (e)=>{
345
+ const f = e.target.files[0];
346
+ if(!f) return;
347
+ const bits = await fileToBits(f);
348
+ inferFileBits = [bits];
349
+ datasetBits = null;
350
+ document.querySelector('#inferForm input[name="bits"]').value = bits.slice(0,64).join(' ');
351
+ });
352
+
353
+ document.querySelector('#inferForm input[name="bits"]').addEventListener('input', ()=>{
354
+ inferFileBits = null;
355
+ datasetBits = null;
356
+ });
357
+
358
+ document.getElementById('datasetSelect').addEventListener('change', async (e)=>{
359
+ const val = e.target.value;
360
+ trainFileBits = null;
361
+ inferFileBits = null;
362
+ if(!val){ datasetBits = null; return; }
363
+ const [name, split] = val.split('_');
364
+ const resp = await fetch(`/dataset?name=${name}&split=${split}&size=4&seq_len=64`);
365
+ const data = await resp.json();
366
+ datasetBits = data.bits;
367
+ const preview = data.bits[0].slice(0,64).join(' ');
368
+ document.querySelector('#trainForm input[name="bits"]').value = preview;
369
+ document.querySelector('#inferForm input[name="bits"]').value = preview;
370
+ });
371
+
372
+ document.getElementById('inferLongForm').addEventListener('submit', async (e)=>{
373
+ e.preventDefault();
374
+ const bits = e.target.bits.value.trim().split(/\s+/).map(Number);
375
+ const ctx = parseInt(e.target.ctx_bits.value);
376
+ const ov = parseInt(e.target.overlap.value);
377
+ const res = await postJSON('/infer_long', {bits: bits, ctx_bits: ctx, overlap: ov});
378
+ document.getElementById('inferLongOut').innerText = JSON.stringify(res, null, 2);
379
+ refreshStatus();
380
+ updateChart();
381
+ });
382
+
383
+ document.getElementById('textInferForm').addEventListener('submit', async (e)=>{
384
+ e.preventDefault();
385
+ const text = e.target.text.value;
386
+ const res = await postJSON('/infer_text', {text:text});
387
+ document.getElementById('textInferOut').innerText = JSON.stringify(res, null, 2);
388
+ refreshStatus();
389
+ updateChart();
390
+ });
391
+
392
+ async function loadLambdas(){
393
+ const resp = await fetch('/lambdas');
394
+ const vals = await resp.json();
395
+ for(const k of ['lambda_K','lambda_C','lambda_S']){
396
+ document.getElementById(k).value = vals[k];
397
+ document.getElementById(k+"_val").innerText = vals[k];
398
+ }
399
+ }
400
+
401
+ document.getElementById('lambdaForm').addEventListener('submit', async (e)=>{
402
+ e.preventDefault();
403
+ const data = {
404
+ lambda_K: parseFloat(document.getElementById('lambda_K').value),
405
+ lambda_C: parseFloat(document.getElementById('lambda_C').value),
406
+ lambda_S: parseFloat(document.getElementById('lambda_S').value),
407
+ };
408
+ await postJSON('/lambdas', data);
409
+ for(const k in data){
410
+ document.getElementById(k+"_val").innerText = data[k];
411
+ }
412
+ refreshStatus();
413
+ });
414
+
415
+ loadLambdas();
416
+
417
+ function restoreToggle(id,key,endpoint,field){
418
+ const box = document.getElementById(id);
419
+ const saved = localStorage.getItem(key);
420
+ if(saved !== null){ box.checked = saved === 'true'; postJSON(endpoint,{[field]: box.checked}); }
421
+ box.addEventListener('change', async (e)=>{
422
+ await postJSON(endpoint, {[field]: e.target.checked});
423
+ localStorage.setItem(key, e.target.checked);
424
+ refreshStatus();
425
+ });
426
+ }
427
+
428
+ restoreToggle('diffusion_box','diffusion','/diffusion','diffusion');
429
+ restoreToggle('gpu_box','use_gpu','/gpu','use_gpu');
430
+ restoreToggle('compression_box','compression','/compression','compression');
431
+ restoreToggle('qat_box','qat','/qat','qat');
432
+
433
+ document.getElementById('uploadBtn').addEventListener('click', async ()=>{
434
+ const repo = document.getElementById('hf_repo').value;
435
+ const token = document.getElementById('hf_token').value;
436
+ const res = await postJSON('/save_checkpoint', {repo_id: repo, token: token||undefined});
437
+ document.getElementById('hfStatus').innerText = res.status || res.error;
438
+ });
439
+
440
+ document.getElementById('downloadBtn').addEventListener('click', async ()=>{
441
+ const repo = document.getElementById('hf_repo').value;
442
+ const token = document.getElementById('hf_token').value;
443
+ const res = await postJSON('/download_checkpoint', {repo_id: repo, token: token||undefined});
444
+ document.getElementById('hfStatus').innerText = res.status || res.error;
445
+ refreshStatus();
446
+ updateChart();
447
+ });
448
+
449
+ refreshStatus();
450
+ </script>
451
+ </div>
452
+ </body>
453
+ </html>
454
+
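For driving these endpoints without the browser UI, a hedged sketch using `requests` is shown below; the host and port are assumptions, and the payloads simply mirror what the dashboard JavaScript sends to `/init`, `/train`, `/job/<id>`, and `/metrics`:

```python
import time
import requests  # host/port are assumptions; payloads mirror the dashboard JS

BASE = "http://localhost:5000"

# Initialize a small model, as the initForm handler does.
requests.post(f"{BASE}/init", json={
    "d_model": 32, "nhead": 4, "num_layers": 1,
    "dim_feedforward": 64, "max_seq_len": 16,
})

# Kick off a training job and poll it, as pollJob() does above.
job = requests.post(f"{BASE}/train", json={"bits": [[0, 1, 0, 1]]}).json()
while True:
    state = requests.get(f"{BASE}/job/{job['job_id']}").json()
    if state["status"] in ("completed", "error"):
        break
    time.sleep(1)

print(requests.get(f"{BASE}/metrics").json())
```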
BitTransformerLM/bit_transformer/torch_utils.py ADDED
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ import torch
5
+
6
+
7
+ @contextmanager
8
+ def cpu_autocast(enabled: bool = True):
9
+ """Context manager for bfloat16 autocast on CPU.
10
+
11
+ Parameters
12
+ ----------
13
+ enabled: bool, default True
14
+ Whether to enable autocast. When ``False`` this context manager
15
+ behaves like a no-op.
16
+ """
17
+ if enabled:
18
+ with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
19
+ yield
20
+ else:
21
+ yield
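A short usage sketch for `cpu_autocast`; the `nn.Linear` stand-in is illustrative and any module works the same way:

```python
import torch
from bit_transformer.torch_utils import cpu_autocast

layer = torch.nn.Linear(16, 2)      # placeholder for a model forward pass
x = torch.randn(4, 16)

with cpu_autocast():                # ops run under bfloat16 autocast on CPU
    y = layer(x)

with cpu_autocast(enabled=False):   # no-op: plain float32 execution
    y_fp32 = layer(x)
```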
BitTransformerLM/bit_transformer/training.py ADDED
@@ -0,0 +1,250 @@
1
+ """Common training utilities for BitTransformer models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable, Dict, List, Optional
6
+ import contextlib
7
+ import sys
8
+ import warnings
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader
14
+
15
+ from .compression import compress_bits, pack_bits, unpack_bits
16
+ from .optimization import configure_optimizer
17
+ from .model import BitTransformerLM
18
+ from .utils import set_dropout
19
+ from .torch_utils import cpu_autocast
20
+
21
+
22
+ def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
23
+ """Cosine ramp from ``start`` to ``end`` over ``total_steps``."""
24
+ if total_steps <= 0 or step >= total_steps:
25
+ return end
26
+ cos_inner = math.pi * step / total_steps
27
+ return start + (end - start) * (1 - math.cos(cos_inner)) / 2
28
+
29
+
30
+ def train_loop(
31
+ model: BitTransformerLM,
32
+ data: torch.Tensor,
33
+ *,
34
+ epochs: int = 1,
35
+ extra_steps: int = 0,
36
+ compress_prob: float = 0.5,
37
+ direct_prob: float = 0.0,
38
+ batch_size: int = 8,
39
+ num_workers: int = 0,
40
+ accum_steps: int = 1,
41
+ amp: bool = False,
42
+ compile_model: bool = False,
43
+ log: bool = False,
44
+ forward_kwargs: Optional[Dict] = None,
45
+ optimizer: Optional[torch.optim.Optimizer] = None,
46
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
47
+ diffusion: bool = False,
48
+ noise_fn: Optional[Callable[[], float]] = None,
49
+ diffusion_curriculum: bool = False,
50
+ compress_warmup: int = 0,
51
+ ) -> List[Dict[str, float]]:
52
+ """Generic training loop supporting optional compression and diffusion.
53
+
54
+ ``compress_prob`` controls the fraction of batches that are run through
55
+ ``forward_compressed``. ``direct_prob`` instead feeds the model with the
56
+ bit-packed result of ``compress_bits`` after converting back to a bit
57
+ tensor. When enabled, metrics for direct-compressed batches are tracked
58
+ separately.
59
+
60
+ When ``diffusion`` is ``True`` the loop performs denoising training. Batches
61
+ are noised by randomly flipping bits with a probability given by
62
+ ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When
63
+ ``diffusion_curriculum`` is ``True`` the noise probability decreases
64
+ linearly from ``0.5`` to ``0.0`` over the training epochs. The model is
65
+ then trained to recover the clean sequence using full-context attention
66
+ (``causal=False``).
67
+
68
+ Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow
69
+ integration with long-running training sessions, otherwise new ones are
70
+ created automatically.
71
+ """
72
+ if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
73
+ model = torch.compile(model)
74
+ elif compile_model:
75
+ warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")
76
+
77
+ model.train()
78
+ set_dropout(model, 0.1)
79
+
80
+ device = next(model.parameters()).device
81
+ loader = DataLoader(
82
+ data,
83
+ batch_size=batch_size,
84
+ shuffle=True,
85
+ num_workers=num_workers,
86
+ persistent_workers=num_workers > 0,
87
+ )
88
+ steps_per_epoch = max(1, len(loader))
89
+ total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
90
+ if optimizer is None or scheduler is None:
91
+ optimizer, scheduler = configure_optimizer(
92
+ model, lr=1e-3, total_steps=total_updates
93
+ )
94
+ metrics: List[Dict[str, float]] = []
95
+
96
+ global_step = 0
97
+ for epoch in range(epochs):
98
+ raw_losses: List[float] = []
99
+ raw_accs: List[float] = []
100
+ comp_losses: List[float] = []
101
+ comp_accs: List[float] = []
102
+ comp_ratios: List[float] = []
103
+ direct_losses: List[float] = []
104
+
105
+ last_batch = None
106
+ for step, batch in enumerate(loader):
107
+ last_batch = batch
108
+ batch = batch.to(device)
109
+ cur_compress = (
110
+ cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
111
+ if not diffusion
112
+ else compress_prob
113
+ )
114
+ if diffusion:
115
+ if diffusion_curriculum:
116
+ p = 0.5 * (1 - epoch / max(1, epochs - 1))
117
+ else:
118
+ p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
119
+ noise = (torch.rand_like(batch.float()) < p).long()
120
+ noisy = batch ^ noise
121
+ with (
122
+ torch.cuda.amp.autocast(dtype=torch.bfloat16)
123
+ if amp and torch.cuda.is_available()
124
+ else cpu_autocast() if amp else contextlib.nullcontext()
125
+ ):
126
+ logits, _ = model(noisy, causal=False)
127
+ pred = logits.reshape(-1, 2)
128
+ target = batch.reshape(-1)
129
+ loss = F.cross_entropy(pred, target) / accum_steps
130
+ acc = (pred.argmax(dim=-1) == target).float().mean().item()
131
+ raw_losses.append(loss.item() * accum_steps)
132
+ raw_accs.append(acc)
133
+ loss.backward()
134
+ if (step + 1) % accum_steps == 0:
135
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
136
+ optimizer.step()
137
+ scheduler.step()
138
+ optimizer.zero_grad()
139
+ global_step += 1
140
+ continue
141
+
142
+ r = torch.rand(())
143
+ key = "raw"
144
+ ratio = 1.0
145
+ target = batch[:, 1:].reshape(-1)
146
+
147
+ if r < direct_prob:
148
+ packed = [pack_bits(row.to(torch.uint8)) for row in batch]
149
+ unpacked = [unpack_bits(p, n_bits=batch.size(1)) for p in packed]
150
+ max_len = min(
151
+ max(u.numel() for u in unpacked),
152
+ model.pos_enc.pe.size(0),
153
+ )
154
+ padded = [F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len))) for u in unpacked]
155
+ dc_batch = torch.stack(padded).long()
156
+ with (
157
+ torch.cuda.amp.autocast(dtype=torch.bfloat16)
158
+ if amp and torch.cuda.is_available()
159
+ else cpu_autocast() if amp else contextlib.nullcontext()
160
+ ):
161
+ logits, _ = model(dc_batch, **(forward_kwargs or {}))
162
+ ratio = sum(p.numel() for p in packed) / batch.numel()
163
+ target = dc_batch[:, 1:].reshape(-1)
164
+ key = "direct"
165
+ elif r < direct_prob + cur_compress:
166
+ comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
167
+ with (
168
+ torch.cuda.amp.autocast(dtype=torch.bfloat16)
169
+ if amp and torch.cuda.is_available()
170
+ else cpu_autocast() if amp else contextlib.nullcontext()
171
+ ):
172
+ logits, _ = model.forward_compressed(comp_batch, **(forward_kwargs or {}))
173
+ ratio = sum(c.numel() for c in comp_batch) / batch.numel()
174
+ target = batch[:, 1:].reshape(-1)
175
+ key = "compressed"
176
+ else:
177
+ with (
178
+ torch.cuda.amp.autocast(dtype=torch.bfloat16)
179
+ if amp and torch.cuda.is_available()
180
+ else cpu_autocast() if amp else contextlib.nullcontext()
181
+ ):
182
+ logits, _ = model(batch, **(forward_kwargs or {}))
183
+
184
+ pred = logits[:, :-1, :].reshape(-1, 2)
185
+ loss = F.cross_entropy(pred, target) / accum_steps
186
+ acc = (pred.argmax(dim=-1) == target).float().mean().item()
187
+
188
+ loss.backward()
189
+ if (step + 1) % accum_steps == 0:
190
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
191
+ optimizer.step()
192
+ scheduler.step()
193
+ optimizer.zero_grad()
194
+ global_step += 1
195
+
196
+ if key == "compressed":
197
+ comp_losses.append(loss.item() * accum_steps)
198
+ comp_accs.append(acc)
199
+ comp_ratios.append(ratio)
200
+ elif key == "direct":
201
+ direct_losses.append(loss.item() * accum_steps)
202
+ comp_ratios.append(ratio)
203
+ else:
204
+ raw_losses.append(loss.item() * accum_steps)
205
+ raw_accs.append(acc)
206
+
207
+ # run extra gradient updates using the final batch
208
+ if extra_steps > 0 and last_batch is not None and not diffusion:
209
+ for step in range(extra_steps):
210
+ with (
211
+ torch.cuda.amp.autocast(dtype=torch.bfloat16)
212
+ if amp and torch.cuda.is_available()
213
+ else cpu_autocast() if amp else contextlib.nullcontext()
214
+ ):
215
+ logits, _ = model(last_batch, **(forward_kwargs or {}))
216
+ pred = logits[:, :-1, :].reshape(-1, 2)
217
+ target = last_batch[:, 1:].reshape(-1)
218
+ loss = F.cross_entropy(pred, target) / accum_steps
219
+ acc = (pred.argmax(dim=-1) == target).float().mean().item()
220
+ loss.backward()
221
+ if (step + 1) % accum_steps == 0:
222
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
223
+ optimizer.step()
224
+ scheduler.step()
225
+ optimizer.zero_grad()
226
+ raw_losses.append(loss.item() * accum_steps)
227
+ raw_accs.append(acc)
228
+ global_step += 1
229
+
230
+ m = {
231
+ "raw_loss": float(sum(raw_losses) / len(raw_losses)) if raw_losses else 0.0,
232
+ "raw_acc": float(sum(raw_accs) / len(raw_accs)) if raw_accs else 0.0,
233
+ "compressed_loss": float(sum(comp_losses) / len(comp_losses)) if comp_losses else 0.0,
234
+ "compressed_acc": float(sum(comp_accs) / len(comp_accs)) if comp_accs else 0.0,
235
+ "direct_loss": float(sum(direct_losses) / len(direct_losses)) if direct_losses else 0.0,
236
+ "compression_ratio": float(sum(comp_ratios) / len(comp_ratios)) if comp_ratios else 0.0,
237
+ }
238
+ metrics.append(m)
239
+
240
+ if log:
241
+ print(
242
+ f"Epoch {epoch} "
243
+ f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
244
+ f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
245
+ f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
246
+ )
247
+
248
+ return metrics
249
+
250
+ __all__ = ["train_loop"]
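A minimal sketch of calling `train_loop` on random bits; the constructor arguments mirror those used elsewhere in this upload (e.g. `integration_flow.py`), and the sizes are illustrative rather than tuned:

```python
import torch
from bit_transformer.model import BitTransformerLM
from bit_transformer.training import train_loop

model = BitTransformerLM(
    d_model=32, nhead=4, num_layers=1,
    dim_feedforward=64, max_seq_len=64,
)
data = torch.randint(0, 2, (64, 64), dtype=torch.long)  # 64 random bit sequences

metrics = train_loop(model, data, epochs=1, batch_size=8,
                     compress_prob=0.0, log=True)
print(metrics[-1]["raw_loss"])
```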
BitTransformerLM/bit_transformer/utils.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import gzip
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ def save_model(model: torch.nn.Module, path: str) -> None:
8
+ """Save a model using gzip compression."""
9
+ os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
10
+ with gzip.open(path, 'wb') as f:
11
+ torch.save(model, f)
12
+
13
+
14
+ def load_model(path: str) -> torch.nn.Module:
15
+ """Load a model saved with ``save_model``."""
16
+ with gzip.open(path, 'rb') as f:
17
+ model = torch.load(f, map_location="cpu", weights_only=False)
18
+ return model
19
+
20
+
21
+ def set_dropout(model: torch.nn.Module, p: float) -> None:
22
+ """Set dropout probability ``p`` for all dropout layers in ``model``."""
23
+ for module in model.modules():
24
+ if isinstance(module, nn.Dropout):
25
+ module.p = p
26
+
27
+
28
+ __all__ = ["save_model", "load_model", "set_dropout"]
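A quick round-trip sketch for `save_model` and `load_model`; the `Sequential` model is a placeholder, and any `torch.nn.Module` behaves the same:

```python
import torch
from bit_transformer.utils import save_model, load_model, set_dropout

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(0.1))

save_model(model, "weights/demo.pt.gz")     # gzip-compressed pickle of the module
restored = load_model("weights/demo.pt.gz")
set_dropout(restored, 0.0)                  # e.g. disable dropout for evaluation
```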
BitTransformerLM/bit_transformer_lm_codex_playbook.md ADDED
@@ -0,0 +1,278 @@
1
+ ---
2
+
3
+ # 🧭 BitTransformerLM Codex Playbook (Merged)
4
+
5
+ A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.
6
+
7
+ ---
8
+
9
+ ## Phase 1 — Training Loop & Runtime Optimizations (apply these first)
10
+
11
+ ### Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅
12
+
13
+ **Prompt:**
14
+
15
+ ```bash
16
+ codex run bittransformerlm/patch \
17
+ --file bit_transformer/training.py \
18
+ --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
19
+ ```
20
+
21
+ ✅ OneCycle’s horizon matches reality across runs.
22
+
23
+ ---
24
+
25
+ ### Task 2 — Remove hardcoded `total_steps=100` in dashboard/MCP — COMPLETED ✅
26
+
27
+ **Prompt:**
28
+
29
+ ```bash
30
+ codex run bittransformerlm/patch \
31
+ --file dashboard/manager.py \
32
+ --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
33
+ ```
34
+
35
+ ✅ Aligns scheduler behavior between direct loop and MCP/dashboard.
36
+
37
+ ---
38
+
39
+ ### Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅
40
+
41
+ **Prompt (pseudo-patch):**
42
+
43
+ ```python
44
+ with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
45
+ logits = model(batch)
46
+ loss = criterion(logits, labels)
47
+ loss.backward()
48
+ ```
49
+
50
+ ✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.
51
+
52
+ ---
53
+
54
+ ### Task 4 — Add gradient accumulation — COMPLETED ✅
55
+
56
+ **Prompt:**
57
+
58
+ ```bash
59
+ codex run bittransformerlm/patch \
60
+ --file bit_transformer/training.py \
61
+ --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
62
+ ```
63
+
64
+ ✅ Simulates larger effective batch sizes without extra memory.
65
+
66
+ ---
67
+
68
+ ### Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅
69
+
70
+ **Prompt:**
71
+
72
+ ```bash
73
+ codex run bittransformerlm/patch \
74
+ --file data/wikitext_schedule.py \
75
+ --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
76
+ ```
77
+
78
+ ✅ Removes conversion bottlenecks on large corpora.
79
+
80
+ ---
81
+
82
+ ### Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅
83
+
84
+ **Prompt (pseudo-code):**
85
+
86
+ ```python
87
+ compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
88
+ ```
89
+
90
+ ✅ Prevents early instability from aggressive compression.
91
+
92
+ ---
93
+
94
+ ### Task 7 — Stabilize safety gate (EMA + burn‑in) — COMPLETED ✅
95
+
96
+ **Prompt (pseudo-patch):**
97
+
98
+ ```python
99
+ ema_val = ema(val_loss, decay=0.9)
100
+ if step < burn_in_steps:
101
+ allow_training = True
102
+ elif ema_val > threshold:
103
+ trigger_gate()
104
+ ```
105
+
106
+ ✅ Reduces false positives from noisy early validations.
107
+
108
+ ---
109
+
110
+ ### Task 8 — Enable `torch.compile` selectively — COMPLETED ✅
111
+
112
+ **Prompt:**
113
+
114
+ ```bash
115
+ codex run bittransformerlm/patch \
116
+ --file bit_transformer/training.py \
117
+ --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
118
+ ```
119
+
120
+ ✅ Opportunistic speedup where supported.
121
+
122
+ ---
123
+
124
+ ### Task 9 — Integrate FlashAttention / SDPA
125
+
126
+ **Prompt (pseudo-patch):**
127
+
128
+ ```python
129
+ from torch.nn import functional as F
130
+
131
+ def forward_attention(q, k, v, is_causal=True):
132
+ return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
133
+ ```
134
+
135
+ ✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.
136
+
137
+ ---
138
+
139
+ ### Task 10 — Cache causal masks — COMPLETED ✅
140
+
141
+ **Prompt (pseudo-code):**
142
+
143
+ ```python
144
+ mask_cache = {}
145
+
146
+ def get_tri_mask(seq_len, device):
147
+ key = (seq_len, device)
148
+ if key not in mask_cache:
149
+ mask_cache[key] = torch.triu(
150
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
151
+ )
152
+ return mask_cache[key]
153
+ ```
154
+
155
+ ✅ Avoids repeated `triu` allocations when masks are still needed.
156
+
157
+ ---
158
+
159
+ ### Task 11 — Fix stitched attention negative indexing — COMPLETED ✅
160
+
161
+ **Prompt (pseudo-code):**
162
+
163
+ ```python
164
+ start = max(s - overlap, 0)
165
+ end = min(s + chunk_size, T)
166
+ canvas[..., start:end] = attn_chunk[..., : end - start]
167
+ ```
168
+
169
+ ✅ Prevents wrap-around misplacement during T×T map reconstruction.
170
+
171
+ ---
172
+
173
+ ### Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅
174
+
175
+ **Prompt:**
176
+
177
+ ```bash
178
+ codex run bittransformerlm/patch \
179
+ --file bit_transformer/model.py \
180
+ --edit "Set full_attn_logging=False by default when chunk_size is set"
181
+ ```
182
+
183
+ ✅ Big memory/time savings without losing training signal.
184
+
185
+ ---
186
+
187
+ ## Phase 2 — Model Creation & Training Tasks (run after Phase 1)
188
+
189
+ ### Task A — Train the best current baseline (8×256 with ACT)
190
+
191
+ **Prompt:**
192
+
193
+ ```bash
194
+ codex run bittransformerlm/train \
195
+ --layers 8 \
196
+ --d_model 256 \
197
+ --nhead 8 \
198
+ --causal true \
199
+ --chunk_size 128 \
200
+ --act true \
201
+ --reversible true \
202
+ --checkpointing true \
203
+ --batch_size 64 \
204
+ --accum_steps 2 \
205
+ --amp bf16 \
206
+ --lr_schedule progressive_plateau \
207
+ --full_attn_logging false
208
+ ```
209
+
210
+ ✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.
211
+
212
+ ---
213
+
214
+ ### Task B — CPU‑friendly deployment (8×128, INT8 + optional QAT)
215
+
216
+ **Prompt:**
217
+
218
+ ```bash
219
+ codex run bittransformerlm/train \
220
+ --layers 8 \
221
+ --d_model 128 \
222
+ --nhead 8 \
223
+ --causal true \
224
+ --chunk_size 128 \
225
+ --quantization int8 \
226
+ --qat true \
227
+ --reversible true \
228
+ --checkpointing true \
229
+ --batch_size 128 \
230
+ --accum_steps 1 \
231
+ --amp bf16
232
+ ```
233
+
234
+ ✅ Efficient CPU target; QAT optional based on deployment constraints.
235
+
236
+ ---
237
+
238
+ ### Task C — Cautious scale‑up candidate (16×256)
239
+
240
+ **Prompt:**
241
+
242
+ ```bash
243
+ codex run bittransformerlm/train \
244
+ --layers 16 \
245
+ --d_model 256 \
246
+ --nhead 8 \
247
+ --causal true \
248
+ --chunk_size 128 \
249
+ --act true \
250
+ --reversible true \
251
+ --checkpointing true \
252
+ --batch_size 48 \
253
+ --accum_steps 3 \
254
+ --amp bf16 \
255
+ --lr_schedule progressive_plateau
256
+ ```
257
+
258
+ ⚠️ Use only after data expansion and schedule retune.
259
+
260
+ ---
261
+
262
+ ## Recommended Execution Order
263
+
264
+ 1. **Phase 1 Tasks 1–12** (apply all optimizations).
265
+ 2. **Task A** baseline → validate.
266
+ 3. **Task B** CPU build → validate + (optional) QAT.
267
+ 4. **Task C** scale‑up **only** when data/schedule allow.
268
+
269
+ ---
270
+
271
+ ### Notes
272
+
273
+ - Pair Phase 1 changes with CI that runs a short sanity fit (few hundred steps) to confirm loss decreases and no scheduler drift.
274
+ - Keep `full_attn_logging=false` in chunked runs; enable selectively when inspecting attention.
275
+ - When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.
276
+
277
+ ---
278
+
BitTransformerLM/build_full_bits.py ADDED
@@ -0,0 +1,23 @@
1
+ import pathlib
2
+ import torch
3
+ from datasets import load_dataset
4
+
5
+ TXT_MB = 100
6
+ OUT = pathlib.Path('full_bits.pt')
7
+
8
+
9
+ def build_bits(out: pathlib.Path = OUT, txt_mb: int = TXT_MB) -> None:
10
+ ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
11
+ buf = bytearray()
12
+ for line in ds['text']:
13
+ buf.extend(line.encode() + b"\n")
14
+ if len(buf) >= txt_mb * 2 ** 20:
15
+ break
16
+ bits = []
17
+ for byte in buf:
18
+ bits.extend(int(b) for b in f'{byte:08b}')
19
+ tensor = torch.tensor(bits, dtype=torch.uint8)
20
+ torch.save(tensor, out)
21
+
22
+ if __name__ == '__main__':
23
+ build_bits()
BitTransformerLM/context_extension.md ADDED
1
+ Increasing the BitTransformerLM context window
2
+
3
+ Current limitations and mechanisms
4
+ The default max_seq_len in BitTransformerLM is 1 024 bits. Since text is encoded using parity bits (9 bits per byte), this translates to roughly 113 bytes (≈113 characters) of input. The model uses full self‑attention, giving quadratic memory complexity in sequence length. To train on very long sequences, train_full_sequence slides a fixed‑size context window along a long bit tensor, detaching the computation graph periodically. Compression can shorten sequences via run‑length encoding, and chunked attention can divide long inputs into overlapping windows for attention calculations. However, the maximum positional encoding still defines an upper bound.
15
+
16
+ Strategies to reach ~2 k‑word context (~18 k bits)
17
+ Increase max_seq_len and positional encoding. The positional encoding precomputes a [max_len, d_model] matrix. Raising max_len to accommodate ~18 000 bits (for ~2 000 words × 9 bits per word) is possible but memory‑intensive. At d_model=128, the positional encoding would be ~18 000×128 ≈ 2.3 M floats (≈9 MB). That is reasonable for a CPU VM. Codex can modify the default max_seq_len and update any dependent tests.
20
+ Use chunked attention and overlapping windows. LoggingTransformerEncoderLayer already supports chunk_size and overlap parameters. Setting chunk_size (e.g., 2 048 bits) and an overlap of e.g., 128 bits enables the model to handle sequences far longer than the attention window while still allowing information flow across chunks. Codex can expose chunk_size and overlap through the dashboard and CLI so users can tune them for longer contexts.
23
+ Codex prompt example: “Modify the dashboard /init endpoint to accept chunk_size and overlap fields and pass them to BitTransformerLM. Update the HTML template to include input fields for these parameters.”
24
+ Apply sliding‑window training and inference. The train_full_sequence method trains on long bit tensors by sliding a context window and detaching the graph every ctx_bits bits. For inference, a similar sliding approach could produce outputs for long sequences. Codex can add an infer_long_sequence method that divides a long bit sequence into overlapping windows, runs the model with causal=True to preserve order, and stitches the outputs.
27
+ Prompt example: “Implement def infer_long_sequence(model: BitTransformerLM, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256): that processes a long bit tensor in sliding windows with overlap, uses causal=True, and returns the concatenated output bits.”
28
+ Exploit run‑length compression more aggressively. Since binary data often contains runs of identical bits (e.g., long sequences of zeros), increasing compression ratio reduces the effective sequence length. Codex could add additional compression schemes (e.g., bit‑packing into bytes using numpy.packbits) and integrate them into the model’s I/O pipeline. Care must be taken to maintain parity bits for error detection.
29
+ Prompt example: “Add functions pack_bits and unpack_bits that use numpy.packbits to pack 8 bits into a byte. Modify train_loop so that when direct_prob>0 the model is trained on packed bits with a suitable embedding.”
30
+ Memory‑efficient attention alternatives. For even larger contexts, one could replace full attention with sparse, local or linear attention mechanisms. However, this would change the core architecture, which the task seeks to avoid. Using chunked attention (already present) and reversible layers is therefore preferred.
31
+ Dynamic quantization and mixed precision. Larger context sizes increase model activations. Enabling use_autocast=True to compute in bfloat16 and applying quantize_dynamic after training reduces memory usage. Codex can create scripts that quantify memory usage and automatically toggle these features when large contexts are requested.
35
+ Proposed Codex tasks to implement context extension
36
+ Expose context parameters in the API/UI. Extend the dashboard and MCP server to allow clients to specify max_seq_len, chunk_size, overlap, and ctx_bits when initializing a model or running long inference.
37
+ Prompt example: “Add optional parameters max_seq_len, chunk_size and overlap to the /init endpoint and pass them into BitTransformerLM and ModelManager. Update the HTML template to include these fields.”
38
+ Implement sliding‑window inference. Add a function infer_long_sequence as described above and expose it via the dashboard and MCP server.
39
+ Prompt example: “Add a new endpoint /infer_long to mcp_server.py that accepts a list of bits and processes them using a sliding window with overlap. The endpoint should return the predicted bits and telemetry summaries for each window.”
40
+ Allow dynamic context scaling. Add a method to BitTransformerLM to adjust its pos_enc buffer when the context exceeds the current max_seq_len. This can be done by creating a new positional encoding tensor with the new length and copying the existing values.
41
+ Prompt example: “Implement BitTransformerLM.expand_positional_encoding(new_len: int) that creates a new positional encoding buffer of size new_len and copies the existing encoding. Update the model’s max_seq_len accordingly.”
42
+ Integrate aggressive compression. Implement alternative compression schemes (e.g., bit‑packing or general‑purpose compressors) and add toggles for them in training and inference. Evaluate compression ratio and latency to decide when to use them.
43
+ Benchmark and tune hyperparameters. Write scripts to benchmark model memory use and throughput for various max_seq_len, chunk_size, reversible, use_act, and quantization settings. These benchmarks can inform safe defaults for the VM build.
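As one concrete shape for the bit-packing idea above, a minimal sketch using `numpy.packbits` follows; the repository's own `pack_bits`/`unpack_bits` in `bit_transformer/compression.py` may differ in signature and padding behaviour:

```python
import numpy as np
import torch

def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack a 1-D tensor of 0/1 values into bytes (8 bits per byte)."""
    packed = np.packbits(bits.to(torch.uint8).cpu().numpy())
    return torch.from_numpy(packed)

def unpack_bits(packed: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Inverse of pack_bits, truncated back to the original bit count."""
    bits = np.unpackbits(packed.cpu().numpy())[:n_bits]
    return torch.from_numpy(bits)

# Round trip on a small example
original = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.uint8)
assert torch.equal(unpack_bits(pack_bits(original), original.numel()), original)
```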
BitTransformerLM/example.py ADDED
@@ -0,0 +1,6 @@
1
+ from bit_transformer import example_training_step
2
+
3
+ if __name__ == "__main__":
4
+ loss, telemetry = example_training_step()
5
+ print("Training loss:", loss)
6
+ print("Available telemetry:", list(telemetry.keys()))
BitTransformerLM/full_bits_train.py ADDED
@@ -0,0 +1,51 @@
1
+ import pathlib
2
+ import torch
3
+ from bit_transformer import BitTransformerLM
4
+
5
+ DATA_PATH = pathlib.Path('full_bits.pt')
6
+
7
+ class BitSeq(torch.utils.data.IterableDataset):
8
+ def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
9
+ self.bits = torch.load(path, mmap=True)
10
+ self.seq = seq
11
+
12
+ def __len__(self) -> int:
13
+ return (self.bits.numel() // self.seq) - 1
14
+
15
+ def __iter__(self):
16
+ N = (self.bits.numel() // self.seq) - 1
17
+ for i in range(N):
18
+ s = i * self.seq
19
+ yield (
20
+ self.bits[s:s+self.seq].long(),
21
+ self.bits[s+1:s+self.seq+1].long(),
22
+ )
23
+
24
+ def main() -> None:
25
+ dl = torch.utils.data.DataLoader(
26
+ BitSeq(DATA_PATH, seq=2048),
27
+ batch_size=8,
28
+ num_workers=0,
29
+ pin_memory=False,
30
+ )
31
+
32
+ model = BitTransformerLM(
33
+ d_model=64,
34
+ nhead=4,
35
+ num_layers=2,
36
+ dim_feedforward=256,
37
+ max_seq_len=2048,
38
+ reversible=True,
39
+ use_autocast=True,
40
+ )
41
+
42
+ loss_fn = torch.nn.CrossEntropyLoss()
43
+ xb, yb = next(iter(dl))
44
+ logits, _ = model(xb)
45
+ pred = logits.reshape(-1, 2)
46
+ target = yb.reshape(-1)
47
+ loss = loss_fn(pred, target)
48
+ print('Batch loss:', float(loss))
49
+
50
+ if __name__ == '__main__':
51
+ main()
BitTransformerLM/integration_flow.py ADDED
@@ -0,0 +1,110 @@
1
+ import torch
2
+ from torch.profiler import profile
3
+ from bit_transformer import (
4
+ BitTransformerLM,
5
+ quantize_dynamic,
6
+ hil_safe_inference,
7
+ collapse_submodel,
8
+ )
9
+ from bit_transformer.training import train_loop
10
+ from bit_transformer.torch_utils import cpu_autocast
11
+
12
+ def train(
13
+ model: BitTransformerLM,
14
+ data: torch.Tensor,
15
+ epochs: int = 3,
16
+ compress_prob: float = 0.5,
17
+ direct_prob: float = 0.0,
18
+ log: bool = False,
19
+ forward_kwargs: dict | None = None,
20
+ ) -> list[dict]:
21
+ """Train on bit sequences with optional random compression.
22
+
23
+ If ``direct_prob`` is positive, some batches are fed using their
24
+ run-length encoded representation packed into bits. Loss on these
25
+ direct-compressed batches is tracked separately.
26
+
27
+ Returns a list of per-epoch metric dictionaries containing raw and
28
+ compressed loss/accuracy statistics and the mean compression ratio.
29
+ """
30
+ return train_loop(
31
+ model,
32
+ data,
33
+ epochs=epochs,
34
+ compress_prob=compress_prob,
35
+ direct_prob=direct_prob,
36
+ log=log,
37
+ forward_kwargs=forward_kwargs,
38
+ )
39
+
40
+
41
+ def main() -> None:
42
+ data = torch.randint(0, 2, (64, 128), dtype=torch.long)
43
+ validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
44
+ input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
45
+ bit_sequence_data = data.tolist()
46
+
47
+ model = BitTransformerLM(
48
+ d_model=32,
49
+ nhead=4,
50
+ num_layers=1,
51
+ dim_feedforward=64,
52
+ max_seq_len=128,
53
+ use_act=True,
54
+ act_threshold=0.7,
55
+ reversible=True,
56
+ chunk_size=128,
57
+ )
58
+
59
+ for step in range(1, 13):
60
+ if step % 2 == 0:
61
+ model = model.double_width()
62
+ else:
63
+ model = model.double_layers()
64
+ train(model, data, epochs=3, compress_prob=0.5, log=True)
65
+ _, telemetry = model(validation_bits)
66
+ K = telemetry["negentropy_logits"].mean().item()
67
+ C = telemetry["lz_complexity_logits"].mean().item()
68
+ S = telemetry["symbiosis_score"].mean().item()
69
+ assert (
70
+ K > 0.3 and C > 0.35 and S > 0.5
71
+ ), f"Step {step} telemetry floor failure"
72
+
73
+ with cpu_autocast():
74
+ model(input_bits)
75
+
76
+ quantized_model = quantize_dynamic(model)
77
+ quantized_model.eval()
78
+
79
+ safe_output, _ = hil_safe_inference(
80
+ quantized_model, input_bits, c_floor=0.35, s_floor=0.5
81
+ )
82
+
83
+ student_model, _ = collapse_submodel(
84
+ bit_sequence_data,
85
+ target_params=dict(
86
+ d_model=16,
87
+ nhead=4,
88
+ num_layers=1,
89
+ dim_feedforward=32,
90
+ max_seq_len=128,
91
+ ),
92
+ floors={"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5},
93
+ )
94
+
95
+ compiled_model = (
96
+ torch.compile(student_model)
97
+ if hasattr(torch, "compile")
98
+ else student_model
99
+ )
100
+ compiled_model.eval()
101
+
102
+ with profile() as prof:
103
+ compiled_model(input_bits)
104
+
105
+ prof.export_chrome_trace("trace12.json")
106
+ print("Safe output bits:", safe_output.squeeze(0).tolist())
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()
BitTransformerLM/integration_schedule.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+ import time
3
+ import math
4
+ from itertools import cycle
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from bit_transformer import (
10
+ BitTransformerLM,
11
+ text_to_bits,
12
+ quantize_dynamic,
13
+ prepare_qat_fx,
14
+ convert_qat_fx,
15
+ hil_safe_inference,
16
+ collapse_submodel,
17
+ diffusion_inference,
18
+ TelemetrySynthesizer,
19
+ save_distilled_model,
20
+ )
21
+ from bit_transformer.training import train_loop as train
22
+ from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
23
+ from bit_transformer.utils import save_model, load_model, set_dropout
24
+ from bit_transformer.torch_utils import cpu_autocast
25
+
26
+
27
+ def lines_to_tensor(lines, max_len):
28
+ seqs = []
29
+ for text in lines:
30
+ bits = text_to_bits(text)[:max_len]
31
+ if len(bits) < max_len:
32
+ bits.extend([0] * (max_len - len(bits)))
33
+ seqs.append(bits)
34
+ return torch.tensor(seqs, dtype=torch.long)
35
+
36
+
37
+ def load_wikitext(dataset_size=128, max_len=64):
38
+ try:
39
+ from datasets import load_dataset
40
+
41
+ ds = load_dataset("wikitext", "wikitext-2-raw-v1")
42
+ train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
43
+ valid_split = max(1, dataset_size // 4)
44
+ valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
45
+ train = lines_to_tensor(train_lines, max_len)
46
+ valid = lines_to_tensor(valid_lines, max_len)
47
+ return train, valid, train_lines
48
+ except Exception as e:
49
+ print("Dataset load failed, using random bits", e)
50
+ train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
51
+ valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
52
+ return train, valid, ["" for _ in range(len(train))]
53
+
54
+
55
+ def _warmup(
56
+ model: BitTransformerLM,
57
+ data: torch.Tensor,
58
+ steps: int = 5,
59
+ freeze_old: bool = False,
60
+ old_layers: int = 0,
61
+ *,
62
+ diffusion: bool = False,
63
+ curriculum: bool = False,
64
+ optimizer: Optional[torch.optim.Optimizer] = None,
65
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
66
+ ) -> None:
67
+ """Run a short warm-up loop after expansion."""
68
+ model.train()
69
+ set_dropout(model, 0.1)
70
+ if freeze_old:
71
+ for idx, layer in enumerate(model.layers):
72
+ if idx < old_layers:
73
+ for p in layer.parameters():
74
+ p.requires_grad_(False)
75
+ if optimizer is None or scheduler is None:
76
+ optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
77
+ it = iter(data.split(8))
78
+ for idx in range(steps):
79
+ try:
80
+ batch = next(it)
81
+ except StopIteration:
82
+ it = iter(data.split(8))
83
+ batch = next(it)
84
+ if diffusion:
85
+ p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
86
+ noise = (torch.rand_like(batch.float()) < p).long()
87
+ noisy = batch ^ noise
88
+ logits, _ = model(noisy, causal=False)
89
+ pred = logits.reshape(-1, 2)
90
+ target = batch.reshape(-1)
91
+ else:
92
+ logits, _ = model(batch)
93
+ pred = logits[:, :-1, :].reshape(-1, 2)
94
+ target = batch[:, 1:].reshape(-1)
95
+ loss = F.cross_entropy(pred, target)
96
+ loss.backward()
97
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
98
+ optimizer.step()
99
+ scheduler.step()
100
+ optimizer.zero_grad()
101
+ for p in model.parameters():
102
+ p.requires_grad_(True)
103
+ model.eval()
104
+ set_dropout(model, 0.0)
105
+
106
+
107
+ def integration_schedule(
108
+ steps: int = 10,
109
+ max_len: int = 64,
110
+ dataset_size: int = 128,
111
+ *,
112
+ weights_path: str = "weights/model.pt.gz",
113
+ plateau_steps: int = 0,
114
+ collapsed_path: str | None = None,
115
+ epochs_per_step: int = 2,
116
+ extra_steps: int = 3,
117
+ collapse: bool = True,
118
+ diffusion: bool = False,
119
+ noise_schedule: str = "linear",
120
+ diffusion_steps: int = 8,
121
+ diffusion_curriculum: bool = False,
122
+ use_checkpoint: bool = True,
123
+ reversible: bool = True,
124
+ improve_thresh: float = 0.01,
125
+ qat: bool = False,
126
+ ):
127
+ start = time.time()
128
+ train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
129
+ if os.path.exists(weights_path):
130
+ try:
131
+ model = load_model(weights_path)
132
+ print(f"Loaded model from {weights_path}")
133
+ except Exception as e:
134
+ print("Failed to load weights, initializing new model", e)
135
+ model = BitTransformerLM(
136
+ d_model=32,
137
+ nhead=4,
138
+ num_layers=1,
139
+ dim_feedforward=64,
140
+ max_seq_len=max_len,
141
+ use_act=True,
142
+ act_threshold=0.7,
143
+ reversible=reversible,
144
+ chunk_size=max_len,
145
+ use_autocast=True,
146
+ use_checkpoint=use_checkpoint,
147
+ )
148
+ else:
149
+ model = BitTransformerLM(
150
+ d_model=32,
151
+ nhead=4,
152
+ num_layers=1,
153
+ dim_feedforward=64,
154
+ max_seq_len=max_len,
155
+ use_act=True,
156
+ act_threshold=0.7,
157
+ reversible=reversible,
158
+ chunk_size=max_len,
159
+ use_autocast=True,
160
+ use_checkpoint=use_checkpoint,
161
+ )
162
+ if qat:
163
+ model = prepare_qat_fx(model)
164
+ results = []
165
+ scale_cycle = cycle(["layers", "width", "context"])
166
+ base_lr = 1e-3
167
+ prev_val_loss: Optional[float] = None
168
+ for step in range(steps):
169
+ model.train()
170
+ set_dropout(model, 0.1)
171
+ opt, sched = configure_optimizer(
172
+ model, lr=base_lr, total_steps=epochs_per_step
173
+ )
174
+ train(
175
+ model,
176
+ train_bits,
177
+ epochs=epochs_per_step,
178
+ extra_steps=extra_steps,
179
+ compress_prob=0.0 if diffusion else 1.0,
180
+ log=True,
181
+ diffusion=diffusion,
182
+ diffusion_curriculum=diffusion_curriculum,
183
+ optimizer=opt,
184
+ scheduler=sched,
185
+ )
186
+
187
+ model.eval()
188
+ set_dropout(model, 0.0)
189
+ with torch.no_grad():
190
+ logits, telemetry = model(valid_bits, causal=not diffusion)
191
+ if diffusion:
192
+ pred = logits.reshape(-1, 2)
193
+ target = valid_bits.reshape(-1)
194
+ else:
195
+ pred = logits[:, :-1, :].reshape(-1, 2)
196
+ target = valid_bits[:, 1:].reshape(-1)
197
+ val_loss = F.cross_entropy(pred, target).item()
198
+ k = telemetry["negentropy_logits"].mean().item()
199
+ c = telemetry["lz_complexity_logits"].mean().item()
200
+ s = telemetry["symbiosis_score"].mean().item()
201
+ print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
202
+ results.append((step, val_loss, k, c, s))
203
+
204
+ if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
205
+ strategy = next(scale_cycle)
206
+ base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
207
+ if strategy == "layers":
208
+ old_layers = model.num_layers
209
+ model = model.double_layers()
210
+ warm_opt, warm_sched = configure_optimizer(
211
+ model, lr=base_lr, total_steps=100
212
+ )
213
+ _warmup(
214
+ model,
215
+ train_bits,
216
+ steps=100,
217
+ freeze_old=True,
218
+ old_layers=old_layers,
219
+ diffusion=diffusion,
220
+ curriculum=diffusion_curriculum,
221
+ optimizer=warm_opt,
222
+ scheduler=warm_sched,
223
+ )
224
+ elif strategy == "width":
225
+ model = model.double_width()
226
+ warm_opt, warm_sched = configure_optimizer(
227
+ model, lr=base_lr, total_steps=100
228
+ )
229
+ _warmup(
230
+ model,
231
+ train_bits,
232
+ steps=100,
233
+ diffusion=diffusion,
234
+ curriculum=diffusion_curriculum,
235
+ optimizer=warm_opt,
236
+ scheduler=warm_sched,
237
+ )
238
+ else:
239
+ max_len *= 2
240
+ train_bits, valid_bits, train_lines = load_wikitext(
241
+ dataset_size, max_len
242
+ )
243
+ model = model.double_length()
244
+ warm_opt, warm_sched = configure_optimizer(
245
+ model, lr=base_lr, total_steps=100
246
+ )
247
+ _warmup(
248
+ model,
249
+ train_bits,
250
+ steps=100,
251
+ diffusion=diffusion,
252
+ curriculum=diffusion_curriculum,
253
+ optimizer=warm_opt,
254
+ scheduler=warm_sched,
255
+ )
256
+
257
+ prev_val_loss = val_loss
258
+ if time.time() - start > 8 * 60:
259
+ print("Time limit reached")
260
+ break
261
+
262
+ # optional plateau phase at final size
263
+ for p in range(plateau_steps):
264
+ model.train()
265
+ set_dropout(model, 0.1)
266
+ train(
267
+ model,
268
+ train_bits,
269
+ epochs=epochs_per_step,
270
+ extra_steps=extra_steps,
271
+ compress_prob=0.0 if diffusion else 1.0,
272
+ log=True,
273
+ diffusion=diffusion,
274
+ diffusion_curriculum=diffusion_curriculum,
275
+ )
276
+ model.eval()
277
+ set_dropout(model, 0.0)
278
+ with torch.no_grad():
279
+ logits, telemetry = model(valid_bits, causal=not diffusion)
280
+ if diffusion:
281
+ pred = logits.reshape(-1, 2)
282
+ target = valid_bits.reshape(-1)
283
+ else:
284
+ pred = logits[:, :-1, :].reshape(-1, 2)
285
+ target = valid_bits[:, 1:].reshape(-1)
286
+ val_loss = F.cross_entropy(pred, target).item()
287
+ k = telemetry["negentropy_logits"].mean().item()
288
+ c = telemetry["lz_complexity_logits"].mean().item()
289
+ s = telemetry["symbiosis_score"].mean().item()
290
+ idx = steps + p
291
+ print(
292
+ f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
293
+ )
294
+ results.append((idx, val_loss, k, c, s))
295
+ if time.time() - start > 8 * 60:
296
+ print("Time limit reached")
297
+ break
298
+
299
+ # final validation after last step
300
+ model.eval()
301
+ set_dropout(model, 0.0)
302
+ with torch.no_grad():
303
+ logits, telemetry = model(valid_bits, causal=not diffusion)
304
+ if diffusion:
305
+ pred = logits.reshape(-1, 2)
306
+ target = valid_bits.reshape(-1)
307
+ else:
308
+ pred = logits[:, :-1, :].reshape(-1, 2)
309
+ target = valid_bits[:, 1:].reshape(-1)
310
+ val_loss = F.cross_entropy(pred, target).item()
311
+ k = telemetry["negentropy_logits"].mean().item()
312
+ c = telemetry["lz_complexity_logits"].mean().item()
313
+ s = telemetry["symbiosis_score"].mean().item()
314
+
315
+ print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
316
+ results.append((steps + plateau_steps, val_loss, k, c, s))
317
+
318
+ # persist final model weights for future runs
319
+ save_model(model, weights_path)
320
+
321
+ input_bits = valid_bits[:1]
322
+ if qat:
323
+ qmodel = convert_qat_fx(model)
324
+ else:
325
+ with cpu_autocast():
326
+ model(input_bits)
327
+ qmodel = quantize_dynamic(model)
328
+ qmodel.eval()
329
+ try:
330
+ hil_safe_inference(
331
+ qmodel,
332
+ input_bits,
333
+ c_floor=0.3,
334
+ s_floor=0.5,
335
+ causal=not diffusion,
336
+ strict=not diffusion,
337
+ )
338
+ except RuntimeError as e:
339
+ print("Safety gate triggered", e)
340
+ collapsed = None
341
+ if collapse:
342
+ synth = TelemetrySynthesizer(n_clusters=8)
343
+ reps = synth.cluster_sequences(model, train_bits[:64])
344
+ floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
345
+ collapsed, metrics = collapse_submodel(
346
+ reps,
347
+ target_params=dict(
348
+ d_model=16,
349
+ nhead=4,
350
+ num_layers=1,
351
+ dim_feedforward=32,
352
+ max_seq_len=max_len,
353
+ ),
354
+ floors=floors,
355
+ )
356
+ collapsed.eval()
357
+ with torch.no_grad():
358
+ logits, _ = collapsed(valid_bits)
359
+ pred = logits[:, :-1, :].reshape(-1, 2)
360
+ target = valid_bits[:, 1:].reshape(-1)
361
+ c_loss = F.cross_entropy(pred, target).item()
362
+ print("Collapsed model validation loss:", c_loss)
363
+ if collapsed_path is not None:
364
+ save_distilled_model(
365
+ collapsed,
366
+ collapsed_path,
367
+ {**metrics, "val_loss": c_loss},
368
+ floors=floors,
369
+ )
370
+ if diffusion:
371
+ sample = diffusion_inference(
372
+ model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
373
+ )
374
+ print("Diffusion sample:", sample[0].tolist())
375
+ return results, collapsed
376
+
377
+
378
+ if __name__ == "__main__":
379
+ integration_schedule()
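A minimal usage sketch of the schedule defined above, assuming `integration_schedule.py` is importable from the repository root; the argument values and output paths are illustrative:

```python
# Hypothetical invocation of integration_schedule(); values shown are examples only.
from integration_schedule import integration_schedule

results, collapsed = integration_schedule(
    steps=6,                                 # scale-up evaluation rounds
    max_len=64,                              # initial sequence length (doubles on "context" scaling)
    dataset_size=128,                        # WikiText2 lines converted to bits
    weights_path="weights/model.pt.gz",
    collapsed_path="weights/collapsed.pt.gz",
    diffusion=False,                         # True switches to non-causal denoising training
    qat=False,                               # True enables quantization-aware training
)
for step, val_loss, k, c, s in results:
    print(f"step={step} loss={val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
```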
BitTransformerLM/mcp_server.py ADDED
@@ -0,0 +1,322 @@
1
+ import io
2
+ import os
3
+ import gzip
4
+ import uuid
5
+ import traceback
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from flask import Flask, request, jsonify, send_file
8
+ import matplotlib.pyplot as plt
9
+ import torch
10
+
11
+ from bit_transformer.dashboard_app import ModelManager
12
+ from bit_transformer.dashboard import plot_telemetry
13
+ from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
14
+ from bit_transformer.optimization import configure_optimizer
15
+ from bit_transformer.bit_io import text_to_bits
16
+
17
+ app = Flask(__name__)
18
+ manager = ModelManager()
19
+
20
+ # background job management
21
+ executor = ThreadPoolExecutor(max_workers=4)
22
+ jobs: dict[str, dict] = {}
23
+
24
+
25
+ def _submit_job(fn, *args, **kwargs) -> str:
26
+ """Schedule a function for background execution and return a job id."""
27
+ job_id = str(uuid.uuid4())
28
+ jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}
29
+
30
+ def wrapper():
31
+ jobs[job_id]["status"] = "running"
32
+ try:
33
+ jobs[job_id]["result"] = fn(*args, **kwargs)
34
+ jobs[job_id]["status"] = "completed"
35
+ except Exception as err: # pragma: no cover - captured for client
36
+ jobs[job_id]["status"] = "error"
37
+ jobs[job_id]["error"] = str(err)
38
+ jobs[job_id]["trace"] = traceback.format_exc()
39
+
40
+ executor.submit(wrapper)
41
+ return job_id
42
+
43
+
44
+ @app.errorhandler(Exception)
45
+ def handle_exception(err):
46
+ """Return JSON error responses with stack traces."""
47
+ return (
48
+ jsonify({"error": str(err), "trace": traceback.format_exc()}),
49
+ getattr(err, "code", 500),
50
+ )
51
+
52
+
53
+ @app.route("/init", methods=["POST"])
54
+ def init_model():
55
+ data = request.json or {}
56
+ int_fields = {
57
+ "d_model",
58
+ "nhead",
59
+ "num_layers",
60
+ "dim_feedforward",
61
+ "max_seq_len",
62
+ "chunk_size",
63
+ "overlap",
64
+ }
65
+ float_fields = {"act_threshold"}
66
+ bool_fields = {"reversible", "use_checkpoint"}
67
+ params = {}
68
+ for k, v in data.items():
69
+ if v is None:
70
+ params[k] = None
71
+ elif k in int_fields:
72
+ params[k] = int(v)
73
+ elif k in float_fields:
74
+ params[k] = float(v)
75
+ elif k in bool_fields:
76
+ params[k] = bool(v)
77
+ else:
78
+ params[k] = v
79
+ manager.init_model(params)
80
+ return jsonify({"status": "initialized", "params": params})
81
+
82
+ @app.route("/train", methods=["POST"])
83
+ def train_model():
84
+ bits = request.json["bits"]
85
+
86
+ def task():
87
+ tensor = torch.tensor(bits, dtype=torch.long)
88
+ loss, ratio = manager.train_step(tensor)
89
+ return {"loss": loss, "ratio": ratio}
90
+
91
+ job_id = _submit_job(task)
92
+ return jsonify({"job_id": job_id})
93
+
94
+
95
+ @app.route("/train_epochs", methods=["POST"])
96
+ def train_epochs_route():
97
+ data = request.json
98
+ bits = data["bits"]
99
+ epochs = int(data.get("epochs", 1))
100
+ compress_prob = float(data.get("compress_prob", 0.5))
101
+ direct_prob = float(data.get("direct_prob", 0.0))
102
+
103
+ def task():
104
+ tensor = torch.tensor(bits, dtype=torch.long)
105
+ metrics = manager.train_epochs(
106
+ tensor,
107
+ epochs=epochs,
108
+ compress_prob=compress_prob,
109
+ direct_prob=direct_prob,
110
+ )
111
+ return {"metrics": metrics}
112
+
113
+ job_id = _submit_job(task)
114
+ return jsonify({"job_id": job_id})
115
+
116
+ @app.route("/scale_up", methods=["POST"])
117
+ def scale_up():
118
+ width_mult = float(request.json.get("width_mult", 1.0))
119
+
120
+ def task():
121
+ manager.scale_up(width_mult)
122
+ return {
123
+ "status": "scaled",
124
+ "layers": manager.model.num_layers,
125
+ "d_model": manager.model.d_model,
126
+ }
127
+
128
+ job_id = _submit_job(task)
129
+ return jsonify({"job_id": job_id})
130
+
131
+ @app.route("/collapse", methods=["POST"])
132
+ def collapse_model():
133
+ cluster_bits = request.json["clusters"]
134
+ params = {k: int(v) for k, v in request.json["params"].items()}
135
+ width_scale = float(request.json.get("width_scale", 1.0))
136
+
137
+ def task():
138
+ manager.collapse(cluster_bits, params, width_scale)
139
+ return {"status": "collapsed"}
140
+
141
+ job_id = _submit_job(task)
142
+ return jsonify({"job_id": job_id})
143
+
144
+
145
+ @app.route("/job/<job_id>", methods=["GET"])
146
+ def get_job(job_id: str):
147
+ job = jobs.get(job_id)
148
+ if job is None:
149
+ return jsonify({"error": "not found"}), 404
150
+ return jsonify(job)
151
+
152
+
153
+ @app.route("/jobs", methods=["GET"])
154
+ def list_jobs():
155
+ return jsonify(jobs)
156
+
157
+ @app.route("/lambdas", methods=["GET", "POST"])
158
+ def update_lambdas():
159
+ if request.method == "POST":
160
+ data = request.json
161
+ manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
162
+ return jsonify({"status": "updated"})
163
+ else:
164
+ return jsonify({
165
+ "lambda_K": manager.lambda_K,
166
+ "lambda_C": manager.lambda_C,
167
+ "lambda_S": manager.lambda_S,
168
+ })
169
+
170
+ @app.route("/diffusion", methods=["GET", "POST"])
171
+ def update_diffusion():
172
+ if request.method == "POST":
173
+ manager.set_diffusion(bool(request.json.get("diffusion", False)))
174
+ return jsonify({"status": "updated"})
175
+ return jsonify({"diffusion": manager.diffusion})
176
+
177
+
178
+ @app.route("/qat", methods=["GET", "POST"])
179
+ def update_qat():
180
+ if request.method == "POST":
181
+ manager.set_qat(bool(request.json.get("qat", False)))
182
+ return jsonify({"status": "updated"})
183
+ return jsonify({"qat": manager.qat})
184
+
185
+
186
+ @app.route("/gpu", methods=["GET", "POST"])
187
+ def update_gpu():
188
+ if request.method == "POST":
189
+ manager.set_gpu(bool(request.json.get("use_gpu", False)))
190
+ return jsonify({"status": "updated"})
191
+ return jsonify({"use_gpu": manager.use_gpu})
192
+
193
+ @app.route("/infer", methods=["POST"])
194
+ def inference():
195
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
196
+ result = manager.infer(bits)
197
+ return jsonify(result)
198
+
199
+
200
+ @app.route("/infer_long", methods=["POST"])
201
+ def inference_long():
202
+ bits = torch.tensor(request.json["bits"], dtype=torch.long)
203
+ ctx = int(request.json.get("ctx_bits", 4096))
204
+ overlap = int(request.json.get("overlap", 256))
205
+ result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
206
+ return jsonify(result)
207
+
208
+ @app.route("/infer_text", methods=["POST"])
209
+ def inference_text():
210
+ text = request.json.get("text", "")
211
+ result = manager.infer_text(text)
212
+ return jsonify(result)
213
+
214
+ @app.route("/status", methods=["GET"])
215
+ def status():
216
+ return jsonify(manager.get_status())
217
+
218
+
219
+ @app.route("/model_config", methods=["GET"])
220
+ def model_config():
221
+ return jsonify(manager.get_model_config())
222
+
223
+
224
+ @app.route("/metrics", methods=["GET"])
225
+ def metrics():
226
+ return jsonify(manager.get_metrics())
227
+
228
+
229
+ @app.route("/save_checkpoint", methods=["POST"])
230
+ def save_checkpoint_route():
231
+ repo_id = request.json.get("repo_id")
232
+ token = request.json.get("token") or os.getenv("HF_TOKEN")
233
+ if manager.model is None:
234
+ return jsonify({"error": "model not initialized"}), 400
235
+ if token:
236
+ hf_login(token=token)
237
+ save_checkpoint(manager.model, repo_id=repo_id)
238
+ return jsonify({"status": "saved"})
239
+
240
+
241
+ @app.route("/download_checkpoint", methods=["POST"])
242
+ def download_checkpoint_route():
243
+ repo_id = request.json.get("repo_id")
244
+ token = request.json.get("token") or os.getenv("HF_TOKEN")
245
+ if token:
246
+ hf_login(token=token)
247
+ dest = manager.weights_path + ".gz"
248
+ ok = download_checkpoint(dest, repo_id=repo_id)
249
+ if not ok:
250
+ return jsonify({"status": "failed"}), 500
251
+ if manager.model is None:
252
+ return jsonify({"status": "downloaded", "loaded": False})
253
+ with gzip.open(dest, "rb") as f:
254
+ state = torch.load(f, map_location="cpu")
255
+ manager.model.load_state_dict(state)
256
+ manager.optimizer, manager.scheduler = configure_optimizer(
257
+ manager.model, lr=1e-3, total_steps=manager.total_steps
258
+ )
259
+ manager._apply_device()
260
+ manager._save_state()
261
+ return jsonify({"status": "downloaded", "loaded": True})
262
+
263
+ @app.route("/plot.png")
264
+ def plot_png():
265
+ fig, _ = plot_telemetry(manager.metrics)
266
+ buf = io.BytesIO()
267
+ fig.savefig(buf, format="png")
268
+ plt.close(fig)
269
+ buf.seek(0)
270
+ return send_file(buf, mimetype="image/png")
271
+
272
+
273
+ @app.route("/text_to_bits", methods=["POST"])
274
+ def text_to_bits_route():
275
+ text = request.json.get("text", "")
276
+ if len(text) > 100_000:
277
+ return jsonify({"error": "text too large"}), 413
278
+ return jsonify({"bits": text_to_bits(text)})
279
+
280
+
281
+ @app.route("/dataset", methods=["GET"])
282
+ def dataset_route():
283
+ name = request.args.get("name", "")
284
+ split = request.args.get("split", "train")
285
+ size = int(request.args.get("size", 1))
286
+ seq_len = int(request.args.get("seq_len", 64))
287
+ if size * seq_len > 1_000_000:
288
+ return jsonify({"error": "dataset too large"}), 413
289
+ if name == "wikitext2":
290
+ try:
291
+ from datasets import load_dataset
292
+
293
+ ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
294
+ lines = [t for t in ds["text"] if t.strip()][:size]
295
+ except Exception:
296
+ bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
297
+ return jsonify({"bits": bits.tolist()})
298
+ bits_list = []
299
+ for text in lines:
300
+ b = text_to_bits(text)[:seq_len]
301
+ if len(b) < seq_len:
302
+ b.extend([0] * (seq_len - len(b)))
303
+ bits_list.append(b)
304
+ if len(bits_list) < size:
305
+ pad = size - len(bits_list)
306
+ bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
307
+ return jsonify({"bits": bits_list})
308
+ return jsonify({"error": "unknown dataset"}), 400
309
+
310
+
311
+ @app.route("/health")
312
+ def health_check():
313
+ return jsonify({"status": "ok"})
314
+
315
+
316
+ def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
317
+ app.run(host=host, port=port, debug=True)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ import torch
322
+ run_mcp_server()
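A minimal client sketch for the HTTP API above, assuming the server was started with `run_mcp_server()` on the default port 7000; the model parameters, training bits, and polling interval are illustrative:

```python
# Hypothetical MCP client; endpoints match the Flask routes defined above.
import time
import requests

BASE = "http://localhost:7000"

# Initialize a small model (unspecified fields fall back to ModelManager defaults).
requests.post(f"{BASE}/init", json={"d_model": 32, "nhead": 4, "num_layers": 1,
                                    "dim_feedforward": 64, "max_seq_len": 16})

# Submit a background training job and poll until it finishes.
bits = [[0, 1] * 8]  # one 16-bit sequence
job_id = requests.post(f"{BASE}/train", json={"bits": bits}).json()["job_id"]
while True:
    job = requests.get(f"{BASE}/job/{job_id}").json()
    if job["status"] in ("completed", "error"):
        break
    time.sleep(1)
print(job)

# Run inference and inspect server state.
print(requests.post(f"{BASE}/infer", json={"bits": bits}).json())
print(requests.get(f"{BASE}/status").json())
```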
BitTransformerLM/progressive_scaleup.py ADDED
@@ -0,0 +1,216 @@
1
+ """Legacy progressive scale-up demo.
2
+
3
+ This script is retained for historical reference but has been superseded by
4
+ ``integration_schedule.py`` which provides a more flexible scaling workflow.
5
+ """
6
+
7
+ import argparse
8
+ import warnings
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from bit_transformer import (
12
+ BitTransformerLM,
13
+ configure_optimizer,
14
+ expand_model,
15
+ text_to_bits,
16
+ )
17
+ from bit_transformer.training import train_loop as basic_train
18
+
19
+ warnings.warn(
20
+ "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
21
+ DeprecationWarning,
22
+ stacklevel=2,
23
+ )
24
+
25
+
26
+ def progressive_scale_up(
27
+ eps: float = 0.65,
28
+ steps: int = 2,
29
+ width_mult: float = 1.0,
30
+ forward_kwargs: dict | None = None,
31
+ ) -> None:
32
+ """Demonstrate automatic scaling of the model on random data."""
33
+ params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
34
+ model = BitTransformerLM(**params)
35
+ steps_per_epoch = 64 // 8
36
+ optimizer, scheduler = configure_optimizer(
37
+ model, lr=1e-3, total_steps=steps * steps_per_epoch
38
+ )
39
+
40
+ train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
41
+ valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)
42
+
43
+ for step in range(steps):
44
+ # one epoch over train
45
+ basic_train(
46
+ model,
47
+ train,
48
+ epochs=1,
49
+ compress_prob=0.5,
50
+ log=False,
51
+ forward_kwargs=forward_kwargs,
52
+ )
53
+
54
+ with torch.no_grad():
55
+ logits, _ = model(valid, **(forward_kwargs or {}))
56
+ pred = logits[:, :-1, :].reshape(-1, 2)
57
+ target = valid[:, 1:].reshape(-1)
58
+ val_loss = F.cross_entropy(pred, target).item()
59
+ print(f"Step {step} validation loss: {val_loss:.4f}")
60
+ if val_loss < eps:
61
+ params["num_layers"] *= 2
62
+ params["d_model"] = int(params["d_model"] * width_mult)
63
+ params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
64
+ model = expand_model(model, params)
65
+ optimizer, scheduler = configure_optimizer(
66
+ model, lr=1e-3, total_steps=steps * steps_per_epoch
67
+ )
68
+ print(
69
+ "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
70
+ )
71
+
72
+
73
+ def progressive_scale_up_text(
74
+ improve_thresh: float = 0.01,
75
+ steps: int = 2,
76
+ width_mult: float = 2.0,
77
+ max_len: int = 64,
78
+ dataset_size: int = 512,
79
+ forward_kwargs: dict | None = None,
80
+ ) -> None:
81
+ """Scale up using WikiText2 lines converted to bits.
82
+
83
+ Parameters
84
+ ----------
85
+ improve_thresh: float
86
+ Relative validation loss improvement required to avoid scaling.
87
+ If improvement is <= this threshold, model size is increased.
88
+ steps: int
89
+ Number of training steps.
90
+ width_mult: float
91
+ Multiplier applied when increasing model width.
92
+ max_len: int
93
+ Initial sequence length.
94
+ dataset_size: int
95
+ Number of training lines to load from WikiText2.
96
+ forward_kwargs: dict | None
97
+ Extra keyword arguments for the forward pass.
98
+ """
99
+ from datasets import load_dataset
100
+
101
+ ds = load_dataset("wikitext", "wikitext-2-raw-v1")
102
+ train_iter = ds["train"]["text"]
103
+ valid_iter = ds["validation"]["text"]
104
+
105
+ train_lines = []
106
+ for line in train_iter:
107
+ train_lines.append(line)
108
+ if len(train_lines) >= dataset_size:
109
+ break
110
+
111
+ valid_lines = []
112
+ for line in valid_iter:
113
+ valid_lines.append(line)
114
+ if len(valid_lines) >= dataset_size // 4:
115
+ break
116
+
117
+ def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
118
+ seqs = []
119
+ for text in lines:
120
+ bits = text_to_bits(text)[:length]
121
+ if len(bits) < length:
122
+ bits.extend([0] * (length - len(bits)))
123
+ seqs.append(bits)
124
+ return torch.tensor(seqs, dtype=torch.long)
125
+
126
+ train = lines_to_tensor(train_lines, max_len)
127
+ valid = lines_to_tensor(valid_lines, max_len)
128
+
129
+ params = dict(
130
+ d_model=32,
131
+ nhead=4,
132
+ num_layers=1,
133
+ dim_feedforward=64,
134
+ max_seq_len=max_len,
135
+ )
136
+ model = BitTransformerLM(**params)
137
+ steps_per_epoch = len(train) // 8
138
+ optimizer, scheduler = configure_optimizer(
139
+ model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
140
+ )
141
+
142
+ prev_loss: float | None = None
143
+ scale_length = True
144
+
145
+ for step in range(steps):
146
+ basic_train(
147
+ model,
148
+ train,
149
+ epochs=1,
150
+ compress_prob=0.5,
151
+ log=False,
152
+ forward_kwargs=forward_kwargs,
153
+ )
154
+
155
+ with torch.no_grad():
156
+ logits, _ = model(valid, **(forward_kwargs or {}))
157
+ pred = logits[:, :-1, :].reshape(-1, 2)
158
+ target = valid[:, 1:].reshape(-1)
159
+ val_loss = F.cross_entropy(pred, target).item()
160
+ print(f"Step {step} validation loss: {val_loss:.4f}")
161
+ if prev_loss is not None:
162
+ improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
163
+ if improvement <= improve_thresh:
164
+ if scale_length:
165
+ params["max_seq_len"] *= 2
166
+ train = lines_to_tensor(train_lines, params["max_seq_len"])
167
+ valid = lines_to_tensor(valid_lines, params["max_seq_len"])
168
+ model = model.double_length()
169
+ steps_per_epoch = len(train) // 8
170
+ scale_type = "length"
171
+ else:
172
+ params["d_model"] = int(params["d_model"] * width_mult)
173
+ params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
174
+ model = expand_model(model, params)
175
+ scale_type = "width"
176
+ optimizer, scheduler = configure_optimizer(
177
+ model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
178
+ )
179
+ scale_length = not scale_length
180
+ param_count = sum(p.numel() for p in model.parameters())
181
+ print(
182
+ f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
183
+ )
184
+ prev_loss = val_loss
185
+
186
+
187
+ if __name__ == "__main__":
188
+ parser = argparse.ArgumentParser(description="Progressively scale model length and width")
189
+ parser.add_argument("--steps", type=int, default=2, help="number of training steps")
190
+ parser.add_argument(
191
+ "--improve-thresh",
192
+ type=float,
193
+ default=0.01,
194
+ help="relative loss improvement required to avoid scaling",
195
+ )
196
+ parser.add_argument(
197
+ "--width-mult", type=float, default=2.0, help="width multiplier when scaling"
198
+ )
199
+ parser.add_argument("--causal", action="store_true", help="use causal attention during training")
200
+ parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
201
+ args = parser.parse_args()
202
+ if args.wikitext:
203
+ progressive_scale_up_text(
204
+ improve_thresh=args.improve_thresh,
205
+ steps=args.steps,
206
+ width_mult=args.width_mult,
207
+ forward_kwargs={"causal": args.causal} if args.causal else None,
208
+ )
209
+ else:
210
+ progressive_scale_up(
211
+ eps=args.improve_thresh,
212
+ steps=args.steps,
213
+ width_mult=args.width_mult,
214
+ forward_kwargs={"causal": args.causal} if args.causal else None,
215
+ )
216
+
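For completeness, a sketch of driving the deprecated demo directly from Python instead of the CLI flags above; the argument values are illustrative and trigger the same DeprecationWarning:

```python
# Hypothetical direct call into the legacy demo defined above.
from progressive_scaleup import progressive_scale_up_text

progressive_scale_up_text(
    improve_thresh=0.01,              # scale when relative improvement is at or below this
    steps=2,
    width_mult=2.0,
    max_len=64,
    dataset_size=512,
    forward_kwargs={"causal": True},  # same effect as passing --causal on the CLI
)
```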
BitTransformerLM/pyproject.toml ADDED
@@ -0,0 +1,18 @@
1
+ [build-system]
2
+ requires = ["setuptools>=67", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "bit-transformer"
7
+ version = "0.1.0"
8
+ description = "Bit-based transformer language model nearing v1 release with telemetry"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = {text = "All Rights Reserved"}
12
+ authors = [{name = "WCNegentropy"}]
13
+
14
+ [project.urls]
15
+ Homepage = "https://example.com"
16
+
17
+ [tool.setuptools.packages.find]
18
+ include = ["bit_transformer"]
BitTransformerLM/recursive_integration_flow.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.profiler import profile
4
+ from bit_transformer import (
5
+ BitTransformerLM,
6
+ quantize_dynamic,
7
+ hil_safe_inference,
8
+ collapse_submodel,
9
+ )
10
+ from bit_transformer.training import train_loop
11
+ from bit_transformer.torch_utils import cpu_autocast
12
+
13
+
14
+ def train(
15
+ model: BitTransformerLM,
16
+ data: torch.Tensor,
17
+ epochs: int = 1,
18
+ compress_prob: float = 0.5,
19
+ log: bool = False,
20
+ forward_kwargs: dict | None = None,
21
+ ) -> list[dict]:
22
+ """Train with random compression; returns per-epoch metrics."""
23
+ return train_loop(
24
+ model,
25
+ data,
26
+ epochs=epochs,
27
+ compress_prob=compress_prob,
28
+ direct_prob=0.0,
29
+ log=log,
30
+ forward_kwargs=forward_kwargs,
31
+ )
32
+
33
+
34
+ def recursive_integration_flow(steps: int = 4, max_len: int = 64) -> None:
35
+ """Run a dynamic scale-up loop with telemetry-based gating."""
36
+ train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
37
+ valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
38
+ input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)
39
+ bit_sequence_data = train_bits.tolist()
40
+
41
+ best_K = best_C = best_S = 0.0
42
+
43
+ model = BitTransformerLM(
44
+ d_model=32,
45
+ nhead=4,
46
+ num_layers=1,
47
+ dim_feedforward=64,
48
+ max_seq_len=max_len,
49
+ use_act=True,
50
+ act_threshold=0.7,
51
+ reversible=True,
52
+ chunk_size=max_len,
53
+ use_autocast=True,
54
+ )
55
+
56
+ results = []
57
+ for step in range(steps + 1):
58
+ epochs = min(10, 2 + step // 2)
59
+ train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)
60
+
61
+ with torch.no_grad():
62
+ with cpu_autocast():
63
+ logits, telemetry = model(valid_bits)
64
+ pred = logits[:, :-1, :].reshape(-1, 2)
65
+ target = valid_bits[:, 1:].reshape(-1)
66
+ val_loss = F.cross_entropy(pred, target).item()
67
+ k = telemetry["negentropy_logits"].mean().item()
68
+ c = telemetry["lz_complexity_logits"].mean().item()
69
+ s = telemetry["symbiosis_score"].mean().item()
70
+
71
+ print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
72
+ results.append((step, val_loss, k, c, s))
73
+
74
+ if step > 0:
75
+ if k < best_K - 0.3 or c < best_C - 0.3 or s < best_S - 0.3:
76
+ print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
77
+ break
78
+ best_K = max(best_K, k)
79
+ best_C = max(best_C, c)
80
+ best_S = max(best_S, s)
81
+
82
+ if step < steps:
83
+ if step % 2 == 0:
84
+ model = model.double_width()
85
+ else:
86
+ model = model.double_layers()
87
+
88
+ # Post-scaling optimizations
89
+ with cpu_autocast():
90
+ model(input_bits)
91
+
92
+ qmodel = quantize_dynamic(model)
93
+ qmodel.eval()
94
+
95
+ safe_output = hil_safe_inference(
96
+ qmodel, input_bits, c_floor=0.5, s_floor=0.2
97
+ )
98
+
99
+ student_model, _ = collapse_submodel(
100
+ bit_sequence_data,
101
+ target_params=dict(
102
+ d_model=16,
103
+ nhead=4,
104
+ num_layers=1,
105
+ dim_feedforward=32,
106
+ max_seq_len=max_len,
107
+ ),
108
+ floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
109
+ )
110
+
111
+ if hasattr(torch, "compile"):
112
+ try:
113
+ compiled = torch.compile(student_model)
114
+ except RuntimeError as exc:
115
+ print(f"Compilation skipped: {exc}")
116
+ compiled = student_model
117
+ else:
118
+ compiled = student_model
119
+ compiled.eval()
120
+
121
+ with profile() as prof:
122
+ compiled(input_bits)
123
+ prof.export_chrome_trace("trace12.json")
124
+ print("Safe output bits:", safe_output[0].tolist())
125
+
126
+
127
+ if __name__ == "__main__":
128
+ recursive_integration_flow()
BitTransformerLM/requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ -i https://pypi.org/simple
2
+ --extra-index-url https://download.pytorch.org/whl/cpu
3
+ torch==2.7.1+cpu
4
+
5
+ # --extra-index-url https://download.pytorch.org/whl/cu118
6
+ # torch==2.7.1+cu118
7
+
8
+ pytest==8.4.1
9
+ scikit-learn==1.7.1
10
+ matplotlib==3.10.3
11
+ datasets==4.0.0
12
+ flask==3.1.1
13
+ numpy==2.3.1
14
+ requests==2.32.3
15
+ watchdog==6.0.0
16
+ huggingface-hub==0.34.3
BitTransformerLM/review_cli.py ADDED
@@ -0,0 +1,49 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from bit_transformer import BitTransformerLM
7
+
8
+
9
+ def list_candidates(path: Path):
10
+ models = sorted(path.glob("*.pt"))
11
+ for m in models:
12
+ metrics_file = m.with_suffix(m.suffix + ".json")
13
+ metrics = {}
14
+ if metrics_file.exists():
15
+ with open(metrics_file) as f:
16
+ metrics = json.load(f)
17
+ yield m, metrics
18
+
19
+
20
+ def main():
21
+ parser = argparse.ArgumentParser(description="Review distilled submodels")
22
+ parser.add_argument("directory", type=Path, help="Directory with candidate models")
23
+ parser.add_argument("--approve-dir", type=Path, default=Path("approved"), help="Directory to store approved models")
24
+ args = parser.parse_args()
25
+
26
+ args.approve_dir.mkdir(exist_ok=True)
27
+ log_file = args.approve_dir / "review_log.jsonl"
28
+
29
+ for model_path, metrics in list_candidates(args.directory):
30
+ print("Candidate:", model_path.name)
31
+ for k, v in metrics.items():
32
+ print(f" {k}: {v}")
33
+ ans = input("Approve this model? [y/N] ").strip().lower()
34
+ if ans == "y":
35
+ approved_path = args.approve_dir / model_path.name
36
+ torch.save(torch.load(model_path), approved_path)
37
+ entry = {"model": approved_path.name, "metrics": metrics, "approved": True}
38
+ with open(log_file, "a") as lf:
39
+ lf.write(json.dumps(entry) + "\n")
40
+ print("Approved and saved to", approved_path)
41
+ else:
42
+ entry = {"model": model_path.name, "metrics": metrics, "approved": False}
43
+ with open(log_file, "a") as lf:
44
+ lf.write(json.dumps(entry) + "\n")
45
+ print("Rejected", model_path.name)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
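A small sketch of listing candidates programmatically with the `list_candidates` helper above, assuming distilled checkpoints sit in a local `distilled/` directory (the directory name is illustrative):

```python
# Hypothetical non-interactive inspection of distilled submodel candidates.
from pathlib import Path
from review_cli import list_candidates

for model_path, metrics in list_candidates(Path("distilled")):
    print(model_path.name, metrics.get("val_loss"), metrics)
```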
BitTransformerLM/start.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/bin/bash
2
+ set -e
3
+ python mcp_server.py &
4
+ sleep 2
5
+ python -m bit_transformer.dashboard_app
BitTransformerLM/state_of_the_repo_audit.md ADDED
@@ -0,0 +1,98 @@
1
+ summary: |
2
+ The **BitTransformerLM** repository is well-structured and aligns closely with the README’s feature set.
3
+ All core functionalities (bit-level modeling, telemetry metrics, progressive scaling, compression, context extension, diffusion mode, dashboard, etc.) are present and largely consistent with documentation.
4
+ The code is generally clean and well-tested (no TODOs or obvious dead code) with an effective CI in place.
5
+ We identified a few issues via static analysis: a critical **security flaw** where the dashboard’s `/exec` endpoint executes arbitrary code, a missing import that breaks the compression toggle, and a rare edge-case in bit-sequence decompression logic.
6
+ No functions exceed 300 lines, though the `BitTransformerLM.forward` method is complex with deeply nested logic (~6 levels) and duplicated code blocks for the halting mechanism.
7
+ Naming conventions are consistent (snake_case for functions, CamelCase for classes), and dependency versions are up-to-date.
8
+ Documentation and code behavior are in sync – for example, the MCP server’s `/health` endpoint described in docs is implemented.
9
+ Overall, the project appears **nearly production-ready**, with these fixes and refinements needed before a 1.0 release.
10
+
11
+ findings:
12
+ - severity: P0
13
+ effort: S
14
+ category: security
15
+ location: bit_transformer/dashboard_app.py:533
16
+ description: "Unrestricted `/exec` HTTP endpoint allows arbitrary code execution."
17
+ recommendation: "Disable or restrict the `/exec` route (e.g. remove it or require an admin token) to prevent remote code execution."
18
+ status: completed ✅
19
+ - severity: P1
20
+ effort: S
21
+ category: static
22
+ location: bit_transformer/dashboard_app.py:195
23
+ description: "NameError risk – `compress_bits` is used without being imported."
24
+ recommendation: "Import the `compress_bits` function in `dashboard_app.py` (e.g. `from .compression import compress_bits`) so compression toggles don’t crash."
25
+ status: completed ✅
26
+ - severity: P2
27
+ effort: M
28
+ category: static
29
+ location: bit_transformer/model.py:320
30
+ description: "Edge-case bug – `_maybe_decompress` skips decompression if all values ≤1, which can misinterpret run-length encoding outputs of all 1s."
31
+ recommendation: "Adjust the decompress condition (e.g. track whether input was compressed) to ensure even uniformly alternating bit sequences get properly decompressed."
32
+ status: completed ✅
33
+ - severity: P3
34
+ effort: M
35
+ category: static
36
+ location: bit_transformer/model.py:415
37
+ description: "Duplicate code – nearly identical halting logic is implemented in both reversible and normal forward loops."
38
+ recommendation: "Refactor the halting (ACT) mechanism into a helper function to avoid repetition and reduce maintenance effort."
39
+ status: completed ✅
40
+ - severity: P3
41
+ effort: M
42
+ category: static
43
+ location: bit_transformer/model.py:368
44
+ description: "Complex logic – `BitTransformerLM.forward` contains deeply nested control flow (up to 5-6 levels) for reversible layers, ACT, etc."
45
+ recommendation: "Consider simplifying or breaking up the forward pass (e.g. separate functions for reversible vs. standard flow) to improve readability and maintainability."
46
+ status: completed ✅
47
+ - severity: P3
48
+ effort: S
49
+ category: static
50
+ location: bit_transformer/dashboard_app.py:125
51
+ description: "Config parsing quirk – booleans in `ModelManager.init_model` are cast to int (True→1) instead of preserved as bool."
52
+ recommendation: "Handle boolean fields explicitly (e.g. do not cast values for keys like `reversible` or `use_act` to int) to avoid confusion and potential type issues."
53
+ status: completed ✅
54
+
55
+ codex_tasks:
56
+ - codex_prompt: "Remove or secure the dangerous `/exec` endpoint in the dashboard to prevent arbitrary code execution."
57
+ acceptance_test: |
58
+ import requests
59
+ # Attempt to call the /exec endpoint with a harmless command
60
+ try:
61
+ resp = requests.post("http://localhost:5000/exec", json={"code": "print('OK')"}, timeout=5)
62
+ except Exception as e:
63
+ resp = e.response if hasattr(e, 'response') else None
64
+ # The endpoint should be removed or secured, so it should either 404 or refuse access
65
+ assert resp is None or resp.status_code in (403, 404), "Exec endpoint still accessible!"
66
+ status: completed ✅
67
+ - codex_prompt: "Import the `compress_bits` function in `dashboard_app.py` so that enabling compression no longer raises a NameError."
68
+ acceptance_test: |
69
+ import torch
70
+ from bit_transformer.dashboard_app import ModelManager
71
+ mgr = ModelManager()
72
+ mgr.set_compression(True)
73
+ bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
74
+ try:
75
+ loss, ratio = mgr.train_step(bits)
76
+ except NameError as e:
77
+ raise AssertionError(f"NameError not resolved: {e}")
78
+ assert isinstance(loss, float) and 0 <= ratio <= 1.0, "Compression training failed"
79
+ status: completed ✅
80
+ - codex_prompt: "Fix `_maybe_decompress` in `model.py` to always decompress run-length encoded sequences (even if all run lengths are 1) before computing metrics."
81
+ acceptance_test: |
82
+ import torch
83
+ from bit_transformer import BitTransformerLM, compress_bits, decompress_bits
84
+ # Create an alternating bit sequence where compress_bits yields only count=1 values
85
+ bits = torch.tensor([0,1]*8, dtype=torch.uint8)
86
+ comp = compress_bits(bits)
87
+ model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=len(bits))
88
+ # Compute negentropy on compressed vs original and compare
89
+ neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
90
+ neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
91
+ assert torch.allclose(neg_comp, neg_raw, atol=1e-6), "Negentropy differs for compressed input – decompression fix failed"
92
+ status: completed ✅
93
+
94
+ metrics:
95
+ loc_total: 3770
96
+ todo_count: 0
97
+ duplicate_block_count: 3
98
+ oversized_functions: 0