HaiwangYu committed on
Commit
93dae0f
·
verified ·
1 Parent(s): 6138642

Delete pdsp/to-safetensors.py

Browse files
Files changed (1) hide show
  1. pdsp/to-safetensors.py +0 -180
pdsp/to-safetensors.py DELETED
@@ -1,180 +0,0 @@
1
- #!/usr/bin/env python
2
- # filepath: convert_pth_to_safetensors.py
3
-
4
- import os
5
- import argparse
6
- import torch
7
- from safetensors.torch import save_file
8
- from pathlib import Path
9
-
10
def convert_pth_to_safetensors(pth_path, output_dir=None, output_name=None):
    """
    Convert a PyTorch pickle (.pth) model to safetensors format.

    Args:
        pth_path (str): Path to the PyTorch .pth file
        output_dir (str, optional): Directory to save the converted model.
            Defaults to the directory of the input file. An empty string
            means the current working directory.
        output_name (str, optional): Name for the output file. Defaults to
            the input filename with a .safetensors extension.

    Returns:
        str: Path to the saved safetensors file

    Raises:
        RuntimeError: If the input cannot be loaded or the output cannot be saved.
        ValueError: If no state_dict can be located in the loaded object.
    """
    print(f"Loading PyTorch model from: {pth_path}")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    try:
        loaded = torch.load(pth_path, map_location="cpu")
    except Exception as e:
        raise RuntimeError(f"Failed to load PyTorch model: {e}") from e

    state_dict, description = _locate_state_dict(loaded)
    if not isinstance(state_dict, dict):
        # e.g. a checkpoint whose 'state_dict' entry is None or a list
        raise ValueError("Could not find state_dict in the loaded file")
    print(f"{description} with {len(state_dict)} parameters")

    # Keep only tensors, building a new dict so the loaded object is not mutated.
    tensors = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # save_file rejects non-contiguous tensors; contiguous() is a
            # no-op for tensors that are already contiguous.
            tensors[key] = value.contiguous()
        else:
            print(f"Warning: Removing non-tensor value for key '{key}' of type {type(value)}")

    # Determine output path
    if output_dir is None:
        output_dir = os.path.dirname(pth_path)

    if output_name is None:
        base_name = os.path.basename(pth_path)
        output_name = os.path.splitext(base_name)[0] + ".safetensors"

    # os.path.dirname("model.pth") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_name)

    # Save to safetensors format
    print(f"Saving model to: {output_path}")
    try:
        save_file(tensors, output_path)
    except Exception as e:
        raise RuntimeError(f"Failed to save safetensors file: {e}") from e

    print(f"Successfully saved safetensors file: {output_path}")
    return output_path


def _locate_state_dict(loaded):
    """Return (candidate state_dict, log-message prefix) for a loaded checkpoint object."""
    if isinstance(loaded, dict):
        if all(isinstance(v, torch.Tensor) for v in loaded.values()):
            # It's already a state_dict
            return loaded, "Loaded state_dict"
        if 'state_dict' in loaded:
            return loaded['state_dict'], "Extracted state_dict from checkpoint"
        if 'model_state_dict' in loaded:
            return loaded['model_state_dict'], "Extracted model_state_dict from checkpoint"
        # Fall back to the first nested dict that holds at least one tensor
        for key, value in loaded.items():
            if isinstance(value, dict) and any(
                isinstance(item, torch.Tensor) for item in value.values()
            ):
                return value, f"Extracted state_dict from key '{key}'"
        raise ValueError("Could not find state_dict in the loaded file")

    if hasattr(loaded, 'state_dict'):
        # It's a full model object
        return loaded.state_dict(), "Extracted state_dict from model object"

    raise ValueError("Unsupported format: loaded object is not a state_dict or model")
85
-
86
def convert_directory(input_dir, output_dir=None, recursive=False, file_pattern="*.pth"):
    """
    Convert every matching PyTorch model under a directory to safetensors.

    Each input file's location relative to ``input_dir`` is reproduced under
    ``output_dir``. A file that fails to convert is reported on stdout and
    skipped; the remaining files are still processed.

    Args:
        input_dir (str): Directory to search for PyTorch models
        output_dir (str, optional): Root directory for the converted files. Defaults to input_dir.
        recursive (bool, optional): Also search subdirectories. Defaults to False.
        file_pattern (str, optional): Glob pattern selecting input files. Defaults to "*.pth".

    Returns:
        list: Paths of the safetensors files that were written
    """
    destination_root = input_dir if output_dir is None else output_dir
    source_root = Path(input_dir)

    # rglob descends into subdirectories, glob stays at the top level
    finder = source_root.rglob if recursive else source_root.glob
    candidates = list(finder(file_pattern))

    if len(candidates) == 0:
        print(f"No PyTorch files found in {input_dir} with pattern {file_pattern}")
        return []

    print(f"Found {len(candidates)} PyTorch files to convert")

    results = []
    for candidate in candidates:
        # Mirror the input's sub-path below the destination root
        mirrored_dir = Path(destination_root) / candidate.relative_to(source_root).parent
        mirrored_dir.mkdir(parents=True, exist_ok=True)

        try:
            results.append(
                convert_pth_to_safetensors(
                    str(candidate),
                    str(mirrored_dir),
                    candidate.stem + ".safetensors",
                )
            )
        except Exception as e:
            print(f"Error converting {candidate}: {e}")

    return results
137
-
138
def main():
    """
    Command-line entry point: convert one .pth file or a directory of them.

    For a single input file, --output may name either the output file or a
    directory (an existing directory, or a path ending in a path separator);
    in the directory case the default output filename is used.

    Returns:
        int: 0 on success, 1 if the input path does not exist or conversion raised.
    """
    parser = argparse.ArgumentParser(description="Convert PyTorch .pth models to safetensors format")

    parser.add_argument("input", help="Input PyTorch model file or directory")
    parser.add_argument("--output", "-o", help="Output file or directory for safetensors files")
    parser.add_argument("--recursive", "-r", action="store_true",
                        help="Recursively search for PyTorch files in subdirectories")
    parser.add_argument("--pattern", "-p", default="*.pth",
                        help="File pattern to match when searching directories (default: *.pth)")

    args = parser.parse_args()

    input_path = Path(args.input)

    # A missing path used to fall into the directory branch and exit 0
    # after printing "No PyTorch files found"; report it as an error instead.
    if not input_path.exists():
        print(f"Error: input path does not exist: {input_path}")
        return 1

    if input_path.is_file():
        # Convert single file
        output_dir = None
        output_name = None
        if args.output:
            if args.output.endswith(("/", os.sep)) or os.path.isdir(args.output):
                # --output names a directory: keep the default filename
                output_dir = args.output
            else:
                # A bare filename has an empty dirname; use the current directory
                output_dir = os.path.dirname(args.output) or "."
                output_name = os.path.basename(args.output)

        try:
            converted_file = convert_pth_to_safetensors(str(input_path), output_dir, output_name)
            print(f"Conversion completed: {converted_file}")
        except Exception as e:
            print(f"Error: {e}")
            return 1
    else:
        # Convert directory
        try:
            converted_files = convert_directory(
                str(input_path),
                args.output,
                args.recursive,
                args.pattern,
            )
            print(f"Converted {len(converted_files)} files")
        except Exception as e:
            print(f"Error: {e}")
            return 1

    return 0
178
-
179
if __name__ == "__main__":
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # `site` module and is missing under `python -S` or in frozen builds.
    raise SystemExit(main())