import math import contextlib import logging from typing import Dict, List, Tuple, Optional import torch import torch.distributed as dist import sys import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from .torch_utils import cpu_autocast from .optimization import configure_optimizer from .compression import decompress_bits from .parity import enforce_parity _mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {} _attention_cache: Dict[str, torch.Tensor] = {} # For caching attention patterns _MAX_CACHE_SIZE = 50 # Limit cache growth def clear_cache(): """Clear memory caches to prevent OOM in long sequences.""" global _mask_cache, _attention_cache _mask_cache.clear() _attention_cache.clear() def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor: """Return or create a cached upper-triangular mask with memory management.""" key = (seq_len, device) # Clear cache if it gets too large if len(_mask_cache) > _MAX_CACHE_SIZE: clear_cache() if key not in _mask_cache: _mask_cache[key] = torch.triu( torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1 ) return _mask_cache[key] try: # torch.compile may not work on all Python versions if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11): compile_fn = torch.compile else: raise RuntimeError except Exception: # pragma: no cover - handle missing torch or unsupported version def compile_fn(fn=None, **kwargs): if fn is None: return lambda f: f return fn class PositionalEncoding(nn.Module): """Sinusoidal positional encoding.""" def __init__(self, d_model: int, max_len: int = 1024) -> None: super().__init__() pe = torch.zeros(max_len, d_model) pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) inv = torch.exp( torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(pos * inv) pe[:, 1::2] = torch.cos(pos * inv) self.register_buffer("pe", pe.unsqueeze(1)) def forward(self, x: torch.Tensor) -> torch.Tensor: """Add positional encoding to input tensor.""" return x + self.pe[: x.size(0)] class LoggingTransformerEncoderLayer(nn.Module): """Transformer encoder layer that exposes attention weights. It optionally performs chunked attention with a fixed window size. """ def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 512, dropout: float = 0.1, chunk_size: Optional[int] = None, overlap: int = 0, full_attn_logging: Optional[bool] = None, ) -> None: super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) self.chunk_size = chunk_size self.overlap = overlap if full_attn_logging is None: full_attn_logging = False if chunk_size is not None else True self.full_attn_logging = full_attn_logging self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = F.relu def _chunked_attn( self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform memory-efficient chunked self attention with overlap.""" T, B, D = src.shape # Early return for small sequences if T <= 128 or self.chunk_size is None or self.chunk_size >= T: return self._full_attn(src, attn_mask) src_b = src.transpose(0, 1) # [B, T, D] C = self.chunk_size O = self.overlap n_chunks = (T + C - 1) // C pad_len = n_chunks * C - T # Process chunks with gradient checkpointing for memory efficiency outputs = [] weights_list = [] # Use memory-efficient processing with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()): for chunk_idx in range(n_chunks): start_idx = chunk_idx * C end_idx = min(start_idx + C + 2 * O, T + O) # Extract chunk with overlap chunk_start = max(0, start_idx - O) chunk_end = min(T, end_idx) chunk = src_b[:, chunk_start:chunk_end] # Pad if necessary if chunk.size(1) < C + 2 * O: pad_size = C + 2 * O - chunk.size(1) chunk = F.pad(chunk, (0, 0, 0, pad_size)) chunk_len = chunk.size(1) mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None # Apply attention to chunk out, weights = self.self_attn( chunk, chunk, chunk, attn_mask=mask, need_weights=self.full_attn_logging, average_attn_weights=False, ) # Extract the core part (remove overlap) core_start = O if chunk_idx > 0 else 0 core_end = core_start + min(C, T - start_idx) outputs.append(out[:, core_start:core_end]) if self.full_attn_logging and weights is not None: weights_list.append(weights[:, :, core_start:core_end]) # Clear intermediate tensors to save memory del out, weights, chunk if torch.cuda.is_available(): torch.cuda.empty_cache() # Concatenate outputs seq = torch.cat(outputs, dim=1) # Handle attention weights if self.full_attn_logging and weights_list: # Use sparse representation for large sequences if T > 1024: attn_out = torch.empty(0, device=src.device) # Skip full attention for very long sequences else: attn_out = torch.cat(weights_list, dim=2) else: attn_out = torch.empty(0, device=src.device) return seq.transpose(0, 1), attn_out def _full_attn(self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """Standard full attention for smaller sequences.""" qkv = src.transpose(0, 1) attn_output, attn_weights = self.self_attn( qkv, qkv, qkv, attn_mask=attn_mask, need_weights=True, average_attn_weights=False, ) return attn_output.transpose(0, 1), attn_weights def forward( self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Return output and attention map.""" if self.chunk_size is not None: attn_output, attn_weights = self._chunked_attn(src, attn_mask) else: qkv = src.transpose(0, 1) attn_output, attn_weights = self.self_attn( qkv, qkv, qkv, attn_mask=attn_mask, need_weights=True, average_attn_weights=False, ) attn_output = attn_output.transpose(0, 1) src = src + self.dropout1(attn_output) src = self.norm1(src) out = self.linear2(self.dropout(self.activation(self.linear1(src)))) src = src + self.dropout2(out) src = self.norm2(src) return src, attn_weights.detach() class ReversibleLoggingTransformerEncoderLayer(nn.Module): """Reversible transformer encoder layer with checkpointing.""" def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 512, dropout: float = 0.1, chunk_size: Optional[int] = None, overlap: int = 0, full_attn_logging: Optional[bool] = None, ) -> None: super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) self.chunk_size = chunk_size self.overlap = overlap if full_attn_logging is None: full_attn_logging = False if chunk_size is not None else True self.full_attn_logging = full_attn_logging self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = F.relu @compile_fn def _sa_block( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: if self.chunk_size is not None: T, B, D = x.shape x_b = x.transpose(0, 1) C = self.chunk_size or T O = self.overlap n_chunks = (T + C - 1) // C pad_len = n_chunks * C - T src_pad = F.pad(x_b, (0, 0, O, pad_len + O)) chunk_len = C + 2 * O chunks = src_pad.unfold(1, chunk_len, C) mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None out, weights = self.self_attn( chunks.reshape(B * n_chunks, chunk_len, D), chunks.reshape(B * n_chunks, chunk_len, D), chunks.reshape(B * n_chunks, chunk_len, D), attn_mask=mask, need_weights=True, average_attn_weights=False, ) out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C] weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[ :, :, :, O : O + C ] seq = out.reshape(B, n_chunks * C, D)[:, :T] if self.full_attn_logging and C < T: full_attn = torch.zeros( B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device ) for idx in range(n_chunks): s = idx * C start = max(s - O, 0) end = min(s + C, n_chunks * C) src_start = O - (s - start) src_end = src_start + (end - start) full_attn[:, :, s : s + C, start:end] = weights[ :, idx, :, src_start:src_end ] full_attn = full_attn[:, :, :T, :T] weights = full_attn.detach() else: weights = torch.empty(0, device=x.device) attn_out = seq.transpose(0, 1) else: qkv = x.transpose(0, 1) attn_out, weights = self.self_attn( qkv, qkv, qkv, attn_mask=attn_mask, need_weights=True, average_attn_weights=False, ) attn_out = attn_out.transpose(0, 1) x = self.norm1(x + self.dropout1(attn_out)) return x, weights.detach() @compile_fn def _ff_block(self, x: torch.Tensor) -> torch.Tensor: out = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = self.norm2(x + self.dropout2(out)) return x def forward( self, x1: torch.Tensor, x2: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: y1, weights = self._sa_block(x2, attn_mask) y1 = x1 + y1 y2 = x2 + self._ff_block(y1) return y1, y2, weights class BitTransformerLM(nn.Module): """Transformer language model that operates on raw bits (0/1) with telemetry.""" def __init__( self, d_model: int = 128, nhead: int = 8, num_layers: int = 4, dim_feedforward: int = 512, max_seq_len: int = 1024, lambda_K: float = 1.0, lambda_C: float = 1.0, lambda_S: float = 1.0, reversible: bool = False, use_checkpoint: bool = True, use_autocast: bool = False, use_act: bool = False, act_threshold: float = 0.9, chunk_size: Optional[int] = None, overlap: int = 0, full_attn_logging: Optional[bool] = None, ) -> None: """Create a BitTransformer language model. Args: full_attn_logging: When ``False`` and ``chunk_size`` is smaller than the sequence length, the model skips reconstructing the full ``T×T`` attention matrices for telemetry to reduce memory use. """ super().__init__() self.d_model = d_model self.num_layers = num_layers self.lambda_K = lambda_K self.lambda_C = lambda_C self.lambda_S = lambda_S self.reversible = reversible self.use_checkpoint = use_checkpoint self.use_autocast = use_autocast self.use_act = use_act self.act_threshold = act_threshold self.chunk_size = chunk_size self.overlap = overlap if full_attn_logging is None: full_attn_logging = False if chunk_size is not None else True self.full_attn_logging = full_attn_logging # Bit embedding: two possible input values self.embedding = nn.Embedding(2, d_model) self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len) layer_cls = ( ReversibleLoggingTransformerEncoderLayer if reversible else LoggingTransformerEncoderLayer ) self.layers = nn.ModuleList( [ layer_cls( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, chunk_size=chunk_size, overlap=overlap, full_attn_logging=full_attn_logging, ) for _ in range(num_layers) ] ) if self.use_act: self.halt_projs = nn.ModuleList( [nn.Linear(d_model, 1) for _ in range(num_layers)] ) self.out_head = nn.Linear(d_model, 2) # output logits for bit=0 or bit=1 def expand_positional_encoding(self, new_len: int) -> None: """Expand positional encoding to at least ``new_len``.""" cur_len = self.pos_enc.pe.size(0) if new_len <= cur_len: return device = self.pos_enc.pe.device d_model = self.d_model pe = torch.zeros(new_len, d_model, device=device) pe[:cur_len] = self.pos_enc.pe.squeeze(1) pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1) inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) pe[cur_len:, 0::2] = torch.sin(pos * inv) pe[cur_len:, 1::2] = torch.cos(pos * inv) self.pos_enc.pe = pe.unsqueeze(1) def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None: """Update weighting coefficients for telemetry metrics.""" self.lambda_K = lambda_K self.lambda_C = lambda_C self.lambda_S = lambda_S def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor: """Return raw bit sequences, decompressing if input appears run-length encoded.""" if codes.dim() <= 1: return codes needs_decompress = codes.max().item() > 1 if not needs_decompress and codes.size(1) % 2 == 0: vals = codes[:, 0::2] if torch.all(vals[:, 1:] != vals[:, :-1]): needs_decompress = True if not needs_decompress: return codes seqs = [decompress_bits(row.to(torch.uint8)) for row in codes] max_len = max(seq.numel() for seq in seqs) padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs] return torch.stack(padded) def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor: """Approximate negentropy of bit sequences. Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered sequence (all zeros or ones) and ``0`` reflects maximal entropy. """ codes = self._maybe_decompress(codes) p = codes.float().mean(dim=1) entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) max_e = math.log(2.0) return 1 - entropy / max_e def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor: """Differentiable proxy for Lempel–Ziv complexity. Values near ``0`` indicate highly compressible sequences while values approaching ``1`` correspond to rapid bit alternation. """ codes = self._maybe_decompress(codes) diffs = torch.abs(codes[:, 1:] - codes[:, :-1]) return diffs.float().mean(dim=1) def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: """Negentropy computed from model logits. Parameters ---------- logits: ``torch.Tensor`` Logit tensor of shape ``(B, T, 2)``. detach: bool, default ``True`` When ``True`` the computation is detached from the autograd graph. """ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" prob = logits.softmax(-1) if detach: prob = prob.detach() p = prob[..., 1].mean(dim=1) entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) max_e = math.log(2.0) return 1 - entropy / max_e def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: """LZ complexity proxy computed from logits. Parameters ---------- logits: ``torch.Tensor`` Logit tensor of shape ``(B, T, 2)``. detach: bool, default ``True`` When ``True`` the computation is detached from the autograd graph. """ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" prob = logits.softmax(-1) if detach: prob = prob.detach() prob1 = prob[..., 1] diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1]) return diffs.mean(dim=1) def symbiosis_kl_logits( self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True ) -> torch.Tensor: """Symbiosis score from KL divergence to a reference distribution. Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with the reference distribution and ``0`` indicating maximal divergence. """ assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" probs = logits.softmax(-1) if detach: probs = probs.detach() ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device) kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1) max_kl = math.log(2.0) return 1 - kl / max_kl def _act_step( self, hidden: torch.Tensor, idx: int, halt_prob: torch.Tensor, act_state: torch.Tensor, halt_history: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, bool]: """Apply one step of ACT halting logic.""" p = torch.sigmoid(self.halt_projs[idx](hidden)) delta = (1 - halt_prob) * p halt_prob = halt_prob + delta act_state = act_state + hidden * delta halt_history.append(halt_prob.detach()) min_prob = halt_prob.detach().min() if dist.is_initialized(): dist.all_reduce(min_prob, op=dist.ReduceOp.MIN) return halt_prob, act_state, min_prob.item() >= self.act_threshold def forward( self, bit_seq: torch.Tensor, causal: bool = True ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Forward pass returning logits and telemetry from the same graph. By default the model uses causal masking and (optional) chunked attention. When ``causal`` is ``False`` the model operates in "Diffusion LM" mode. In this mode chunked attention is temporarily disabled so that every token can attend to the full sequence bidirectionally. The original chunking configuration is restored after the forward pass. """ # Disable chunking when running in bidirectional (non-causal) mode orig_chunks = None orig_model_chunk = None if not causal and self.chunk_size is not None: orig_model_chunk = self.chunk_size orig_chunks = [layer.chunk_size for layer in self.layers] self.chunk_size = None for layer in self.layers: layer.chunk_size = None try: ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext() with ctx: x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model) x = self.pos_enc(x) attn_mask = get_tri_mask(x.size(0), x.device) if causal else None activations: List[torch.Tensor] = [] attn_maps: List[torch.Tensor] = [] halt_history: List[torch.Tensor] = [] if self.use_act: halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device) act_state = torch.zeros_like(x) if self.reversible: x1, x2 = x, x for idx, layer in enumerate(self.layers): if self.use_checkpoint: x1, x2, attn = checkpoint.checkpoint( layer, x1, x2, attn_mask ) else: x1, x2, attn = layer(x1, x2, attn_mask) combined = (x1 + x2) / 2 activations.append(combined) if attn.numel() > 0: attn_maps.append(attn) if self.use_act: halt_prob, act_state, should_break = self._act_step( combined, idx, halt_prob, act_state, halt_history ) if should_break: break x = (x1 + x2) / 2 else: for idx, layer in enumerate(self.layers): if self.use_checkpoint: x, attn = checkpoint.checkpoint(layer, x, attn_mask) else: x, attn = layer(x, attn_mask) activations.append(x) if attn.numel() > 0: attn_maps.append(attn) if self.use_act: halt_prob, act_state, should_break = self._act_step( x, idx, halt_prob, act_state, halt_history ) if should_break: break if self.use_act: act_state = act_state + x * (1 - halt_prob) x = act_state logits = self.out_head(x) # Per-layer entropy of activations entropies = [] for act in activations: prob = act.softmax(-1) ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean() entropies.append(ent) attn_entropies = [] for attn in attn_maps: prob = attn # weights are already softmaxed ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1) ent = ent.mean(1) attn_entropies.append(ent) if attn_entropies: attn_entropy_map = torch.stack(attn_entropies).mean(0) else: attn_entropy_map = torch.zeros( bit_seq.size(0), bit_seq.size(1), device=bit_seq.device ) max_ent = math.log(attn_entropy_map.size(-1)) attn_entropy_map = attn_entropy_map / max_ent attn_entropy = attn_entropy_map.mean(1) logits_bt = logits.transpose(0, 1) negentropy_in = self.negentropy_kpi(bit_seq) lz_in = self.lz_complexity(bit_seq.float()) negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False) lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False) kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False) raw_sym = ( (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2 + negentropy_logits_b * lz_logits_b - self.lambda_S * kl_div_b - 0.1 * attn_entropy ) weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach() raw_sym = raw_sym - 0.01 * weight_norm sym_score = torch.sigmoid(raw_sym) B, T = bit_seq.shape assert logits_bt.shape[:2] == (B, T) assert attn_entropy_map.shape == (B, T) telemetry = { "activations": activations, "attention_maps": attn_maps, "attention_entropy": attn_entropy_map, "entropy": entropies, "attention_entropy_mean": attn_entropy, "negentropy_input": negentropy_in.detach(), "lz_complexity_input": lz_in.detach(), "negentropy_logits": negentropy_logits_b.detach(), "lz_complexity_logits": lz_logits_b.detach(), "symbiosis_kl": kl_div_b.detach(), "symbiosis_score": sym_score.detach(), } if self.use_act: telemetry["halt_probs"] = halt_history return logits_bt, telemetry finally: if orig_chunks is not None: self.chunk_size = orig_model_chunk for layer, chunk in zip(self.layers, orig_chunks): layer.chunk_size = chunk def forward_compressed( self, compressed_bits, causal: bool = True ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Decompress bit sequences then run the normal forward pass.""" if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1: sequences = [decompress_bits(compressed_bits).to(torch.long)] else: sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits] lengths = [seq.numel() for seq in sequences] if len(set(lengths)) != 1: raise ValueError("Sequences decompress to different lengths") bits = torch.stack(sequences) return self.forward(bits, causal=causal) def _current_params(self) -> Dict: """Return a dictionary with the current model hyperparameters.""" return { "d_model": self.d_model, "nhead": self.layers[0].self_attn.num_heads, "num_layers": self.num_layers, "dim_feedforward": self.layers[0].linear1.out_features, "max_seq_len": self.pos_enc.pe.size(0), "lambda_K": self.lambda_K, "lambda_C": self.lambda_C, "lambda_S": self.lambda_S, "reversible": self.reversible, "use_checkpoint": self.use_checkpoint, "use_autocast": self.use_autocast, "use_act": self.use_act, "act_threshold": self.act_threshold, "chunk_size": self.chunk_size, "overlap": self.overlap, } def double_width(self) -> "BitTransformerLM": """Return a copy of the model with doubled hidden size.""" from .scale import expand_model params = self._current_params() params["d_model"] *= 2 params["dim_feedforward"] *= 2 return expand_model(self, params) def double_layers(self) -> "BitTransformerLM": """Return a copy of the model with twice as many layers.""" from .scale import expand_model params = self._current_params() params["num_layers"] *= 2 return expand_model(self, params) def double_length(self) -> "BitTransformerLM": """Return a copy of the model with doubled maximum sequence length.""" from .scale import expand_model params = self._current_params() params["max_seq_len"] *= 2 params["chunk_size"] = params["max_seq_len"] return expand_model(self, params) def train_full_sequence( self, bits: torch.Tensor, *, ctx_bits: int = 4096, detach_every_n: int = 1_048_576, ) -> float: """Train on a long bit tensor using sliding windows. Parameters ---------- bits: ``torch.Tensor`` 1D tensor containing the full bit sequence. ctx_bits: int Size of the training context window. detach_every_n: int Interval in bits for optimizer updates and graph detachment. Returns ------- float Mean loss over all windows. """ self.train() optimizer, scheduler = configure_optimizer( self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits) ) accum = 0 total_loss = 0.0 count = 0 for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits): segment = bits[start : start + ctx_bits + 1].unsqueeze(0) logits, _ = self(segment) pred = logits[:, :-1, :].reshape(-1, 2) target = segment[:, 1:].reshape(-1) loss = F.cross_entropy(pred, target) loss.backward() accum += ctx_bits total_loss += loss.item() count += 1 if accum >= detach_every_n: torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() accum = 0 if accum > 0: torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() return total_loss / max(1, count) def infer_long_sequence( model: BitTransformerLM, bits: torch.Tensor, *, ctx_bits: int = 4096, overlap: int = 256, ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]: """Infer a long bit sequence using sliding windows with overlap.""" model.eval() device = next(model.parameters()).device bits = bits.to(device) step = ctx_bits - overlap outputs: List[torch.Tensor] = [] logs: List[Dict[str, torch.Tensor]] = [] for start in range(0, bits.numel(), step): window = bits[start : start + ctx_bits].unsqueeze(0) logits, tele = model(window, causal=True) pred = logits.argmax(-1).squeeze(0) outputs.append(pred) logs.append(tele) out = torch.cat(outputs)[: bits.numel()] return out, logs def diffusion_inference( model: BitTransformerLM, *, length: int, steps: int = 8, batch_size: int = 1, init_bits: Optional[torch.Tensor] = None, schedule: str = "linear", ) -> torch.Tensor: """Generate bit sequences using iterative denoising diffusion. Parameters ---------- model: ``BitTransformerLM`` The model used for denoising. It is run in non-causal mode with chunked attention disabled, enabling full-context bidirectional attention. length: int Length of the bit sequences to generate. steps: int, default ``8`` Number of denoising iterations. More steps generally yield sharper samples at the cost of compute. batch_size: int, default ``1`` Number of sequences to generate in parallel. init_bits: ``torch.Tensor`` | ``None`` Optional initial noisy bits of shape ``(batch_size, length)``. When ``None`` random noise is used. schedule: str, default ``"linear"`` Noise schedule for the denoising mask probability. Options are ``"linear"``, ``"cosine"``, and ``"exp"``. Returns ------- ``torch.Tensor`` A tensor of shape ``(batch_size, length)`` containing generated bits. """ model.eval() device = next(model.parameters()).device if init_bits is None: bits = torch.randint(0, 2, (batch_size, length), device=device) else: bits = init_bits.to(device) if bits.shape != (batch_size, length): raise ValueError("init_bits must have shape (batch_size, length)") for step in range(steps): logits, _ = model(bits, causal=False) prob = logits.softmax(-1)[..., 1] t = (step + 1) / steps if schedule == "linear": mask_prob = 1.0 - t elif schedule == "cosine": mask_prob = math.cos(math.pi * t / 2) elif schedule == "exp": mask_prob = math.exp(-5 * t) else: raise ValueError(f"unknown schedule: {schedule}") mask = (torch.rand_like(bits.float()) < mask_prob).long() sampled = torch.bernoulli(prob).long() bits = torch.where(mask.bool(), sampled, bits) if bits.shape[-1] % 9 == 0: bits, corrections = enforce_parity(bits) if corrections: logging.info("Parity corrections applied: %d", corrections) try: from .safety import hil_safe_inference hil_safe_inference(model, bits, causal=False, strict=False) except RuntimeError as exc: logging.warning("Safety gate warning: %s", exc) return bits def example_usage() -> float: """Run the example from the README and return the loss.""" B, L = 4, 16 model = BitTransformerLM( d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=L ) bits = torch.randint(0, 2, (B, L), dtype=torch.long) logits, _ = model(bits) pred = logits[:, :-1, :].reshape(-1, 2) target = bits[:, 1:].reshape(-1) loss = F.cross_entropy(pred, target) return loss.item() def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]: """Demonstrate a training step where metrics do not affect gradients.""" B, L = 4, 16 model = BitTransformerLM( d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L ) optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1) bits = torch.randint(0, 2, (B, L), dtype=torch.long) logits, telemetry = model(bits) pred = logits[:, :-1, :].reshape(-1, 2) target = bits[:, 1:].reshape(-1) loss = F.cross_entropy(pred, target) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() return loss.item(), telemetry if __name__ == "__main__": loss, telemetry = example_training_step() print("Composite loss:", loss) print("Telemetry keys:", list(telemetry.keys()))