File size: 2,738 Bytes
e9af1a0
 
 
 
b347285
e9af1a0
 
b347285
 
 
e9af1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b347285
 
 
e9af1a0
 
 
b347285
 
 
 
 
e9af1a0
b347285
 
e9af1a0
 
 
 
b347285
 
 
 
 
e9af1a0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import torch
from safetensors.torch import save_file
import glob
import shutil

def convert_model_to_safetensors(model_path, output_path):
    """Convert a PyTorch checkpoint found in ``model_path`` to safetensors.

    The directory is searched for ``*.pt``, ``*.pth`` and
    ``pytorch_model.bin`` files (in that priority order, alphabetically
    within each pattern) and the first match is converted. If the loaded
    checkpoint is a dict containing a ``model_state_dict`` or ``state_dict``
    entry, that entry is used; otherwise the checkpoint dict itself is
    treated as the state dict. Only ``torch.Tensor`` values are written.

    Args:
        model_path: Directory to search for checkpoint files.
        output_path: Destination ``.safetensors`` file. An existing file at
            this path is replaced, but only once the checkpoint has been
            loaded and validated.

    Raises:
        FileNotFoundError: If no checkpoint file is found in ``model_path``.
        ValueError: If the checkpoint contains no tensors.
    """
    print(f"Looking for PyTorch model files in {model_path}")

    # Sort each pattern's matches: glob order is filesystem-dependent, and
    # the first entry decides which checkpoint gets converted.
    model_files = (
        sorted(glob.glob(os.path.join(model_path, "*.pt")))
        + sorted(glob.glob(os.path.join(model_path, "*.pth")))
        + sorted(glob.glob(os.path.join(model_path, "pytorch_model.bin")))
    )

    if not model_files:
        raise FileNotFoundError(f"No PyTorch model files found in {model_path}")

    print(f"Found model file(s): {model_files}")
    model_file = model_files[0]  # Use the first found model file

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a trusted source (or pass
    # weights_only=True on PyTorch versions that support it).
    print(f"Loading model from {model_file}")
    checkpoint = torch.load(model_file, map_location='cpu')

    print(f"Checkpoint type: {type(checkpoint)}")
    print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

    # Extract only the model weights, removing metadata
    model_state_dict = {}
    if isinstance(checkpoint, dict):
        # Prefer an explicit nested state dict when the checkpoint has one.
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        print(f"After getting state dict - Keys available: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

        # The nested entry may itself not be a dict; in that case there is
        # nothing to iterate and the "no tensors" error below is raised.
        if isinstance(checkpoint, dict):
            for key, value in checkpoint.items():
                if isinstance(value, torch.Tensor):
                    # save_file rejects non-contiguous tensors; contiguous()
                    # is a no-op for tensors that are already contiguous.
                    model_state_dict[key] = value.contiguous()
                    print(f"Added tensor for key: {key} with shape {value.shape}")

    print(f"Total number of tensors to save: {len(model_state_dict)}")
    if not model_state_dict:
        raise ValueError("No tensors found in the checkpoint! Check the model structure.")

    # Touch the filesystem only after validation succeeded, so a failed run
    # never destroys a previously converted file.
    output_dir = os.path.dirname(output_path)
    if output_dir:  # dirname is "" for a bare filename; makedirs("") raises
        os.makedirs(output_dir, exist_ok=True)
    if os.path.exists(output_path):
        os.remove(output_path)

    # Save as safetensors
    print(f"Converting to safetensors and saving to {output_path}")
    save_file(model_state_dict, output_path)
    print("Conversion completed successfully!")

if __name__ == "__main__":
    # Script entry point: edit the two locations below to match your setup.
    checkpoint_dir = "./checkpoints"  # directory searched for model files
    safetensors_file = "./checkpoints/model.safetensors"  # conversion target

    convert_model_to_safetensors(checkpoint_dir, safetensors_file)