WCNegentropy's picture
🤖 Updated BitTransformerLM from development space
36c78b1 verified
raw
history blame
10.5 kB
"""Common training utilities for BitTransformer models."""
from __future__ import annotations
from typing import Callable, Dict, List, Optional
import contextlib
import sys
import warnings
import math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .compression import compress_bits, pack_bits, unpack_bits
from .optimization import configure_optimizer
from .model import BitTransformerLM
from .utils import set_dropout
from .torch_utils import cpu_autocast
def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Interpolate from ``start`` to ``end`` along a half-cosine curve.

    The value equals ``start`` at ``step == 0`` and ``end`` once ``step``
    reaches ``total_steps``. A non-positive ``total_steps`` disables the ramp
    entirely, so ``end`` is returned straight away.
    """
    ramp_done = step >= total_steps or total_steps <= 0
    if not ramp_done:
        phase = math.pi * step / total_steps
        return start + (end - start) * (1 - math.cos(phase)) / 2
    return end
def _amp_context(amp: bool) -> contextlib.AbstractContextManager:
    """Return the autocast context used around every forward pass.

    CUDA bfloat16 autocast when ``amp`` is set and CUDA is available,
    ``cpu_autocast()`` when ``amp`` is set without CUDA, otherwise a no-op.
    """
    if amp and torch.cuda.is_available():
        # ``torch.autocast("cuda", ...)`` is the non-deprecated spelling of
        # ``torch.cuda.amp.autocast(...)``.
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    if amp:
        return cpu_autocast()
    return contextlib.nullcontext()


def _apply_update(model, optimizer, scheduler) -> None:
    """Clip gradients to norm 1.0, step optimizer and scheduler, clear grads."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _mean(values: List[float]) -> float:
    """Arithmetic mean of ``values`` as a float, or ``0.0`` when empty."""
    return float(sum(values) / len(values)) if values else 0.0


