"""
This file contains the code to do generalize \phi_MATCH.
Given two models, it first retrains the MLPs or other FFN of the models, then runs \phi_MATCH on the distilled MLPs.

May need to be modified depending on model architecture (this code was used for GPT-architecture).
"""

MLP_SIZE = 3072
MLP_SIZE_2 = 3072
EMB_SIZE = 768
EMB_SIZE_2 = 768
N_BLOCKS = 12

import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    GPT2LMHeadModel,
)
import scipy.stats  # scipy.stats.pearsonr is used below; import the submodule explicitly
from collections import defaultdict

import numpy as np

from tracing.utils.evaluate import (
    prepare_hf_dataset,
    prepare_hf_dataloader,
    evaluate,
)

from tracing.utils.utils import manual_seed
from tracing.utils.llama.matching import match_wmats

manual_seed(0)


# The architecture of the MLP trained from scratch can differ from the original;
# a commented-out 2-hidden-layer variant is sketched below this class
# (the original MLP has just one hidden layer).
class CustomLlamaMLP(nn.Module):
    """
    Custom MLP module for Llama-style architecture with SwiGLU activation.

    This implementation allows for flexible architecture changes when training
    replacement MLPs for model distillation and analysis.

    Args:
        config: Model configuration containing embedding dimensions
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.n_embd
        self.intermediate_size = 4 * config.n_embd

        self.gate_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Forward pass implementing SwiGLU activation function.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]

        Returns:
            torch.Tensor: Output tensor after MLP transformation
        """
        down_proj = self.down_proj1(self.act_fn(self.gate_proj1(x)) * self.up_proj1(x))
        return down_proj
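

# A minimal sketch of the 2-hidden-layer variant referenced above, kept commented
# out. The extra-layer structure and the names below are illustrative assumptions,
# not the configuration used for the reported runs.
# class CustomLlamaMLP2Layer(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.hidden_size = config.n_embd
#         self.intermediate_size = 4 * config.n_embd
#
#         # first SwiGLU block: hidden -> intermediate
#         self.gate_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
#         self.up_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
#         # second SwiGLU block: intermediate -> intermediate
#         self.gate_proj2 = nn.Linear(self.intermediate_size, self.intermediate_size, bias=False)
#         self.up_proj2 = nn.Linear(self.intermediate_size, self.intermediate_size, bias=False)
#         self.down_proj1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
#         self.act_fn = nn.SiLU()
#
#     def forward(self, x):
#         h = self.act_fn(self.gate_proj1(x)) * self.up_proj1(x)
#         h = self.act_fn(self.gate_proj2(h)) * self.up_proj2(h)
#         return self.down_proj1(h)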


def hook_out(m, inp, op, feats, name):
    """
    Forward hook to capture output activations from model layers.

    Args:
        m: Module being hooked
        inp: Input to the module
        op: Output from the module
        feats: Dictionary to store activations
        name: Key to store the activations under
    """
    feats[name].append(op.detach().cpu())


def hook_in(m, inp, op, feats, name):
    """
    Forward hook to capture input activations to model layers.

    Args:
        m: Module being hooked
        inp: Input to the module (tuple)
        op: Output from the module
        feats: Dictionary to store activations
        name: Key to store the activations under
    """
    feats[name].append(inp[0].detach().cpu())


def mlp_layers(base_model_gate, base_model_up, ft_model_gate, ft_model_up, dataloader, i, j):
    """
    Compare gate and up projections between separate model components.

    Tests whether separately trained gate and up projection models have
    consistent permutation patterns, which would indicate functionally
    corresponding neurons.

    Args:
        base_model_gate: First model with gate projection weights
        base_model_up: First model with up projection weights
        ft_model_gate: Second model with gate projection weights
        ft_model_up: Second model with up projection weights
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the first model
        j: Layer index in the second model

    Returns:
        float: Pearson correlation p-value between gate and up projection permutations
    """
    gate_match = mlp_matching(base_model_gate, ft_model_gate, dataloader, i, j)
    up_match = mlp_matching(base_model_up, ft_model_up, dataloader, i, j)

    print(gate_match, up_match, i, j)

    cor, pvalue = scipy.stats.pearsonr(gate_match.tolist(), up_match.tolist())
    return pvalue
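

# Hedged sanity check, not part of the original pipeline: for unrelated models,
# the gate and up permutations recovered by mlp_matching behave like two
# independent random permutations, so the Pearson p-value returned above should
# be roughly uniform on [0, 1] rather than tiny. The helper name and defaults
# below are illustrative.
def null_pvalue_example(n=MLP_SIZE, seed=0):
    """Pearson p-value for two independent random permutations of size n."""
    rng = np.random.default_rng(seed)
    perm_a = rng.permutation(n)
    perm_b = rng.permutation(n)
    _, pvalue = scipy.stats.pearsonr(perm_a, perm_b)
    return pvalue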


def mlp_matching(base_model, ft_model, dataloader, i, j):
    """
    Match neurons between models by comparing activations.

    Collects activations from the feed-forward layer for both models
    and computes a permutation that would align corresponding neurons.

    Args:
        base_model: Base model to compare
        ft_model: Target model to compare against the base model
        dataloader: DataLoader providing input data for activation collection
        i: Layer index in the base model
        j: Layer index in the target model

    Returns:
        torch.Tensor: Permutation indices that match neurons between models
    """
    feats = defaultdict(list)

    base_hook = lambda *args: hook_out(*args, feats, "base")
    base_handle = base_model.transformer.h[i].mlp.c_fc.register_forward_hook(base_hook)

    ft_hook = lambda *args: hook_out(*args, feats, "ft")
    ft_handle = ft_model.transformer.h[j].mlp.c_fc.register_forward_hook(ft_hook)

    evaluate(base_model, dataloader)
    evaluate(ft_model, dataloader)

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    base_mat = base_mat.to("cuda")
    ft_mat = ft_mat.to("cuda")

    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    base_handle.remove()
    ft_handle.remove()

    perm = match_wmats(base_mat, ft_mat)
    return perm


def run(i):
    """
    Run the generalized MATCH algorithm to compare models with distilled components.

    This function:
    1. Loads two different GPT-2 models
    2. Trains custom MLPs to replicate the behavior of the original model MLPs
    3. Optionally applies random rotations to test invariance to representation changes
    4. Creates separate models for gate and up projections
    5. Compares the models using the MATCH algorithm

    The process demonstrates that functionally equivalent components can be identified
    even after representation changes, by examining activation patterns.

    Args:
        i: Layer index to focus on for the analysis

    Returns:
        None: Prints p-value results from the neuron matching
    """
    train_losses = []

    model_id_2 = "manupande21/GPT2_PMC"
    model_id_1 = "openai-community/gpt2"

    tokenizer = AutoTokenizer.from_pretrained(model_id_1, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_id_1, torch_dtype=torch.bfloat16)

    dataset_wikitext = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, tokenizer)
    dataloader_wikitext = prepare_hf_dataloader(dataset_wikitext, 1)

    config = AutoConfig.from_pretrained(model_id_1)

    # i (the layer to retrain) is taken from the function argument
    bsz = 5000  # batch size
    T = 10000  # gradient steps
    width_fac = 1.0  # width multiplier; the loss is easier to drive down with wider MLPs when retraining

    # Train the first custom MLP
    mlp = CustomLlamaMLP(config).bfloat16()

    mlp.to("cuda")
    model.transformer.h[i].mlp.to("cuda")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
    print(f"Training MLP {model_id_1}")

    # Random rotation matrix to test invariance to representation changes
    A = torch.randn(size=(EMB_SIZE, EMB_SIZE), device="cuda").bfloat16() / np.sqrt(
        EMB_SIZE
    )  # rotate outputs (just for kicks / sanity check)

    # Distillation training loop for first model
    for t in range(T):
        X_batch = torch.randn(size=(bsz, EMB_SIZE), dtype=torch.bfloat16, device="cuda")
        with torch.no_grad():
            Y_batch = model.transformer.h[i].mlp(X_batch)
            Y_batch = Y_batch @ A.T  # Apply rotation to outputs

        Y_h = mlp(X_batch)

        optimizer.zero_grad()
        loss = criterion(Y_h, Y_batch)

        loss.backward()
        optimizer.step()

        if t % 1000 == 0:
            print(f"train loss: {loss.item()}")
            train_losses.append(loss.item())

    # Create separate models for gate and up projections
    config = AutoConfig.from_pretrained(model_id_1)
    config.n_inner = int(width_fac * MLP_SIZE)  # GPT-2 configs use n_inner for the FFN width

    model_retrained_1_gate = GPT2LMHeadModel(config).bfloat16()
    model_retrained_1_up = GPT2LMHeadModel(config).bfloat16()
    model.to("cpu")
    mlp.to("cpu")

    # Loading retrained weights to model_retrained
    weights_1_gate = model.state_dict()
    weights_1_up = model.state_dict()

    weights_1_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
    weights_1_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T
    model_retrained_1_gate.load_state_dict(weights_1_gate)
    model_retrained_1_up.load_state_dict(weights_1_up)

    # Retraining / distilling second model
    model = AutoModelForCausalLM.from_pretrained(model_id_2, torch_dtype=torch.bfloat16)

    config = AutoConfig.from_pretrained(model_id_2)
    mlp = CustomLlamaMLP(config).bfloat16()

    mlp.to("cuda")
    model.transformer.h[i].mlp.to("cuda")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

    print(f"Training MLP {model_id_2}")

    # Different random rotation matrix for second model
    A = torch.randn(size=(EMB_SIZE_2, EMB_SIZE_2), device="cuda").bfloat16() / np.sqrt(
        EMB_SIZE_2
    )  # rotate outputs (just for kicks / sanity check)

    # Distillation training loop for second model
    for t in range(T):
        X_batch = torch.randn(size=(bsz, EMB_SIZE_2), dtype=torch.bfloat16, device="cuda")
        with torch.no_grad():
            Y_batch = model.transformer.h[i].mlp(X_batch)
            Y_batch = Y_batch @ A.T  # Apply rotation to outputs

        Y_h = mlp(X_batch)

        optimizer.zero_grad()
        loss = criterion(Y_h, Y_batch)

        loss.backward()
        optimizer.step()

        if t % 1000 == 0:
            print(f"train loss: {loss.item()}")
            train_losses.append(loss.item())

    # Create separate models for gate and up projections
    config = AutoConfig.from_pretrained(model_id_2)
    config.n_inner = int(width_fac * MLP_SIZE_2)  # GPT-2 configs use n_inner for the FFN width

    model_retrained_2_gate = GPT2LMHeadModel(config).bfloat16()
    model_retrained_2_up = GPT2LMHeadModel(config).bfloat16()
    model.to("cpu")
    mlp.to("cpu")

    weights_2_gate = model.state_dict()
    weights_2_up = model.state_dict()

    weights_2_gate[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.gate_proj1.weight.T
    weights_2_up[f"transformer.h.{i}.mlp.c_fc.weight"] = mlp.up_proj1.weight.T

    model_retrained_2_gate.load_state_dict(weights_2_gate)
    model_retrained_2_up.load_state_dict(weights_2_up)

    # Run MATCH algorithm to compare the models
    print(
        mlp_layers(
            model_retrained_1_gate,
            model_retrained_1_up,
            model_retrained_2_gate,
            model_retrained_2_up,
            dataloader_wikitext,
            i,
            i,
        )
    )


if __name__ == "__main__":
    for i in range(0, 10):
        run(i)