|
|
import os
|
|
|
import torch
|
|
|
from safetensors.torch import save_file
|
|
|
import glob
|
|
|
import shutil
|
|
|
|
|
|
def _find_model_files(model_path):
    """Return candidate PyTorch checkpoint paths inside ``model_path``.

    Patterns are tried in priority order (``*.pt``, ``*.pth``,
    ``pytorch_model.bin``). Matches for each pattern are sorted because
    ``glob.glob`` returns files in arbitrary, filesystem-dependent order,
    which would otherwise make the chosen checkpoint nondeterministic.
    """
    candidates = []
    for pattern in ("*.pt", "*.pth", "pytorch_model.bin"):
        candidates.extend(sorted(glob.glob(os.path.join(model_path, pattern))))
    return candidates


def _extract_state_dict(checkpoint):
    """Unwrap a loaded checkpoint object down to its weight mapping.

    A pickled ``nn.Module`` is reduced with ``state_dict()``. A dict that nests
    its weights under ``'model_state_dict'`` (checked first) or ``'state_dict'``
    is unwrapped one level. Any other dict is returned unchanged.

    Raises:
        ValueError: if the result is not a dict, so no tensors can be read.
    """
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
    if not isinstance(checkpoint, dict):
        raise ValueError(
            f"Unsupported checkpoint structure ({type(checkpoint).__name__}); "
            "expected a dict of tensors or an nn.Module."
        )
    return checkpoint


def _storage_ptr(tensor):
    """Return the address of the storage backing ``tensor`` (shared by views)."""
    try:
        return tensor.untyped_storage().data_ptr()
    except AttributeError:  # older torch releases lack untyped_storage()
        return tensor.storage().data_ptr()


def _collect_tensors(state_dict):
    """Return the tensor entries of ``state_dict`` in a form safetensors accepts.

    Non-tensor values are skipped. safetensors' ``save_file`` refuses
    non-contiguous tensors and tensors that share storage (e.g. tied
    weights). Each tensor is therefore detached and made contiguous. Any
    tensor whose storage was already claimed by an earlier entry is cloned
    so every saved tensor owns its memory.
    """
    tensors = {}
    claimed_storages = set()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        tensor = value.detach().contiguous()
        ptr = _storage_ptr(tensor)
        if ptr in claimed_storages:
            # Shared storage (tied/aliased weights): give this entry its own copy.
            tensor = tensor.clone()
        else:
            claimed_storages.add(ptr)
        tensors[key] = tensor
        print(f"Added tensor for key: {key} with shape {value.shape}")
    return tensors


def convert_model_to_safetensors(model_path, output_path):
    """Convert the first PyTorch checkpoint found in ``model_path`` to safetensors.

    Only top-level tensor entries of the (possibly unwrapped) state dict are
    written. Other values are dropped.

    Args:
        model_path: Directory searched for ``*.pt``, ``*.pth`` and
            ``pytorch_model.bin`` files.
        output_path: Destination ``.safetensors`` file. Its parent directory
            is created if needed, and an existing file is replaced. Both
            happen only after the checkpoint has loaded and yielded at least
            one tensor, so a failed conversion leaves any previous output
            intact.

    Raises:
        FileNotFoundError: if no checkpoint file matches in ``model_path``.
        ValueError: if the checkpoint is not a dict/module or holds no tensors.
    """
    print(f"Looking for PyTorch model files in {model_path}")
    model_files = _find_model_files(model_path)
    if not model_files:
        raise FileNotFoundError(f"No PyTorch model files found in {model_path}")

    print(f"Found model file(s): {model_files}")
    model_file = model_files[0]
    if len(model_files) > 1:
        print(f"Warning: multiple model files found; using {model_file}")

    print(f"Loading model from {model_file}")
    # NOTE(security): torch.load unpickles the file. On torch versions where
    # weights_only defaults to False this can execute arbitrary code, so only
    # convert checkpoints from trusted sources.
    checkpoint = torch.load(model_file, map_location='cpu')

    print(f"Checkpoint type: {type(checkpoint)}")
    print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

    state_dict = _extract_state_dict(checkpoint)
    print(f"After getting state dict - Keys available: {state_dict.keys()}")

    model_state_dict = _collect_tensors(state_dict)
    print(f"Total number of tensors to save: {len(model_state_dict)}")
    if not model_state_dict:
        raise ValueError("No tensors found in the checkpoint! Check the model structure.")

    # dirname is "" for a bare filename; os.makedirs("") would raise.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Replace any previous output only now that there is something to write.
    if os.path.exists(output_path):
        os.remove(output_path)

    print(f"Converting to safetensors and saving to {output_path}")
    save_file(model_state_dict, output_path)
    print("Conversion completed successfully!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: convert the checkpoint found in ./checkpoints and
    # write the safetensors file into that same directory.
    source_dir = "./checkpoints"
    target_file = "./checkpoints/model.safetensors"
    convert_model_to_safetensors(model_path=source_dir, output_path=target_file)