from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn

from .model import BitTransformerLM


@dataclass
class TelemetryLog:
    """Telemetry container holding attention maps across steps.

    Attributes:
        attention_maps: Tensor of shape [steps, heads, seq, seq].
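
    Example:
        A minimal sketch, assuming per-step attention maps are stacked along
        the leading dimension::

            >>> maps = torch.rand(10, 4, 16, 16)  # 10 steps, 4 heads, seq len 16
            >>> log = TelemetryLog(attention_maps=maps)
            >>> log.attention_maps.shape
            torch.Size([10, 4, 16, 16])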
    """

    attention_maps: torch.Tensor


def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` according to attention telemetry.

    Args:
        model: Teacher model to distill from.
        scale: Fraction of weights to retain (0 < scale <= 1).
        telemetry: Logged attention maps used to estimate parameter importance.

    This function computes an importance score for each weight in the model's
    linear layers using the supplied attention maps. The score is the mean
    activation over time multiplied by the number of visits (non-zero
    attention). The bottom ``(1 - scale)`` fraction of weights in each layer are
    zeroed out, yielding a sparsified student model.
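
    Example:
        An illustrative sketch only; it assumes the remaining
        ``BitTransformerLM`` constructor arguments have defaults and that the
        telemetry shapes match the model's head count and sequence length::

            >>> teacher = BitTransformerLM(d_model=32, nhead=4, num_layers=1,
            ...                            dim_feedforward=64, max_seq_len=16)
            >>> log = TelemetryLog(attention_maps=torch.rand(8, 4, 16, 16))
            >>> student = distill_step(teacher, scale=0.5, telemetry=log)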
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

    # Clone the model so the teacher remains untouched.
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

    attn = telemetry.attention_maps  # [steps, heads, seq, seq]
    heads = attn.shape[1]
    # Per-head importance: mean attention activation over steps and positions,
    # weighted by how often the head attended at all (non-zero entries),
    # then normalized to sum to one.
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    head_importance = head_importance / head_importance.sum()

    prune_frac = 1.0 - scale

    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            # Map head importances onto output rows: if the rows split evenly
            # across heads, repeat each head's score; otherwise fall back to
            # the mean head importance for every row.
            if out_features % heads == 0:
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                row_scores = head_importance.mean().expand(out_features, 1)
            # Keep the scores on the same device and dtype as the weights.
            row_scores = row_scores.to(device=weight.device, dtype=weight.dtype)

            importance = weight.abs() * row_scores
            k = int(importance.numel() * prune_frac)
            if k > 0:
                # Zero the k least important weights; ties at the threshold
                # are pruned as well.
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    # Also zero the bias of any row whose weights were fully pruned.
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)

    return student