import logging
import math
import time
from typing import Dict, Optional, Tuple

import torch

from .model import BitTransformerLM

# Library code logs through its own named logger so it never triggers the
# implicit ``logging.basicConfig()`` that root-level ``logging.warning`` does.
logger = logging.getLogger(__name__)


class SafetyGate:
    """Exponential moving average safety gate with burn-in.

    The gate keeps one EMA per metric and compares the smoothed values with
    their floors. During the first ``burn_in`` calls the averages are updated
    but the gate never triggers.

    Parameters
    ----------
    c_floor, s_floor:
        Floors for the two smoothed metrics. The gate triggers when either
        average is less than or equal to its floor.
    decay:
        Weight given to the previous average; ``1 - decay`` is the weight of
        the new observation.
    burn_in:
        Number of initial :meth:`should_trigger` calls that always return
        ``False``.
    """

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        self.c_floor = c_floor
        self.s_floor = s_floor
        self.decay = decay
        self.burn_in = burn_in
        # Number of observations seen so far.
        self.step = 0
        # ``None`` until the first observation seeds the averages.
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Update EMA scores and check if gating should trigger.

        Parameters
        ----------
        c_val, s_val:
            Latest observations of the two metrics.

        Returns
        -------
        bool
            ``False`` while ``step <= burn_in``; afterwards ``True`` when
            either smoothed metric is at or below its floor.
        """
        self.step += 1
        if self._c_ema is None:
            # The first observation seeds both averages directly.
            self._c_ema = c_val
            self._s_ema = s_val
        else:
            self._c_ema = self.decay * self._c_ema + (1 - self.decay) * c_val
            self._s_ema = self.decay * self._s_ema + (1 - self.decay) * s_val
        if self.step <= self.burn_in:
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor


def _clamp_unit(value: float) -> float:
    """Clamp ``value`` to ``[0.0, 1.0]``, mapping NaN to ``0.0``.

    ``max(0.0, min(1.0, nan))`` evaluates to ``1.0`` because every comparison
    with NaN is false, which would let a NaN metric pass the gate with the
    best possible score. NaN is mapped to the lowest score instead so the
    gate fails closed.
    """
    if math.isnan(value):
        return 0.0
    return max(0.0, min(1.0, value))


def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    The model is switched to eval mode (and left there) and called under
    ``torch.no_grad()``. The means of the ``"lz_complexity_logits"`` and
    ``"symbiosis_score"`` telemetry entries are clamped to ``[0, 1]`` and
    checked against the floors.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
        A value equal to its floor counts as a failure. These arguments are
        only consulted when ``gate`` is ``None``; a supplied gate applies its
        own floors.
    causal:
        Whether to run the model in causal (autoregressive) mode. When
        ``False`` the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors
        are not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors. Its state is updated on every call,
        including calls with ``strict=False``.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        ``logits.argmax(-1)`` and the telemetry dictionary returned by the
        model.

    Raises
    ------
    RuntimeError
        If ``strict`` is true and the check fails. The message reports the
        clamped values of the current call, not the gate's smoothed values.
    """
    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)
    # NaN telemetry is treated as the lowest score (see ``_clamp_unit``).
    c_val = _clamp_unit(float(telemetry["lz_complexity_logits"].mean().item()))
    s_val = _clamp_unit(float(telemetry["symbiosis_score"].mean().item()))
    if gate is not None:
        triggered = gate.should_trigger(c_val, s_val)
    else:
        triggered = c_val <= c_floor or s_val <= s_floor
    if strict and triggered:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry


def demo_hil_safety() -> None:
    """Demonstrate gating on random bits.

    Builds a small model, runs :func:`hil_safe_inference` on one random
    8-bit sequence with both floors set to ``0.0`` and prints either the
    output bits or the gate error.
    """
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
    )
    try:
        out, _ = hil_safe_inference(model, bits, c_floor=0.0, s_floor=0.0)
        print("Safe output bits:", out.squeeze(0).tolist())
    except RuntimeError as e:
        print("Gate triggered:", e)


def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    After a failed attempt the helper sleeps for ``backoff * 2 ** attempt``
    seconds, switches to diffusion mode (``causal=False``) and replaces the
    input with freshly sampled random bits of the same shape, dtype and
    device. A warning is logged for every failed attempt.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences for the first attempt.
    c_floor, s_floor:
        Floors forwarded to :func:`hil_safe_inference`.
    causal:
        Mode used for the first attempt; every retry uses ``causal=False``.
    max_retries:
        Total number of attempts. Must be at least 1.
    backoff:
        Base delay in seconds for the exponential backoff.
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.

    Raises
    ------
    ValueError
        If ``max_retries`` is smaller than 1.
    RuntimeError
        Re-raised from the last attempt when every attempt fails.
    """
    if max_retries < 1:
        # Without this guard the loop would never run and the function would
        # silently return ``None`` instead of the documented tuple.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")
    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:  # safety gate triggered
            # NOTE(review): any other RuntimeError raised while running the
            # model is handled the same way as a gate failure.
            logger.warning(
                "Safety gate failed (attempt %d/%d): %s",
                attempt + 1,
                max_retries,
                exc,
            )
            if attempt >= max_retries - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
            causal = False  # retry in diffusion mode
            bit_seq = torch.randint(
                0,
                2,
                bit_seq.shape,
                dtype=bit_seq.dtype,
                device=bit_seq.device,
            )
    # Every loop iteration either returns or raises.
    raise AssertionError("unreachable")