from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn

from .model import BitTransformerLM


@dataclass
class TelemetryLog:
    """Telemetry container holding attention maps across steps.

    Attributes:
        attention_maps: Tensor of shape ``[steps, heads, seq, seq]``.
    """

    attention_maps: torch.Tensor


def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` according to attention telemetry.

    An importance score is computed for each weight in the model's linear
    layers using the supplied attention maps: each head's score is its mean
    activation over time multiplied by its visit count (the number of
    non-zero attention entries), and each weight's score is its magnitude
    scaled by the score of the head owning its output row. The bottom
    ``(1 - scale)`` fraction of weights in each layer is zeroed out,
    yielding a sparsified student model.

    Args:
        model: Teacher model to distill from.
        scale: Fraction of weights to retain (0 < scale <= 1).
        telemetry: Logged attention maps used to estimate parameter importance.
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

    # Clone the model so the teacher remains untouched.
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

    attn = telemetry.attention_maps  # [steps, heads, seq, seq]
    heads = attn.shape[1]

    # Per-head importance: mean activation weighted by visit count,
    # normalized so the scores sum to one.
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    head_importance = head_importance / head_importance.sum()

    prune_frac = 1.0 - scale
    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            if out_features % heads == 0:
                # Broadcast each head's score across its slice of output rows.
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                # Layer is not head-aligned; fall back to a uniform score.
                row_scores = head_importance.mean().expand(out_features, 1)
            # Telemetry may live on a different device/dtype than the weights.
            row_scores = row_scores.to(device=weight.device, dtype=weight.dtype)
            importance = weight.abs() * row_scores
            k = int(importance.numel() * prune_frac)
            if k > 0:
                # Threshold at the k-th smallest score; ties at the threshold
                # are pruned as well.
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    # Zero biases of rows whose weights were fully pruned.
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)
    return student
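

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it builds a
    # small teacher with hypothetical hyperparameters, feeds synthetic
    # attention telemetry, and reports the resulting sparsity. It assumes
    # BitTransformerLM's remaining constructor arguments have defaults, and
    # must be run as a module within the package (``python -m ...``) so the
    # relative import of BitTransformerLM resolves.
    teacher = BitTransformerLM(
        d_model=64, nhead=4, num_layers=2, dim_feedforward=128, max_seq_len=32
    )
    # Synthetic telemetry: 10 logged steps, 4 heads, 32x32 attention maps.
    log = TelemetryLog(attention_maps=torch.rand(10, 4, 32, 32))
    student = distill_step(teacher, scale=0.5, telemetry=log)

    # Measure how many linear-layer weights were zeroed; with scale=0.5,
    # roughly half of each layer's weights should be pruned.
    linear_weights = [
        m.weight for m in student.modules() if isinstance(m, nn.Linear)
    ]
    zeros = sum(int((w == 0).sum()) for w in linear_weights)
    total = sum(w.numel() for w in linear_weights)
    print(f"Linear-layer sparsity after distill_step: {zeros / total:.2%}")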