Mentors4EDU committed on
Commit
b347285
·
verified ·
1 Parent(s): 8edce2f

Upload convert_to_safetensors.py

Browse files
Files changed (1) hide show
  1. convert_to_safetensors.py +19 -1
convert_to_safetensors.py CHANGED
@@ -2,8 +2,12 @@
2
  import torch
3
  from safetensors.torch import save_file
4
  import glob
 
5
 
6
  def convert_model_to_safetensors(model_path, output_path):
 
 
 
7
  print(f"Looking for PyTorch model files in {model_path}")
8
 
9
  # Create the output directory if it doesn't exist
@@ -24,15 +28,29 @@ def convert_model_to_safetensors(model_path, output_path):
24
  print(f"Loading model from {model_file}")
25
  checkpoint = torch.load(model_file, map_location='cpu')
26
 
 
 
 
27
  # Extract only the model weights, removing metadata
28
  model_state_dict = {}
29
  if isinstance(checkpoint, dict):
30
- if 'state_dict' in checkpoint:
 
 
 
 
31
  checkpoint = checkpoint['state_dict']
 
 
32
  # Only keep tensor values
33
  for key, value in checkpoint.items():
34
  if isinstance(value, torch.Tensor):
35
  model_state_dict[key] = value
 
 
 
 
 
36
 
37
  # Save as safetensors
38
  print(f"Converting to safetensors and saving to {output_path}")
 
2
  import torch
3
  from safetensors.torch import save_file
4
  import glob
5
+ import shutil
6
 
7
  def convert_model_to_safetensors(model_path, output_path):
8
+ # Delete output file if it exists
9
+ if os.path.exists(output_path):
10
+ os.remove(output_path)
11
  print(f"Looking for PyTorch model files in {model_path}")
12
 
13
  # Create the output directory if it doesn't exist
 
28
  print(f"Loading model from {model_file}")
29
  checkpoint = torch.load(model_file, map_location='cpu')
30
 
31
+ print(f"Checkpoint type: {type(checkpoint)}")
32
+ print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")
33
+
34
  # Extract only the model weights, removing metadata
35
  model_state_dict = {}
36
  if isinstance(checkpoint, dict):
37
+ # If model_state_dict exists in checkpoint, use that
38
+ if 'model_state_dict' in checkpoint:
39
+ checkpoint = checkpoint['model_state_dict']
40
+ # Otherwise try state_dict
41
+ elif 'state_dict' in checkpoint:
42
  checkpoint = checkpoint['state_dict']
43
+ print(f"After getting state dict - Keys available: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")
44
+
45
  # Only keep tensor values
46
  for key, value in checkpoint.items():
47
  if isinstance(value, torch.Tensor):
48
  model_state_dict[key] = value
49
+ print(f"Added tensor for key: {key} with shape {value.shape}")
50
+
51
+ print(f"Total number of tensors to save: {len(model_state_dict)}")
52
+ if len(model_state_dict) == 0:
53
+ raise ValueError("No tensors found in the checkpoint! Check the model structure.")
54
 
55
  # Save as safetensors
56
  print(f"Converting to safetensors and saving to {output_path}")