import sys
from typing import List, TYPE_CHECKING

import torch

try:
    # torch.compile arrived in PyTorch 2.0 and, in the 2.0 releases, did not yet
    # support Python 3.11, hence both checks before using it.
    _version = tuple(map(int, torch.__version__.split(".")[:2]))
    if _version >= (2, 0) and sys.version_info < (3, 11):
        compile_fn = torch.compile
    else:
        raise RuntimeError("torch.compile unavailable")
except Exception:

    def compile_fn(fn=None, **kwargs):
        """No-op stand-in for torch.compile; usable bare or with keyword arguments."""
        if fn is None:
            return lambda f: f
        return fn


if TYPE_CHECKING:
    from .model import BitTransformerLM
|
|
@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Convert bytes to bits, appending an even-parity bit after each byte."""
    result: List[int] = []
    for b in data:
        bits = [(b >> i) & 1 for i in reversed(range(8))]  # MSB first
        parity = sum(bits) % 2  # even parity over the eight data bits
        result.extend(bits + [parity])
    return result
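
# Worked example: b"A" is 0x41 == 0b01000001; its eight MSB-first bits sum to 2,
# so the parity bit is 0 and bytes_to_bits(b"A") == [0, 1, 0, 0, 0, 0, 0, 1, 0].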
|
|
@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Convert parity-protected bits back to bytes."""
    if len(bits) % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    out = bytearray()
    for i in range(0, len(bits), 9):
        chunk = bits[i : i + 9]
        payload = chunk[:8]
        parity = chunk[8]
        if parity != sum(payload) % 2:
            raise ValueError("Parity check failed")
        value = 0
        for bit in payload:
            value = (value << 1) | bit
        out.append(value)
    return bytes(out)
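
# The codec round-trips byte-for-byte: bits_to_bytes(bytes_to_bits(b"hi")) == b"hi",
# and any single flipped bit inside a 9-bit group raises "Parity check failed".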
|
|
def text_to_bits(text: str) -> List[int]:
    """Encode text as UTF-8 bytes expanded into parity-protected bits."""
    return bytes_to_bits(text.encode("utf-8"))


def bits_to_text(bits: List[int]) -> str:
    """Decode parity-protected bits back to text, replacing invalid UTF-8."""
    return bits_to_bytes(bits).decode("utf-8", errors="replace")
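
# Multi-byte UTF-8 survives the round trip as well: "é" encodes as two bytes
# (0xC3, 0xA9), hence two 9-bit groups, and bits_to_text(text_to_bits("é")) == "é".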
|
|
def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run text through the model using the safety gate."""
    from .safety import hil_safe_inference

    bits = text_to_bits(text)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)  # shape (1, seq_len)
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
    return bits_to_text(out_bits.squeeze(0).tolist())
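
# Illustrative call (constructor arguments elided; see BitTransformerLM in
# model.py for the real signature):
#   model = BitTransformerLM(...)
#   decoded = infer_text(model, "hello", c_floor=0.3, s_floor=0.5)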
|
|
def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text from the model using simple top-p (nucleus) sampling."""
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():  # inference only; no autograd graph needed
        for _ in range(max_new_tokens * 9):  # 9 bits (8 data + parity) per byte
            if tensor.size(1) >= model.pos_enc.pe.size(0):
                break  # positional-encoding table exhausted
            logits, _ = model(tensor, causal=True)
            # Temperature must rescale the logits before the softmax; dividing
            # the probabilities afterwards is undone by the renormalization below.
            prob = (logits[0, -1] / temperature).softmax(-1)
            sorted_prob, sorted_idx = prob.sort(descending=True)
            cumulative = sorted_prob.cumsum(0)
            # Keep the smallest prefix whose mass reaches top_p; offsetting the
            # cutoff by one position ensures the top bit is never zeroed out.
            mask = cumulative - sorted_prob > top_p
            sorted_prob[mask] = 0
            sorted_prob = sorted_prob / sorted_prob.sum()
            next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
            tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    out = tensor.squeeze(0).tolist()
    out = out[: len(out) - len(out) % 9]  # drop any partial byte from an early break
    return bits_to_text(out)
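

# Minimal smoke test for the codec helpers; runnable directly and needing no
# trained model (the model-dependent helpers above are exercised elsewhere).
if __name__ == "__main__":
    message = "parity check"
    encoded = text_to_bits(message)
    assert len(encoded) == 9 * len(message.encode("utf-8"))
    assert bits_to_text(encoded) == message
    corrupted = list(encoded)
    corrupted[0] ^= 1  # flip a single bit; the parity check must reject it
    try:
        bits_to_text(corrupted)
    except ValueError as exc:
        print(f"corruption detected as expected: {exc}")
    print("codec round-trip OK")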
|
|