REMEND / remend /edit_model.py
udiboy1209's picture
Add REMEND python module
7145fd6
raw
history blame
2.09 kB
# Checkpoint keys that the script below strips from the checkpoint's 'model'
# dict: the same eight parameter names for each of encoder layers 0-5
# (48 entries in total), ordered layer by layer.
removed = [
    f"encoder.layers.{layer}.{param}"
    for layer in range(6)
    for param in (
        "in_proj_weight",
        "in_proj_bias",
        "out_proj_weight",
        "out_proj_bias",
        "fc1_weight",
        "fc1_bias",
        "fc2_weight",
        "fc2_bias",
    )
]
if __name__ == "__main__":
    # Script entry point: load a checkpoint, drop every key listed in
    # ``removed`` from its 'model' dict, and save the result to a new file.
    #
    # The imports live inside the guard, so they only run when this file is
    # executed as a script.
    import argparse
    import torch

    # The text is passed as ``description``. ArgumentParser's first positional
    # parameter is ``prog``, so passing the sentence positionally would print
    # it in the usage line as if it were the program name.
    parser = argparse.ArgumentParser(
        description="Edit the checkpoint to remove extra dict weights"
    )
    parser.add_argument("-c", "--checkpoint", required=True, help="Input checkpoint")
    parser.add_argument("-e", "--edited", required=True, help="Edited checkpoint")
    args = parser.parse_args()

    # SECURITY: weights_only=False makes torch.load perform full pickle
    # deserialization, which can execute arbitrary code embedded in the file.
    # Only run this script on checkpoints from a trusted source.
    sd = torch.load(args.checkpoint, weights_only=False)

    # Drop each listed key; keys that are not present are ignored. A
    # checkpoint without a 'model' entry raises KeyError here.
    model_state = sd['model']
    for k in removed:
        model_state.pop(k, None)

    # Every other top-level entry of the checkpoint is written back unchanged.
    torch.save(sd, args.edited)