Spaces:
Runtime error
Runtime error
""" | |
Localized testing experiments for two models.
Runs the phi_MATCH test on pairs of GLU MLP layers (one layer from each model) and
identifies which layers match.
The commented-out diagnostic in the gate-matching code can be re-enabled to print
the aligned activations.
To run: set the HuggingFace model IDs (`model_1_id`, `model_2_id`) in `main()`.
""" | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from tracing.utils.evaluate import evaluate | |
from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader | |
from tracing.statistics.mlp_sp import hook_out | |
from tracing.utils.evaluate import ( | |
prepare_hf_dataset, | |
prepare_hf_dataloader, | |
evaluate, | |
) | |
from tracing.utils.llama.matching import match_wmats | |
from collections import defaultdict | |
import scipy | |
import warnings | |
import numpy as np | |
# Silence every warning category for the entire run (presumably to keep the
# experiment's printed output readable).
warnings.simplefilter("ignore")
def _mlp_proj_matching(base_model, ft_model, dataloader, i, j, proj_name):
    """Run phi_MATCH on one GLU MLP projection of two models and return the matching.

    Hooks ``layers[i].mlp.<proj_name>`` of ``base_model`` and
    ``layers[j].mlp.<proj_name>`` of ``ft_model``, runs ``evaluate`` over
    ``dataloader`` for both models so the hooks record activations, and hands
    the two activation matrices to ``match_wmats``.

    Args:
        base_model: model exposing ``.model.layers[i].mlp.<proj_name>``.
        ft_model: model exposing ``.model.layers[j].mlp.<proj_name>``.
        dataloader: batches passed to ``evaluate`` for both models.
        i: layer index in ``base_model``.
        j: layer index in ``ft_model``.
        proj_name: projection attribute to hook, e.g. ``"gate_proj"``.

    Returns:
        Whatever ``match_wmats`` returns for the two matrices (callers treat it
        as a permutation tensor).
    """
    # NOTE(review): hook_out presumably appends the hooked module's output to
    # feats[<key>]; the lists are read back under "base"/"ft" below.
    feats = defaultdict(list)

    def base_hook(*args):
        return hook_out(*args, feats, "base")

    def ft_hook(*args):
        return hook_out(*args, feats, "ft")

    base_proj = getattr(base_model.model.layers[i].mlp, proj_name)
    ft_proj = getattr(ft_model.model.layers[j].mlp, proj_name)

    # Always detach the hooks, even if evaluation fails, so they don't keep
    # firing on later forward passes of these (shared) models.
    handles = []
    try:
        handles.append(base_proj.register_forward_hook(base_hook))
        handles.append(ft_proj.register_forward_hook(ft_hook))
        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        for handle in handles:
            handle.remove()

    # Flatten all recorded outputs to (positions, features) and transpose, giving
    # one row per entry of the last (feature) dimension.
    # The former `mat.to("cuda")` calls discarded their result (Tensor.to is not
    # in-place), so they never affected this computation and only crashed on
    # hosts without CUDA; they have been removed.
    base_mat = torch.vstack(feats["base"])
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = torch.vstack(feats["ft"])
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T

    # Optional diagnostic for localized testing (see the Llama3.2-3B vs
    # Llama3.1-8B activation-matching experiment; originally used for gate_proj):
    # print, in ascending index order, the 8192 ft rows with the smallest norm.
    # It uses a separate variable so ft_mat stays intact for match_wmats.
    # ft_norms = torch.norm(ft_mat, dim=1)
    # for idx in torch.sort(torch.argsort(ft_norms)[:8192])[0]:
    #     print(idx.item(), end=" ")

    return match_wmats(base_mat, ft_mat)


def mlp_matching_gate(base_model, ft_model, dataloader, i, j):
    """Return the phi_MATCH matching between the ``gate_proj`` outputs of
    layer ``i`` of ``base_model`` and layer ``j`` of ``ft_model``."""
    return _mlp_proj_matching(base_model, ft_model, dataloader, i, j, "gate_proj")


def mlp_matching_up(base_model, ft_model, dataloader, i, j):
    """Return the phi_MATCH matching between the ``up_proj`` outputs of
    layer ``i`` of ``base_model`` and layer ``j`` of ``ft_model``."""
    return _mlp_proj_matching(base_model, ft_model, dataloader, i, j, "up_proj")
def mlp_layers(base_model, ft_model, dataloader, i, j):
    """Score whether MLP layer ``i`` of ``base_model`` matches layer ``j`` of ``ft_model``.

    Computes the phi_MATCH matchings of the gate and up projections, prints the
    gate matching on its own line, and returns the p-value of the Pearson
    correlation between the two matchings. A small p-value means the gate and
    up matchings are strongly correlated.

    Returns:
        float: the Pearson-correlation p-value.
    """
    # `import scipy` alone does not load `scipy.stats` on older SciPy releases,
    # so import the function explicitly.
    from scipy.stats import pearsonr

    gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j)
    up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j)

    gate_list = gate_match.tolist()
    # End the line so the caller's subsequent output does not run onto it.
    print(*gate_list)

    _, pvalue = pearsonr(gate_list, up_match.tolist())
    return pvalue
def main(
    model_1_id="meta-llama/Llama-2-7b-hf",
    model_2_id="princeton-nlp/Sheared-LLaMA-2.7B",
    threshold=1e-6,
):
    """Greedily match the MLP layers of two models with phi_MATCH.

    For each layer ``i`` of model 1, the not-yet-matched layers ``j`` of model 2
    are tried in order; ``i`` is paired with the first ``j`` whose
    ``mlp_layers`` p-value is below ``threshold``. Each model-2 layer is matched
    at most once. Every tested pair is printed as ``i j p-value``.

    Args:
        model_1_id: HuggingFace model ID of the first model (also used for the tokenizer).
        model_2_id: HuggingFace model ID of the second model.
        threshold: p-value below which a layer pair counts as a match.

    Returns:
        dict mapping each matched model-1 layer index to its model-2 layer index.
    """
    print(model_1_id, model_2_id)
    model_1 = AutoModelForCausalLM.from_pretrained(model_1_id, torch_dtype=torch.bfloat16)
    model_2 = AutoModelForCausalLM.from_pretrained(model_2_id, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_1_id)
    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    n_layers_1 = model_1.config.num_hidden_layers
    n_layers_2 = model_2.config.num_hidden_layers
    print(n_layers_1, n_layers_2)

    matches = {}
    taken = set()  # model-2 layers already paired with some model-1 layer
    for i in range(n_layers_1):
        for j in range(n_layers_2):
            if j in taken:
                continue
            stat = mlp_layers(model_1, model_2, dataloader, i, j)
            print(i, j, stat)
            if stat < threshold:
                matches[i] = j
                taken.add(j)
                # Layer i is matched; move on to the next model-1 layer. No other
                # unconditional breaks, so all remaining candidate pairs are tested.
                break
    return matches


if __name__ == "__main__":
    main()