"""Common training utilities for BitTransformer models.""" from __future__ import annotations from typing import Callable, Dict, List, Optional import contextlib import sys import warnings import math import torch import torch.nn.functional as F from torch.utils.data import DataLoader from .compression import compress_bits, pack_bits, unpack_bits from .optimization import configure_optimizer from .model import BitTransformerLM from .utils import set_dropout from .torch_utils import cpu_autocast def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float: """Cosine ramp from ``start`` to ``end`` over ``total_steps``.""" if total_steps <= 0 or step >= total_steps: return end cos_inner = math.pi * step / total_steps return start + (end - start) * (1 - math.cos(cos_inner)) / 2 def train_loop( model: BitTransformerLM, data: torch.Tensor, *, epochs: int = 1, extra_steps: int = 0, compress_prob: float = 0.5, direct_prob: float = 0.0, batch_size: int = 8, num_workers: int = 0, accum_steps: int = 1, amp: bool = False, compile_model: bool = False, log: bool = False, forward_kwargs: Optional[Dict] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, diffusion: bool = False, noise_fn: Optional[Callable[[], float]] = None, diffusion_curriculum: bool = False, compress_warmup: int = 0, ) -> List[Dict[str, float]]: """Generic training loop supporting optional compression and diffusion. ``compress_prob`` controls the fraction of batches that are run through ``forward_compressed``. ``direct_prob`` instead feeds the model with the bit-packed result of ``compress_bits`` after converting back to a bit tensor. When enabled, metrics for direct-compressed batches are tracked separately. When ``diffusion`` is ``True`` the loop performs denoising training. Batches are noised by randomly flipping bits with a probability given by ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When ``diffusion_curriculum`` is ``True`` the noise probability decreases linearly from ``0.5`` to ``0.0`` over the training epochs. The model is then trained to recover the clean sequence using full-context attention (``causal=False``). Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow integration with long-running training sessions, otherwise new ones are created automatically. 
""" if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1": model = torch.compile(model) elif compile_model: warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12") model.train() set_dropout(model, 0.1) device = next(model.parameters()).device loader = DataLoader( data, batch_size=batch_size, shuffle=True, num_workers=num_workers, persistent_workers=num_workers > 0, ) steps_per_epoch = max(1, len(loader)) total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps) if optimizer is None or scheduler is None: optimizer, scheduler = configure_optimizer( model, lr=1e-3, total_steps=total_updates ) metrics: List[Dict[str, float]] = [] global_step = 0 for epoch in range(epochs): raw_losses: List[float] = [] raw_accs: List[float] = [] comp_losses: List[float] = [] comp_accs: List[float] = [] comp_ratios: List[float] = [] direct_losses: List[float] = [] last_batch = None for step, batch in enumerate(loader): last_batch = batch batch = batch.to(device) cur_compress = ( cosine_ramp(global_step, 0.0, compress_prob, compress_warmup) if not diffusion else compress_prob ) if diffusion: if diffusion_curriculum: p = 0.5 * (1 - epoch / max(1, epochs - 1)) else: p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5) noise = (torch.rand_like(batch.float()) < p).long() noisy = batch ^ noise with ( torch.cuda.amp.autocast(dtype=torch.bfloat16) if amp and torch.cuda.is_available() else cpu_autocast() if amp else contextlib.nullcontext() ): logits, _ = model(noisy, causal=False) pred = logits.reshape(-1, 2) target = batch.reshape(-1) loss = F.cross_entropy(pred, target) / accum_steps acc = (pred.argmax(dim=-1) == target).float().mean().item() raw_losses.append(loss.item() * accum_steps) raw_accs.append(acc) loss.backward() if (step + 1) % accum_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() global_step += 1 continue r = torch.rand(()) key = "raw" ratio = 1.0 target = batch[:, 1:].reshape(-1) if r < direct_prob: packed = [pack_bits(row.to(torch.uint8)) for row in batch] unpacked = [unpack_bits(p, n_bits=batch.size(1)) for p in packed] max_len = min( max(u.numel() for u in unpacked), model.pos_enc.pe.size(0), ) padded = [F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len))) for u in unpacked] dc_batch = torch.stack(padded).long() with ( torch.cuda.amp.autocast(dtype=torch.bfloat16) if amp and torch.cuda.is_available() else cpu_autocast() if amp else contextlib.nullcontext() ): logits, _ = model(dc_batch, **(forward_kwargs or {})) ratio = sum(p.numel() for p in packed) / batch.numel() target = dc_batch[:, 1:].reshape(-1) key = "direct" elif r < direct_prob + cur_compress: comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch] with ( torch.cuda.amp.autocast(dtype=torch.bfloat16) if amp and torch.cuda.is_available() else cpu_autocast() if amp else contextlib.nullcontext() ): logits, _ = model.forward_compressed(comp_batch, **(forward_kwargs or {})) ratio = sum(c.numel() for c in comp_batch) / batch.numel() target = batch[:, 1:].reshape(-1) key = "compressed" else: with ( torch.cuda.amp.autocast(dtype=torch.bfloat16) if amp and torch.cuda.is_available() else cpu_autocast() if amp else contextlib.nullcontext() ): logits, _ = model(batch, **(forward_kwargs or {})) pred = logits[:, :-1, :].reshape(-1, 2) loss = F.cross_entropy(pred, target) / accum_steps acc = (pred.argmax(dim=-1) == target).float().mean().item() loss.backward() if (step + 1) % 
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            global_step += 1
            if key == "compressed":
                comp_losses.append(loss.item() * accum_steps)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(loss.item() * accum_steps)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)

        # run extra gradient updates using the final batch
        if extra_steps > 0 and last_batch is not None and not diffusion:
            for step in range(extra_steps):
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast()
                    if amp
                    else contextlib.nullcontext()
                ):
                    logits, _ = model(last_batch, **(forward_kwargs or {}))
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = last_batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                global_step += 1

        m = {
            "raw_loss": float(sum(raw_losses) / len(raw_losses)) if raw_losses else 0.0,
            "raw_acc": float(sum(raw_accs) / len(raw_accs)) if raw_accs else 0.0,
            "compressed_loss": float(sum(comp_losses) / len(comp_losses)) if comp_losses else 0.0,
            "compressed_acc": float(sum(comp_accs) / len(comp_accs)) if comp_accs else 0.0,
            "direct_loss": float(sum(direct_losses) / len(direct_losses)) if direct_losses else 0.0,
            "compression_ratio": float(sum(comp_ratios) / len(comp_ratios)) if comp_ratios else 0.0,
        }
        metrics.append(m)
        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )
    return metrics


__all__ = ["train_loop"]
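
# Example usage (illustrative sketch only; the ``BitTransformerLM`` constructor
# call below is an assumption and the model's real arguments may differ):
#
#     model = BitTransformerLM()                      # hypothetical defaults
#     bits = torch.randint(0, 2, (256, 64))           # toy batch of bit sequences
#     metrics = train_loop(model, bits, epochs=2, batch_size=8, log=True)
#     print(metrics[-1]["raw_loss"], metrics[-1]["raw_acc"])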