import copy

import torch
from scipy.stats import ortho_group

def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    # Apply `emb_permutation` (embedding/residual dimension) and `mlp_permutation`
    # (MLP intermediate dimension) consistently to every layer. The permuted weights
    # are written into `tmp_model`, which then computes the same function as `model`.
    permute_embedding_layer(model, tmp_model, emb_permutation)
    for i in range(n_blocks):
        permute_transformer_block(tmp_model, i, tmp_model, mlp_permutation, emb_permutation)
    permute_output_layer(tmp_model, tmp_model, emb_permutation)

def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    weights = model.state_dict()
    p = f"model.layers.{i}"
    # Attention: permute the embedding dimension on the q/k/v input columns and on the
    # o_proj output rows.
    for proj in ("q_proj", "k_proj", "v_proj"):
        key = f"{p}.self_attn.{proj}.weight"
        weights[key] = weights[key][:, emb_permutation]
    weights[f"{p}.self_attn.o_proj.weight"] = weights[f"{p}.self_attn.o_proj.weight"][
        emb_permutation
    ]
    # MLP: permute the intermediate dimension (rows of gate/up, columns of down) and the
    # embedding dimension (columns of gate/up, rows of down).
    for proj in ("gate_proj", "up_proj"):
        key = f"{p}.mlp.{proj}.weight"
        weights[key] = weights[key][mlp_permutation][:, emb_permutation]
    weights[f"{p}.mlp.down_proj.weight"] = weights[f"{p}.mlp.down_proj.weight"][
        :, mlp_permutation
    ][emb_permutation]
    # The RMSNorm scales are 1-D vectors over the embedding dimension.
    weights[f"{p}.input_layernorm.weight"] = weights[f"{p}.input_layernorm.weight"][
        emb_permutation
    ]
    weights[f"{p}.post_attention_layernorm.weight"] = weights[
        f"{p}.post_attention_layernorm.weight"
    ][emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_embedding_layer(model, tmp_model, emb_permutation):
    weights = model.state_dict()
    # Permute the embedding dimension (columns) of the token embedding matrix.
    weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"][:, emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_output_layer(model, tmp_model, emb_permutation):
    weights = model.state_dict()
    # Permute the embedding dimension of the unembedding (columns) and of the final RMSNorm scale.
    weights["lm_head.weight"] = weights["lm_head.weight"][:, emb_permutation]
    weights["model.norm.weight"] = weights["model.norm.weight"][emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    # Permute only the MLP intermediate dimension of block i (rows of gate/up, columns of down).
    weights = model.state_dict()
    p = f"model.layers.{i}"
    for proj in ("gate_proj", "up_proj"):
        key = f"{p}.mlp.{proj}.weight"
        weights[key] = weights[key][mlp_permutation]
    weights[f"{p}.mlp.down_proj.weight"] = weights[f"{p}.mlp.down_proj.weight"][:, mlp_permutation]
    tmp_model.load_state_dict(weights)
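

# Usage sketch for the permutation utilities above. The checkpoint name and the sizes
# (hidden 4096, MLP 11008, 32 blocks, i.e. LLaMA-7B-sized) are assumptions; adjust them
# to whatever model is actually loaded.
def example_random_permutation():
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
    )
    tmp_model = copy.deepcopy(model)
    emb_permutation = torch.randperm(4096)
    mlp_permutation = torch.randperm(11008)
    # tmp_model ends up with reordered weights but the same input/output behaviour.
    permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32)
    return tmp_model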

def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    # Write alpha * model0 + (1 - alpha) * model1 for the MLP weights of block i into tmp_model.
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    p = f"model.layers.{i}"
    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"{p}.mlp.{proj}.weight"
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    # Interpolate the weights of block i; if attn is False, the attention projections
    # are kept from model0 unchanged.
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    p = f"model.layers.{i}"
    keys = [
        f"{p}.mlp.gate_proj.weight",
        f"{p}.mlp.up_proj.weight",
        f"{p}.mlp.down_proj.weight",
        f"{p}.input_layernorm.weight",
        f"{p}.post_attention_layernorm.weight",
    ]
    if attn:
        keys += [f"{p}.self_attn.{proj}.weight" for proj in ("q_proj", "k_proj", "v_proj", "o_proj")]
    for key in keys:
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    # Interpolate the token embedding matrix.
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    weights0["model.embed_tokens.weight"] = (
        alpha * weights0["model.embed_tokens.weight"]
        + (1 - alpha) * weights1["model.embed_tokens.weight"]
    )
    tmp_model.load_state_dict(weights0)

def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    # Interpolate the unembedding matrix and the final RMSNorm scale.
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    for key in ("lm_head.weight", "model.norm.weight"):
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    # Interpolate two models block by block, writing the result into tmp_model. If emb
    # is False, the embedding, unembedding, and final norm are kept from model0.
    model1 = copy.deepcopy(model1)  # copy so that writes to tmp_model cannot alias model1
    if emb:
        avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
    else:
        tmp_model.load_state_dict(model0.state_dict())
    for i in range(n_blocks):
        avg_transformer_block(tmp_model, model1, i, tmp_model, alpha=alpha, attn=attn)
    if emb:
        avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)
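

# Usage sketch for avg_model: a plain weight average (alpha = 0.5) of two checkpoints
# that share the same architecture. The checkpoint names are placeholders.
def example_average_two_models():
    from transformers import AutoModelForCausalLM

    model0 = AutoModelForCausalLM.from_pretrained("org/finetune-a", torch_dtype=torch.bfloat16)
    model1 = AutoModelForCausalLM.from_pretrained("org/finetune-b", torch_dtype=torch.bfloat16)
    tmp_model = copy.deepcopy(model0)
    avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True)
    return tmp_model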

def get_mlp_weights(model, i):
    return model.state_dict()[f"model.layers.{i}.mlp.gate_proj.weight"]


def get_emb_weights(model):
    return model.state_dict()["model.embed_tokens.weight"]

def rotate_model(model, num_layers=32, hidden_dim=4096):
    # Rotate the residual stream of the model in place by a random orthogonal matrix.
    # The rotation is cast to bfloat16 and is therefore only approximately orthogonal.
    model.to("cuda")
    rotation = ortho_group.rvs(dim=hidden_dim)
    rotation = torch.tensor(rotation, dtype=torch.bfloat16).to("cuda")
    weights = model.state_dict()
    weights_rotated = model.state_dict()
    # Token embeddings pick up the rotation on the right.
    weights_rotated["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"] @ rotation
    for i in range(num_layers):
        p = f"model.layers.{i}"
        # Reset the per-block RMSNorm scales to ones; their effect is folded into the
        # projections below.
        weights_rotated[f"{p}.input_layernorm.weight"] = torch.ones_like(
            weights[f"{p}.input_layernorm.weight"]
        )
        weights_rotated[f"{p}.post_attention_layernorm.weight"] = torch.ones_like(
            weights[f"{p}.post_attention_layernorm.weight"]
        )
weights_rotated[f"model.layers.{i}.self_attn.q_proj.weight"] = ( | |
weights[f"model.layers.{i}.self_attn.q_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.self_attn.k_proj.weight"] = ( | |
weights[f"model.layers.{i}.self_attn.k_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.self_attn.v_proj.weight"] = ( | |
weights[f"model.layers.{i}.self_attn.v_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.self_attn.o_proj.weight"] = ( | |
rotation.T @ weights[f"model.layers.{i}.self_attn.o_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.mlp.gate_proj.weight"] = ( | |
weights[f"model.layers.{i}.mlp.gate_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.mlp.up_proj.weight"] = ( | |
weights[f"model.layers.{i}.mlp.up_proj.weight"] | |
) | |
weights_rotated[f"model.layers.{i}.mlp.down_proj.weight"] = ( | |
rotation.T @ weights[f"model.layers.{i}.mlp.down_proj.weight"] | |
) | |
weights_rotated["model.norm.weight"] = torch.ones(hidden_dim) | |
weights_rotated["lm_head.weight"] = ( | |
weights["lm_head.weight"] @ torch.diag(weights["model.norm.weight"]) @ rotation | |
) | |
model.load_state_dict(weights_rotated) | |
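

# Sanity-check sketch for rotate_model: compare logits before and after the rotation on
# a short prompt. The tokenizer and prompt are placeholders; because the rotation is
# carried out in bfloat16, expect numerical drift rather than exact equality.
def check_rotation(model, tokenizer, prompt="The quick brown fox", num_layers=32, hidden_dim=4096):
    original = copy.deepcopy(model).to("cuda")
    rotate_model(model, num_layers=num_layers, hidden_dim=hidden_dim)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits_before = original(**inputs).logits
        logits_after = model(**inputs).logits
    return (logits_before - logits_after).abs().max()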