from typing import List, TYPE_CHECKING
import torch
import sys

try:  # torch.compile may be unavailable or unsupported
    if (
        torch.__version__
        and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0)
        and sys.version_info < (3, 11)
    ):
compile_fn = torch.compile
else:
raise RuntimeError
except Exception: # pragma: no cover
def compile_fn(fn=None, **kwargs):
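        # No-op fallback: returns the function unchanged, supporting both bare and keyword decorator use.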
if fn is None:
return lambda f: f
return fn


if TYPE_CHECKING:  # pragma: no cover
from .model import BitTransformerLM


@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
"""Convert bytes to bits with per-byte parity bit."""
result: List[int] = []
for b in data:
bits = [(b >> i) & 1 for i in reversed(range(8))]
parity = sum(bits) % 2
result.extend(bits + [parity])
return result


@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
"""Convert parity-protected bits back to bytes."""
if len(bits) % 9 != 0:
raise ValueError("Bit stream length must be multiple of 9")
out = bytearray()
for i in range(0, len(bits), 9):
chunk = bits[i : i + 9]
payload = chunk[:8]
parity = chunk[8]
if parity != sum(payload) % 2:
raise ValueError("Parity check failed")
value = 0
for bit in payload:
value = (value << 1) | bit
out.append(value)
return bytes(out)


def text_to_bits(text: str) -> List[int]:
    """Encode UTF-8 text as parity-protected bits."""
    return bytes_to_bits(text.encode("utf-8"))


def bits_to_text(bits: List[int]) -> str:
    """Decode parity-protected bits back to text, replacing invalid UTF-8 sequences."""
    return bits_to_bytes(bits).decode("utf-8", errors="replace")


def infer_text(
model: "BitTransformerLM",
text: str,
c_floor: float = 0.3,
s_floor: float = 0.5,
) -> str:
"""Run text through the model using the safety gate."""
from .safety import hil_safe_inference
bits = text_to_bits(text)
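    # Single-sequence batch of shape (1, num_bits) for the safety-gated forward pass.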
tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
return bits_to_text(out_bits.squeeze(0).tolist())


def sample_text(
model: "BitTransformerLM",
prompt: str,
max_new_tokens: int = 16,
temperature: float = 1.0,
top_p: float = 1.0,
) -> str:
"""Generate text from the model using simple top-p sampling."""
model.eval()
bits = text_to_bits(prompt)
tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
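    # Each generated byte needs 9 bits (8 data + parity); stop early if the positional encoding is exhausted.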
for _ in range(max_new_tokens * 9):
if tensor.size(1) >= model.pos_enc.pe.size(0):
break
logits, _ = model(tensor, causal=True)
        prob = (logits[0, -1] / temperature).softmax(-1)  # scale logits by temperature before softmax
sorted_prob, sorted_idx = prob.sort(descending=True)
cumulative = sorted_prob.cumsum(0)
        mask = cumulative > top_p
        mask[0] = False  # always keep the most probable bit so the distribution cannot sum to zero
        sorted_prob[mask] = 0
        sorted_prob = sorted_prob / sorted_prob.sum()
next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
return bits_to_text(tensor.squeeze(0).tolist())
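

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch: exercises only the parity-protected codec, since
    # infer_text and sample_text require a trained BitTransformerLM instance.
    sample = "hello, bits"
    encoded = text_to_bits(sample)
    assert len(encoded) == 9 * len(sample.encode("utf-8"))  # 9 bits per UTF-8 byte
    print(f"{sample!r} -> {len(encoded)} bits -> {bits_to_text(encoded)!r}")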