|
import math |
|
import contextlib |
|
import logging

import sys
|
from typing import Dict, List, Tuple, Optional |
|
|
|
import torch |
|
import torch.distributed as dist

import torch.nn as nn
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint as checkpoint |
|
|
|
from .torch_utils import cpu_autocast |
|
|
|
from .optimization import configure_optimizer |
|
from .compression import decompress_bits |
|
from .parity import enforce_parity |
|
|
|
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {} |
|
_attention_cache: Dict[str, torch.Tensor] = {} |
|
_MAX_CACHE_SIZE = 50 |
|
|
|
|
|
def clear_cache(): |
|
"""Clear memory caches to prevent OOM in long sequences.""" |
|
global _mask_cache, _attention_cache |
|
_mask_cache.clear() |
|
_attention_cache.clear() |
|
|
|
|
|
def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor: |
|
"""Return or create a cached upper-triangular mask with memory management.""" |
|
key = (seq_len, device) |
|
|
|
|
|
if len(_mask_cache) > _MAX_CACHE_SIZE: |
|
clear_cache() |
|
|
|
if key not in _mask_cache: |
|
_mask_cache[key] = torch.triu( |
|
torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1 |
|
) |
|
return _mask_cache[key] |
|
|
|
try:
    # ``torch.compile`` is only used on PyTorch >= 2.0 and (at the time of writing)
    # Python < 3.11, where it is supported; otherwise fall back to a no-op decorator.
    _torch_version = tuple(map(int, torch.__version__.split(".")[:2]))
    if _torch_version >= (2, 0) and sys.version_info < (3, 11):
        compile_fn = torch.compile
    else:
        raise RuntimeError("torch.compile unavailable")
except Exception:

    def compile_fn(fn=None, **kwargs):
        """No-op replacement for ``torch.compile`` when it is unavailable."""
        if fn is None:
            return lambda f: f
        return fn
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
"""Sinusoidal positional encoding.""" |
|
|
|
def __init__(self, d_model: int, max_len: int = 1024) -> None: |
|
super().__init__() |
|
pe = torch.zeros(max_len, d_model) |
|
pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) |
|
inv = torch.exp( |
|
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) |
|
) |
|
pe[:, 0::2] = torch.sin(pos * inv) |
|
pe[:, 1::2] = torch.cos(pos * inv) |
|
self.register_buffer("pe", pe.unsqueeze(1)) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Add positional encoding to input tensor.""" |
|
return x + self.pe[: x.size(0)] |
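
# Usage sketch (illustrative, not executed on import): the layers below feed
# ``PositionalEncoding`` time-major activations of shape ``(T, B, d_model)``.
#
#     pe = PositionalEncoding(d_model=64, max_len=32)
#     x = torch.zeros(32, 4, 64)   # (T, B, D)
#     y = pe(x)                    # same shape, sinusoidal offsets added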
|
|
|
|
|
class LoggingTransformerEncoderLayer(nn.Module): |
|
"""Transformer encoder layer that exposes attention weights. |
|
|
|
It optionally performs chunked attention with a fixed window size. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
nhead: int, |
|
dim_feedforward: int = 512, |
|
dropout: float = 0.1, |
|
chunk_size: Optional[int] = None, |
|
overlap: int = 0, |
|
full_attn_logging: Optional[bool] = None, |
|
) -> None: |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) |
|
self.chunk_size = chunk_size |
|
self.overlap = overlap |
|
if full_attn_logging is None: |
|
            full_attn_logging = chunk_size is None
|
self.full_attn_logging = full_attn_logging |
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
self.dropout = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
self.norm1 = nn.LayerNorm(d_model) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
self.activation = F.relu |
|
|
|
def _chunked_attn( |
|
self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Perform memory-efficient chunked self attention with overlap.""" |
|
T, B, D = src.shape |
|
|
|
|
|
if T <= 128 or self.chunk_size is None or self.chunk_size >= T: |
|
return self._full_attn(src, attn_mask) |
|
|
|
src_b = src.transpose(0, 1) |
|
C = self.chunk_size |
|
O = self.overlap |
|
n_chunks = (T + C - 1) // C |
|
pad_len = n_chunks * C - T |
|
|
|
|
|
outputs = [] |
|
weights_list = [] |
|
|
|
|
|
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()): |
|
for chunk_idx in range(n_chunks): |
|
start_idx = chunk_idx * C |
|
end_idx = min(start_idx + C + 2 * O, T + O) |
|
|
|
|
|
chunk_start = max(0, start_idx - O) |
|
chunk_end = min(T, end_idx) |
|
chunk = src_b[:, chunk_start:chunk_end] |
|
|
|
|
|
if chunk.size(1) < C + 2 * O: |
|
pad_size = C + 2 * O - chunk.size(1) |
|
chunk = F.pad(chunk, (0, 0, 0, pad_size)) |
|
|
|
chunk_len = chunk.size(1) |
|
mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None |
|
|
|
|
|
out, weights = self.self_attn( |
|
chunk, chunk, chunk, |
|
attn_mask=mask, |
|
need_weights=self.full_attn_logging, |
|
average_attn_weights=False, |
|
) |
|
|
|
|
|
core_start = O if chunk_idx > 0 else 0 |
|
core_end = core_start + min(C, T - start_idx) |
|
outputs.append(out[:, core_start:core_end]) |
|
|
|
if self.full_attn_logging and weights is not None: |
|
weights_list.append(weights[:, :, core_start:core_end]) |
|
|
|
|
|
del out, weights, chunk |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
|
|
seq = torch.cat(outputs, dim=1) |
|
|
|
|
|
if self.full_attn_logging and weights_list: |
|
|
|
if T > 1024: |
|
attn_out = torch.empty(0, device=src.device) |
|
else: |
|
attn_out = torch.cat(weights_list, dim=2) |
|
else: |
|
attn_out = torch.empty(0, device=src.device) |
|
|
|
return seq.transpose(0, 1), attn_out |
|
|
|
def _full_attn(self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Standard full attention for smaller sequences.""" |
|
qkv = src.transpose(0, 1) |
|
attn_output, attn_weights = self.self_attn( |
|
qkv, qkv, qkv, |
|
attn_mask=attn_mask, |
|
need_weights=True, |
|
average_attn_weights=False, |
|
) |
|
return attn_output.transpose(0, 1), attn_weights |
|
|
|
def forward( |
|
self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Return output and attention map.""" |
|
if self.chunk_size is not None: |
|
attn_output, attn_weights = self._chunked_attn(src, attn_mask) |
|
else: |
|
qkv = src.transpose(0, 1) |
|
attn_output, attn_weights = self.self_attn( |
|
qkv, |
|
qkv, |
|
qkv, |
|
attn_mask=attn_mask, |
|
need_weights=True, |
|
average_attn_weights=False, |
|
) |
|
attn_output = attn_output.transpose(0, 1) |
|
src = src + self.dropout1(attn_output) |
|
src = self.norm1(src) |
|
out = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
|
src = src + self.dropout2(out) |
|
src = self.norm2(src) |
|
return src, attn_weights.detach() |
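
# Illustrative sketch of chunked attention: with ``chunk_size`` set, sequences longer
# than 128 tokens are attended to in windows of ``chunk_size`` tokens plus ``overlap``
# tokens of neighbouring context, instead of a full T x T attention matrix.
#
#     layer = LoggingTransformerEncoderLayer(d_model=64, nhead=4, chunk_size=64, overlap=8)
#     src = torch.randn(256, 2, 64)                       # (T, B, D)
#     out, attn = layer(src, get_tri_mask(256, src.device))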
|
|
|
|
|
class ReversibleLoggingTransformerEncoderLayer(nn.Module): |
|
"""Reversible transformer encoder layer with checkpointing.""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
nhead: int, |
|
dim_feedforward: int = 512, |
|
dropout: float = 0.1, |
|
chunk_size: Optional[int] = None, |
|
overlap: int = 0, |
|
full_attn_logging: Optional[bool] = None, |
|
) -> None: |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) |
|
self.chunk_size = chunk_size |
|
self.overlap = overlap |
|
if full_attn_logging is None: |
|
            full_attn_logging = chunk_size is None
|
self.full_attn_logging = full_attn_logging |
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
self.dropout = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
self.norm1 = nn.LayerNorm(d_model) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
self.activation = F.relu |
|
|
|
@compile_fn |
|
def _sa_block( |
|
self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if self.chunk_size is not None: |
|
T, B, D = x.shape |
|
x_b = x.transpose(0, 1) |
|
C = self.chunk_size or T |
|
O = self.overlap |
|
n_chunks = (T + C - 1) // C |
|
pad_len = n_chunks * C - T |
|
src_pad = F.pad(x_b, (0, 0, O, pad_len + O)) |
|
chunk_len = C + 2 * O |
|
            # ``Tensor.unfold`` appends the window dimension last, yielding
            # (B, n_chunks, D, chunk_len); move it next to the chunk index before
            # flattening so each window is a (chunk_len, D) block.
            chunks = src_pad.unfold(1, chunk_len, C).permute(0, 1, 3, 2)
            chunks = chunks.reshape(B * n_chunks, chunk_len, D)
            mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
            out, weights = self.self_attn(
                chunks,
                chunks,
                chunks,
|
attn_mask=mask, |
|
need_weights=True, |
|
average_attn_weights=False, |
|
) |
|
out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C] |
|
weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[ |
|
:, :, :, O : O + C |
|
] |
|
seq = out.reshape(B, n_chunks * C, D)[:, :T] |
|
if self.full_attn_logging and C < T: |
|
full_attn = torch.zeros( |
|
B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device |
|
) |
|
for idx in range(n_chunks): |
|
s = idx * C |
|
start = max(s - O, 0) |
|
end = min(s + C, n_chunks * C) |
|
src_start = O - (s - start) |
|
src_end = src_start + (end - start) |
|
full_attn[:, :, s : s + C, start:end] = weights[ |
|
:, idx, :, src_start:src_end |
|
] |
|
full_attn = full_attn[:, :, :T, :T] |
|
weights = full_attn.detach() |
|
else: |
|
weights = torch.empty(0, device=x.device) |
|
attn_out = seq.transpose(0, 1) |
|
else: |
|
qkv = x.transpose(0, 1) |
|
attn_out, weights = self.self_attn( |
|
qkv, |
|
qkv, |
|
qkv, |
|
attn_mask=attn_mask, |
|
need_weights=True, |
|
average_attn_weights=False, |
|
) |
|
attn_out = attn_out.transpose(0, 1) |
|
x = self.norm1(x + self.dropout1(attn_out)) |
|
return x, weights.detach() |
|
|
|
@compile_fn |
|
def _ff_block(self, x: torch.Tensor) -> torch.Tensor: |
|
out = self.linear2(self.dropout(self.activation(self.linear1(x)))) |
|
x = self.norm2(x + self.dropout2(out)) |
|
return x |
|
|
|
def forward( |
|
self, |
|
x1: torch.Tensor, |
|
x2: torch.Tensor, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
y1, weights = self._sa_block(x2, attn_mask) |
|
y1 = x1 + y1 |
|
y2 = x2 + self._ff_block(y1) |
|
return y1, y2, weights |
|
|
|
|
|
class BitTransformerLM(nn.Module): |
|
"""Transformer language model that operates on raw bits (0/1) with telemetry.""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int = 128, |
|
nhead: int = 8, |
|
num_layers: int = 4, |
|
dim_feedforward: int = 512, |
|
max_seq_len: int = 1024, |
|
lambda_K: float = 1.0, |
|
lambda_C: float = 1.0, |
|
lambda_S: float = 1.0, |
|
reversible: bool = False, |
|
use_checkpoint: bool = True, |
|
use_autocast: bool = False, |
|
use_act: bool = False, |
|
act_threshold: float = 0.9, |
|
chunk_size: Optional[int] = None, |
|
overlap: int = 0, |
|
full_attn_logging: Optional[bool] = None, |
|
) -> None: |
|
"""Create a BitTransformer language model. |
|
|
|
Args: |
|
full_attn_logging: When ``False`` and ``chunk_size`` is |
|
smaller than the sequence length, the model skips |
|
                reconstructing the full ``T×T`` attention matrices for
|
telemetry to reduce memory use. |
|
""" |
|
super().__init__() |
|
self.d_model = d_model |
|
self.num_layers = num_layers |
|
self.lambda_K = lambda_K |
|
self.lambda_C = lambda_C |
|
self.lambda_S = lambda_S |
|
self.reversible = reversible |
|
self.use_checkpoint = use_checkpoint |
|
self.use_autocast = use_autocast |
|
self.use_act = use_act |
|
self.act_threshold = act_threshold |
|
self.chunk_size = chunk_size |
|
self.overlap = overlap |
|
if full_attn_logging is None: |
|
            full_attn_logging = chunk_size is None
|
self.full_attn_logging = full_attn_logging |
|
|
|
|
|
self.embedding = nn.Embedding(2, d_model) |
|
self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len) |
|
|
|
layer_cls = ( |
|
ReversibleLoggingTransformerEncoderLayer |
|
if reversible |
|
else LoggingTransformerEncoderLayer |
|
) |
|
self.layers = nn.ModuleList( |
|
[ |
|
layer_cls( |
|
d_model=d_model, |
|
nhead=nhead, |
|
dim_feedforward=dim_feedforward, |
|
chunk_size=chunk_size, |
|
overlap=overlap, |
|
full_attn_logging=full_attn_logging, |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
|
|
if self.use_act: |
|
self.halt_projs = nn.ModuleList( |
|
[nn.Linear(d_model, 1) for _ in range(num_layers)] |
|
) |
|
|
|
self.out_head = nn.Linear(d_model, 2) |
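
    # Illustrative construction sketch: a reversible, chunked configuration intended
    # for longer inputs (parameter values here are arbitrary examples).
    #
    #     model = BitTransformerLM(d_model=64, nhead=4, num_layers=2, dim_feedforward=256,
    #                              max_seq_len=2048, reversible=True, chunk_size=256, overlap=16)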
|
|
|
def expand_positional_encoding(self, new_len: int) -> None: |
|
"""Expand positional encoding to at least ``new_len``.""" |
|
cur_len = self.pos_enc.pe.size(0) |
|
if new_len <= cur_len: |
|
return |
|
device = self.pos_enc.pe.device |
|
d_model = self.d_model |
|
pe = torch.zeros(new_len, d_model, device=device) |
|
pe[:cur_len] = self.pos_enc.pe.squeeze(1) |
|
pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1) |
|
inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) |
|
pe[cur_len:, 0::2] = torch.sin(pos * inv) |
|
pe[cur_len:, 1::2] = torch.cos(pos * inv) |
|
self.pos_enc.pe = pe.unsqueeze(1) |
|
|
|
def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None: |
|
"""Update weighting coefficients for telemetry metrics.""" |
|
self.lambda_K = lambda_K |
|
self.lambda_C = lambda_C |
|
self.lambda_S = lambda_S |
|
|
|
def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor: |
|
"""Return raw bit sequences, decompressing if input appears run-length encoded.""" |
|
if codes.dim() <= 1: |
|
return codes |
|
needs_decompress = codes.max().item() > 1 |
|
if not needs_decompress and codes.size(1) % 2 == 0: |
|
vals = codes[:, 0::2] |
|
if torch.all(vals[:, 1:] != vals[:, :-1]): |
|
needs_decompress = True |
|
if not needs_decompress: |
|
return codes |
|
seqs = [decompress_bits(row.to(torch.uint8)) for row in codes] |
|
max_len = max(seq.numel() for seq in seqs) |
|
padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs] |
|
return torch.stack(padded) |
|
|
|
def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor: |
|
"""Approximate negentropy of bit sequences. |
|
|
|
Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered |
|
sequence (all zeros or ones) and ``0`` reflects maximal entropy. |
|
""" |
|
codes = self._maybe_decompress(codes) |
|
p = codes.float().mean(dim=1) |
|
entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) |
|
max_e = math.log(2.0) |
|
return 1 - entropy / max_e |
|
|
|
def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor: |
|
"""Differentiable proxy for LempelβZiv complexity. |
|
|
|
Values near ``0`` indicate highly compressible sequences while values |
|
approaching ``1`` correspond to rapid bit alternation. |
|
""" |
|
codes = self._maybe_decompress(codes) |
|
diffs = torch.abs(codes[:, 1:] - codes[:, :-1]) |
|
return diffs.float().mean(dim=1) |
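
    # Illustrative sketch (``model`` assumed to be a ``BitTransformerLM`` instance):
    # a constant bit sequence is maximally ordered, an alternating one maximally complex.
    #
    #     ones = torch.ones(1, 16, dtype=torch.long)
    #     alt = torch.tensor([[0, 1] * 8])
    #     model.negentropy_kpi(ones)   # ~1.0 (fully ordered)
    #     model.lz_complexity(alt)     # 1.0 (a bit flip at every step)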
|
|
|
def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: |
|
"""Negentropy computed from model logits. |
|
|
|
Parameters |
|
---------- |
|
logits: ``torch.Tensor`` |
|
Logit tensor of shape ``(B, T, 2)``. |
|
detach: bool, default ``True`` |
|
When ``True`` the computation is detached from the autograd graph. |
|
""" |
|
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
|
prob = logits.softmax(-1) |
|
if detach: |
|
prob = prob.detach() |
|
p = prob[..., 1].mean(dim=1) |
|
entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) |
|
max_e = math.log(2.0) |
|
return 1 - entropy / max_e |
|
|
|
def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: |
|
"""LZ complexity proxy computed from logits. |
|
|
|
Parameters |
|
---------- |
|
logits: ``torch.Tensor`` |
|
Logit tensor of shape ``(B, T, 2)``. |
|
detach: bool, default ``True`` |
|
When ``True`` the computation is detached from the autograd graph. |
|
""" |
|
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
|
prob = logits.softmax(-1) |
|
if detach: |
|
prob = prob.detach() |
|
prob1 = prob[..., 1] |
|
diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1]) |
|
return diffs.mean(dim=1) |
|
|
|
def symbiosis_kl_logits( |
|
self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True |
|
) -> torch.Tensor: |
|
"""Symbiosis score from KL divergence to a reference distribution. |
|
|
|
Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with |
|
the reference distribution and ``0`` indicating maximal divergence. |
|
""" |
|
assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
|
probs = logits.softmax(-1) |
|
if detach: |
|
probs = probs.detach() |
|
ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device) |
|
kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1) |
|
max_kl = math.log(2.0) |
|
return 1 - kl / max_kl |
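
    # Illustrative sketch (assumed ``model`` instance and logits of shape (B, T, 2)):
    # each logit-based telemetry head returns one score per sequence in the batch.
    #
    #     logits = torch.randn(4, 16, 2)
    #     model.negentropy_logits(logits)    # order of the predicted bits, in [0, 1]
    #     model.lz_complexity_logits(logits) # predicted bit-flip rate proxy
    #     model.symbiosis_kl_logits(logits)  # agreement with the 50/50 reference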
|
|
|
def _act_step( |
|
self, |
|
hidden: torch.Tensor, |
|
idx: int, |
|
halt_prob: torch.Tensor, |
|
act_state: torch.Tensor, |
|
halt_history: List[torch.Tensor], |
|
) -> Tuple[torch.Tensor, torch.Tensor, bool]: |
|
"""Apply one step of ACT halting logic.""" |
|
p = torch.sigmoid(self.halt_projs[idx](hidden)) |
|
delta = (1 - halt_prob) * p |
|
halt_prob = halt_prob + delta |
|
act_state = act_state + hidden * delta |
|
halt_history.append(halt_prob.detach()) |
|
min_prob = halt_prob.detach().min() |
|
if dist.is_initialized(): |
|
dist.all_reduce(min_prob, op=dist.ReduceOp.MIN) |
|
return halt_prob, act_state, min_prob.item() >= self.act_threshold |
|
|
|
def forward( |
|
self, bit_seq: torch.Tensor, causal: bool = True |
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
|
"""Forward pass returning logits and telemetry from the same graph. |
|
|
|
By default the model uses causal masking and (optional) chunked |
|
attention. When ``causal`` is ``False`` the model operates in |
|
"Diffusion LM" mode. In this mode chunked attention is temporarily |
|
disabled so that every token can attend to the full sequence |
|
bidirectionally. The original chunking configuration is restored after |
|
the forward pass. |
|
""" |
|
|
|
|
|
orig_chunks = None |
|
orig_model_chunk = None |
|
if not causal and self.chunk_size is not None: |
|
orig_model_chunk = self.chunk_size |
|
orig_chunks = [layer.chunk_size for layer in self.layers] |
|
self.chunk_size = None |
|
for layer in self.layers: |
|
layer.chunk_size = None |
|
|
|
try: |
|
ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext() |
|
with ctx: |
|
x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model) |
|
x = self.pos_enc(x) |
|
|
|
attn_mask = get_tri_mask(x.size(0), x.device) if causal else None |
|
|
|
activations: List[torch.Tensor] = [] |
|
attn_maps: List[torch.Tensor] = [] |
|
halt_history: List[torch.Tensor] = [] |
|
if self.use_act: |
|
halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device) |
|
act_state = torch.zeros_like(x) |
|
if self.reversible: |
|
x1, x2 = x, x |
|
for idx, layer in enumerate(self.layers): |
|
if self.use_checkpoint: |
|
x1, x2, attn = checkpoint.checkpoint( |
|
layer, x1, x2, attn_mask |
|
) |
|
else: |
|
x1, x2, attn = layer(x1, x2, attn_mask) |
|
combined = (x1 + x2) / 2 |
|
activations.append(combined) |
|
if attn.numel() > 0: |
|
attn_maps.append(attn) |
|
if self.use_act: |
|
halt_prob, act_state, should_break = self._act_step( |
|
combined, idx, halt_prob, act_state, halt_history |
|
) |
|
if should_break: |
|
break |
|
x = (x1 + x2) / 2 |
|
else: |
|
for idx, layer in enumerate(self.layers): |
|
if self.use_checkpoint: |
|
x, attn = checkpoint.checkpoint(layer, x, attn_mask) |
|
else: |
|
x, attn = layer(x, attn_mask) |
|
activations.append(x) |
|
if attn.numel() > 0: |
|
attn_maps.append(attn) |
|
if self.use_act: |
|
halt_prob, act_state, should_break = self._act_step( |
|
x, idx, halt_prob, act_state, halt_history |
|
) |
|
if should_break: |
|
break |
|
if self.use_act: |
|
act_state = act_state + x * (1 - halt_prob) |
|
x = act_state |
|
logits = self.out_head(x) |
|
|
|
|
|
entropies = [] |
|
for act in activations: |
|
prob = act.softmax(-1) |
|
ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean() |
|
entropies.append(ent) |
|
|
|
attn_entropies = [] |
|
for attn in attn_maps: |
|
prob = attn |
|
ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1) |
|
ent = ent.mean(1) |
|
attn_entropies.append(ent) |
|
if attn_entropies: |
|
attn_entropy_map = torch.stack(attn_entropies).mean(0) |
|
else: |
|
attn_entropy_map = torch.zeros( |
|
bit_seq.size(0), bit_seq.size(1), device=bit_seq.device |
|
) |
|
max_ent = math.log(attn_entropy_map.size(-1)) |
|
attn_entropy_map = attn_entropy_map / max_ent |
|
attn_entropy = attn_entropy_map.mean(1) |
|
|
|
logits_bt = logits.transpose(0, 1) |
|
negentropy_in = self.negentropy_kpi(bit_seq) |
|
lz_in = self.lz_complexity(bit_seq.float()) |
|
negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False) |
|
lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False) |
|
kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False) |
|
|
|
raw_sym = ( |
|
(self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2 |
|
+ negentropy_logits_b * lz_logits_b |
|
- self.lambda_S * kl_div_b |
|
- 0.1 * attn_entropy |
|
) |
|
weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach() |
|
raw_sym = raw_sym - 0.01 * weight_norm |
|
sym_score = torch.sigmoid(raw_sym) |
|
|
|
B, T = bit_seq.shape |
|
assert logits_bt.shape[:2] == (B, T) |
|
assert attn_entropy_map.shape == (B, T) |
|
|
|
telemetry = { |
|
"activations": activations, |
|
"attention_maps": attn_maps, |
|
"attention_entropy": attn_entropy_map, |
|
"entropy": entropies, |
|
"attention_entropy_mean": attn_entropy, |
|
"negentropy_input": negentropy_in.detach(), |
|
"lz_complexity_input": lz_in.detach(), |
|
"negentropy_logits": negentropy_logits_b.detach(), |
|
"lz_complexity_logits": lz_logits_b.detach(), |
|
"symbiosis_kl": kl_div_b.detach(), |
|
"symbiosis_score": sym_score.detach(), |
|
} |
|
if self.use_act: |
|
telemetry["halt_probs"] = halt_history |
|
|
|
return logits_bt, telemetry |
|
finally: |
|
if orig_chunks is not None: |
|
self.chunk_size = orig_model_chunk |
|
for layer, chunk in zip(self.layers, orig_chunks): |
|
layer.chunk_size = chunk |
|
|
|
def forward_compressed( |
|
self, compressed_bits, causal: bool = True |
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
|
"""Decompress bit sequences then run the normal forward pass.""" |
|
if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1: |
|
sequences = [decompress_bits(compressed_bits).to(torch.long)] |
|
else: |
|
sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits] |
|
lengths = [seq.numel() for seq in sequences] |
|
if len(set(lengths)) != 1: |
|
raise ValueError("Sequences decompress to different lengths") |
|
bits = torch.stack(sequences) |
|
return self.forward(bits, causal=causal) |
|
|
|
def _current_params(self) -> Dict: |
|
"""Return a dictionary with the current model hyperparameters.""" |
|
return { |
|
"d_model": self.d_model, |
|
"nhead": self.layers[0].self_attn.num_heads, |
|
"num_layers": self.num_layers, |
|
"dim_feedforward": self.layers[0].linear1.out_features, |
|
"max_seq_len": self.pos_enc.pe.size(0), |
|
"lambda_K": self.lambda_K, |
|
"lambda_C": self.lambda_C, |
|
"lambda_S": self.lambda_S, |
|
"reversible": self.reversible, |
|
"use_checkpoint": self.use_checkpoint, |
|
"use_autocast": self.use_autocast, |
|
"use_act": self.use_act, |
|
"act_threshold": self.act_threshold, |
|
"chunk_size": self.chunk_size, |
|
"overlap": self.overlap, |
|
} |
|
|
|
def double_width(self) -> "BitTransformerLM": |
|
"""Return a copy of the model with doubled hidden size.""" |
|
from .scale import expand_model |
|
|
|
params = self._current_params() |
|
params["d_model"] *= 2 |
|
params["dim_feedforward"] *= 2 |
|
return expand_model(self, params) |
|
|
|
def double_layers(self) -> "BitTransformerLM": |
|
"""Return a copy of the model with twice as many layers.""" |
|
from .scale import expand_model |
|
|
|
params = self._current_params() |
|
params["num_layers"] *= 2 |
|
return expand_model(self, params) |
|
|
|
def double_length(self) -> "BitTransformerLM": |
|
"""Return a copy of the model with doubled maximum sequence length.""" |
|
from .scale import expand_model |
|
|
|
params = self._current_params() |
|
params["max_seq_len"] *= 2 |
|
params["chunk_size"] = params["max_seq_len"] |
|
return expand_model(self, params) |
|
|
|
def train_full_sequence( |
|
self, |
|
bits: torch.Tensor, |
|
*, |
|
ctx_bits: int = 4096, |
|
detach_every_n: int = 1_048_576, |
|
) -> float: |
|
"""Train on a long bit tensor using sliding windows. |
|
|
|
Parameters |
|
---------- |
|
bits: ``torch.Tensor`` |
|
1D tensor containing the full bit sequence. |
|
ctx_bits: int |
|
Size of the training context window. |
|
detach_every_n: int |
|
            Number of bits to accumulate before each optimizer and scheduler step.
|
Returns |
|
------- |
|
float |
|
Mean loss over all windows. |
|
""" |
|
self.train() |
|
optimizer, scheduler = configure_optimizer( |
|
self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits) |
|
) |
|
accum = 0 |
|
total_loss = 0.0 |
|
count = 0 |
|
for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits): |
|
segment = bits[start : start + ctx_bits + 1].unsqueeze(0) |
|
logits, _ = self(segment) |
|
pred = logits[:, :-1, :].reshape(-1, 2) |
|
target = segment[:, 1:].reshape(-1) |
|
loss = F.cross_entropy(pred, target) |
|
loss.backward() |
|
accum += ctx_bits |
|
total_loss += loss.item() |
|
count += 1 |
|
if accum >= detach_every_n: |
|
torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) |
|
optimizer.step() |
|
scheduler.step() |
|
optimizer.zero_grad() |
|
accum = 0 |
|
if accum > 0: |
|
torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) |
|
optimizer.step() |
|
scheduler.step() |
|
optimizer.zero_grad() |
|
return total_loss / max(1, count) |
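

# A small, hedged usage sketch of ``BitTransformerLM.train_full_sequence``; the tensor
# length and window sizes below are arbitrary examples, not recommended settings.
def example_full_sequence_training() -> float:
    """Illustrative sketch of sliding-window training on a long bit tensor."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=1024
    )
    long_bits = torch.randint(0, 2, (8192,), dtype=torch.long)
    return model.train_full_sequence(long_bits, ctx_bits=512, detach_every_n=2048)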
|
|
|
|
|
def infer_long_sequence( |
|
model: BitTransformerLM, |
|
bits: torch.Tensor, |
|
*, |
|
ctx_bits: int = 4096, |
|
overlap: int = 256, |
|
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]: |
|
"""Infer a long bit sequence using sliding windows with overlap.""" |
|
model.eval() |
|
device = next(model.parameters()).device |
|
bits = bits.to(device) |
|
step = ctx_bits - overlap |
|
outputs: List[torch.Tensor] = [] |
|
logs: List[Dict[str, torch.Tensor]] = [] |
|
for start in range(0, bits.numel(), step): |
|
window = bits[start : start + ctx_bits].unsqueeze(0) |
|
logits, tele = model(window, causal=True) |
|
pred = logits.argmax(-1).squeeze(0) |
|
        # Windows after the first overlap the previous window by ``overlap`` bits;
        # keep only the newly covered positions so the concatenation aligns with ``bits``.
        outputs.append(pred if start == 0 else pred[overlap:])
|
logs.append(tele) |
|
out = torch.cat(outputs)[: bits.numel()] |
|
return out, logs |
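

# A small, hedged usage sketch of the sliding-window inference helper above; the model
# configuration and sequence length are arbitrary examples.
def example_long_sequence_inference() -> torch.Tensor:
    """Illustrative sketch of :func:`infer_long_sequence`."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=256
    )
    bits = torch.randint(0, 2, (1000,), dtype=torch.long)
    preds, logs = infer_long_sequence(model, bits, ctx_bits=256, overlap=32)
    return preds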
|
|
|
|
|
def diffusion_inference( |
|
model: BitTransformerLM, |
|
*, |
|
length: int, |
|
steps: int = 8, |
|
batch_size: int = 1, |
|
init_bits: Optional[torch.Tensor] = None, |
|
schedule: str = "linear", |
|
) -> torch.Tensor: |
|
"""Generate bit sequences using iterative denoising diffusion. |
|
|
|
Parameters |
|
---------- |
|
model: ``BitTransformerLM`` |
|
The model used for denoising. It is run in non-causal mode with |
|
chunked attention disabled, enabling full-context bidirectional |
|
attention. |
|
length: int |
|
Length of the bit sequences to generate. |
|
steps: int, default ``8`` |
|
Number of denoising iterations. More steps generally yield sharper |
|
samples at the cost of compute. |
|
batch_size: int, default ``1`` |
|
Number of sequences to generate in parallel. |
|
init_bits: ``torch.Tensor`` | ``None`` |
|
Optional initial noisy bits of shape ``(batch_size, length)``. When |
|
``None`` random noise is used. |
|
schedule: str, default ``"linear"`` |
|
Noise schedule for the denoising mask probability. Options are |
|
``"linear"``, ``"cosine"``, and ``"exp"``. |
|
|
|
Returns |
|
------- |
|
``torch.Tensor`` |
|
A tensor of shape ``(batch_size, length)`` containing generated bits. |
|
""" |
|
|
|
model.eval() |
|
device = next(model.parameters()).device |
|
if init_bits is None: |
|
bits = torch.randint(0, 2, (batch_size, length), device=device) |
|
else: |
|
bits = init_bits.to(device) |
|
if bits.shape != (batch_size, length): |
|
raise ValueError("init_bits must have shape (batch_size, length)") |
|
|
|
for step in range(steps): |
|
logits, _ = model(bits, causal=False) |
|
prob = logits.softmax(-1)[..., 1] |
|
t = (step + 1) / steps |
|
if schedule == "linear": |
|
mask_prob = 1.0 - t |
|
elif schedule == "cosine": |
|
mask_prob = math.cos(math.pi * t / 2) |
|
elif schedule == "exp": |
|
mask_prob = math.exp(-5 * t) |
|
else: |
|
raise ValueError(f"unknown schedule: {schedule}") |
|
mask = (torch.rand_like(bits.float()) < mask_prob).long() |
|
sampled = torch.bernoulli(prob).long() |
|
bits = torch.where(mask.bool(), sampled, bits) |
|
if bits.shape[-1] % 9 == 0: |
|
bits, corrections = enforce_parity(bits) |
|
if corrections: |
|
logging.info("Parity corrections applied: %d", corrections) |
|
try: |
|
from .safety import hil_safe_inference |
|
|
|
hil_safe_inference(model, bits, causal=False, strict=False) |
|
except RuntimeError as exc: |
|
logging.warning("Safety gate warning: %s", exc) |
|
return bits |
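

# A small, hedged usage sketch of ``diffusion_inference``; the model configuration,
# step count, and schedule choice are arbitrary examples.
def example_diffusion_inference() -> torch.Tensor:
    """Illustrative sketch of iterative denoising generation."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=64
    )
    return diffusion_inference(model, length=64, steps=2, batch_size=2, schedule="cosine")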
|
|
|
|
|
def example_usage() -> float: |
|
"""Run the example from the README and return the loss.""" |
|
B, L = 4, 16 |
|
model = BitTransformerLM( |
|
d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=L |
|
) |
|
bits = torch.randint(0, 2, (B, L), dtype=torch.long) |
|
logits, _ = model(bits) |
|
pred = logits[:, :-1, :].reshape(-1, 2) |
|
target = bits[:, 1:].reshape(-1) |
|
loss = F.cross_entropy(pred, target) |
|
return loss.item() |
|
|
|
|
|
def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]: |
|
"""Demonstrate a training step where metrics do not affect gradients.""" |
|
B, L = 4, 16 |
|
model = BitTransformerLM( |
|
d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L |
|
) |
|
optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1) |
|
|
|
bits = torch.randint(0, 2, (B, L), dtype=torch.long) |
|
logits, telemetry = model(bits) |
|
|
|
pred = logits[:, :-1, :].reshape(-1, 2) |
|
target = bits[:, 1:].reshape(-1) |
|
loss = F.cross_entropy(pred, target) |
|
|
|
loss.backward() |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
optimizer.step() |
|
scheduler.step() |
|
optimizer.zero_grad() |
|
return loss.item(), telemetry |
|
|
|
|
|
if __name__ == "__main__": |
|
loss, telemetry = example_training_step() |
|
print("Composite loss:", loss) |
|
print("Telemetry keys:", list(telemetry.keys())) |
|
|