File size: 3,052 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import List, TYPE_CHECKING
import torch
import sys

# Select the decorator used below: ``torch.compile`` when torch >= 2.0 and the
# interpreter is older than 3.11; otherwise (including when the torch version
# string cannot be parsed) a pass-through stand-in with the same call forms.
try:  # torch.compile may be unavailable or unsupported
    if not (
        torch.__version__
        and tuple(int(part) for part in torch.__version__.split(".")[:2]) >= (2, 0)
        and sys.version_info < (3, 11)
    ):
        raise RuntimeError
    compile_fn = torch.compile
except Exception:  # pragma: no cover

    def compile_fn(fn=None, **kwargs):
        """No-op replacement for ``torch.compile``.

        Works both as ``@compile_fn`` (returns ``fn`` unchanged) and as
        ``@compile_fn(**options)`` (returns an identity decorator); any keyword
        options are accepted and ignored.
        """
        return (lambda f: f) if fn is None else fn


if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Expand each byte into a 9-bit group: 8 data bits (MSB first) plus parity.

    The trailing bit is the parity of the data bits (1 when the byte has an odd
    number of set bits), so every 9-bit group carries an even count of ones.

    Args:
        data: Bytes to encode.

    Returns:
        A flat list of 0/1 ints of length ``9 * len(data)``.
    """
    encoded: List[int] = []
    for byte in data:
        # Most-significant bit first: shifts 7, 6, ..., 0.
        group = [(byte >> shift) & 1 for shift in range(7, -1, -1)]
        group.append(sum(group) % 2)
        encoded += group
    return encoded


@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Decode 9-bit parity-protected groups back into bytes.

    Each group is 8 data bits (MSB first) followed by a parity bit that must
    equal the parity of the data bits; this inverts ``bytes_to_bits``.

    Args:
        bits: Flat sequence of bits whose length is a multiple of 9.

    Returns:
        The decoded bytes, one per 9-bit group.

    Raises:
        ValueError: if the length is not a multiple of 9, or if any group's
            parity bit does not match its data bits.
    """
    if len(bits) % 9:
        raise ValueError("Bit stream length must be multiple of 9")
    decoded = bytearray()
    for start in range(0, len(bits), 9):
        *data_bits, parity_bit = bits[start : start + 9]
        if sum(data_bits) % 2 != parity_bit:
            raise ValueError("Parity check failed")
        byte_value = 0
        for bit in data_bits:
            # Accumulate MSB first.
            byte_value = (byte_value << 1) | bit
        decoded.append(byte_value)
    return bytes(decoded)


def text_to_bits(text: str) -> List[int]:
    """Encode ``text`` as UTF-8 and return its parity-protected bit list."""
    utf8_bytes = text.encode("utf-8")
    return bytes_to_bits(utf8_bytes)


def bits_to_text(bits: List[int]) -> str:
    """Decode a parity-protected bit list back to text.

    The recovered bytes are decoded as UTF-8, with invalid sequences replaced
    by U+FFFD rather than raising. Length/parity errors from ``bits_to_bytes``
    propagate as ``ValueError``.
    """
    raw = bits_to_bytes(bits)
    return raw.decode("utf-8", errors="replace")


def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run ``text`` through the model via the safety-gated inference helper.

    Args:
        model: Passed unchanged to ``hil_safe_inference``.
        text: Input text; encoded as UTF-8 bits with per-byte parity.
        c_floor: Forwarded as ``c_floor`` to ``hil_safe_inference``.
        s_floor: Forwarded as ``s_floor`` to ``hil_safe_inference``.

    Returns:
        The first output of ``hil_safe_inference`` decoded back to text; its
        second output is discarded.
    """
    # Imported lazily; presumably to avoid a circular import with the safety
    # module — confirm before moving it to the top of the file.
    from .safety import hil_safe_inference

    # Shape (1, n_bits): a batch containing a single sequence.
    batch = torch.tensor(text_to_bits(text), dtype=torch.long)[None]
    gated, _ = hil_safe_inference(model, batch, c_floor=c_floor, s_floor=s_floor)
    return bits_to_text(gated.squeeze(0).tolist())


def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text by autoregressively sampling bits with temperature and top-p.

    Each "token" is one 9-bit group (8 data bits + parity), so up to
    ``max_new_tokens * 9`` bits are appended to the encoded ``prompt``.
    Generation stops early once the sequence length reaches the first
    dimension of ``model.pos_enc.pe`` (presumably the maximum supported
    context length — confirm in the model).

    Args:
        model: Called as ``model(x, causal=True)``; the first element of the
            returned pair is indexed as ``logits[0, -1]`` to obtain the
            next-bit logits. The model is switched to eval mode.
        prompt: Text prompt to continue.
        max_new_tokens: Number of 9-bit groups to generate.
        temperature: Logits are divided by this before the softmax. A value
            ``<= 0`` selects the most likely bit at every step (greedy).
        top_p: Nucleus threshold. The smallest set of most-likely bits whose
            cumulative probability reaches ``top_p`` is kept. The single most
            likely bit is always kept.

    Returns:
        The decoded prompt plus continuation. A trailing incomplete 9-bit
        group (possible when the length limit stops generation mid-byte) is
        dropped before decoding.

    Raises:
        ValueError: if a 9-bit group fails its parity check when decoded.
    """
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():  # inference only: don't build an autograd graph
        for _ in range(max_new_tokens * 9):
            if tensor.size(1) >= model.pos_enc.pe.size(0):
                break
            logits, _ = model(tensor, causal=True)
            next_logits = logits[0, -1]
            if temperature <= 0:
                # Greedy decoding; avoids dividing by zero below.
                next_bit = next_logits.argmax(-1, keepdim=True)
            else:
                # Temperature must scale the logits *before* the softmax.
                # Dividing probabilities afterwards is undone by renormalisation.
                prob = (next_logits / temperature).softmax(-1)
                sorted_prob, sorted_idx = prob.sort(descending=True)
                # Remove a candidate only if the mass of the candidates ranked
                # above it already exceeds top_p. The top candidate (preceding
                # mass 0) therefore always survives, so the distribution can
                # never become all zeros.
                preceding_mass = sorted_prob.cumsum(0) - sorted_prob
                sorted_prob[preceding_mass > top_p] = 0
                sorted_prob = sorted_prob / sorted_prob.sum()
                next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
            tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    out_bits = tensor.squeeze(0).tolist()
    # Stopping at the length limit can leave a partial 9-bit group; drop it so
    # decoding doesn't reject the whole stream for its length.
    return bits_to_text(out_bits[: len(out_bits) - len(out_bits) % 9])