import torch import os import json from safetensors.torch import save_file from safetensors import safe_open from collections import OrderedDict from tqdm import tqdm import gc # For garbage collection # --- Configuration --- # INPUT_MODEL_DIR = "F:/Models/SkyReels-V2-DF-14B-540P" INPUT_MODEL_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P" OUTPUT_SHARD_DIR = os.path.join(INPUT_MODEL_DIR, "converted_fp8_shards") # Subdirectory for new shards # Example output shard filename: fp8-model-00001-of-00012.safetensors TARGET_FP8_DTYPE = torch.float8_e5m2 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"--- SCRIPT START (Shard-by-Shard Conversion) ---") print(f"Using device for conversion: {DEVICE}") print(f"Target FP8 dtype: {TARGET_FP8_DTYPE}") print(f"Input model directory: {INPUT_MODEL_DIR}") print(f"Output shard directory: {OUTPUT_SHARD_DIR}") def should_convert_to_fp8(tensor_name: str) -> bool: if not tensor_name.endswith(".weight"): return False if not "blocks." in tensor_name: return False if "cross_attn" in tensor_name or \ "ffn" in tensor_name or \ "self_attn" in tensor_name: if ".norm_k.weight" in tensor_name or \ ".norm_q.weight" in tensor_name or \ ".norm.weight" in tensor_name: return False return True return False def convert_and_save_shards(): print(f"--- ENTERING convert_and_save_shards() ---") index_json_path = os.path.join(INPUT_MODEL_DIR, "model.safetensors.index.json") print(f"Index JSON path: {index_json_path}") if not os.path.exists(index_json_path): print(f"Error: model.safetensors.index.json not found in {INPUT_MODEL_DIR}") return os.makedirs(OUTPUT_SHARD_DIR, exist_ok=True) print(f"Output directory for converted shards created/exists: {OUTPUT_SHARD_DIR}") print(f"Loading index JSON...") try: with open(index_json_path, 'r') as f: index_data = json.load(f) print(f"Index JSON loaded successfully.") except Exception as e: print(f"Error loading or parsing index.json: {e}") return weight_map = index_data.get("weight_map") if not weight_map: print(f"Error: 'weight_map' not found in {index_json_path} or it is empty.") return print(f"Weight map found with {len(weight_map)} entries.") if not weight_map: print(f"Error: 'weight_map' is empty. Cannot proceed.") return # Group tensors by their original shard filename tensors_by_shard = {} for tensor_name, original_shard_filename in weight_map.items(): if original_shard_filename not in tensors_by_shard: tensors_by_shard[original_shard_filename] = [] tensors_by_shard[original_shard_filename].append(tensor_name) total_original_shards = len(tensors_by_shard) print(f"Found {total_original_shards} unique input shards to process.") # Process each original shard for shard_idx, (original_shard_filename, tensor_names_in_shard) in enumerate( tqdm(tensors_by_shard.items(), desc="Processing input shards", total=total_original_shards) ): current_input_shard_path = os.path.join(INPUT_MODEL_DIR, original_shard_filename) # Construct output shard name, e.g., fp8-model-00001-of-00012.safetensors # Assuming original_shard_filename is like "model-00001-of-00012.safetensors" output_shard_filename_parts = original_shard_filename.split('-') if len(output_shard_filename_parts) == 3: # model-xxxxx-of-yyyyy.safetensors output_shard_filename = f"fp8-{output_shard_filename_parts[0]}-{output_shard_filename_parts[1]}-{output_shard_filename_parts[2]}" else: # Fallback if naming is different output_shard_filename = f"fp8_converted_{original_shard_filename}" current_output_shard_path = os.path.join(OUTPUT_SHARD_DIR, output_shard_filename) print(f"\n--- Processing Shard {shard_idx + 1}/{total_original_shards} ---") print(f"Input shard: {current_input_shard_path}") print(f"Output shard: {current_output_shard_path}") # Skip if output shard already exists (for resumability) if os.path.exists(current_output_shard_path): print(f"Output shard {current_output_shard_path} already exists. Skipping.") # Basic check: try to open it to see if it's valid (optional, adds time) try: with safe_open(current_output_shard_path, framework="pt", device="cpu") as f_test: _ = f_test.keys() # Just try to get keys print(f"Existing output shard {current_output_shard_path} seems valid.") except Exception as e_test: print(f"Warning: Existing output shard {current_output_shard_path} might be corrupted: {e_test}. Consider deleting it and rerunning for this shard.") continue if not os.path.exists(current_input_shard_path): print(f"Error: Input shard file {current_input_shard_path} not found. Skipping this shard.") continue shard_state_dict = OrderedDict() try: with safe_open(current_input_shard_path, framework="pt", device="cpu") as f_in: for tensor_name in tqdm(tensor_names_in_shard, desc=f"Tensors in {original_shard_filename}", leave=False): print(f" Loading tensor: {tensor_name}") # Debug if needed original_tensor = f_in.get_tensor(tensor_name) print(f" Tensor '{tensor_name}' loaded. Dtype: {original_tensor.dtype}, Shape: {original_tensor.shape}") if should_convert_to_fp8(tensor_name): print(f" Converting '{tensor_name}' to {TARGET_FP8_DTYPE} on {DEVICE}...") converted_tensor = original_tensor.to(DEVICE).to(TARGET_FP8_DTYPE).to("cpu") shard_state_dict[tensor_name] = converted_tensor else: print(f" Keeping '{tensor_name}' as {original_tensor.dtype}.") shard_state_dict[tensor_name] = original_tensor.to("cpu") # Ensure on CPU if shard_state_dict: print(f"Saving {len(shard_state_dict)} tensors to new shard: {current_output_shard_path}") save_file(shard_state_dict, current_output_shard_path) print(f"Successfully saved new shard: {current_output_shard_path}") else: print(f"No tensors processed for output shard: {current_output_shard_path}") except Exception as e: print(f"CRITICAL ERROR processing input shard {current_input_shard_path}: {e}") import traceback traceback.print_exc() print(f"Skipping rest of shard {original_shard_filename} due to error.") # Optionally, you might want to delete a partially written output shard if an error occurs mid-save if os.path.exists(current_output_shard_path) and not shard_state_dict: # If error before any save pass # No partial file to worry about if save_file hasn't been called # If error during save_file, it's harder to handle cleanly without more complex logic # Explicitly clear and collect garbage to free memory del shard_state_dict if 'original_tensor' in locals(): del original_tensor if 'converted_tensor' in locals(): del converted_tensor gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"Memory cleanup after processing shard {original_shard_filename}") print(f"\n--- All input shards processed. Converted shards are in {OUTPUT_SHARD_DIR} ---") if __name__ == "__main__": print(f"--- __main__ block start ---") if not os.path.exists(INPUT_MODEL_DIR): print(f"Error: Input model directory not found: {INPUT_MODEL_DIR}") else: print(f"Input model directory exists. Calling convert_and_save_shards().") convert_and_save_shards() print(f"--- __main__ block end (Shard-by-Shard Conversion) ---")