|
"""Common training utilities for BitTransformer models.""" |
|
|
|
from __future__ import annotations |
|
|
|
from typing import Callable, Dict, List, Optional |
|
import contextlib |
|
import sys |
|
import warnings |
|
import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
|
|
from .compression import compress_bits, pack_bits, unpack_bits |
|
from .optimization import configure_optimizer |
|
from .model import BitTransformerLM |
|
from .utils import set_dropout |
|
from .torch_utils import cpu_autocast |
|
|
|
|
|
def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Interpolate between ``start`` and ``end`` along a half-cosine curve.

    At ``step == 0`` the result is ``start``; the value eases towards ``end``
    and is clamped to ``end`` once ``step`` reaches ``total_steps``. A
    non-positive ``total_steps`` means "no ramp" and yields ``end`` directly.
    """
    ramp_finished = total_steps <= 0 or step >= total_steps
    if ramp_finished:
        return end

    angle = math.pi * step / total_steps
    # Half-cosine easing weight: rises smoothly from 0 (step 0) to 1 (total_steps).
    weight = (1 - math.cos(angle)) / 2
    return start + (end - start) * weight
|
|
|
|
|
def _autocast_context(amp: bool, device: torch.device) -> contextlib.AbstractContextManager:
    """Return the mixed-precision context used to wrap a forward pass.

    With ``amp`` disabled this is a no-op context. Otherwise the choice follows
    the device the model actually lives on: bfloat16 ``torch.autocast`` for
    CUDA, and the project's ``cpu_autocast`` helper for any other device.
    """
    if not amp:
        return contextlib.nullcontext()
    if device.type == "cuda":
        # ``torch.cuda.amp.autocast`` is deprecated; ``torch.autocast`` is the
        # supported spelling of the same bfloat16 CUDA autocast.
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return cpu_autocast()


def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of ``values``, or ``0.0`` when it is empty."""
    return float(sum(values) / len(values)) if values else 0.0


