"""Strip a leading ``module.`` prefix from parameter names in a PyTorch
checkpoint directory, rewriting the files in place.

Prefixed keys like this commonly come from saving a model that was wrapped
in ``DataParallel`` / ``DistributedDataParallel``. This script rewrites:

1. ``pytorch_model.bin.index.json`` -- the keys of its ``weight_map``;
2. each shard ``.bin`` file -- the keys of the state dict stored in it.

Every file is written to a temporary file in the same directory and then
moved over the original with ``os.replace``, so an interrupted run does not
leave a truncated file behind.

Usage::

    python this_script.py [MODEL_DIR] [--prefix PREFIX]
"""
import argparse
import json
import os
import shutil
import tempfile
from collections.abc import Mapping

import torch

DEFAULT_MODEL_DIR = "./"
DEFAULT_PREFIX = "module."
INDEX_FILENAME = "pytorch_model.bin.index.json"


def strip_prefix(key, prefix=DEFAULT_PREFIX):
    """Return *key* with one leading *prefix* removed; other keys are returned unchanged."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _rename_keys(mapping, prefix):
    """Return a new dict with *prefix* stripped from every key, preserving order.

    Raises:
        ValueError: if two keys become identical after stripping (e.g. both
            ``"module.x"`` and ``"x"`` are present); keeping only one would
            silently drop a tensor.
    """
    renamed = {}
    for key, value in mapping.items():
        new_key = strip_prefix(key, prefix)
        if new_key in renamed:
            raise ValueError(
                f"Stripping {prefix!r} maps more than one key to {new_key!r}"
            )
        renamed[new_key] = value
    return renamed


def _atomic_replace(path, write):
    """Call ``write(tmp_path)`` on a temp file beside *path*, then move it over *path*.

    The temp file is created in the same directory so ``os.replace`` is a
    rename within one filesystem (atomic on POSIX). If ``write`` or the
    rename fails, the temp file is removed and *path* is left untouched.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix=f".{os.path.basename(path)}.", suffix=".tmp"
    )
    os.close(fd)
    try:
        write(tmp_path)
        # mkstemp creates the file with 0600; keep the original's permission bits.
        if os.path.exists(path):
            shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _process_index(model_dir, prefix):
    """Strip *prefix* from the ``weight_map`` keys of the shard index, in place.

    Returns:
        Sorted list of shard filenames referenced by the index, or ``None``
        when *model_dir* contains no index file.

    Raises:
        KeyError: if the index JSON has no ``"weight_map"`` entry.
        ValueError: if stripping the prefix makes two keys collide.
    """
    index_file = os.path.join(model_dir, INDEX_FILENAME)
    if not os.path.isfile(index_file):
        print(f"未找到 {index_file}，跳过索引处理")
        return None

    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    weight_map = index_data["weight_map"]
    index_data["weight_map"] = _rename_keys(weight_map, prefix)

    def write(tmp_path):
        with open(tmp_path, "w", encoding="utf-8") as out:
            json.dump(index_data, out, indent=2, ensure_ascii=False)

    _atomic_replace(index_file, write)
    print(f"已处理 {index_file}")
    return sorted(set(weight_map.values()))


def _candidate_bin_files(model_dir, shard_names):
    """Return the ``.bin`` paths to rewrite.

    With an index, only the shards it references are returned, so unrelated
    ``.bin`` files in the directory are never touched. Without an index,
    every ``.bin`` file in *model_dir* is returned, in sorted order.
    """
    if shard_names is None:
        shard_names = sorted(
            name for name in os.listdir(model_dir) if name.endswith(".bin")
        )
    return [os.path.join(model_dir, name) for name in shard_names]


def _process_bin(bin_path, prefix):
    """Strip *prefix* from the keys of the state dict stored in *bin_path*, in place.

    Files whose loaded content is not a mapping are reported and skipped.
    """
    print(f"处理 {bin_path} ...")
    # NOTE(review): torch.load's default for ``weights_only`` depends on the
    # installed torch version; arbitrary (non-tensor) pickles may fail to
    # load here rather than being skipped.
    state_dict = torch.load(bin_path, map_location="cpu")
    if not isinstance(state_dict, Mapping):
        print(f"跳过 {bin_path}（不是 state dict）")
        return
    # NOTE(review): the result is a plain dict, so any attributes attached to
    # the loaded object (e.g. an OrderedDict ``_metadata``) are not carried
    # over -- same as the original script; confirm this is acceptable.
    new_state_dict = _rename_keys(state_dict, prefix)
    _atomic_replace(bin_path, lambda tmp_path: torch.save(new_state_dict, tmp_path))
    print(f"完成 {os.path.basename(bin_path)}")


def strip_module_prefix(model_dir=DEFAULT_MODEL_DIR, prefix=DEFAULT_PREFIX):
    """Strip *prefix* from the shard index and every shard in *model_dir*, in place.

    Raises:
        ValueError: if stripping makes two keys in one file collide; that
            file is left unchanged.
    """
    shard_names = _process_index(model_dir, prefix)
    for bin_path in _candidate_bin_files(model_dir, shard_names):
        _process_bin(bin_path, prefix)
    print(f"所有文件已完成 {prefix} 前缀删除 ✅")


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Strip a key prefix from a PyTorch checkpoint directory in place."
    )
    parser.add_argument(
        "model_dir",
        nargs="?",
        default=DEFAULT_MODEL_DIR,
        help="checkpoint directory (default: %(default)s)",
    )
    parser.add_argument(
        "--prefix",
        default=DEFAULT_PREFIX,
        help="key prefix to strip (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    strip_module_prefix(args.model_dir, args.prefix)


if __name__ == "__main__":
    main()