import os
from safetensors.torch import load_file, save_file
from collections import OrderedDict
from tqdm import tqdm
import glob  # For finding shard files

# --- Configuration ---
# Should match OUTPUT_SHARD_DIR from the previous script
# CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-DF-14B-540P/converted_fp8_shards"  # DF path
CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P/converted_fp8_shards"  # T2V path

# Define the final single output file
FINAL_OUTPUT_MODEL_NAME = "SkyReels-V2-T2V-14B-540P-fp8_e5m2.safetensors"  # Example final name
FINAL_OUTPUT_MODEL_PATH = os.path.join(os.path.dirname(CONVERTED_SHARDS_DIR), FINAL_OUTPUT_MODEL_NAME)  # Saves in parent of shards dir

# The original index would be needed to know the *intended order* of tensors (if
# that matters) and to map tensor names to the *new* shard files, if the merge
# logic required it. For a simple merge, we can just load all tensors from all
# new shards and save them in whatever order they come. If a specific order is
# critical, the original index.json from the FP32 model would be needed to
# guide the loading order.
# ORIGINAL_FP32_INDEX_JSON = "F:/Models/SkyReels-V2-DF-14B-540P/model.safetensors.index.json"

print(f"--- SCRIPT START (Merge Converted Shards) ---")
print(f"Converted shards directory: {CONVERTED_SHARDS_DIR}")
print(f"Final output model path: {FINAL_OUTPUT_MODEL_PATH}")


def merge_converted_shards():
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Directory with converted shards not found: {CONVERTED_SHARDS_DIR}")
        return

    # Find all .safetensors files in the converted shards directory.
    # Sort them so shards are processed in a consistent order (00001, 00002, ...).
    shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "fp8_converted_model-*-of-*.safetensors")))
    # Or a more generic pattern if your naming was different:
    # shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "*.safetensors")))

    if not shard_files:
        print(f"Error: No converted shard files found in {CONVERTED_SHARDS_DIR}")
        return

    print(f"Found {len(shard_files)} converted shards to merge.")

    merged_state_dict = OrderedDict()
    for shard_path in tqdm(shard_files, desc="Merging shards"):
        print(f"Loading tensors from: {shard_path}")
        try:
            # Load all tensors from the current converted shard.
            # No need for safe_open with individual get_tensor calls here;
            # load_file is fine, as these shards are smaller.
            current_shard_state_dict = load_file(shard_path, device="cpu")
            merged_state_dict.update(current_shard_state_dict)
            print(f"  Added {len(current_shard_state_dict)} tensors from {os.path.basename(shard_path)}")
        except Exception as e:
            print(f"Error loading shard {shard_path}: {e}")
            # Decide if you want to stop or continue
            return  # Stop if a shard can't be loaded for the merge

    if not merged_state_dict:
        print("No tensors were loaded from shards. Final model file will not be created.")
        return

    print(f"\nMerge complete. Total tensors in merged model: {len(merged_state_dict)}")
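    # Optional (sketch, not part of the original flow): if the tensor order in
    # the final file must match the original FP32 checkpoint, the merged dict
    # can be reordered via the original index. This assumes
    # ORIGINAL_FP32_INDEX_JSON (commented out in the configuration above)
    # points to a standard safetensors index with a
    # {"weight_map": {tensor_name: shard_file}} layout; names missing from
    # either side are silently skipped.
    REORDER_WITH_ORIGINAL_INDEX = False  # flip to True and set the path above to use this
    if REORDER_WITH_ORIGINAL_INDEX:
        import json
        with open(ORIGINAL_FP32_INDEX_JSON, "r") as f:
            weight_map = json.load(f)["weight_map"]
        # Rebuild the state dict following the index's tensor order.
        merged_state_dict = OrderedDict(
            (name, merged_state_dict[name]) for name in weight_map if name in merged_state_dict
        )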
    print(f"Saving merged model to {FINAL_OUTPUT_MODEL_PATH}...")
    try:
        os.makedirs(os.path.dirname(FINAL_OUTPUT_MODEL_PATH), exist_ok=True)
        save_file(merged_state_dict, FINAL_OUTPUT_MODEL_PATH)
        print(f"Successfully saved final merged model to {FINAL_OUTPUT_MODEL_PATH}")
    except Exception as e:
        print(f"Error saving the final merged model: {e}")


if __name__ == "__main__":
    print(f"--- __main__ block start ---")
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Converted shards directory not found: {CONVERTED_SHARDS_DIR}")
    else:
        merge_converted_shards()
    print(f"--- __main__ block end (Merge Converted Shards) ---")
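
# Optional sanity check (a sketch, not part of the original script): spot-check
# the merged file without loading it all back into memory. safe_open reads
# tensor data lazily, so this stays cheap even for a 14B model. Call it
# manually after a merge, e.g. from a REPL after importing this module.
def verify_merged_file(path=FINAL_OUTPUT_MODEL_PATH):
    from safetensors import safe_open
    with safe_open(path, framework="pt", device="cpu") as f:
        names = list(f.keys())
        print(f"Tensor count in merged file: {len(names)}")
        # Pull one tensor to confirm the dtype survived the merge (fp8_e5m2 expected).
        first = f.get_tensor(names[0])
        print(f"{names[0]}: dtype={first.dtype}, shape={tuple(first.shape)}")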