def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
    """Generic training loop supporting optional compression and diffusion.

    Each batch drawn from ``data`` is routed down one of three next-bit
    prediction paths:

    * **direct** (probability ``direct_prob``): every row is passed through
      ``pack_bits`` and back through ``unpack_bits``, truncated/padded to the
      model's positional-encoding length, and fed to the model as a bit tensor.
    * **compressed** (probability ramping from ``0`` to ``compress_prob`` along
      a cosine curve over the first ``compress_warmup`` micro-batches): rows are
      run through ``compress_bits`` and ``model.forward_compressed``.
    * **raw**: the batch is fed to the model unchanged.

    Metrics for direct and compressed batches are tracked separately from raw
    batches. After each epoch, ``extra_steps`` additional raw steps are taken
    on the last batch of that epoch.

    When ``diffusion`` is ``True`` the loop performs denoising training instead
    (``compress_prob``, ``direct_prob`` and ``extra_steps`` are ignored).
    Batches are noised by randomly flipping bits with a probability given by
    ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When
    ``diffusion_curriculum`` is ``True`` the noise probability decreases
    linearly from ``0.5`` to ``0.0`` over the training epochs. The model is
    then trained to recover the clean sequence using full-context attention
    (``causal=False``).

    Gradients are accumulated over ``accum_steps`` consecutive micro-batches,
    counted globally across epoch boundaries and extra steps; a trailing
    partial window is applied before returning. Every update clips the
    gradient norm to ``1.0`` and steps both the optimizer and the scheduler.

    Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow
    integration with long-running training sessions, otherwise new ones are
    created automatically (both must be supplied for either to be used).

    Returns:
        One metrics dict per epoch with keys ``raw_loss``, ``raw_acc``,
        ``compressed_loss``, ``compressed_acc``, ``direct_loss`` and
        ``compression_ratio`` (each ``0.0`` when no batch contributed to it).

    Raises:
        ValueError: if ``accum_steps`` is less than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
        model = torch.compile(model)
    elif compile_model:
        warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")

    model.train()
    set_dropout(model, 0.1)

    device = next(model.parameters()).device
    fwd_kwargs = forward_kwargs or {}
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
    # Upper bound on optimizer updates: the loop performs at most
    # ceil(micro_batches / accum_steps) of them, so a scheduler sized with this
    # value is never stepped past its end.
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )

    # Micro-batches processed so far across all epochs. Drives both the
    # compression warmup ramp and the gradient-accumulation windows.
    global_step = 0

    def _apply_update() -> None:
        """Clip gradients, step optimizer and scheduler, then clear gradients."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    def _accumulate(loss: torch.Tensor) -> None:
        """Backpropagate one micro-batch; update after every ``accum_steps`` of them."""
        nonlocal global_step
        (loss / accum_steps).backward()
        global_step += 1
        if global_step % accum_steps == 0:
            _apply_update()

    metrics: List[Dict[str, float]] = []
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []

        last_batch: Optional[torch.Tensor] = None
        for batch in loader:
            batch = batch.to(device)
            # Keep the on-device batch: the extra steps below feed it straight
            # to the model, so it must live on the model's device.
            last_batch = batch

            if diffusion:
                if diffusion_curriculum:
                    p = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
                # Flip each bit independently with probability ``p``.
                noise = (torch.rand_like(batch.float()) < p).long()
                noisy = batch ^ noise
                with _autocast_context(amp, device):
                    logits, _ = model(noisy, causal=False)
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target)
                raw_losses.append(loss.item())
                raw_accs.append((pred.argmax(dim=-1) == target).float().mean().item())
                _accumulate(loss)
                continue

            cur_compress = cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
            r = torch.rand(())
            key = "raw"
            ratio = 1.0
            # Next-bit prediction: logits at position t are scored against bit t+1.
            target = batch[:, 1:].reshape(-1)

            if r < direct_prob:
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(pk, n_bits=batch.size(1)) for pk in packed]
                # Never exceed the length of the positional-encoding table.
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [
                    F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len)))
                    for u in unpacked
                ]
                # NOTE(review): pack/unpack may not preserve the device; move
                # explicitly so the forward pass runs where the model lives.
                dc_batch = torch.stack(padded).long().to(device)
                with _autocast_context(amp, device):
                    logits, _ = model(dc_batch, **fwd_kwargs)
                ratio = sum(pk.numel() for pk in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
                with _autocast_context(amp, device):
                    logits, _ = model.forward_compressed(comp_batch, **fwd_kwargs)
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                key = "compressed"
            else:
                with _autocast_context(amp, device):
                    logits, _ = model(batch, **fwd_kwargs)

            pred = logits[:, :-1, :].reshape(-1, 2)
            loss = F.cross_entropy(pred, target)
            acc = (pred.argmax(dim=-1) == target).float().mean().item()
            _accumulate(loss)

            if key == "compressed":
                comp_losses.append(loss.item())
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(loss.item())
                comp_ratios.append(ratio)
            else:
                raw_losses.append(loss.item())
                raw_accs.append(acc)

        if extra_steps > 0 and last_batch is not None and not diffusion:
            extra_target = last_batch[:, 1:].reshape(-1)
            for _extra_step in range(extra_steps):
                with _autocast_context(amp, device):
                    logits, _ = model(last_batch, **fwd_kwargs)
                pred = logits[:, :-1, :].reshape(-1, 2)
                loss = F.cross_entropy(pred, extra_target)
                raw_losses.append(loss.item())
                raw_accs.append((pred.argmax(dim=-1) == extra_target).float().mean().item())
                _accumulate(loss)

        m = {
            "raw_loss": _mean(raw_losses),
            "raw_acc": _mean(raw_accs),
            "compressed_loss": _mean(comp_losses),
            "compressed_acc": _mean(comp_accs),
            "direct_loss": _mean(direct_losses),
            "compression_ratio": _mean(comp_ratios),
        }
        metrics.append(m)

        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )

    if global_step % accum_steps != 0:
        # Apply the trailing partial accumulation window rather than leaving
        # its gradients sitting on the parameters after training ends.
        _apply_update()

    return metrics
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "train_loop",
]
|
|