|
import os |
|
import json |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
|
|
|
|
def merge_safetensors(input_dir, output_file, config_file):
    """Merge every ``.safetensors`` shard in *input_dir* into a single file.

    Shards are read in sorted filename order, so the result is deterministic.
    If two shards define the same tensor name, the later shard wins and a
    ``UserWarning`` is emitted. A file in *input_dir* whose path equals
    *output_file* is skipped, so re-running the merge does not fold a
    previous merge result back into itself.

    The header metadata written to *output_file* contains:

    * ``format``: always ``"pt"``.
    * ``total_size``: the summed byte size of all merged tensors, as a string.
    * ``_diffusers_version`` / ``_class_name``: copied (as strings) from the
      JSON *config_file*, or ``""`` when the key is absent.

    Args:
        input_dir: Directory containing the ``.safetensors`` shard files.
        output_file: Path of the merged ``.safetensors`` file to write.
        config_file: Path to a JSON config supplying the diffusers metadata.
    """
    import warnings

    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Normalised absolute path of the destination, used to avoid re-reading it
    # when it sits in the same directory as the shards.
    output_path = os.path.normcase(os.path.abspath(output_file))

    merged_tensors = {}
    total_size = 0

    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".safetensors"):
            continue
        file_path = os.path.join(input_dir, filename)
        if os.path.normcase(os.path.abspath(file_path)) == output_path:
            continue

        with safe_open(file_path, framework="pt", device="cpu") as f:
            # NOTE: f.metadata() returns the header's "__metadata__" mapping
            # itself (str -> str), and standard shards carry no per-file
            # "total_size". So the total is computed from the tensors.
            for key in f.keys():
                tensor = f.get_tensor(key)
                previous = merged_tensors.get(key)
                if previous is not None:
                    # Last shard wins; keep the running byte count accurate.
                    total_size -= previous.numel() * previous.element_size()
                    warnings.warn(
                        f"Tensor {key!r} in {filename} overrides an earlier "
                        "shard's tensor with the same name",
                        stacklevel=2,
                    )
                merged_tensors[key] = tensor
                total_size += tensor.numel() * tensor.element_size()

    # safetensors requires every metadata value to be a str.
    metadata = {
        "format": "pt",
        "total_size": str(total_size),
        "_diffusers_version": str(config.get("_diffusers_version", "")),
        "_class_name": str(config.get("_class_name", "")),
    }

    save_file(merged_tensors, output_file, metadata)
|
|
|
|
|
# Default locations for a direct script run; the output and config paths are
# derived from the input directory so they cannot drift apart.
input_directory = './10_1'

output_file = os.path.join(input_directory, 'flux1-merge-S10_D1.safetensors')

config_file = os.path.join(input_directory, 'config.json')

# Only perform the (expensive, file-writing) merge when executed as a script,
# not when this module is imported for merge_safetensors.
if __name__ == "__main__":
    merge_safetensors(input_directory, output_file, config_file)