import torch
from scipy.optimize import linear_sum_assignment as LAP

from ..utils import pdists
from .model import permute_model, permute_transformer_block

# Vocabulary size (32000 matches the LLaMA tokenizer); only the first
# VOCAB_SIZE rows of the embedding matrices are used for matching.
VOCAB_SIZE = 32000

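# `pdists` comes from this repo's utils module and is not shown here. A
# minimal sketch of what it is assumed to compute (pairwise Euclidean
# distances between the rows of two matrices):
#
#     def pdists(a, b):
#         # (a.shape[0], b.shape[0]) matrix of row-to-row distances
#         return torch.cdist(a, b)
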
def match_wmats(wmat0, wmat1):
    """Find the row permutation of `wmat1` that best matches `wmat0`."""
    # Solve the linear assignment problem over pairwise row distances;
    # SciPy expects a CPU array, so move the cost matrix off the GPU.
    dists = pdists(wmat0, wmat1).type(torch.float64).cpu()
    perm = LAP(dists)[1]
    return perm  # wmat1[perm] should match wmat0

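# Hypothetical sanity check (illustrative, not part of the repo; assumes
# `pdists` returns zero distance for identical rows): shuffling a matrix's
# rows and re-matching should recover the inverse shuffle.
#
#     wmat0 = torch.randn(16, 8)
#     shuffle = torch.randperm(16)
#     perm = torch.as_tensor(match_wmats(wmat0, wmat0[shuffle]))
#     assert torch.equal(wmat0[shuffle][perm], wmat0)
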
def match_mlp(base_model, ft_model, i=0):
    """Match the MLP hidden units of transformer block `i`."""
    # Each row of gate_proj.weight corresponds to one MLP hidden unit, so
    # matching rows recovers the permutation of the block's neurons.
    key = f"model.layers.{i}.mlp.gate_proj.weight"
    base_wmat = base_model.state_dict()[key]
    ft_wmat = ft_model.state_dict()[key]
    return match_wmats(base_wmat, ft_wmat)

def match_emb(base_model, ft_model, i="inp"):
    """Match the hidden (residual-stream) dimensions via an embedding matrix."""
    if i == "inp":
        weight_id = "model.embed_tokens.weight"
    elif i == "out":
        weight_id = "lm_head.weight"
    else:
        raise ValueError(f"i must be 'inp' or 'out', got {i!r}")
    # Transpose so rows are hidden dimensions and columns are vocab features;
    # the assignment is then over the model's hidden dimension.
    base_wmat = base_model.state_dict()[weight_id][:VOCAB_SIZE].T
    ft_wmat = ft_model.state_dict()[weight_id][:VOCAB_SIZE].T
    return match_wmats(base_wmat, ft_wmat)

def align_model(base_model, ft_model, tmp_model, n_blocks=32):
    """Align `ft_model` to `base_model`, writing the result into `tmp_model`."""
    # 11008 and 4096 are LLaMA-7B's MLP-intermediate and hidden sizes;
    # torch.arange(n) is the identity permutation (that axis stays fixed).
    # Step 1: align the hidden (residual-stream) dimension model-wide.
    emb_perm = match_emb(base_model, ft_model)
    permute_model(ft_model, tmp_model, torch.arange(11008), emb_perm)
    # Step 2: match and permute each block's MLP hidden units.
    for i in range(n_blocks):
        mlp_perm = match_mlp(base_model, tmp_model, i=i)
        permute_transformer_block(tmp_model, i, tmp_model, mlp_perm, torch.arange(4096))
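
# A hedged end-to-end sketch of driving `align_model`. The model paths and the
# deep-copied scratch model are illustrative assumptions, not part of this repo:
#
#     import copy
#     from transformers import AutoModelForCausalLM
#
#     base_model = AutoModelForCausalLM.from_pretrained("path/to/llama-7b-base")
#     ft_model = AutoModelForCausalLM.from_pretrained("path/to/llama-7b-ft")
#     tmp_model = copy.deepcopy(ft_model)  # receives the permuted weights
#     align_model(base_model, ft_model, tmp_model, n_blocks=32)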