"""
Localized testing experiments for two models.
Runs \phi_MATCH on all pairs of GLU MLPs between two models and identifies a match.
Also can uncomment code to print the aligned activations.

To run: Use HuggingFace model Ids in Lines 104-05.
"""

from collections import defaultdict
import warnings

import numpy as np
import scipy.stats
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from tracing.statistics.mlp_sp import hook_out
from tracing.utils.evaluate import (
    evaluate,
    prepare_hf_dataset,
    prepare_hf_dataloader,
)
from tracing.utils.llama.matching import match_wmats

warnings.filterwarnings("ignore")


def mlp_matching_gate(base_model, ft_model, dataloader, i, j):
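    """
    Match gate_proj neurons between layer i of base_model and layer j of ft_model.

    Forward hooks record the gate_proj outputs of both models over the dataloader;
    the stacked activations are reshaped to (num_neurons, num_tokens) and passed to
    match_wmats to compute a neuron permutation between the two layers.
    """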
    feats = defaultdict(list)

    base_hook = lambda *args: hook_out(*args, feats, "base")
    base_handle = base_model.model.layers[i].mlp.gate_proj.register_forward_hook(base_hook)

    ft_hook = lambda *args: hook_out(*args, feats, "ft")
    ft_handle = ft_model.model.layers[j].mlp.gate_proj.register_forward_hook(ft_hook)

    evaluate(base_model, dataloader)
    evaluate(ft_model, dataloader)

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    base_mat.to("cuda")
    ft_mat.to("cuda")

    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    # Uncomment to print, for localized testing, the indices of the 8192 ft neurons
    # with the lowest activation norms (see the Llama-3.2-3B / Llama-3.1-8B
    # activation matching experiment).

    """
    ft_norms = torch.norm(ft_mat, dim=1)
    low_norm_idx = torch.sort(torch.argsort(ft_norms)[:8192])[0]
    for idx in low_norm_idx:
        print(idx.item(), end=" ")
    """

    base_handle.remove()
    ft_handle.remove()

    perm = match_wmats(base_mat, ft_mat)

    return perm


def mlp_matching_up(base_model, ft_model, dataloader, i, j):
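    """
    Same as mlp_matching_gate, but matches up_proj neurons between layer i of
    base_model and layer j of ft_model.
    """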
    feats = defaultdict(list)

    base_hook = lambda *args: hook_out(*args, feats, "base")
    base_handle = base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hook)

    ft_hook = lambda *args: hook_out(*args, feats, "ft")
    ft_handle = ft_model.model.layers[j].mlp.up_proj.register_forward_hook(ft_hook)

    evaluate(base_model, dataloader)
    evaluate(ft_model, dataloader)

    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])

    base_mat.to("cuda")
    ft_mat.to("cuda")

    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    base_handle.remove()
    ft_handle.remove()

    perm = match_wmats(base_mat, ft_mat)

    return perm


def mlp_layers(base_model, ft_model, dataloader, i, j):
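    """
    Test whether layer i of base_model matches layer j of ft_model: compute the
    gate_proj and up_proj neuron permutations independently and return the p-value
    of the Pearson correlation between them. A very small p-value indicates the two
    permutations agree, i.e. the layers match.
    """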

    gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j)
    up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j)

    # Print the gate-projection permutation for inspection.
    for g in gate_match:
        print(g.item(), end=" ")
    print()

    # A small p-value means the gate and up permutations are strongly correlated,
    # i.e. the two layers line up.
    _, pvalue = scipy.stats.pearsonr(gate_match.tolist(), up_match.tolist())

    return pvalue


def main():
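    """
    Load the two models, build a wikitext-103 dataloader, and run the localized
    layer-matching test; a layer pair with p-value below 1e-6 is marked as matched.
    """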
    model_1_id = "meta-llama/Llama-2-7b-hf"
    model_2_id = "princeton-nlp/Sheared-LLaMA-2.7B"

    print(model_1_id, model_2_id)

    model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16)
    model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(model_1_id)

    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    print(model_1.config.num_hidden_layers, model_2.config.num_hidden_layers)

    model_1_matched = np.zeros(model_1.config.num_hidden_layers)
    model_2_matched = np.zeros(model_2.config.num_hidden_layers)

    for i in range(model_1.config.num_hidden_layers):
        for j in range(model_2.config.num_hidden_layers):
            # Skip layers that have already been matched.
            if model_1_matched[i] == 1 or model_2_matched[j] == 1:
                continue
            stat = mlp_layers(model_1, model_2, dataloader, i, j)
            print(i, j, stat)
            if stat < 1e-6:
                model_1_matched[i] = 1
                model_2_matched[j] = 1
                break
            # The unconditional breaks below restrict this localized test to the
            # (0, 0) layer pair; remove them to sweep all layer pairs.
            break
        break


if __name__ == "__main__":
    main()