def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
    """Generic training loop supporting optional compression and diffusion.

    ``compress_prob`` controls the fraction of batches that are run through
    ``forward_compressed``. The fraction is ramped up from ``0`` with a
    cosine schedule over the first ``compress_warmup`` batches. ``direct_prob``
    instead feeds the model the bit-packed result of ``pack_bits`` after
    converting it back to a bit tensor. When enabled, metrics for
    direct-compressed batches are tracked separately.

    When ``diffusion`` is ``True`` the loop performs denoising training.
    Batches are noised by randomly flipping bits with a probability given by
    ``noise_fn``, which defaults to a uniform draw in ``[0, 0.5]``. When
    ``diffusion_curriculum`` is ``True`` the noise probability instead
    decreases linearly from ``0.5`` to ``0.0`` over the training epochs. The
    model is then trained to recover the clean sequence using full-context
    attention (``causal=False``). Extra steps are not run in diffusion mode.

    Existing ``optimizer`` and ``scheduler`` instances may be supplied to
    allow integration with long-running training sessions. Both must be
    given; otherwise new ones are created with ``configure_optimizer`` and
    a warning is emitted if only one was passed.

    After each epoch over ``data``, ``extra_steps`` additional updates are
    run on the last batch of that epoch (non-diffusion mode only).

    Gradients are accumulated over ``accum_steps`` batches per optimizer
    update. Any partial accumulation window left at the end of an epoch
    carries over into the next window; it is not flushed.

    Returns one metrics dict per epoch with keys ``raw_loss``, ``raw_acc``,
    ``compressed_loss``, ``compressed_acc``, ``direct_loss`` and
    ``compression_ratio``. A metric is ``0.0`` when no batch contributed to it.
    """
    if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
        model = torch.compile(model)
    elif compile_model:
        warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")
    model.train()
    set_dropout(model, 0.1)
    device = next(model.parameters()).device
    fwd_kwargs = forward_kwargs or {}

    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
    # Extra steps never run in diffusion mode, so don't budget scheduler
    # updates for them there (the LR schedule would otherwise never finish).
    extra_per_epoch = 0 if diffusion else extra_steps
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_per_epoch) / accum_steps)
    if optimizer is None or scheduler is None:
        if optimizer is not None or scheduler is not None:
            warnings.warn(
                "train_loop: optimizer and scheduler must be supplied together; "
                "the supplied one is being replaced by configure_optimizer()"
            )
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )
    # Drop any gradients left on the model before this call so they are not
    # folded into the first accumulated update.
    optimizer.zero_grad()

    metrics: List[Dict[str, float]] = []
    global_step = 0
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []
        last_batch = None
        for step, batch in enumerate(loader):
            batch = batch.to(device)
            # Keep the device-resident tensor: the extra-step loop below feeds
            # it straight to the model, so a CPU copy would break GPU training.
            last_batch = batch

            if diffusion:
                if diffusion_curriculum:
                    noise_p = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    noise_p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
                noise = (torch.rand_like(batch.float()) < noise_p).long()
                noisy = batch ^ noise
                with _amp_context(amp):
                    logits, _ = model(noisy, causal=False)
                # Denoising predicts every position, so no next-token shift.
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    _apply_update(model, optimizer, scheduler)
                global_step += 1
                continue

            cur_compress = cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
            r = torch.rand(())
            if r < direct_prob:
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(pk, n_bits=batch.size(1)) for pk in packed]
                # Never exceed the positional-encoding table length.
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [
                    F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len)))
                    for u in unpacked
                ]
                dc_batch = torch.stack(padded).long()
                with _amp_context(amp):
                    logits, _ = model(dc_batch, **fwd_kwargs)
                ratio = sum(pk.numel() for pk in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
                with _amp_context(amp):
                    logits, _ = model.forward_compressed(comp_batch, **fwd_kwargs)
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                target = batch[:, 1:].reshape(-1)
                key = "compressed"
            else:
                with _amp_context(amp):
                    logits, _ = model(batch, **fwd_kwargs)
                ratio = 1.0
                target = batch[:, 1:].reshape(-1)
                key = "raw"

            # Next-bit prediction: position t predicts bit t + 1.
            pred = logits[:, :-1, :].reshape(-1, 2)
            loss = F.cross_entropy(pred, target) / accum_steps
            acc = (pred.argmax(dim=-1) == target).float().mean().item()
            loss.backward()
            if (step + 1) % accum_steps == 0:
                _apply_update(model, optimizer, scheduler)
            global_step += 1

            # Undo the accumulation scaling so logged losses are per-batch.
            step_loss = loss.item() * accum_steps
            if key == "compressed":
                comp_losses.append(step_loss)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(step_loss)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(step_loss)
                raw_accs.append(acc)

        # Run extra gradient updates using the final batch of this epoch.
        if extra_steps > 0 and last_batch is not None and not diffusion:
            extra_target = last_batch[:, 1:].reshape(-1)
            for extra_step in range(extra_steps):
                with _amp_context(amp):
                    logits, _ = model(last_batch, **fwd_kwargs)
                pred = logits[:, :-1, :].reshape(-1, 2)
                loss = F.cross_entropy(pred, extra_target) / accum_steps
                acc = (pred.argmax(dim=-1) == extra_target).float().mean().item()
                loss.backward()
                if (extra_step + 1) % accum_steps == 0:
                    _apply_update(model, optimizer, scheduler)
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                global_step += 1

        m = {
            "raw_loss": _mean(raw_losses),
            "raw_acc": _mean(raw_accs),
            "compressed_loss": _mean(comp_losses),
            "compressed_acc": _mean(comp_accs),
            "direct_loss": _mean(direct_losses),
            "compression_ratio": _mean(comp_ratios),
        }
        metrics.append(m)
        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )
    return metrics
# Names exported by ``from <module> import *``; only the training entry
# point is part of the public API.
__all__ = [
    "train_loop",
]