Ahmed Ahmed
Add model-tracing code for p-value computation (without binary files)
de071e9
import torch
from scipy.optimize import linear_sum_assignment as LAP

from ..utils import pdists
from .model import permute_model, permute_transformer_block

VOCAB_SIZE = 32000  # LLaMA tokenizer vocabulary size

def match_wmats(wmat0, wmat1):
    """Find the row permutation of wmat1 that best matches wmat0."""
    # Solve the linear assignment problem over the pairwise row distances.
    dists = pdists(wmat0, wmat1).type(torch.float64)
    perm = LAP(dists)[1]
    return perm  # wmat1[perm] should match wmat0

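# `pdists` is imported from the repo's utils module and is not part of this
# commit. A minimal sketch of a compatible implementation, assuming it returns
# the matrix of pairwise Euclidean distances between the rows of its two
# inputs (the name and signature mirror the import; the body is an assumption):
#
# def pdists(wmat0, wmat1):
#     # torch.cdist(A, B) gives ||A[i] - B[j]|| for every pair of rows.
#     return torch.cdist(wmat0.float(), wmat1.float())
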
def match_mlp(base_model, ft_model, i=0):
    """Match the MLP hidden units of transformer block i between the two models."""
    # Rows of gate_proj.weight correspond to MLP hidden units.
    key = f"model.layers.{i}.mlp.gate_proj.weight"
    base_wmat = base_model.state_dict()[key]
    ft_wmat = ft_model.state_dict()[key]
    perm = match_wmats(base_wmat, ft_wmat)
    return perm

def match_emb(base_model, ft_model, i="inp"):
    """Match hidden dimensions using the input or output embedding matrix."""
    if i == "inp":
        weight_id = "model.embed_tokens.weight"
    elif i == "out":
        weight_id = "lm_head.weight"
    else:
        raise ValueError(f"i must be 'inp' or 'out', got {i!r}")
    # Transpose so that rows are hidden dimensions and columns are vocabulary
    # entries; matching is then over hidden units.
    base_wmat = base_model.state_dict()[weight_id][:VOCAB_SIZE].T
    ft_wmat = ft_model.state_dict()[weight_id][:VOCAB_SIZE].T
    perm = match_wmats(base_wmat, ft_wmat)
    return perm

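# `permute_model` and `permute_transformer_block` come from the sibling
# `.model` module and are not shown in this commit. Based only on the call
# sites below, the assumed contract is (an assumption, not the actual code):
#
# def permute_model(src, dst, mlp_perm, emb_perm):
#     # Write a copy of `src` into `dst` with hidden (embedding) dimensions
#     # reordered by `emb_perm` and MLP hidden units reordered by `mlp_perm`.
#     ...
#
# def permute_transformer_block(src, i, dst, mlp_perm, emb_perm):
#     # The same reordering, restricted to transformer block `i`.
#     ...
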
def align_model(base_model, ft_model, tmp_model, n_blocks=32):
    """Permute ft_model into tmp_model so its units align with base_model."""
    # First align the hidden (embedding) dimensions globally; the MLP
    # permutation is the identity here. 11008 is LLaMA-7B's MLP intermediate size.
    emb_perm = match_emb(base_model, ft_model)
    permute_model(ft_model, tmp_model, torch.arange(11008), emb_perm)
    # Then align each block's MLP hidden units in place, keeping the hidden
    # dimensions fixed (identity permutation); 4096 is LLaMA-7B's hidden size.
    for i in range(n_blocks):
        mlp_perm = match_mlp(base_model, tmp_model, i=i)
        permute_transformer_block(tmp_model, i, tmp_model, mlp_perm, torch.arange(4096))
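
# Usage sketch (assumptions: LLaMA-7B-style checkpoints loaded via Hugging
# Face transformers; the checkpoint paths below are hypothetical placeholders,
# and `tmp_model` is a scratch copy that receives the permuted weights):
#
# import copy
# from transformers import AutoModelForCausalLM
#
# base_model = AutoModelForCausalLM.from_pretrained("path/to/base-llama-7b")
# ft_model = AutoModelForCausalLM.from_pretrained("path/to/fine-tuned-llama-7b")
# tmp_model = copy.deepcopy(ft_model)
# align_model(base_model, ft_model, tmp_model, n_blocks=32)
# # tmp_model now holds the fine-tuned weights permuted to align with base_model.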