WCNegentropy's picture
🚀 Final optimization: Update model.py with production-ready enhancements
cf1ded2 verified
raw
history blame
36.1 kB
import math
import contextlib
import logging
from typing import Dict, List, Tuple, Optional
import torch
import torch.distributed as dist
import sys
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from .torch_utils import cpu_autocast
from .optimization import configure_optimizer
from .compression import decompress_bits
from .parity import enforce_parity
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
_attention_cache: Dict[str, torch.Tensor] = {}  # For caching attention patterns
_MAX_CACHE_SIZE = 50  # Maximum number of masks kept in ``_mask_cache``


def clear_cache():
    """Clear memory caches to prevent OOM in long sequences."""
    # ``.clear()`` mutates the dicts in place, so no ``global`` statement is needed.
    _mask_cache.clear()
    _attention_cache.clear()


def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Return a cached ``seq_len x seq_len`` strictly upper-triangular bool mask.

    Entries with column > row are ``True``; used as an attention mask, they
    block each position from attending to later positions (causal masking).

    Masks are cached per ``(seq_len, device)``. The cache never holds more than
    ``_MAX_CACHE_SIZE`` masks: when a new mask is needed and the cache is full,
    the oldest cached mask is evicted. Cache hits never evict anything.

    The returned tensor is shared between callers and should not be modified
    in place.
    """
    key = (seq_len, device)
    mask = _mask_cache.get(key)
    if mask is None:
        # Evict before inserting so the size limit is never exceeded; dicts
        # preserve insertion order, so the first key is the oldest entry.
        while _mask_cache and len(_mask_cache) >= _MAX_CACHE_SIZE:
            _mask_cache.pop(next(iter(_mask_cache)))
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
        _mask_cache[key] = mask
    return mask
def _identity_compile(fn=None, **kwargs):
    """Pass-through stand-in for ``torch.compile``.

    With a function it hands that function back untouched; with no function
    it returns a no-op decorator, mirroring ``torch.compile``'s factory form.
    """
    if fn is not None:
        return fn
    return lambda f: f


try:  # torch.compile may not work on all Python versions
    _torch_release = torch.__version__
    _torch_major_minor = (
        tuple(int(part) for part in _torch_release.split(".")[:2])
        if _torch_release
        else None
    )
    _compile_supported = (
        _torch_major_minor is not None
        and _torch_major_minor >= (2, 0)
        and sys.version_info < (3, 11)
    )
    compile_fn = torch.compile if _compile_supported else _identity_compile
except Exception:  # pragma: no cover - handle missing torch or unsupported version
    compile_fn = _identity_compile
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    Registers a buffer ``pe`` of shape ``(max_len, 1, d_model)``. Even feature
    columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``.
    Odd ``d_model`` values are supported: the last column is then a sine term.
    """

    def __init__(self, d_model: int, max_len: int = 1024) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        inv = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = pos * inv  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; drop the extra
        # frequency for odd d_model instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input tensor.

        Dim 0 of ``x`` is the position axis (``pe[: x.size(0)]``). The encoding
        is broadcast over dim 1, so ``x`` is expected to be
        ``(T, B, d_model)`` with ``T <= max_len``.
        """
        return x + self.pe[: x.size(0)]
class LoggingTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that exposes attention weights.

    Inputs and outputs are sequence-first ``(T, B, D)`` tensors; the attention
    module itself runs batch-first internally. The layer is post-norm:
    ``norm1(x + attn(x))`` followed by ``norm2(x + ffn(x))``.

    It optionally performs chunked attention with a fixed window size
    (``chunk_size``), with ``overlap`` extra context tokens around each chunk.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        # By default, attention maps are reconstructed only when chunking is off.
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _chunked_attn(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform memory-efficient chunked self attention with overlap.

        Sequences of at most 128 tokens, or sequences that fit in one chunk,
        fall back to :meth:`_full_attn`.

        Otherwise each block of ``chunk_size`` query positions attends over a
        window with:
        - up to ``overlap`` tokens of left context;
        - following tokens up to ``chunk_size + 2 * overlap`` positions past
          the block start;
        - trailing zero padding, so the window is at least
          ``chunk_size + 2 * overlap`` long.

        When ``attn_mask`` is given, a causal mask is applied inside each
        window. On CUDA tensors the attention runs under CUDA autocast.

        Returns:
            ``(output, weights)``. ``output`` has shape ``(T, B, D)``.
            When ``full_attn_logging`` is enabled and ``T <= 1024``,
            ``weights`` is a ``(B, heads, T, T)`` map assembled from the
            per-chunk weights; entries outside a chunk's window are zero.
            Otherwise ``weights`` is an empty tensor.
        """
        T = src.size(0)
        # Early return for small sequences
        if T <= 128 or self.chunk_size is None or self.chunk_size >= T:
            return self._full_attn(src, attn_mask)
        src_b = src.transpose(0, 1)  # [B, T, D]
        C = self.chunk_size
        O = self.overlap
        n_chunks = (T + C - 1) // C
        window = C + 2 * O  # shorter windows are zero-padded up to this length
        # Rebuilding the dense T x T map is only worthwhile for moderate lengths.
        build_map = self.full_attn_logging and T <= 1024
        full_map: Optional[torch.Tensor] = None
        outputs: List[torch.Tensor] = []
        amp_ctx = (
            torch.autocast(device_type="cuda")
            if src.is_cuda
            else contextlib.nullcontext()
        )
        with amp_ctx:
            for chunk_idx in range(n_chunks):
                start_idx = chunk_idx * C
                core_len = min(C, T - start_idx)
                chunk_start = max(0, start_idx - O)
                chunk_end = min(T, start_idx + C + 2 * O)
                chunk = src_b[:, chunk_start:chunk_end]
                n_real = chunk.size(1)  # tokens before any padding
                if n_real < window:
                    chunk = F.pad(chunk, (0, 0, 0, window - n_real))
                chunk_len = chunk.size(1)
                mask = (
                    get_tri_mask(chunk_len, src.device)
                    if attn_mask is not None
                    else None
                )
                out, weights = self.self_attn(
                    chunk, chunk, chunk,
                    attn_mask=mask,
                    need_weights=self.full_attn_logging,
                    average_attn_weights=False,
                )
                # Offset of this block's first query inside the window. It is
                # O normally, but smaller when the left context is clipped at
                # the sequence start (e.g. when overlap > chunk_size).
                core_start = start_idx - chunk_start
                outputs.append(out[:, core_start : core_start + core_len])
                if build_map and weights is not None:
                    if full_map is None:
                        full_map = weights.new_zeros(
                            src_b.size(0), weights.size(1), T, T
                        )
                    # Window column j is sequence position chunk_start + j;
                    # padded columns (j >= n_real) are dropped.
                    full_map[
                        :, :,
                        start_idx : start_idx + core_len,
                        chunk_start : chunk_start + n_real,
                    ] = weights[
                        :, :, core_start : core_start + core_len, :n_real
                    ].detach()
        seq = torch.cat(outputs, dim=1)
        if full_map is not None:
            attn_out = full_map
        else:
            attn_out = torch.empty(0, device=src.device)
        return seq.transpose(0, 1), attn_out

    def _full_attn(self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard full attention over a ``(T, B, D)`` input.

        Returns the ``(T, B, D)`` attention output and the per-head
        attention weights.
        """
        qkv = src.transpose(0, 1)
        attn_output, attn_weights = self.self_attn(
            qkv, qkv, qkv,
            attn_mask=attn_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        return attn_output.transpose(0, 1), attn_weights

    def forward(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return output and attention map.

        ``src`` is ``(T, B, D)``. The output has the same shape, and the
        attention map is returned detached from the autograd graph.
        """
        if self.chunk_size is not None:
            attn_output, attn_weights = self._chunked_attn(src, attn_mask)
        else:
            attn_output, attn_weights = self._full_attn(src, attn_mask)
        src = self.norm1(src + self.dropout1(attn_output))
        out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(out))
        return src, attn_weights.detach()
