File size: 4,852 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import logging
import time
import torch
from typing import Dict, Optional, Tuple

from .model import BitTransformerLM


class SafetyGate:
    """Smoothed telemetry gate that stays silent during an initial burn-in.

    Each call to :meth:`should_trigger` folds a new pair of scores into two
    exponential moving averages (EMAs). Once more than ``burn_in`` calls have
    been made, the gate reports a trigger whenever either average is at or
    below its floor.
    """

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        # Running averages stay unset until the first observation seeds them.
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None
        # Number of observations seen so far.
        self.step = 0
        self.c_floor, self.s_floor = c_floor, s_floor
        self.decay, self.burn_in = decay, burn_in

    def _blend(self, previous: float, sample: float) -> float:
        """Return the EMA after mixing ``sample`` into ``previous``."""
        return self.decay * previous + (1 - self.decay) * sample

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Record one observation and report whether the gate fires.

        The averages are updated on every call, including calls made during
        burn-in; only the verdict is suppressed while ``step <= burn_in``.
        """
        self.step += 1

        seeded = self._c_ema is not None
        self._c_ema = self._blend(self._c_ema, c_val) if seeded else c_val
        self._s_ema = self._blend(self._s_ema, s_val) if seeded else s_val

        still_warming_up = self.step <= self.burn_in
        if still_warming_up:
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor


def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    Parameters
    ----------
    model:
        Model to run inference with. ``model.eval()`` is called and the
        previous training/eval mode is not restored.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
        These are only consulted when ``gate`` is ``None``; a supplied gate
        makes the decision itself via ``gate.should_trigger``.
    causal:
        Whether to run the model in causal (autoregressive) mode. When ``False``
        the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors are
        not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        ``logits.argmax(-1)`` and the telemetry dictionary returned by the
        model.

    Raises
    ------
    RuntimeError
        If ``strict`` is true and the safety check triggers. NaN telemetry
        always counts as a trigger.
    """
    import math  # local import keeps this block self-contained

    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)
        c_val = float(telemetry["lz_complexity_logits"].mean().item())
        s_val = float(telemetry["symbiosis_score"].mean().item())
        if math.isnan(c_val) or math.isnan(s_val):
            # Fail closed: ``max(0.0, min(1.0, nan))`` evaluates to 1.0, which
            # would report a broken metric as perfectly safe. The gate is not
            # updated either, so a NaN cannot contaminate its running averages.
            triggered = True
        else:
            # Clamp both scores into [0, 1] before comparing against floors.
            c_val = max(0.0, min(1.0, c_val))
            s_val = max(0.0, min(1.0, s_val))
            if gate is not None:
                triggered = gate.should_trigger(c_val, s_val)
            else:
                triggered = c_val <= c_floor or s_val <= s_floor
        if strict and triggered:
            raise RuntimeError(
                f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
            )
        return logits.argmax(-1), telemetry


def demo_hil_safety() -> None:
    """Run one gated inference on random bits and print the outcome.

    Builds a random 1x8 bit tensor and a small ``BitTransformerLM``, then
    calls :func:`hil_safe_inference` with both floors set to ``0.0``. Prints
    the predicted bits, or the error message if the gate raised.
    """
    # The input is drawn before the model is built, so RNG use keeps its order.
    random_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    tiny_model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
    )
    try:
        predicted, _telemetry = hil_safe_inference(
            tiny_model, random_bits, c_floor=0.0, s_floor=0.0
        )
        flat_bits = predicted.squeeze(0).tolist()
        print("Safe output bits:", flat_bits)
    except RuntimeError as gate_error:
        print("Gate triggered:", gate_error)


def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    Each attempt runs in strict mode. After a failed safety check the helper
    logs a warning, sleeps, and retries with ``causal=False`` (diffusion
    mode) on freshly drawn random bits of the same shape, dtype and device as
    ``bit_seq``. The delay before retry ``k`` (0-based) is
    ``backoff * 2 ** k`` seconds.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences for the first attempt.
    c_floor, s_floor:
        Floors forwarded to :func:`hil_safe_inference`.
    causal:
        Mode used for the first attempt only; every retry uses ``False``.
    max_retries:
        Total number of attempts (not additional retries). Must be at least 1.
    backoff:
        Base delay in seconds for the exponential backoff.
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.

    Raises
    ------
    ValueError
        If ``max_retries`` is less than 1.
    RuntimeError
        Re-raised from the final attempt when every attempt trips the gate.
    """
    if max_retries < 1:
        # Without this guard the loop body never runs and the function would
        # fall through and return ``None`` instead of the documented tuple.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:  # safety gate triggered
            logging.warning("Safety gate failed (attempt %d/%d): %s", attempt + 1, max_retries, exc)
            if attempt >= max_retries - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
            causal = False  # retry in diffusion mode
            bit_seq = torch.randint(0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device)