# OpenPeerLLM / convert_to_safetensors.py
# Commit b347285 (verified) by Mentors4EDU: "Upload convert_to_safetensors.py" (2.74 kB)
import os
import torch
from safetensors.torch import save_file
import glob
import shutil
def _storage_ptr(tensor):
    """Return the address of the storage backing *tensor* (used to detect aliasing)."""
    try:
        return tensor.untyped_storage().data_ptr()
    except AttributeError:  # older torch versions have no untyped_storage()
        return tensor.storage().data_ptr()


def _find_model_file(model_path, output_path):
    """Return the checkpoint file in *model_path* to convert.

    Patterns are tried in priority order: ``*.pt``, ``*.pth``, then
    ``pytorch_model.bin``. Names within each pattern are sorted so the choice
    does not depend on filesystem listing order. *output_path* is never
    selected as an input.

    Raises:
        FileNotFoundError: if no candidate checkpoint exists.
    """
    output_abs = os.path.abspath(output_path)
    model_files = [
        path
        for pattern in ("*.pt", "*.pth", "pytorch_model.bin")
        for path in sorted(glob.glob(os.path.join(model_path, pattern)))
        if os.path.abspath(path) != output_abs
    ]
    if not model_files:
        raise FileNotFoundError(f"No PyTorch model files found in {model_path}")
    print(f"Found model file(s): {model_files}")
    return model_files[0]  # Use the first (highest-priority, sorted) file


def _extract_tensors(checkpoint):
    """Return a ``{name: tensor}`` dict holding only the tensors of *checkpoint*.

    Unwraps ``checkpoint['model_state_dict']`` or ``checkpoint['state_dict']``
    when present. Uses ``.state_dict()`` when the payload is an object (e.g. an
    ``nn.Module``) rather than a dict. Non-tensor entries are dropped. Returns an
    empty dict when no tensor mapping can be found.
    """
    if isinstance(checkpoint, dict):
        # Training checkpoints often wrap the weights next to optimizer state etc.
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        print(f"After getting state dict - Keys available: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

    # A whole module was pickled (top level or under one of the keys above).
    if not isinstance(checkpoint, dict) and callable(getattr(checkpoint, "state_dict", None)):
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        return {}

    tensors = {}
    seen_storages = set()
    for key, value in checkpoint.items():
        if not isinstance(value, torch.Tensor):
            continue  # drop metadata such as epoch counters or config objects
        if _storage_ptr(value) in seen_storages or not value.is_contiguous():
            # safetensors.save_file refuses tensors that share memory (e.g. tied
            # input/output embeddings) and non-contiguous tensors; an
            # independent contiguous copy serializes to the same values.
            value = value.clone(memory_format=torch.contiguous_format)
        seen_storages.add(_storage_ptr(value))
        tensors[key] = value
        print(f"Added tensor for key: {key} with shape {value.shape}")
    return tensors


def convert_model_to_safetensors(model_path, output_path):
    """Convert the PyTorch checkpoint found in *model_path* to a safetensors file.

    Args:
        model_path: directory searched for ``*.pt``, ``*.pth`` or
            ``pytorch_model.bin`` (first sorted match in that priority order).
        output_path: destination ``.safetensors`` file. Its parent directory
            is created if needed. An existing file is only replaced once
            conversion succeeds.

    Raises:
        FileNotFoundError: no checkpoint file was found in *model_path*.
        ValueError: the checkpoint contains no tensors.
    """
    print(f"Looking for PyTorch model files in {model_path}")
    model_file = _find_model_file(model_path, output_path)

    # Load the state dict
    print(f"Loading model from {model_file}")
    # NOTE(security): depending on the torch version/defaults, torch.load may
    # unpickle arbitrary objects; only convert checkpoints from trusted sources.
    checkpoint = torch.load(model_file, map_location='cpu')
    print(f"Checkpoint type: {type(checkpoint)}")
    print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

    # Extract only the model weights, removing metadata
    model_state_dict = _extract_tensors(checkpoint)
    print(f"Total number of tensors to save: {len(model_state_dict)}")
    if not model_state_dict:
        raise ValueError("No tensors found in the checkpoint! Check the model structure.")

    # Create the output directory only now (no side effects on failed runs),
    # and only when there is one: os.makedirs("") raises for a bare file name.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save as safetensors. save_file overwrites output_path, so no pre-emptive
    # delete is needed and a previous output survives any earlier failure.
    print(f"Converting to safetensors and saving to {output_path}")
    save_file(model_state_dict, output_path)
    print("Conversion completed successfully!")
if __name__ == "__main__":
    # Command-line entry point. The defaults match the previously hard-coded
    # paths, so running the script with no arguments behaves as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a PyTorch checkpoint (*.pt, *.pth or pytorch_model.bin) to safetensors."
    )
    parser.add_argument(
        "--model-path",
        default="./checkpoints",
        help="Directory containing the PyTorch checkpoint (default: %(default)s)",
    )
    parser.add_argument(
        "--output-path",
        default="./checkpoints/model.safetensors",
        help="Destination .safetensors file (default: %(default)s)",
    )
    args = parser.parse_args()

    model_path = args.model_path
    output_path = args.output_path
    convert_model_to_safetensors(model_path, output_path)