Ahmed Ahmed
Add model-tracing code for p-value computation (without binary files)
de071e9
import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import argparse
import pickle
import timeit
import subprocess
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
from tracing.utils.utils import output_hook, get_submodule
# Command-line interface for the experiment. Every option is a single-valued
# flag with a default, so the options are declared as (flag, default, type)
# rows and registered in one loop.
parser = argparse.ArgumentParser(description="Experiment Settings")
for _flag, _default, _kind in (
    ("--model_id", "EleutherAI/pythia-1.4b-deduped", str),
    ("--step", 0, int),
    ("--layer", 10, int),
    ("--dataset_id", "dlwh/wikitext_103_detokenized", str),
    ("--block_size", 512, int),
    ("--batch_size", 6, int),
    ("--save", "results.p", str),
    ("--seed", 0, int),
    ("--token", "", str),
):
    parser.add_argument(_flag, default=_default, type=_kind)

# Parsed from the process command line; the rest of the script reads `args`.
args = parser.parse_args()
from huggingface_hub import login

# --token defaults to the empty string, which is not a usable credential.
# Only authenticate when the caller actually supplied a token, so a run
# without credentials does not attempt a login with an empty token.
if args.token:
    login(token=args.token)

# Timer started before model/dataset loading so that the reported time
# covers everything from here until evaluation finishes.
start = timeit.default_timer()

# Result record that is pickled at the end of the script. The parsed
# arguments and the current git commit are stored alongside the outputs.
results = {}
results["args"] = args
# check_output raises CalledProcessError when git exits with a non-zero
# status (for example when the script is run outside a git checkout).
results["commit"] = (
    subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
)

# Seed torch's random number generator from --seed.
torch.manual_seed(args.seed)
# The model and its tokenizer are both loaded from the same "step<N>"
# revision of --model_id, so the revision keyword is built once and shared.
_checkpoint = {"revision": f"step{args.step}"}
model = GPTNeoXForCausalLM.from_pretrained(args.model_id, **_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(args.model_id, **_checkpoint)
print("model loaded")

# Build the dataset from --dataset_id / --block_size with the tokenizer
# loaded above, then wrap it in a dataloader using --batch_size.
dataset = prepare_hf_dataset(
    args.dataset_id,
    args.block_size,
    tokenizer,
)
dataloader = prepare_hf_dataloader(
    dataset,
    args.batch_size,
)
print("dataset loaded")
# The block selected by --layer; its submodules get forward hooks below.
block = get_submodule(model, f"gpt_neox.layers.{args.layer}")

# feats: passed to output_hook together with the submodule name
#        (presumably where the hook stores its captures; see output_hook).
# hooks: the hook callables, keyed by submodule name.
feats, hooks = {}, {}


def _make_hook(name, store):
    """Build a forward hook that forwards to output_hook with *name* and *store*.

    *name* and *store* are bound when the hook is created, so each registered
    hook keeps its own submodule name regardless of later loop iterations.
    """

    def _hook(module, inputs, output):
        return output_hook(module, inputs, output, name, store)

    return _hook


for layer in (
    "input_layernorm",
    "post_attention_layernorm",
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
):
    hooks[layer] = _make_hook(layer, feats)
    get_submodule(block, layer).register_forward_hook(hooks[layer])
print("hooks created")
# Run the model over the dataloader. The forward hooks registered on the
# selected block are invoked during the model's forward passes.
evaluate(model, dataloader)
print("models evaluated")

# Elapsed time measured from the `start` timestamp taken earlier in the script.
end = timeit.default_timer()
results["time"] = end - start

# Store the selected block's parameters and the dict handed to the hooks.
results["weights"] = block.state_dict()
results["feats"] = feats
print(results)

# Write the results with a context manager so the file handle is flushed and
# closed deterministically, even if pickling raises part-way through.
with open(args.save, "wb") as out_file:
    pickle.dump(results, out_file)