from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn

from .model import BitTransformerLM


@dataclass
class TelemetryLog:
    """Telemetry container holding attention maps across steps.

    Attributes:
        attention_maps: Tensor of shape [steps, heads, seq, seq].
    """

    attention_maps: torch.Tensor


def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` according to attention telemetry.

    Args:
        model: Teacher model to distill from.
        scale: Fraction of weights to retain (0 < scale <= 1).
        telemetry: Logged attention maps used to estimate parameter importance.

    This function computes an importance score for each weight in the model's
    linear layers using the supplied attention maps. Each head is scored by its
    mean activation over the logged steps multiplied by its number of visits
    (non-zero attention entries), and a weight's importance is its magnitude
    scaled by the score of the head its output row belongs to. The bottom
    ``(1 - scale)`` fraction of weights in each linear layer is zeroed out,
    yielding a sparsified student model.
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

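    # Rebuild the teacher's architecture and copy its weights so pruning starts
    # from an exact clone of the original model.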
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

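    # Estimate per-head importance from the logged attention maps: heads that
    # attend more strongly and more often receive higher scores.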
    attn = telemetry.attention_maps
    heads = attn.shape[1]
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    head_importance = head_importance / head_importance.sum()

    prune_frac = 1.0 - scale

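    # Prune: weight each linear layer's output rows by the importance of the
    # head they map to and zero the lowest-scoring fraction of weights.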
    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            if out_features % heads == 0:
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                row_scores = head_importance.mean().expand(out_features, 1)

            importance = weight.abs() * row_scores
            k = int(importance.numel() * prune_frac)
            if k > 0:
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)

    return student
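

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Hypothetical values throughout: it
# assumes BitTransformerLM accepts the keyword arguments shown above with
# defaults for the rest, and the random telemetry stands in for attention maps
# logged during real training. Because of the relative import, run this as a
# module (``python -m ...``) from the package root rather than as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    nhead, seq_len = 4, 16
    teacher = BitTransformerLM(
        d_model=64,
        nhead=nhead,
        num_layers=2,
        dim_feedforward=128,
        max_seq_len=seq_len,
    )
    # Fake telemetry with shape [steps, heads, seq, seq].
    log = TelemetryLog(attention_maps=torch.rand(8, nhead, seq_len, seq_len))
    # Keep roughly half of the weights in every linear layer.
    student = distill_step(teacher, scale=0.5, telemetry=log)

    linears = [m for m in student.modules() if isinstance(m, nn.Linear)]
    total = sum(m.weight.numel() for m in linears)
    nonzero = sum(int((m.weight != 0).sum()) for m in linears)
    print(f"retained {nonzero / total:.1%} of linear-layer weights")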