class ReversibleLoggingTransformerEncoderLayer(nn.Module):
    """Two-stream (reversible-style) transformer encoder layer with attention logging.

    ``forward`` takes two ``(T, B, D)`` streams and returns ``(y1, y2, weights)``
    where ``y1 = x1 + sa(x2)`` and ``y2 = x2 + ff(y1)``. The self-attention and
    feed-forward sub-blocks are wrapped with ``compile_fn``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        # By default, attention maps are reconstructed only when chunking is off.
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    @compile_fn
    def _sa_block(
        self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Self-attention sub-block returning ``(norm1(x + attn(x)), weights)``.

        ``x`` is ``(T, B, D)``. With ``chunk_size`` set, the sequence is split
        into chunks of ``C`` positions. Each chunk attends over a window of
        ``C + 2 * overlap`` tokens: its neighbours, or zero padding at the
        sequence edges. A causal mask is applied inside each window whenever
        ``attn_mask`` is given.

        In chunked mode the full ``(B, heads, T, T)`` map is rebuilt only when
        ``full_attn_logging`` is set and ``C < T``; otherwise ``weights`` is an
        empty tensor.
        """
        if self.chunk_size is not None:
            T, B, D = x.shape
            x_b = x.transpose(0, 1)  # [B, T, D]
            C = self.chunk_size or T
            O = self.overlap
            n_chunks = (T + C - 1) // C
            pad_len = n_chunks * C - T
            # O zero tokens before the sequence; alignment padding plus O after.
            src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
            chunk_len = C + 2 * O
            # ``unfold`` appends the window as the LAST dim, giving
            # [B, n_chunks, D, chunk_len]. Swap the last two dims so each window
            # is (chunk_len, D) before folding the chunks into the batch.
            windows = src_pad.unfold(1, chunk_len, C).transpose(-1, -2)
            windows = windows.reshape(B * n_chunks, chunk_len, D)
            mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
            out, weights = self.self_attn(
                windows,
                windows,
                windows,
                attn_mask=mask,
                need_weights=True,
                average_attn_weights=False,
            )
            # Keep each window's central C query positions.
            out = out.reshape(B, n_chunks, chunk_len, D)[:, :, O : O + C]
            weights = weights.reshape(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
                :, :, :, O : O + C
            ]
            seq = out.reshape(B, n_chunks * C, D)[:, :T]
            if self.full_attn_logging and C < T:
                full_attn = torch.zeros(
                    B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device
                )
                for idx in range(n_chunks):
                    s = idx * C
                    start = max(s - O, 0)
                    end = min(s + C, n_chunks * C)
                    # Window column j corresponds to sequence position s - O + j.
                    src_start = O - (s - start)
                    src_end = src_start + (end - start)
                    full_attn[:, :, s : s + C, start:end] = weights[
                        :, idx, :, src_start:src_end
                    ]
                full_attn = full_attn[:, :, :T, :T]
                weights = full_attn.detach()
            else:
                weights = torch.empty(0, device=x.device)
            attn_out = seq.transpose(0, 1)
        else:
            qkv = x.transpose(0, 1)
            attn_out, weights = self.self_attn(
                qkv,
                qkv,
                qkv,
                attn_mask=attn_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            attn_out = attn_out.transpose(0, 1)
        x = self.norm1(x + self.dropout1(attn_out))
        return x, weights.detach()

    @compile_fn
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block returning ``norm2(x + ffn(x))``."""
        out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.norm2(x + self.dropout2(out))
        return x

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(x1 + sa(x2), x2 + ff(y1), attention weights)``."""
        y1, weights = self._sa_block(x2, attn_mask)
        y1 = x1 + y1
        y2 = x2 + self._ff_block(y1)
        return y1, y2, weights
class BitTransformerLM(nn.Module):
    """Transformer language model that operates on raw bits (0/1) with telemetry.

    Bits are embedded and given sinusoidal positional encodings. They then
    pass through a stack of logging encoder layers, which can optionally be
    reversible, use gradient checkpointing, or use adaptive computation time.
    ``forward`` returns next-bit logits together with a telemetry dictionary.
    """

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        max_seq_len: int = 1024,
        lambda_K: float = 1.0,
        lambda_C: float = 1.0,
        lambda_S: float = 1.0,
        reversible: bool = False,
        use_checkpoint: bool = True,
        use_autocast: bool = False,
        use_act: bool = False,
        act_threshold: float = 0.9,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        """Create a BitTransformer language model.

        Args:
            d_model: Embedding / hidden size.
            nhead: Number of attention heads per layer.
            num_layers: Number of encoder layers.
            dim_feedforward: Hidden size of each layer's feed-forward block.
            max_seq_len: Number of positions in the positional-encoding buffer.
            lambda_K, lambda_C, lambda_S: Weights of the negentropy, LZ
                complexity and KL terms in the symbiosis score.
            reversible: Use ``ReversibleLoggingTransformerEncoderLayer`` layers.
            use_checkpoint: Apply gradient checkpointing to each layer.
            use_autocast: Run the forward pass under ``cpu_autocast()``.
            use_act: Enable adaptive computation time halting across layers.
            act_threshold: Halting probability at which ACT stops early.
            chunk_size: Optional attention chunk size (``None`` disables chunking).
            overlap: Extra context tokens around each attention chunk.
            full_attn_logging: When ``False`` and ``chunk_size`` is
                smaller than the sequence length, the model skips
                reconstructing the full ``T×T`` attention matrices for
                telemetry to reduce memory use. ``None`` selects ``False``
                when ``chunk_size`` is set and ``True`` otherwise.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S
        self.reversible = reversible
        self.use_checkpoint = use_checkpoint
        self.use_autocast = use_autocast
        self.use_act = use_act
        self.act_threshold = act_threshold
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging
        # Bit embedding: two possible input values
        self.embedding = nn.Embedding(2, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)
        layer_cls = (
            ReversibleLoggingTransformerEncoderLayer
            if reversible
            else LoggingTransformerEncoderLayer
        )
        self.layers = nn.ModuleList(
            [
                layer_cls(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    chunk_size=chunk_size,
                    overlap=overlap,
                    full_attn_logging=full_attn_logging,
                )
                for _ in range(num_layers)
            ]
        )
        if self.use_act:
            self.halt_projs = nn.ModuleList(
                [nn.Linear(d_model, 1) for _ in range(num_layers)]
            )
        self.out_head = nn.Linear(d_model, 2)  # output logits for bit=0 or bit=1

    def expand_positional_encoding(self, new_len: int) -> None:
        """Expand positional encoding to at least ``new_len``.

        Existing rows are kept. New rows use the same sinusoidal formula, and
        the buffer keeps its device and dtype. Nothing happens if the buffer
        is already long enough.
        """
        cur_pe = self.pos_enc.pe
        cur_len = cur_pe.size(0)
        if new_len <= cur_len:
            return
        device = cur_pe.device
        d_model = self.d_model
        pe = torch.zeros(new_len, d_model, device=device, dtype=cur_pe.dtype)
        pe[:cur_len] = cur_pe.squeeze(1)
        pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1)
        inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model))
        angles = pos * inv
        pe[cur_len:, 0::2] = torch.sin(angles)
        # Only floor(d_model / 2) odd columns exist when d_model is odd.
        pe[cur_len:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.pos_enc.pe = pe.unsqueeze(1)

    def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
        """Update weighting coefficients for telemetry metrics."""
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S

    def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
        """Return raw bit sequences, decompressing if input appears run-length encoded.

        This is a heuristic. A non-empty 2-D input is decompressed when either:
        - it contains a value greater than 1; or
        - it has an even number of columns, at least two pairs, and the values
          at even column indices differ between every consecutive pair.

        Decompressed rows are zero-padded to a common length.
        """
        if codes.dim() <= 1 or codes.numel() == 0:
            return codes
        needs_decompress = codes.max().item() > 1
        if not needs_decompress and codes.size(1) % 2 == 0:
            vals = codes[:, 0::2]
            # ``torch.all`` of an empty comparison is True, which would flag
            # every raw 2-bit row as compressed; require at least two pairs.
            if vals.size(1) > 1 and torch.all(vals[:, 1:] != vals[:, :-1]):
                needs_decompress = True
        if not needs_decompress:
            return codes
        seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
        max_len = max(seq.numel() for seq in seqs)
        padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
        return torch.stack(padded)

    def negentropy_kpi(self, codes: torch.Tensor, *, decompress: bool = True) -> torch.Tensor:
        """Approximate negentropy of bit sequences.

        Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
        sequence (all zeros or ones) and ``0`` reflects maximal entropy.

        Args:
            codes: ``(B, T)`` bits, or run-length codes when ``decompress`` is set.
            decompress: When ``True`` (default), inputs that look run-length
                encoded are decompressed first. Pass ``False`` for inputs
                known to be raw bits.
        """
        if decompress:
            codes = self._maybe_decompress(codes)
        p = codes.float().mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity(self, codes: torch.Tensor, *, decompress: bool = True) -> torch.Tensor:
        """Differentiable proxy for Lempel–Ziv complexity.

        Values near ``0`` indicate highly compressible sequences while values
        approaching ``1`` correspond to rapid bit alternation.

        Args:
            codes: ``(B, T)`` bits, or run-length codes when ``decompress`` is set.
            decompress: When ``True`` (default), inputs that look run-length
                encoded are decompressed first. Pass ``False`` for inputs
                known to be raw bits.
        """
        if decompress:
            codes = self._maybe_decompress(codes)
        # Work in float: unsigned integer dtypes would wrap on 0 - 1.
        codes = codes.float()
        diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
        return diffs.mean(dim=1)

    def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """Negentropy computed from model logits.

        Uses the probability of bit 1, averaged over positions, and returns
        ``1 - H(p) / log(2)``.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        p = prob[..., 1].mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """LZ complexity proxy computed from logits.

        This is the mean absolute change in the probability of bit 1 between
        consecutive positions.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        prob1 = prob[..., 1]
        diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
        return diffs.mean(dim=1)

    def symbiosis_kl_logits(
        self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
    ) -> torch.Tensor:
        """Symbiosis score from KL divergence to a reference distribution.

        Returns a value in ``[0, 1]``. ``1`` means perfect agreement with the
        reference distribution ``[1 - ref_prob, ref_prob]``; ``0`` means the
        largest possible divergence from it.

        Raises:
            ValueError: if ``ref_prob`` is not strictly between 0 and 1.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        if not 0.0 < ref_prob < 1.0:
            raise ValueError(f"ref_prob must be in (0, 1), got {ref_prob}")
        probs = logits.softmax(-1)
        if detach:
            probs = probs.detach()
        ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
        kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1)
        # The largest KL(p || ref) puts all mass on the less likely reference
        # outcome, giving -log(min(ref)). That equals log(2) for ref_prob=0.5
        # and keeps the score within [0, 1] for other references.
        max_kl = -math.log(min(ref_prob, 1.0 - ref_prob))
        return 1 - kl / max_kl

    def _act_step(
        self,
        hidden: torch.Tensor,
        idx: int,
        halt_prob: torch.Tensor,
        act_state: torch.Tensor,
        halt_history: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Apply one step of ACT halting logic.

        Returns the updated halting probability, the updated accumulated state,
        and whether the minimum halting probability has reached
        ``act_threshold``. When torch.distributed is initialised, that minimum
        is taken across all ranks.
        """
        p = torch.sigmoid(self.halt_projs[idx](hidden))
        delta = (1 - halt_prob) * p
        halt_prob = halt_prob + delta
        act_state = act_state + hidden * delta
        halt_history.append(halt_prob.detach())
        min_prob = halt_prob.detach().min()
        # ``is_available`` guards builds compiled without distributed support.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
        return halt_prob, act_state, min_prob.item() >= self.act_threshold

    def forward(
        self, bit_seq: torch.Tensor, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass returning logits and telemetry from the same graph.

        By default the model uses causal masking and (optional) chunked
        attention. When ``causal`` is ``False`` the model operates in
        "Diffusion LM" mode. In this mode chunked attention is temporarily
        disabled so that every token can attend to the full sequence
        bidirectionally. The original chunking configuration is restored after
        the forward pass.

        Args:
            bit_seq: ``(B, T)`` tensor of 0/1 bits.
            causal: Apply a causal attention mask.

        Returns:
            ``(logits, telemetry)``. ``logits`` has shape ``(B, T, 2)``.
            ``telemetry`` holds per-layer activations and attention maps, plus
            the KPI tensors used to compute ``symbiosis_score``.
        """
        # Disable chunking when running in bidirectional (non-causal) mode
        orig_chunks = None
        orig_model_chunk = None
        if not causal and self.chunk_size is not None:
            orig_model_chunk = self.chunk_size
            orig_chunks = [layer.chunk_size for layer in self.layers]
            self.chunk_size = None
            for layer in self.layers:
                layer.chunk_size = None
        try:
            ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
            with ctx:
                x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
                x = self.pos_enc(x)
                attn_mask = get_tri_mask(x.size(0), x.device) if causal else None
                # Checkpointing only pays off when a backward pass can follow.
                use_ckpt = self.use_checkpoint and torch.is_grad_enabled()
                activations: List[torch.Tensor] = []
                attn_maps: List[torch.Tensor] = []
                halt_history: List[torch.Tensor] = []
                if self.use_act:
                    halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
                    act_state = torch.zeros_like(x)
                if self.reversible:
                    x1, x2 = x, x
                    for idx, layer in enumerate(self.layers):
                        if use_ckpt:
                            x1, x2, attn = checkpoint.checkpoint(
                                layer, x1, x2, attn_mask, use_reentrant=False
                            )
                        else:
                            x1, x2, attn = layer(x1, x2, attn_mask)
                        combined = (x1 + x2) / 2
                        activations.append(combined)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                combined, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                    x = (x1 + x2) / 2
                else:
                    for idx, layer in enumerate(self.layers):
                        if use_ckpt:
                            x, attn = checkpoint.checkpoint(
                                layer, x, attn_mask, use_reentrant=False
                            )
                        else:
                            x, attn = layer(x, attn_mask)
                        activations.append(x)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                x, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                if self.use_act:
                    act_state = act_state + x * (1 - halt_prob)
                    x = act_state
                logits = self.out_head(x)
                # Per-layer entropy of activations
                entropies = []
                for act in activations:
                    prob = act.softmax(-1)
                    ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean()
                    entropies.append(ent)
                attn_entropies = []
                for attn in attn_maps:
                    prob = attn  # weights are already softmaxed
                    ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1)
                    ent = ent.mean(1)
                    attn_entropies.append(ent)
                if attn_entropies:
                    attn_entropy_map = torch.stack(attn_entropies).mean(0)
                else:
                    attn_entropy_map = torch.zeros(
                        bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
                    )
                n_positions = attn_entropy_map.size(-1)
                # log(1) == 0 for single-token inputs; avoid dividing by zero.
                max_ent = math.log(n_positions) if n_positions > 1 else 1.0
                attn_entropy_map = attn_entropy_map / max_ent
                attn_entropy = attn_entropy_map.mean(1)
                logits_bt = logits.transpose(0, 1)
                # ``bit_seq`` went through the 2-entry embedding, so it is raw
                # 0/1 bits. Skip the run-length heuristic, which misfires on
                # some raw patterns (e.g. 0,0,1,1,0,0,1,1).
                negentropy_in = self.negentropy_kpi(bit_seq, decompress=False)
                lz_in = self.lz_complexity(bit_seq.float(), decompress=False)
                negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
                lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
                kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)
                raw_sym = (
                    (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
                    + negentropy_logits_b * lz_logits_b
                    - self.lambda_S * kl_div_b
                    - 0.1 * attn_entropy
                )
                weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach()
                raw_sym = raw_sym - 0.01 * weight_norm
                sym_score = torch.sigmoid(raw_sym)
                B, T = bit_seq.shape
                assert logits_bt.shape[:2] == (B, T)
                assert attn_entropy_map.shape == (B, T)
                telemetry = {
                    "activations": activations,
                    "attention_maps": attn_maps,
                    "attention_entropy": attn_entropy_map,
                    "entropy": entropies,
                    "attention_entropy_mean": attn_entropy,
                    "negentropy_input": negentropy_in.detach(),
                    "lz_complexity_input": lz_in.detach(),
                    "negentropy_logits": negentropy_logits_b.detach(),
                    "lz_complexity_logits": lz_logits_b.detach(),
                    "symbiosis_kl": kl_div_b.detach(),
                    "symbiosis_score": sym_score.detach(),
                }
                if self.use_act:
                    telemetry["halt_probs"] = halt_history
                return logits_bt, telemetry
        finally:
            # Restore chunking that was disabled for non-causal mode.
            if orig_chunks is not None:
                self.chunk_size = orig_model_chunk
                for layer, chunk in zip(self.layers, orig_chunks):
                    layer.chunk_size = chunk

    def forward_compressed(
        self, compressed_bits, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decompress bit sequences then run the normal forward pass.

        Accepts either a single 1-D compressed tensor or an iterable of them.

        Raises:
            ValueError: if the sequences decompress to different lengths.
        """
        if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
            sequences = [decompress_bits(compressed_bits).to(torch.long)]
        else:
            sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
        lengths = [seq.numel() for seq in sequences]
        if len(set(lengths)) != 1:
            raise ValueError("Sequences decompress to different lengths")
        bits = torch.stack(sequences)
        return self.forward(bits, causal=causal)

    def _current_params(self) -> Dict:
        """Return a dictionary with the current model hyperparameters."""
        return {
            "d_model": self.d_model,
            "nhead": self.layers[0].self_attn.num_heads,
            "num_layers": self.num_layers,
            "dim_feedforward": self.layers[0].linear1.out_features,
            "max_seq_len": self.pos_enc.pe.size(0),
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "reversible": self.reversible,
            "use_checkpoint": self.use_checkpoint,
            "use_autocast": self.use_autocast,
            "use_act": self.use_act,
            "act_threshold": self.act_threshold,
            "chunk_size": self.chunk_size,
            "overlap": self.overlap,
            # Preserve the logging choice so scaled copies behave the same.
            "full_attn_logging": self.full_attn_logging,
        }

    def double_width(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled hidden size."""
        from .scale import expand_model
        params = self._current_params()
        params["d_model"] *= 2
        params["dim_feedforward"] *= 2
        return expand_model(self, params)

    def double_layers(self) -> "BitTransformerLM":
        """Return a copy of the model with twice as many layers."""
        from .scale import expand_model
        params = self._current_params()
        params["num_layers"] *= 2
        return expand_model(self, params)

    def double_length(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled maximum sequence length.

        Also sets ``chunk_size`` to the new ``max_seq_len``.
        """
        from .scale import expand_model
        params = self._current_params()
        params["max_seq_len"] *= 2
        params["chunk_size"] = params["max_seq_len"]
        return expand_model(self, params)

    def train_full_sequence(
        self,
        bits: torch.Tensor,
        *,
        ctx_bits: int = 4096,
        detach_every_n: int = 1_048_576,
    ) -> float:
        """Train on a long bit tensor using sliding windows.

        Parameters
        ----------
        bits: ``torch.Tensor``
            1D tensor containing the full bit sequence.
        ctx_bits: int
            Size of the training context window.
        detach_every_n: int
            Interval in bits for optimizer updates and graph detachment.

        Returns
        -------
        float
            Mean loss over all windows (``0.0`` if ``bits`` is too short
            for a single window of ``ctx_bits + 1`` bits).
        """
        self.train()
        optimizer, scheduler = configure_optimizer(
            self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
        )
        # Discard gradients left over from earlier backward passes.
        optimizer.zero_grad()
        accum = 0
        total_loss = 0.0
        count = 0
        # Each window needs ctx_bits + 1 bits (inputs plus shifted targets), so
        # the last valid start is numel - ctx_bits - 1; range's stop is exclusive.
        for start in range(0, bits.numel() - ctx_bits, ctx_bits):
            segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
            logits, _ = self(segment)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = segment[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)
            loss.backward()
            accum += ctx_bits
            total_loss += loss.item()
            count += 1
            if accum >= detach_every_n:
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accum = 0
        # Flush gradients accumulated since the last update.
        if accum > 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        return total_loss / max(1, count)
def infer_long_sequence(
    model: "BitTransformerLM",
    bits: torch.Tensor,
    *,
    ctx_bits: int = 4096,
    overlap: int = 256,
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
    """Infer a long bit sequence using sliding windows with overlap.

    Windows of ``ctx_bits`` bits start every ``ctx_bits - overlap`` positions.
    In every window after the first, the first ``overlap`` positions serve
    only as left context: their predictions are discarded because the
    previous window already produced them. As a result, the returned
    predictions line up one-to-one with ``bits``.

    Args:
        model: Model called as ``model(window, causal=True)`` and returning
            ``(logits, telemetry)``, with logits shaped ``(1, L, 2)``.
        bits: 1-D bit tensor.
        ctx_bits: Window length.
        overlap: Number of context positions shared by consecutive windows.

    Returns:
        ``(pred, logs)``. ``pred`` is a 1-D tensor of argmax predictions with
        ``pred[i]`` produced at input position ``i``. ``logs`` holds the
        telemetry of each processed window.

    Raises:
        ValueError: if ``overlap`` is not in ``[0, ctx_bits)``.
    """
    if not 0 <= overlap < ctx_bits:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < ctx_bits "
            f"(got overlap={overlap}, ctx_bits={ctx_bits})"
        )
    model.eval()
    device = next(model.parameters()).device
    bits = bits.to(device)
    total = bits.numel()
    step = ctx_bits - overlap
    outputs: List[torch.Tensor] = []
    logs: List[Dict[str, torch.Tensor]] = []
    start = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        while start < total:
            window = bits[start : start + ctx_bits].unsqueeze(0)
            logits, tele = model(window, causal=True)
            pred = logits.argmax(-1).squeeze(0)
            # Positions before ``start + overlap`` came from the previous window.
            outputs.append(pred if start == 0 else pred[overlap:])
            logs.append(tele)
            if start + ctx_bits >= total:
                break
            start += step
    if not outputs:
        return torch.empty(0, dtype=torch.long, device=device), logs
    return torch.cat(outputs), logs
def diffusion_inference(
    model: BitTransformerLM,
    *,
    length: int,
    steps: int = 8,
    batch_size: int = 1,
    init_bits: Optional[torch.Tensor] = None,
    schedule: str = "linear",
) -> torch.Tensor:
    """Generate bit sequences using iterative denoising diffusion.

    Parameters
    ----------
    model: ``BitTransformerLM``
        The model used for denoising. It is run in non-causal mode with
        chunked attention disabled, enabling full-context bidirectional
        attention.
    length: int
        Length of the bit sequences to generate.
    steps: int, default ``8``
        Number of denoising iterations. More steps generally yield sharper
        samples at the cost of compute.
    batch_size: int, default ``1``
        Number of sequences to generate in parallel.
    init_bits: ``torch.Tensor`` | ``None``
        Optional initial noisy bits of shape ``(batch_size, length)``. When
        ``None`` random noise is used. The tensor is moved to the model's
        device and converted to ``torch.long``.
    schedule: str, default ``"linear"``
        Noise schedule for the denoising mask probability. Options are
        ``"linear"``, ``"cosine"``, and ``"exp"``.

    Returns
    -------
    ``torch.Tensor``
        A tensor of shape ``(batch_size, length)`` containing generated bits.
        When ``length`` is a multiple of 9, ``enforce_parity`` is applied.
        The result is then passed to ``hil_safe_inference``; a
        ``RuntimeError`` from that gate is logged as a warning.

    Raises
    ------
    ValueError
        If ``schedule`` is unknown or ``init_bits`` has the wrong shape.
    """
    # Resolve the schedule up front so a typo fails before any model call.
    schedules = {
        "linear": lambda t: 1.0 - t,
        "cosine": lambda t: math.cos(math.pi * t / 2),
        "exp": lambda t: math.exp(-5 * t),
    }
    try:
        mask_prob_at = schedules[schedule]
    except KeyError:
        raise ValueError(f"unknown schedule: {schedule}") from None
    model.eval()
    device = next(model.parameters()).device
    if init_bits is None:
        bits = torch.randint(0, 2, (batch_size, length), device=device)
    else:
        # The embedding layer expects integer bit indices.
        bits = init_bits.to(device=device, dtype=torch.long)
        if bits.shape != (batch_size, length):
            raise ValueError("init_bits must have shape (batch_size, length)")
    # Sampling only: no autograd graph is needed for the denoising loop.
    with torch.no_grad():
        for step in range(steps):
            logits, _ = model(bits, causal=False)
            prob = logits.softmax(-1)[..., 1]
            mask_prob = mask_prob_at((step + 1) / steps)
            # Resample a random subset of positions whose size shrinks over time.
            mask = torch.rand_like(bits.float()) < mask_prob
            sampled = torch.bernoulli(prob).long()
            bits = torch.where(mask, sampled, bits)
    if bits.shape[-1] % 9 == 0:
        bits, corrections = enforce_parity(bits)
        if corrections:
            logging.info("Parity corrections applied: %d", corrections)
    try:
        from .safety import hil_safe_inference
        hil_safe_inference(model, bits, causal=False, strict=False)
    except RuntimeError as exc:
        logging.warning("Safety gate warning: %s", exc)
    return bits
def example_usage() -> float:
    """Build a small model, score one random batch, and return the loss.

    Mirrors the README example: next-bit cross-entropy on a ``(4, 16)``
    batch of random bits.
    """
    batch, seq_len = 4, 16
    net = BitTransformerLM(
        d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=seq_len
    )
    sample = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, _telemetry = net(sample)
    # The logit at position t is scored against the bit at position t + 1.
    predictions = logits[:, :-1, :].reshape(-1, 2)
    targets = sample[:, 1:].reshape(-1)
    return F.cross_entropy(predictions, targets).item()
def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
    """Run one optimisation step on random bits and report the outcome.

    The loss is plain next-bit cross-entropy, so the telemetry metrics do
    not contribute to the gradients. Returns the loss measured before the
    update, together with the telemetry dictionary from the forward pass.
    """
    batch, seq_len = 4, 16
    net = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=seq_len
    )
    optimizer, scheduler = configure_optimizer(net, lr=1e-3, total_steps=1)
    sample = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, telemetry = net(sample)
    step_loss = F.cross_entropy(
        logits[:, :-1, :].reshape(-1, 2), sample[:, 1:].reshape(-1)
    )
    step_loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return step_loss.item(), telemetry
if __name__ == "__main__":
    # Smoke-run a single training step and show what telemetry is produced.
    step_loss, step_telemetry = example_training_step()
    print("Composite loss:", step_loss)
    print("Telemetry keys:", list(step_telemetry))