"""Text <-> bit-stream helpers and inference wrappers for BitTransformerLM.

Each byte is encoded as 9 bits: 8 data bits (MSB first) followed by an even
parity bit, so corrupted streams can be detected on decode.
"""

from typing import List, TYPE_CHECKING
import sys

import torch

try:  # torch.compile needs torch >= 2.0 and (here) Python < 3.11
    if (
        torch.__version__
        and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0)
        and sys.version_info < (3, 11)
    ):
        compile_fn = torch.compile
    else:
        raise RuntimeError("torch.compile unavailable")
except Exception:  # pragma: no cover
    def compile_fn(fn=None, **kwargs):
        """No-op fallback that mimics the torch.compile decorator API."""
        if fn is None:
            return lambda f: f
        return fn

if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Convert bytes to bits with a per-byte even-parity bit."""
    result: List[int] = []
    for b in data:
        bits = [(b >> i) & 1 for i in reversed(range(8))]  # MSB first
        parity = sum(bits) % 2
        result.extend(bits + [parity])
    return result


@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Convert parity-protected bits back to bytes.

    Raises ``ValueError`` if the stream length is not a multiple of 9 or if
    any parity bit does not match its payload.
    """
    if len(bits) % 9 != 0:
        raise ValueError("Bit stream length must be a multiple of 9")
    out = bytearray()
    for i in range(0, len(bits), 9):
        chunk = bits[i : i + 9]
        payload = chunk[:8]
        parity = chunk[8]
        if parity != sum(payload) % 2:
            raise ValueError("Parity check failed")
        value = 0
        for bit in payload:
            value = (value << 1) | bit
        out.append(value)
    return bytes(out)


def text_to_bits(text: str) -> List[int]:
    """Encode UTF-8 text as a parity-protected bit stream."""
    return bytes_to_bits(text.encode("utf-8"))


def bits_to_text(bits: List[int]) -> str:
    """Decode a parity-protected bit stream back to UTF-8 text."""
    return bits_to_bytes(bits).decode("utf-8", errors="replace")


def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run text through the model using the safety gate."""
    from .safety import hil_safe_inference

    bits = text_to_bits(text)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
    return bits_to_text(out_bits.squeeze(0).tolist())


def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text from the model using simple top-p (nucleus) sampling."""
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        # One bit is sampled per step; each decoded byte consumes 9 bits.
        for _ in range(max_new_tokens * 9):
            if tensor.size(1) >= model.pos_enc.pe.size(0):
                break  # positional-encoding capacity reached
            logits, _ = model(tensor, causal=True)
            # Apply temperature to the logits before the softmax.
            prob = (logits[0, -1] / temperature).softmax(-1)
            sorted_prob, sorted_idx = prob.sort(descending=True)
            cumulative = sorted_prob.cumsum(0)
            # Drop tokens outside the nucleus, always keeping the top token.
            mask = cumulative - sorted_prob > top_p
            sorted_prob[mask] = 0
            sorted_prob = sorted_prob / sorted_prob.sum()
            next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
            tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    return bits_to_text(tensor.squeeze(0).tolist())
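

# --- Usage sketch -----------------------------------------------------------
# A minimal, self-contained example of the parity-protected encoding above.
# It exercises only the pure helpers defined in this module; `infer_text` and
# `sample_text` additionally require a trained BitTransformerLM instance,
# which is intentionally not constructed here.
if __name__ == "__main__":  # pragma: no cover
    bits = text_to_bits("hi")
    # Each byte expands to 9 bits: 8 data bits followed by 1 even-parity bit.
    assert len(bits) == 2 * 9
    assert bits_to_text(bits) == "hi"

    # Flipping a single bit trips the parity check in `bits_to_bytes`.
    corrupted = bits.copy()
    corrupted[0] ^= 1
    try:
        bits_to_text(corrupted)
    except ValueError as exc:
        print("parity check caught the flipped bit:", exc)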