import logging
import time
from typing import Dict, Optional, Tuple

import torch

from .model import BitTransformerLM

class SafetyGate:
    """Exponential moving average safety gate with burn-in."""

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        self.c_floor = c_floor
        self.s_floor = s_floor
        self.decay = decay
        self.burn_in = burn_in
        self.step = 0
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Update EMA scores and check whether gating should trigger."""
        self.step += 1
        if self._c_ema is None:
            # First observation seeds the moving averages directly.
            self._c_ema = c_val
            self._s_ema = s_val
        else:
            self._c_ema = self.decay * self._c_ema + (1 - self.decay) * c_val
            self._s_ema = self.decay * self._s_ema + (1 - self.decay) * s_val
        if self.step <= self.burn_in:
            # Never trigger during burn-in while the EMAs stabilize.
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor
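

# Minimal usage sketch (added for illustration; ``demo_safety_gate`` is a
# hypothetical helper, not part of the original API). The telemetry scores
# below are synthetic: with decay=0.5 and burn_in=2 the gate stays silent for
# two steps, then trips once the EMAs sink below the floors.
def demo_safety_gate() -> None:
    """Show EMA smoothing and burn-in on synthetic scores."""
    gate = SafetyGate(c_floor=0.3, s_floor=0.5, decay=0.5, burn_in=2)
    scores = [(0.6, 0.7), (0.1, 0.2), (0.1, 0.2), (0.1, 0.2)]
    for step, (c, s) in enumerate(scores, start=1):
        # Expected output: steps 1-2 False (burn-in), steps 3-4 True.
        print(f"step {step}: triggered={gate.should_trigger(c, s)}")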

def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
    causal:
        Whether to run the model in causal (autoregressive) mode. When
        ``False`` the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors
        are not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.
    """
    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)
    c_val = float(telemetry["lz_complexity_logits"].mean().item())
    s_val = float(telemetry["symbiosis_score"].mean().item())
    # Clamp both scores into [0, 1] before comparing against the floors.
    c_val = max(0.0, min(1.0, c_val))
    s_val = max(0.0, min(1.0, s_val))
    if gate is not None:
        triggered = gate.should_trigger(c_val, s_val)
    else:
        triggered = c_val <= c_floor or s_val <= s_floor
    if strict and triggered:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry

def demo_hil_safety() -> None:
    """Demonstrate gating on random bits."""
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    try:
        # Floors of 0.0 keep the gate maximally permissive for the demo.
        out, _ = hil_safe_inference(model, bits, c_floor=0.0, s_floor=0.0)
        print("Safe output bits:", out.squeeze(0).tolist())
    except RuntimeError as e:
        print("Gate triggered:", e)

def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    The helper retries failed safety checks by toggling diffusion mode and
    refreshing the input bits. Exponential backoff is applied between
    attempts and a warning is logged for each retry.

    Parameters
    ----------
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.
    """
    # Guard against falling off the loop and implicitly returning None.
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:  # safety gate triggered
            logging.warning(
                "Safety gate failed (attempt %d/%d): %s",
                attempt + 1,
                max_retries,
                exc,
            )
            if attempt >= max_retries - 1:
                # Out of retries: propagate the gate failure to the caller.
                raise
            time.sleep(backoff * (2 ** attempt))
            causal = False  # retry in diffusion (full-context) mode
            # Resample fresh random bits for the next attempt.
            bit_seq = torch.randint(
                0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device
            )
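

# Hedged addition: a script entry point so the demos can be run directly.
# Whether this module is meant to be executed as a script is an assumption,
# not something stated in the original file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_safety_gate()
    demo_hil_safety()
    demo_non_strict_inference()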