import logging
import time
from typing import Dict, Optional, Tuple

import torch

from .model import BitTransformerLM

|
class SafetyGate:
    """Exponential moving average safety gate with burn-in."""

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        self.c_floor = c_floor
        self.s_floor = s_floor
        self.decay = decay
        self.burn_in = burn_in
        self.step = 0
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Update EMA scores and check if gating should trigger."""
        self.step += 1
        if self._c_ema is None:
            # First observation seeds the moving averages directly.
            self._c_ema = c_val
            self._s_ema = s_val
        else:
            self._c_ema = self.decay * self._c_ema + (1 - self.decay) * c_val
            self._s_ema = self.decay * self._s_ema + (1 - self.decay) * s_val
        if self.step <= self.burn_in:
            # Never trigger during burn-in; early scores are too noisy.
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor
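

# Illustrative sketch (the helper name is hypothetical, not part of the
# module's API): shows how burn-in and EMA smoothing interact. The gate
# ignores the first ``burn_in`` updates, then trips once the smoothed
# scores sit at or below the floors.
def _safety_gate_example() -> None:
    gate = SafetyGate(c_floor=0.3, s_floor=0.5, decay=0.9, burn_in=3)
    for step in range(1, 16):
        # Feed consistently low scores; the gate stays quiet through
        # burn-in, then trips as soon as the EMAs are below the floors.
        if gate.should_trigger(0.1, 0.1):
            print(f"gate tripped at step {step}")
            break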
|
|
def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
    causal:
        Whether to run the model in causal (autoregressive) mode. When
        ``False`` the model performs full-context diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors
        are not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.
    """
    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)
    c_val = float(telemetry["lz_complexity_logits"].mean().item())
    s_val = float(telemetry["symbiosis_score"].mean().item())
    # Clamp telemetry into [0, 1] so the floor comparisons are well defined.
    c_val = max(0.0, min(1.0, c_val))
    s_val = max(0.0, min(1.0, s_val))
    if gate is not None:
        triggered = gate.should_trigger(c_val, s_val)
    else:
        triggered = c_val <= c_floor or s_val <= s_floor
    if strict and triggered:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry
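

# Sketch of non-strict usage (the helper name is hypothetical; the model
# hyperparameters are copied from ``demo_hil_safety`` below): with
# ``strict=False`` the call never raises, so the caller gets outputs plus
# telemetry and can apply its own policy. A shared SafetyGate carries the
# EMA state across successive calls.
def _non_strict_example() -> None:
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    gate = SafetyGate()
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    out, telemetry = hil_safe_inference(model, bits, strict=False, gate=gate)
    print("bits:", out.squeeze(0).tolist())
    print("C:", float(telemetry["lz_complexity_logits"].mean()))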
|
|
def demo_hil_safety() -> None:
    """Demonstrate gating on random bits."""
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    try:
        # Floors of 0.0 only trip when a clamped score is exactly zero.
        out, _ = hil_safe_inference(model, bits, c_floor=0.0, s_floor=0.0)
        print("Safe output bits:", out.squeeze(0).tolist())
    except RuntimeError as e:
        print("Gate triggered:", e)
|
|
def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    The helper retries failed safety checks by switching to diffusion mode
    and refreshing the input bits. Exponential backoff is applied between
    attempts and a warning is logged for each retry.

    Parameters
    ----------
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.
    """
    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:
            logging.warning(
                "Safety gate failed (attempt %d/%d): %s",
                attempt + 1,
                max_retries,
                exc,
            )
            if attempt >= max_retries - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
            # Fall back to full-context diffusion mode and resample the
            # input bits before the next attempt.
            causal = False
            bit_seq = torch.randint(
                0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device
            )
    # Unreachable when max_retries >= 1; guards against a silent ``None``
    # return if the caller passes max_retries <= 0.
    raise ValueError("max_retries must be a positive integer")
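

# Minimal usage sketch (the helper name and hyperparameter choices are
# illustrative): share one SafetyGate so EMA smoothing and burn-in persist
# across the retry attempts rather than resetting on each call.
def _retry_example() -> None:
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    gate = SafetyGate(burn_in=0)
    try:
        out, _ = safe_sample_with_retry(model, bits, gate=gate, max_retries=2)
        print("sampled bits:", out.squeeze(0).tolist())
    except RuntimeError as exc:
        print("all retries failed:", exc)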