HaiwangYu committed
Commit 6138642 · verified · Parent(s): 3032040

Upload folder using huggingface_hub

pdsp/safetensors/to-safetensors.py ADDED
@@ -0,0 +1,180 @@
+#!/usr/bin/env python
+# filepath: convert_pth_to_safetensors.py
+
+import os
+import argparse
+import torch
+from safetensors.torch import save_file
+from pathlib import Path
+
+def convert_pth_to_safetensors(pth_path, output_dir=None, output_name=None):
+    """
+    Convert a PyTorch pickle (.pth) model to safetensors format.
+
+    Args:
+        pth_path (str): Path to the PyTorch .pth file
+        output_dir (str, optional): Directory to save the converted model. Defaults to same directory as input.
+        output_name (str, optional): Name for the output file. Defaults to input filename with .safetensors extension.
+
+    Returns:
+        str: Path to the saved safetensors file
+    """
+    print(f"Loading PyTorch model from: {pth_path}")
+
+    # Load PyTorch model
+    try:
+        state_dict = torch.load(pth_path, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"Failed to load PyTorch model: {e}")
+
+    # Handle different types of saved objects
+    if isinstance(state_dict, dict):
+        # Check if this is a state_dict or a full model save
+        if all(isinstance(v, torch.Tensor) for v in state_dict.values()):
+            # It's already a state_dict
+            print(f"Loaded state_dict with {len(state_dict)} parameters")
+        elif 'state_dict' in state_dict:
+            # It's a checkpoint with 'state_dict' key
+            state_dict = state_dict['state_dict']
+            print(f"Extracted state_dict from checkpoint with {len(state_dict)} parameters")
+        elif 'model_state_dict' in state_dict:
+            # It's a checkpoint with 'model_state_dict' key
+            state_dict = state_dict['model_state_dict']
+            print(f"Extracted model_state_dict from checkpoint with {len(state_dict)} parameters")
+        else:
+            # Try to find a key that contains tensors
+            tensor_keys = [k for k, v in state_dict.items() if isinstance(v, dict) and
+                           any(isinstance(item, torch.Tensor) for item in v.values())]
+            if tensor_keys:
+                state_dict = state_dict[tensor_keys[0]]
+                print(f"Extracted state_dict from key '{tensor_keys[0]}' with {len(state_dict)} parameters")
+            else:
+                raise ValueError("Could not find state_dict in the loaded file")
+    elif hasattr(state_dict, 'state_dict'):
+        # It's a full model object
+        state_dict = state_dict.state_dict()
+        print(f"Extracted state_dict from model object with {len(state_dict)} parameters")
+    else:
+        raise ValueError("Unsupported format: loaded object is not a state_dict or model")
+
+    # Ensure all values are tensors
+    for k, v in list(state_dict.items()):
+        if not isinstance(v, torch.Tensor):
+            print(f"Warning: Removing non-tensor value for key '{k}' of type {type(v)}")
+            state_dict.pop(k)
+
+    # Determine output path
+    if output_dir is None:
+        output_dir = os.path.dirname(pth_path) or "."  # bare filenames have an empty dirname
+
+    if output_name is None:
+        base_name = os.path.basename(pth_path)
+        output_name = os.path.splitext(base_name)[0] + ".safetensors"
+
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, output_name)
+
+    # Save to safetensors format
+    print(f"Saving model to: {output_path}")
+    try:
+        save_file(state_dict, output_path)
+        print(f"Successfully saved safetensors file: {output_path}")
+        return output_path
+    except Exception as e:
+        raise RuntimeError(f"Failed to save safetensors file: {e}")
+
+def convert_directory(input_dir, output_dir=None, recursive=False, file_pattern="*.pth"):
+    """
+    Convert all PyTorch .pth models in a directory to safetensors format.
+
+    Args:
+        input_dir (str): Input directory containing PyTorch models
+        output_dir (str, optional): Output directory for safetensors files. Defaults to input_dir.
+        recursive (bool, optional): Whether to recursively search for models in subdirectories. Defaults to False.
+        file_pattern (str, optional): File pattern to match. Defaults to "*.pth".
+
+    Returns:
+        list: List of paths to converted safetensors files
+    """
+    if output_dir is None:
+        output_dir = input_dir
+
+    converted_files = []
+
+    # Find all PyTorch files
+    input_path = Path(input_dir)
+
+    if recursive:
+        pth_files = list(input_path.rglob(file_pattern))
+    else:
+        pth_files = list(input_path.glob(file_pattern))
+
+    if not pth_files:
+        print(f"No PyTorch files found in {input_dir} with pattern {file_pattern}")
+        return converted_files
+
+    print(f"Found {len(pth_files)} PyTorch files to convert")
+
+    # Convert each file
+    for pth_file in pth_files:
+        relative_path = pth_file.relative_to(input_path)
+        target_dir = Path(output_dir) / relative_path.parent
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        output_name = pth_file.stem + ".safetensors"
+
+        try:
+            converted_file = convert_pth_to_safetensors(
+                str(pth_file),
+                str(target_dir),
+                output_name
+            )
+            converted_files.append(converted_file)
+        except Exception as e:
+            print(f"Error converting {pth_file}: {e}")
+
+    return converted_files
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert PyTorch .pth models to safetensors format")
+
+    parser.add_argument("input", help="Input PyTorch model file or directory")
+    parser.add_argument("--output", "-o", help="Output file or directory for safetensors files")
+    parser.add_argument("--recursive", "-r", action="store_true",
+                        help="Recursively search for PyTorch files in subdirectories")
+    parser.add_argument("--pattern", "-p", default="*.pth",
+                        help="File pattern to match when searching directories (default: *.pth)")
+
+    args = parser.parse_args()
+
+    input_path = Path(args.input)
+
+    if input_path.is_file():
+        # Convert single file
+        output_dir = (os.path.dirname(args.output) or None) if args.output else None  # a bare filename has an empty dirname
+        output_name = os.path.basename(args.output) if args.output else None
+
+        try:
+            converted_file = convert_pth_to_safetensors(str(input_path), output_dir, output_name)
+            print(f"Conversion completed: {converted_file}")
+        except Exception as e:
+            print(f"Error: {e}")
+            return 1
+    else:
+        # Convert directory
+        try:
+            converted_files = convert_directory(
+                str(input_path),
+                args.output,
+                args.recursive,
+                args.pattern
+            )
+            print(f"Converted {len(converted_files)} files")
+        except Exception as e:
+            print(f"Error: {e}")
+            return 1
+
+    return 0
+
+if __name__ == "__main__":
+    exit(main())
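
The script is driven by the argparse interface in main(): it accepts an input .pth file or directory plus optional --output, --recursive, and --pattern flags. A converted file can be sanity-checked by loading it back and comparing tensors against the original checkpoint. The sketch below is not part of the committed script; the file names are placeholders, and it assumes the original checkpoint is already a plain state_dict rather than a nested training checkpoint.

    import torch
    from safetensors.torch import load_file

    # Hypothetical paths: "model.pth" is the source checkpoint,
    # "model.safetensors" is the file written by to-safetensors.py.
    reference = torch.load("model.pth", map_location="cpu")
    converted = load_file("model.safetensors")

    for name, tensor in converted.items():
        assert torch.equal(tensor, reference[name]), f"tensor mismatch for {name}"
    print(f"verified {len(converted)} tensors round-trip exactly")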
pdsp/safetensors/unet-l23-cosmic500-e50.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f4cf50006086c0c9cb8294f6fab72ba45284e0e2854c5bf04231bbda859f3b6
+size 53625164
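
The weights file is tracked with Git LFS, so only the pointer (roughly 54 MB payload) appears in the diff. Once the actual file is downloaded, a minimal sketch for inspecting its tensors is shown below; the path follows this commit's layout, and the model architecture itself is not part of the commit.

    from safetensors.torch import load_file

    state_dict = load_file("pdsp/safetensors/unet-l23-cosmic500-e50.safetensors")
    print(f"{len(state_dict)} tensors")
    for name, tensor in list(state_dict.items())[:5]:  # show the first few entries
        print(name, tuple(tensor.shape), tensor.dtype)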