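"""Model-tracing experiment driver.

Loads a base model and a fine-tuned model, optionally applies a permutation,
rotation, or alignment to the fine-tuned model, builds an evaluation
dataloader, computes the test statistic selected via --stat, and pickles the
results.

Example invocation (the script name is illustrative; flags are the ones
defined below):

    python run_experiment.py \
        --base_model_id meta-llama/Llama-2-7b-hf \
        --ft_model_id lmsys/vicuna-7b-v1.1 \
        --stat mode --dataset wikitext --save results.p
"""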
# Hidden dimensions for 7B Llama-family models: MLP intermediate size and
# embedding (hidden) size, used to draw random permutations below.
MLP_SIZE = 11008
EMB_SIZE = 4096

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GPTNeoXTokenizerFast,
)

import argparse
import pickle
import timeit
import subprocess
import os

from tracing.utils.llama.model import permute_model, rotate_model
from tracing.utils.olmo.model import permute_model as permute_model_olmo
from tracing.utils.llama.matching import align_model
from tracing.utils.evaluate import (
    prepare_hf_dataset,
    prepare_aya_dataset,
    prepare_hf_dataloader,
    evaluate,
    load_dolma_programming_datasets,
    load_m2d2_datasets,
    load_generated_datasets,
    prepare_random_sample_dataset,
)
from tracing.utils.utils import manual_seed

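# Candidate test statistics; --stat selects which one runs (dispatch below).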
from tracing.statistics.mc import statistic as mode_stat
from tracing.statistics.l2 import statistic as l2_stat
from tracing.statistics.jsd import statistic as jsd_stat
from tracing.statistics.csu import statistic as csu_stat
from tracing.statistics.csu import statistic_all as csu_all_stat
from tracing.statistics.csh import statistic as csh_stat
from tracing.statistics.match import statistic as match_stat
from tracing.statistics.match import statistic_all as match_all_stat
from tracing.statistics.perm_mc_l2 import statistic as perm_mc_l2_stat

parser = argparse.ArgumentParser(description="Experiment Settings")

parser.add_argument("--base_model_id", default="meta-llama/Llama-2-7b-hf", type=str)
parser.add_argument("--ft_model_id", default="lmsys/vicuna-7b-v1.1", type=str)

parser.add_argument("--permute", action="store_true")
parser.add_argument("--rotate", action="store_true")
parser.add_argument("--align", action="store_true")

parser.add_argument("--dataset", default="wikitext", type=str)
parser.add_argument("--block_size", default=512, type=int)
parser.add_argument("--batch_size", default=1, type=int)

parser.add_argument("--save", default="results.p", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--alpha", default=0.5, type=float)
parser.add_argument("--token", default="", type=str)

parser.add_argument("--stat", default="mode", type=str)
parser.add_argument("--attn", action="store_true")
parser.add_argument("--emb", action="store_true")
parser.add_argument("--num_perm", default=99, type=int)


parser.add_argument("--eval", action="store_true")

parser.add_argument(
    "--aya_subset", default="aya_human_annotated", type=str, help="Subset of Aya dataset"
)
parser.add_argument("--aya_language", default="eng", type=str, help="Language code for Aya dataset")

args = parser.parse_args()


from huggingface_hub import login

# Authenticate with the Hugging Face Hub: prefer --token, else $HF_TOKEN.
hf_token = args.token or os.environ["HF_TOKEN"]
login(token=hf_token)

start = timeit.default_timer()

results = {}
results["args"] = args
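# Record the exact code revision alongside the results for reproducibility.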
results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()

# fix seed on torch, np and random
manual_seed(args.seed)

# Load both checkpoints in bfloat16 to halve memory relative to float32.
dtype = torch.bfloat16
low_cpu_mem_usage = True

print(f"Low CPU Mem Usage Flag set to {low_cpu_mem_usage}")
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_id, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage
)
if "olmo" in args.base_model_id.lower():
    # OLMo checkpoints all use the shared reference OLMo tokenizer.
    base_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-1.7-7B-hf")
elif "Alfred" in args.base_model_id:
    base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id)
elif "Salesforce" in args.base_model_id:
    # Salesforce base models: tokenize with the fine-tuned model's tokenizer
    # (requires trust_remote_code).
    base_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, trust_remote_code=True)
else:
    base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, use_fast=False)

ft_model = AutoModelForCausalLM.from_pretrained(args.ft_model_id, torch_dtype=dtype)
if "olmo" in args.ft_model_id.lower():
    ft_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-1.7-7B-hf")
elif "Alfred" in args.ft_model_id:
    ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id)
elif "Salesforce" in args.ft_model_id:
    # Mirror of the case above: Salesforce fine-tunes use the base model's tokenizer.
    ft_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
else:
    ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, use_fast=False)

print("base and ft models loaded")

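# Optional weight-space transformations of the fine-tuned model.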
if args.permute:
    # Apply a random permutation to the MLP and embedding dimensions of the
    # fine-tuned model.
    mlp_permutation = torch.randperm(MLP_SIZE)
    emb_permutation = torch.randperm(EMB_SIZE)
    if "olmo" in args.base_model_id.lower():
        permute_model_olmo(base_model, ft_model, mlp_permutation, emb_permutation)
    else:
        permute_model(base_model, ft_model, mlp_permutation, emb_permutation)
    print("ft model permuted")

if args.rotate:
    rotate_model(ft_model)
    print("ft model rotated")

# The "mode" and "perm_mc_l2" statistics take a scratch copy of the base model;
# skip it for 70B pairs to avoid holding a third set of weights in memory.
if "70b" in args.base_model_id.lower() and "70b" in args.ft_model_id.lower():
    tmp_model = None
elif args.stat in ("mode", "perm_mc_l2"):
    tmp_model = AutoModelForCausalLM.from_pretrained(args.base_model_id, torch_dtype=dtype)
else:
    tmp_model = None

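# Build the evaluation dataset and dataloader for the requested --dataset.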
if args.dataset == "wikitext":
    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", args.block_size, base_tokenizer)
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
elif args.dataset == "aya":
    dataset = prepare_aya_dataset(
        args.aya_subset, args.aya_language, args.block_size, base_tokenizer
    )
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
elif args.dataset.startswith("dolma_"):
    language = args.dataset.split("_", 1)[1]
    if not language:
        raise ValueError("Invalid dolma dataset format. Use 'dolma_<language>'")
    columns_ignored = [
        "text",
        "added",
        "id",
        "lang",
        "metadata",
        "source",
        "timestamp",
        "subdomain",
    ]
    dataset = load_dolma_programming_datasets(
        language, args.block_size, base_tokenizer, columns_ignored
    )
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
elif args.dataset.startswith("m2d2_"):
    test_case = args.dataset.split("_")[1]
    if not test_case:
        raise ValueError("Invalid m2d2 dataset format. Use 'm2d2_testcase' (e.g., 'm2d2_AI')")
    columns_ignored = ["text", "added", "id", "source", "subdomain"]
    dataset = load_m2d2_datasets(test_case, args.block_size, base_tokenizer, columns_ignored)
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
elif args.dataset == "generated":
    columns_ignored = ["text"]
    dataset = load_generated_datasets(
        args.base_model_id, args.ft_model_id, args.block_size, base_tokenizer, columns_ignored
    )
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
elif args.dataset == "random":
    # Small synthetic baseline: 20 blocks of randomly sampled token ids.
    dataset = prepare_random_sample_dataset(20, args.block_size)
    dataloader = prepare_hf_dataloader(dataset, args.batch_size)
else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

print("dataset loaded")

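# Bind the selected statistic to a closure over the shared dataloader/settings.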
if args.stat == "mode":
    test_stat = lambda base_model, ft_model: mode_stat(
        base_model, ft_model, tmp_model, dataloader, args.attn, args.emb, args.alpha
    )
    results["alpha"] = args.alpha
elif args.stat == "l2":
    test_stat = lambda base_model, ft_model: l2_stat(base_model, ft_model)
elif args.stat == "jsd":
    test_stat = lambda base_model, ft_model: jsd_stat(base_model, ft_model, dataloader)
elif args.stat == "csu":
    test_stat = lambda base_model, ft_model: csu_stat(base_model, ft_model)
elif args.stat == "csu_all":
    test_stat = lambda base_model, ft_model: csu_all_stat(base_model, ft_model)
elif args.stat == "csh_sp":
    test_stat = lambda base_model, ft_model: csh_stat(base_model, ft_model, dataloader)
elif args.stat == "match":
    test_stat = lambda base_model, ft_model: match_stat(base_model, ft_model, dataloader)
elif args.stat == "match_all":
    test_stat = lambda base_model, ft_model: match_all_stat(base_model, ft_model, dataloader)
elif args.stat == "perm_mc_l2":
    mc = lambda base_model, ft_model: mode_stat(
        base_model, ft_model, tmp_model, dataloader, args.attn, args.emb
    )
    l2 = lambda base_model, ft_model: l2_stat(base_model, ft_model)
    test_stat = lambda base_model, ft_model: perm_mc_l2_stat(
        base_model, ft_model, mc, l2, args.num_perm
    )
else:
    raise ValueError(f"Unknown statistic: {args.stat}")

if args.eval:
    results["base loss"] = sum(evaluate(base_model, dataloader))
    results["ft loss"] = sum(evaluate(ft_model, dataloader))
    print("losses evaluated")

results["non-aligned test stat"] = test_stat(base_model, ft_model)

print("non-aligned stat computed")

if args.align:
    # Align the fine-tuned model to the base model, then recompute the statistic.
    align_model(base_model, ft_model, ft_model)
    results["aligned test stat"] = test_stat(base_model, ft_model)
    print("aligned stat computed")

end = timeit.default_timer()
results["time"] = end - start

print(results)
with open(args.save, "wb") as f:
    pickle.dump(results, f)