from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from .model import BitTransformerLM


@dataclass
class TelemetryLog:
    """Telemetry container holding attention maps across steps.

    Attributes:
        attention_maps: Tensor of shape [steps, heads, seq, seq].
    """

    attention_maps: torch.Tensor


def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` according to attention telemetry.

    Args:
        model: Teacher model to distill from.
        scale: Fraction of weights to retain (0 < scale <= 1).
        telemetry: Logged attention maps used to estimate parameter importance.

    This function computes an importance score for each weight in the model's
    linear layers using the supplied attention maps. The score is the mean
    attention activation over time multiplied by the number of visits
    (non-zero attention entries). The bottom ``(1 - scale)`` fraction of
    weights in each layer is zeroed out, yielding a sparsified student model.
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

    # Clone the model so the teacher remains untouched.
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

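    # Estimate per-head importance from the logged attention maps: the mean
    # attention weight over steps and positions, scaled by how often each head
    # produced non-zero attention, then normalized to sum to one.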
    attn = telemetry.attention_maps  # [steps, heads, seq, seq]
    steps = attn.shape[0]
    heads = attn.shape[1]
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    head_importance = head_importance / head_importance.sum()

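    # Broadcast the per-head scores onto the output rows of every linear layer
    # and prune the lowest-scoring fraction of weights by magnitude.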
    prune_frac = 1.0 - scale
    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            if out_features % heads == 0:
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                row_scores = head_importance.mean().expand(out_features, 1)
            importance = weight.abs() * row_scores
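            # Zero out the ``k`` lowest-importance weights in this layer;
            # values tied with the threshold are pruned as well because the
            # mask keeps only strictly larger scores.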
            k = int(importance.numel() * prune_frac)
            if k > 0:
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)
    return student
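

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how ``distill_step`` might be wired up, assuming a
# teacher BitTransformerLM plus attention maps recorded during evaluation.
# The constructor arguments and tensor shapes below are hypothetical and only
# illustrate the call sequence, not a recommended configuration.
#
#     teacher = BitTransformerLM(d_model=64, nhead=4, num_layers=2,
#                                dim_feedforward=256, max_seq_len=128)
#     attn_maps = torch.rand(8, 4, 128, 128)  # [steps, heads, seq, seq]
#     telemetry = TelemetryLog(attention_maps=attn_maps)
#     student = distill_step(teacher, scale=0.5, telemetry=telemetry)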