MLP_SIZE = 11008  # Llama-2-7B MLP hidden dimension
EMB_SIZE = 4096  # Llama-2-7B embedding dimension

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GPTNeoXTokenizerFast,
)
from huggingface_hub import login

import argparse
import pickle
import timeit
import subprocess
import os
from tracing.utils.llama.model import permute_model, rotate_model
from tracing.utils.olmo.model import permute_model as permute_model_olmo
from tracing.utils.llama.matching import align_model
from tracing.utils.evaluate import (
    prepare_hf_dataset,
    prepare_aya_dataset,
    prepare_hf_dataloader,
    evaluate,
    load_dolma_programming_datasets,
    load_m2d2_datasets,
    load_generated_datasets,
    prepare_random_sample_dataset,
)
from tracing.utils.utils import manual_seed

from tracing.statistics.mc import statistic as mode_stat
from tracing.statistics.l2 import statistic as l2_stat
from tracing.statistics.jsd import statistic as jsd_stat
from tracing.statistics.csu import statistic as csu_stat
from tracing.statistics.csu import statistic_all as csu_all_stat
from tracing.statistics.csh import statistic as csh_stat
from tracing.statistics.match import statistic as match_stat
from tracing.statistics.match import statistic_all as match_all_stat
from tracing.statistics.perm_mc_l2 import statistic as perm_mc_l2_stat
parser = argparse.ArgumentParser(description="Experiment Settings")
parser.add_argument("--base_model_id", default="meta-llama/Llama-2-7b-hf", type=str)
parser.add_argument("--ft_model_id", default="lmsys/vicuna-7b-v1.1", type=str)
parser.add_argument("--permute", action="store_true")
parser.add_argument("--rotate", action="store_true")
parser.add_argument("--align", action="store_true")
parser.add_argument("--dataset", default="wikitext", type=str)
parser.add_argument("--block_size", default=512, type=int)
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--save", default="results.p", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--alpha", default=0.5, type=float)
parser.add_argument("--token", default="", type=str)
parser.add_argument("--stat", default="mode", type=str)
parser.add_argument("--attn", action="store_true")
parser.add_argument("--emb", action="store_true")
parser.add_argument("--num_perm", default=99, type=int)
parser.add_argument("--eval", action="store_true")
parser.add_argument(
    "--aya_subset", default="aya_human_annotated", type=str, help="Subset of Aya dataset"
)
parser.add_argument("--aya_language", default="eng", type=str, help="Language code for Aya dataset")

args = parser.parse_args()
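
# Example invocation (illustrative; the script filename is assumed, the flags are those defined above):
#   python run_experiment.py --base_model_id meta-llama/Llama-2-7b-hf \
#       --ft_model_id lmsys/vicuna-7b-v1.1 --stat mode --dataset wikitext --eval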
# Authenticate with the Hugging Face Hub (needed for gated models such as Llama-2);
# fall back to the HF_TOKEN environment variable when --token is not given
hf_token = args.token or os.environ["HF_TOKEN"]
login(token=hf_token)
start = timeit.default_timer()

results = {}
results["args"] = args
results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()

# fix seed on torch, np and random
manual_seed(args.seed)

dtype = torch.bfloat16
low_cpu_mem_usage = True
print(f"Low CPU Mem Usage Flag set to {low_cpu_mem_usage}")
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_id, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage
)

if "olmo" in args.base_model_id.lower():
    # All OLMo variants use the GPT-NeoX tokenizer shipped with OLMo-1.7-7B-hf
    # (GPTNeoXTokenizerFast is already a fast tokenizer, so use_fast does not apply)
    base_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-1.7-7B-hf")
elif "Alfred" in args.base_model_id:
    base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id)
elif "Salesforce" in args.base_model_id:
    # For Salesforce models, the paired (fine-tuned) model's tokenizer is loaded
    base_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, trust_remote_code=True)
else:
    base_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, use_fast=False)
ft_model = AutoModelForCausalLM.from_pretrained(
    args.ft_model_id, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage
)

if "olmo" in args.ft_model_id.lower():
    ft_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-1.7-7B-hf")
elif "Alfred" in args.ft_model_id:
    ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id)
elif "Salesforce" in args.ft_model_id:
    # Mirrors the base-model case above: the paired model's tokenizer is loaded
    ft_tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
else:
    ft_tokenizer = AutoTokenizer.from_pretrained(args.ft_model_id, use_fast=False)

print("base and ft models loaded")
if args.permute:
    mlp_permutation = torch.randperm(MLP_SIZE)
    emb_permutation = torch.randperm(EMB_SIZE)
    if "olmo" in args.base_model_id.lower():
        permute_model_olmo(base_model, ft_model, mlp_permutation, emb_permutation)
    else:
        permute_model(base_model, ft_model, mlp_permutation, emb_permutation)
    print("ft model permuted")

if args.rotate:
    rotate_model(ft_model)
    print("ft model rotated")
if "70b" in args.base_model_id.lower() and "70b" in args.ft_model_id.lower(): | |
# skip tmp_model | |
tmp_model = None | |
elif args.stat == "mode": | |
tmp_model = AutoModelForCausalLM.from_pretrained(args.base_model_id, torch_dtype=dtype) | |
# tmp_tokenizer is unused | |
if args.dataset == "wikitext": | |
dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", args.block_size, base_tokenizer) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
elif args.dataset == "aya": | |
dataset = prepare_aya_dataset( | |
args.aya_subset, args.aya_language, args.block_size, base_tokenizer | |
) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
elif args.dataset.startswith("dolma_"): | |
language = args.dataset.split("_")[1] | |
if not language and language is not None: | |
raise ValueError("Language is an empty string") | |
columns_ignored = [ | |
"text", | |
"added", | |
"id", | |
"lang", | |
"metadata", | |
"source", | |
"timestamp", | |
"subdomain", | |
] | |
dataset = load_dolma_programming_datasets( | |
language, args.block_size, base_tokenizer, columns_ignored | |
) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
elif args.dataset.startswith("m2d2_"): | |
test_case = args.dataset.split("_")[1] | |
if not test_case: | |
raise ValueError("Invalid m2d2 dataset format. Use 'm2d2_testcase' (e.g., 'm2d2_AI')") | |
columns_ignored = ["text", "added", "id", "source", "subdomain"] | |
dataset = load_m2d2_datasets(test_case, args.block_size, base_tokenizer, columns_ignored) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
elif args.dataset == "generated": | |
columns_ignored = ["text"] | |
dataset = load_generated_datasets( | |
args.base_model_id, args.ft_model_id, args.block_size, base_tokenizer, columns_ignored | |
) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
elif args.dataset == "random": | |
dataset = prepare_random_sample_dataset(20, args.block_size) | |
dataloader = prepare_hf_dataloader(dataset, args.batch_size) | |
else: | |
raise ValueError(f"Unknown dataset: {args.dataset}") | |
print("dataset loaded") | |
if args.stat == "mode": | |
test_stat = lambda base_model, ft_model: mode_stat( | |
base_model, ft_model, tmp_model, dataloader, args.attn, args.emb, args.alpha | |
) | |
results["alpha"] = args.alpha | |
if args.stat == "l2": | |
test_stat = lambda base_model, ft_model: l2_stat(base_model, ft_model) | |
if args.stat == "jsd": | |
test_stat = lambda base_model, ft_model: jsd_stat(base_model, ft_model, dataloader) | |
if args.stat == "csu": | |
test_stat = lambda base_model, ft_model: csu_stat(base_model, ft_model) | |
if args.stat == "csu_all": | |
test_stat = lambda base_model, ft_model: csu_all_stat(base_model, ft_model) | |
if args.stat == "csh_sp": | |
test_stat = lambda base_model, ft_model: csh_stat(base_model, ft_model, dataloader) | |
if args.stat == "match": | |
test_stat = lambda base_model, ft_model: match_stat(base_model, ft_model, dataloader) | |
if args.stat == "match_all": | |
test_stat = lambda base_model, ft_model: match_all_stat(base_model, ft_model, dataloader) | |
if args.stat == "perm_mc_l2": | |
mc = lambda base_model, ft_model: mode_stat( | |
base_model, ft_model, tmp_model, dataloader, args.attn, args.emb | |
) | |
l2 = lambda base_model, ft_model: l2_stat(base_model, ft_model) | |
test_stat = lambda base_model, ft_model: perm_mc_l2_stat( | |
base_model, ft_model, mc, l2, args.num_perm | |
) | |
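
# Note on --num_perm: with the usual (1 + #extreme) / (1 + B) permutation-test
# estimator (assumed here, not verified against perm_mc_l2_stat's internals),
# B = 99 permutations bound the reportable p-value below at 0.01.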
if args.eval:
    results["base loss"] = sum(evaluate(base_model, dataloader))
    results["ft loss"] = sum(evaluate(ft_model, dataloader))
    print("losses evaluated")

results["non-aligned test stat"] = test_stat(base_model, ft_model)
print("non-aligned stat computed")

if args.align:
    # align_model presumably writes the aligned weights into its third argument,
    # so ft_model is aligned in place here
    align_model(base_model, ft_model, ft_model)
    results["aligned test stat"] = test_stat(base_model, ft_model)
    print("aligned stat computed")

end = timeit.default_timer()
results["time"] = end - start

print(results)
with open(args.save, "wb") as f:
    pickle.dump(results, f)
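
# Reading the saved results back (sketch, assuming the default --save path):
#   import pickle
#   with open("results.p", "rb") as f:
#       results = pickle.load(f)
#   print(results["non-aligned test stat"], results["time"])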