Upload convert_to_safetensors.py
Browse files- convert_to_safetensors.py +19 -1
convert_to_safetensors.py
CHANGED
|
@@ -2,8 +2,12 @@
|
|
| 2 |
import torch
|
| 3 |
from safetensors.torch import save_file
|
| 4 |
import glob
|
|
|
|
| 5 |
|
| 6 |
def convert_model_to_safetensors(model_path, output_path):
|
|
|
|
|
|
|
|
|
|
| 7 |
print(f"Looking for PyTorch model files in {model_path}")
|
| 8 |
|
| 9 |
# Create the output directory if it doesn't exist
|
|
@@ -24,15 +28,29 @@ def convert_model_to_safetensors(model_path, output_path):
|
|
| 24 |
print(f"Loading model from {model_file}")
|
| 25 |
checkpoint = torch.load(model_file, map_location='cpu')
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
# Extract only the model weights, removing metadata
|
| 28 |
model_state_dict = {}
|
| 29 |
if isinstance(checkpoint, dict):
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
checkpoint = checkpoint['state_dict']
|
|
|
|
|
|
|
| 32 |
# Only keep tensor values
|
| 33 |
for key, value in checkpoint.items():
|
| 34 |
if isinstance(value, torch.Tensor):
|
| 35 |
model_state_dict[key] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# Save as safetensors
|
| 38 |
print(f"Converting to safetensors and saving to {output_path}")
|
|
|
|
| 2 |
import torch
|
| 3 |
from safetensors.torch import save_file
|
| 4 |
import glob
|
| 5 |
+
import shutil
|
| 6 |
|
| 7 |
def convert_model_to_safetensors(model_path, output_path):
|
| 8 |
+
# Delete output file if it exists
|
| 9 |
+
if os.path.exists(output_path):
|
| 10 |
+
os.remove(output_path)
|
| 11 |
print(f"Looking for PyTorch model files in {model_path}")
|
| 12 |
|
| 13 |
# Create the output directory if it doesn't exist
|
|
|
|
| 28 |
print(f"Loading model from {model_file}")
|
| 29 |
checkpoint = torch.load(model_file, map_location='cpu')
|
| 30 |
|
| 31 |
+
print(f"Checkpoint type: {type(checkpoint)}")
|
| 32 |
+
print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")
|
| 33 |
+
|
| 34 |
# Extract only the model weights, removing metadata
|
| 35 |
model_state_dict = {}
|
| 36 |
if isinstance(checkpoint, dict):
|
| 37 |
+
# If model_state_dict exists in checkpoint, use that
|
| 38 |
+
if 'model_state_dict' in checkpoint:
|
| 39 |
+
checkpoint = checkpoint['model_state_dict']
|
| 40 |
+
# Otherwise try state_dict
|
| 41 |
+
elif 'state_dict' in checkpoint:
|
| 42 |
checkpoint = checkpoint['state_dict']
|
| 43 |
+
print(f"After getting state dict - Keys available: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")
|
| 44 |
+
|
| 45 |
# Only keep tensor values
|
| 46 |
for key, value in checkpoint.items():
|
| 47 |
if isinstance(value, torch.Tensor):
|
| 48 |
model_state_dict[key] = value
|
| 49 |
+
print(f"Added tensor for key: {key} with shape {value.shape}")
|
| 50 |
+
|
| 51 |
+
print(f"Total number of tensors to save: {len(model_state_dict)}")
|
| 52 |
+
if len(model_state_dict) == 0:
|
| 53 |
+
raise ValueError("No tensors found in the checkpoint! Check the model structure.")
|
| 54 |
|
| 55 |
# Save as safetensors
|
| 56 |
print(f"Converting to safetensors and saving to {output_path}")
|