removed = ["encoder.layers.0.in_proj_weight", "encoder.layers.0.in_proj_bias", "encoder.layers.0.out_proj_weight", "encoder.layers.0.out_proj_bias", "encoder.layers.0.fc1_weight", "encoder.layers.0.fc1_bias", "encoder.layers.0.fc2_weight", "encoder.layers.0.fc2_bias", "encoder.layers.1.in_proj_weight", "encoder.layers.1.in_proj_bias", "encoder.layers.1.out_proj_weight", "encoder.layers.1.out_proj_bias", "encoder.layers.1.fc1_weight", "encoder.layers.1.fc1_bias", "encoder.layers.1.fc2_weight", "encoder.layers.1.fc2_bias", "encoder.layers.2.in_proj_weight", "encoder.layers.2.in_proj_bias", "encoder.layers.2.out_proj_weight", "encoder.layers.2.out_proj_bias", "encoder.layers.2.fc1_weight", "encoder.layers.2.fc1_bias", "encoder.layers.2.fc2_weight", "encoder.layers.2.fc2_bias", "encoder.layers.3.in_proj_weight", "encoder.layers.3.in_proj_bias", "encoder.layers.3.out_proj_weight", "encoder.layers.3.out_proj_bias", "encoder.layers.3.fc1_weight", "encoder.layers.3.fc1_bias", "encoder.layers.3.fc2_weight", "encoder.layers.3.fc2_bias", "encoder.layers.4.in_proj_weight", "encoder.layers.4.in_proj_bias", "encoder.layers.4.out_proj_weight", "encoder.layers.4.out_proj_bias", "encoder.layers.4.fc1_weight", "encoder.layers.4.fc1_bias", "encoder.layers.4.fc2_weight", "encoder.layers.4.fc2_bias", "encoder.layers.5.in_proj_weight", "encoder.layers.5.in_proj_bias", "encoder.layers.5.out_proj_weight", "encoder.layers.5.out_proj_bias", "encoder.layers.5.fc1_weight", "encoder.layers.5.fc1_bias", "encoder.layers.5.fc2_weight", "encoder.layers.5.fc2_bias"] if __name__ == "__main__": import argparse import torch parser = argparse.ArgumentParser("Edit the checkpoint to remove extra dict weights") parser.add_argument("-c", "--checkpoint", required=True, help="Input checkpoint") parser.add_argument("-e", "--edited", required=True, help="Edited checkpoint") args = parser.parse_args() sd = torch.load(args.checkpoint, weights_only=False) for k in removed: if k in sd['model']: del sd['model'][k] torch.save(sd, args.edited)