phazei committed
Commit 95b91f6 · 1 Parent(s): fc2a75c

Add readme and scripts
README.md CHANGED
---
license: apache-2.0
tags:
- skywork
- skyreels
- text-to-video
- video-generation
- fp8
- e5m2
- quantized
- 14b
- 540p
- comfyui
base_model:
- Skywork/SkyReels-V2-DF-14B-540P
- Skywork/SkyReels-V2-T2V-14B-540P
---

# SkyReels-V2-14B-540P FP8-E5M2 Quantized Models

This repository contains FP8-E5M2 quantized versions of the Skywork SkyReels-V2 14B 540P models, suitable for use with hardware supporting this precision (e.g., NVIDIA RTX 3090/40-series with `torch.compile`) and popular workflows like those in ComfyUI.

These models were quantized by [phazei](https://huggingface.co/phazei).

## Original Models

These quantized models are based on the following original FP32 models from Skywork:

* **DF Variant:** [Skywork/SkyReels-V2-DF-14B-540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P)
* **T2V Variant:** [Skywork/SkyReels-V2-T2V-14B-540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P)

Please refer to the original model cards for details on their architecture, training, and intended use cases.

## Quantization Details & Acknowledgements

The models were converted from their original FP32 sharded format to a mixed-precision format. The specific layers quantized to `FP8-E5M2` (primarily weight layers within attention and FFN blocks, while biases and normalization layers were kept in FP32) were identified by analyzing the FP8 quantized models provided by **[Kijai](https://huggingface.co/Kijai)** in his repository **[Kijai/WanVideo_comfy](https://huggingface.co/Kijai/WanVideo_comfy)**.

This conversion process replicates the quantization pattern observed in Kijai's converted files to produce these `FP8-E5M2` variants. Many thanks to Kijai for sharing his quantized models, which served as a clear reference for this work and benefit the ComfyUI community.

The conversion was performed using PyTorch and `safetensors`. The scripts used for downloading the original models and performing this conversion are included in the `scripts/` directory of this repository.
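
In essence, the conversion is a name-based filter plus a dtype cast. Below is a condensed, single-shard sketch of that pattern (the actual script, `scripts/convert_to_fp8e5m2.py`, walks every shard listed in `model.safetensors.index.json`, supports resuming, and frees memory between shards); the shard filenames here are placeholders.

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

TARGET_FP8_DTYPE = torch.float8_e5m2

def should_convert_to_fp8(name: str) -> bool:
    # Cast only transformer-block weights inside attention / FFN modules;
    # biases and the norm_q / norm_k / norm weights stay in FP32.
    if not name.endswith(".weight") or "blocks." not in name:
        return False
    if "cross_attn" in name or "self_attn" in name or "ffn" in name:
        if ".norm_q.weight" in name or ".norm_k.weight" in name or ".norm.weight" in name:
            return False
        return True
    return False

converted = {}
with safe_open("model-00001-of-00012.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        converted[name] = tensor.to(TARGET_FP8_DTYPE) if should_convert_to_fp8(name) else tensor

save_file(converted, "fp8_converted_model-00001-of-00012.safetensors")
```

No per-tensor scaling is applied; the cast simply rounds each FP32 value to the nearest representable E5M2 value, matching the pattern used throughout the conversion script.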

**Key characteristics of the quantized models:**

* **Precision:** Mixed (FP32, FP8-E5M2, U8 for metadata)
* **Target FP8 type:** `torch.float8_e5m2`
* **Compatibility:** Intended for use with PyTorch versions supporting `torch.float8_e5m2` and `torch.compile`. Well-suited for ComfyUI workflows that can leverage these models (a minimal standalone loading sketch follows below).
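
ComfyUI loaders handle these files directly. For a quick standalone check outside ComfyUI, a minimal sketch (not part of this repository's scripts) that loads the checkpoint and upcasts the FP8 tensors for hardware without native FP8 kernels might look like this:

```python
import torch
from safetensors.torch import load_file

# Load the single-file checkpoint on CPU; tensor names and dtypes are preserved.
state_dict = load_file("SkyReels-V2-T2V-14B-540P-fp8e5m2.safetensors", device="cpu")

# Optional: upcast FP8-E5M2 weights for runtimes without native FP8 support.
# Every E5M2 value is exactly representable in bfloat16, so this cast is exact.
for name, tensor in state_dict.items():
    if tensor.dtype == torch.float8_e5m2:
        state_dict[name] = tensor.to(torch.bfloat16)

print(f"Loaded {len(state_dict)} tensors")
```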

## Files in this Repository

* `SkyReels-V2-DF-14B-540P-fp8e5m2.safetensors`: The quantized DF variant (single file).
* `SkyReels-V2-T2V-14B-540P-fp8e5m2.safetensors`: The quantized T2V variant (single file).
* `scripts/`: Python scripts for downloading the original models and performing the quantization (run `model_download.py`, then `convert_to_fp8e5m2.py`, then `merge_fp8_shards.py`; `safetensors_info.py` inspects the per-tensor dtypes and sizes of the result):
  * `model_download.py`
  * `convert_to_fp8e5m2.py`
  * `merge_fp8_shards.py`
  * `safetensors_info.py`
* `README.md`: This model card.

## Disclaimer

This is a community-contributed quantization. While efforts were made to maintain model quality by following an established quantization pattern, performance may differ from the original FP32 models or other quantized versions. Use at your own discretion.

## Acknowledgements

* **Skywork AI** for releasing the original SkyReels models.
* **[Kijai](https://huggingface.co/Kijai)** for providing the quantized model versions that served as a reference for the quantization pattern applied in this repository.
scripts/convert_to_fp8e5m2.py ADDED
import torch
import os
import json
from safetensors.torch import save_file
from safetensors import safe_open
from collections import OrderedDict
from tqdm import tqdm
import gc  # For garbage collection

# --- Configuration ---
# INPUT_MODEL_DIR = "F:/Models/SkyReels-V2-DF-14B-540P"
INPUT_MODEL_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P"
OUTPUT_SHARD_DIR = os.path.join(INPUT_MODEL_DIR, "converted_fp8_shards")  # Subdirectory for new shards
# Example output shard filename: fp8_converted_model-00001-of-00012.safetensors

TARGET_FP8_DTYPE = torch.float8_e5m2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"--- SCRIPT START (Shard-by-Shard Conversion) ---")
print(f"Using device for conversion: {DEVICE}")
print(f"Target FP8 dtype: {TARGET_FP8_DTYPE}")
print(f"Input model directory: {INPUT_MODEL_DIR}")
print(f"Output shard directory: {OUTPUT_SHARD_DIR}")

def should_convert_to_fp8(tensor_name: str) -> bool:
    if not tensor_name.endswith(".weight"):
        return False
    if "blocks." not in tensor_name:
        return False
    if "cross_attn" in tensor_name or \
       "ffn" in tensor_name or \
       "self_attn" in tensor_name:
        if ".norm_k.weight" in tensor_name or \
           ".norm_q.weight" in tensor_name or \
           ".norm.weight" in tensor_name:
            return False
        return True
    return False

def convert_and_save_shards():
    print(f"--- ENTERING convert_and_save_shards() ---")
    index_json_path = os.path.join(INPUT_MODEL_DIR, "model.safetensors.index.json")
    print(f"Index JSON path: {index_json_path}")

    if not os.path.exists(index_json_path):
        print(f"Error: model.safetensors.index.json not found in {INPUT_MODEL_DIR}")
        return

    os.makedirs(OUTPUT_SHARD_DIR, exist_ok=True)
    print(f"Output directory for converted shards created/exists: {OUTPUT_SHARD_DIR}")

    print(f"Loading index JSON...")
    try:
        with open(index_json_path, 'r') as f:
            index_data = json.load(f)
        print(f"Index JSON loaded successfully.")
    except Exception as e:
        print(f"Error loading or parsing index.json: {e}")
        return

    weight_map = index_data.get("weight_map")
    if not weight_map:
        print(f"Error: 'weight_map' not found in {index_json_path} or it is empty.")
        return

    print(f"Weight map found with {len(weight_map)} entries.")

    # Group tensors by their original shard filename
    tensors_by_shard = {}
    for tensor_name, original_shard_filename in weight_map.items():
        if original_shard_filename not in tensors_by_shard:
            tensors_by_shard[original_shard_filename] = []
        tensors_by_shard[original_shard_filename].append(tensor_name)

    total_original_shards = len(tensors_by_shard)
    print(f"Found {total_original_shards} unique input shards to process.")

    # Process each original shard
    for shard_idx, (original_shard_filename, tensor_names_in_shard) in enumerate(
        tqdm(tensors_by_shard.items(), desc="Processing input shards", total=total_original_shards)
    ):
        current_input_shard_path = os.path.join(INPUT_MODEL_DIR, original_shard_filename)
        # Construct the output shard name by prefixing the original filename,
        # e.g. "model-00001-of-00012.safetensors" -> "fp8_converted_model-00001-of-00012.safetensors".
        # merge_fp8_shards.py globs for this prefix when merging.
        output_shard_filename = f"fp8_converted_{original_shard_filename}"

        current_output_shard_path = os.path.join(OUTPUT_SHARD_DIR, output_shard_filename)

        print(f"\n--- Processing Shard {shard_idx + 1}/{total_original_shards} ---")
        print(f"Input shard: {current_input_shard_path}")
        print(f"Output shard: {current_output_shard_path}")

        # Skip if output shard already exists (for resumability)
        if os.path.exists(current_output_shard_path):
            print(f"Output shard {current_output_shard_path} already exists. Skipping.")
            # Basic check: try to open it to see if it's valid (optional, adds time)
            try:
                with safe_open(current_output_shard_path, framework="pt", device="cpu") as f_test:
                    _ = f_test.keys()  # Just try to get keys
                print(f"Existing output shard {current_output_shard_path} seems valid.")
            except Exception as e_test:
                print(f"Warning: Existing output shard {current_output_shard_path} might be corrupted: {e_test}. Consider deleting it and rerunning for this shard.")
            continue

        if not os.path.exists(current_input_shard_path):
            print(f"Error: Input shard file {current_input_shard_path} not found. Skipping this shard.")
            continue

        shard_state_dict = OrderedDict()

        try:
            with safe_open(current_input_shard_path, framework="pt", device="cpu") as f_in:
                for tensor_name in tqdm(tensor_names_in_shard, desc=f"Tensors in {original_shard_filename}", leave=False):
                    print(f"  Loading tensor: {tensor_name}")  # Debug if needed
                    original_tensor = f_in.get_tensor(tensor_name)
                    print(f"  Tensor '{tensor_name}' loaded. Dtype: {original_tensor.dtype}, Shape: {original_tensor.shape}")

                    if should_convert_to_fp8(tensor_name):
                        print(f"    Converting '{tensor_name}' to {TARGET_FP8_DTYPE} on {DEVICE}...")
                        converted_tensor = original_tensor.to(DEVICE).to(TARGET_FP8_DTYPE).to("cpu")
                        shard_state_dict[tensor_name] = converted_tensor
                    else:
                        print(f"    Keeping '{tensor_name}' as {original_tensor.dtype}.")
                        shard_state_dict[tensor_name] = original_tensor.to("cpu")  # Ensure on CPU

            if shard_state_dict:
                print(f"Saving {len(shard_state_dict)} tensors to new shard: {current_output_shard_path}")
                save_file(shard_state_dict, current_output_shard_path)
                print(f"Successfully saved new shard: {current_output_shard_path}")
            else:
                print(f"No tensors processed for output shard: {current_output_shard_path}")

        except Exception as e:
            print(f"CRITICAL ERROR processing input shard {current_input_shard_path}: {e}")
            import traceback
            traceback.print_exc()
            print(f"Skipping rest of shard {original_shard_filename} due to error.")
            # If the error happened before save_file was called there is no partial output
            # file to clean up; if save_file itself was interrupted, delete the partially
            # written output shard and rerun this shard.

        # Explicitly clear and collect garbage to free memory
        del shard_state_dict
        if 'original_tensor' in locals(): del original_tensor
        if 'converted_tensor' in locals(): del converted_tensor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Memory cleanup after processing shard {original_shard_filename}")

    print(f"\n--- All input shards processed. Converted shards are in {OUTPUT_SHARD_DIR} ---")

if __name__ == "__main__":
    print(f"--- __main__ block start ---")
    if not os.path.exists(INPUT_MODEL_DIR):
        print(f"Error: Input model directory not found: {INPUT_MODEL_DIR}")
    else:
        print(f"Input model directory exists. Calling convert_and_save_shards().")
        convert_and_save_shards()
    print(f"--- __main__ block end (Shard-by-Shard Conversion) ---")
scripts/merge_fp8_shards.py ADDED
import torch
import os
import json
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from collections import OrderedDict
from tqdm import tqdm
import glob  # For finding shard files

# --- Configuration ---
# Should match OUTPUT_SHARD_DIR from the previous script
# CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-DF-14B-540P/converted_fp8_shards"  # DF path
CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P/converted_fp8_shards"  # T2V path
# Define the final single output file
FINAL_OUTPUT_MODEL_NAME = "SkyReels-V2-T2V-14B-540P-fp8_e5m2.safetensors"  # Example final name
FINAL_OUTPUT_MODEL_PATH = os.path.join(os.path.dirname(CONVERTED_SHARDS_DIR), FINAL_OUTPUT_MODEL_NAME)  # Saves in parent of shards dir

# An index file is only needed if tensor ordering matters or tensors must be mapped
# back to specific shards. For a simple merge we just load every tensor from every
# converted shard and save them all into one file; if the original ordering were
# critical, the FP32 model's model.safetensors.index.json could guide it.
# ORIGINAL_FP32_INDEX_JSON = "F:/Models/SkyReels-V2-DF-14B-540P/model.safetensors.index.json"


print(f"--- SCRIPT START (Merge Converted Shards) ---")
print(f"Converted shards directory: {CONVERTED_SHARDS_DIR}")
print(f"Final output model path: {FINAL_OUTPUT_MODEL_PATH}")

def merge_converted_shards():
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Directory with converted shards not found: {CONVERTED_SHARDS_DIR}")
        return

    # Find all converted shard files in CONVERTED_SHARDS_DIR.
    # Sort them so they are processed in a consistent order (00001, 00002, ...).
    shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "fp8_converted_model-*-of-*.safetensors")))
    # Or a more generic pattern if your naming was different:
    # shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "*.safetensors")))

    if not shard_files:
        print(f"Error: No converted shard files found in {CONVERTED_SHARDS_DIR}")
        return

    print(f"Found {len(shard_files)} converted shards to merge.")

    merged_state_dict = OrderedDict()

    for shard_path in tqdm(shard_files, desc="Merging shards"):
        print(f"Loading tensors from: {shard_path}")
        try:
            # Load all tensors from the current converted shard.
            # load_file is fine here (no need for safe_open with per-tensor reads)
            # since these shards are comparatively small.
            current_shard_state_dict = load_file(shard_path, device="cpu")
            merged_state_dict.update(current_shard_state_dict)
            print(f"  Added {len(current_shard_state_dict)} tensors from {os.path.basename(shard_path)}")
        except Exception as e:
            print(f"Error loading shard {shard_path}: {e}")
            return  # Stop if a shard can't be loaded for the merge

    if not merged_state_dict:
        print("No tensors were loaded from shards. Final model file will not be created.")
        return

    print(f"\nMerge complete. Total tensors in merged model: {len(merged_state_dict)}")
    print(f"Saving merged model to {FINAL_OUTPUT_MODEL_PATH}...")
    try:
        os.makedirs(os.path.dirname(FINAL_OUTPUT_MODEL_PATH), exist_ok=True)
        save_file(merged_state_dict, FINAL_OUTPUT_MODEL_PATH)
        print(f"Successfully saved final merged model to {FINAL_OUTPUT_MODEL_PATH}")
    except Exception as e:
        print(f"Error saving the final merged model: {e}")

if __name__ == "__main__":
    print(f"--- __main__ block start ---")
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Converted shards directory not found: {CONVERTED_SHARDS_DIR}")
    else:
        merge_converted_shards()
    print(f"--- __main__ block end (Merge Converted Shards) ---")
scripts/model_download.py ADDED
import os
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError  # More specific import path
from tqdm import tqdm  # For progress bars

# --- Configuration ---
MODELS_TO_DOWNLOAD = [
    {
        "repo_id": "Skywork/SkyReels-V2-DF-14B-540P",
        "local_base_path": "F:/Models/SkyReels-V2-DF-14B-540P",  # Base path for this model
        "num_shards": 12,
    },
    {
        "repo_id": "Skywork/SkyReels-V2-T2V-14B-540P",
        "local_base_path": "F:/Models/SkyReels-V2-T2V-14B-540P",  # Base path for this model
        "num_shards": 12,
    },
]

# Common files to download in addition to shards
COMMON_FILES = [
    "model.safetensors.index.json"
    # Add other essential files like config.json, tokenizer_config.json, etc., if needed for loading later.
    # For now, we'll stick to the index file, which is what the sharded conversion needs.
    # "config.json",
    # "generation_config.json",
    # "special_tokens_map.json",
    # "tokenizer.json",
    # "tokenizer_config.json",
    # "vocab.json"
]

def download_model_files(repo_id, local_base_path, num_shards):
    """
    Downloads sharded .safetensors model files and common configuration files
    from a Hugging Face repository.
    """
    print(f"\nDownloading files for repository: {repo_id}")
    print(f"Target local directory: {local_base_path}")

    # Create the local directory if it doesn't exist
    os.makedirs(local_base_path, exist_ok=True)

    # --- Download common files ---
    for common_file in COMMON_FILES:
        print(f"Attempting to download: {common_file}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=common_file,
                local_dir=local_base_path,
                local_dir_use_symlinks=False,  # Download actual file
                resume_download=True,
            )
            print(f"Successfully downloaded {common_file}")
        except HfHubHTTPError as e:
            if e.response.status_code == 404:
                print(f"Warning: {common_file} not found in repository {repo_id}. Skipping.")
            else:
                print(f"Error downloading {common_file}: {e}")
        except Exception as e:
            print(f"An unexpected error occurred while downloading {common_file}: {e}")


    # --- Download sharded model files ---
    shard_filenames = []
    for i in range(1, num_shards + 1):
        # Filename format: model-00001-of-00012.safetensors
        shard_filename = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
        shard_filenames.append(shard_filename)

    print(f"\nAttempting to download {num_shards} model shards...")
    for shard_filename in tqdm(shard_filenames, desc=f"Downloading shards for {repo_id}"):
        try:
            # print(f"Downloading {shard_filename} to {local_base_path}...")  # tqdm provides progress
            hf_hub_download(
                repo_id=repo_id,
                filename=shard_filename,
                local_dir=local_base_path,
                local_dir_use_symlinks=False,  # Important: download the actual file
                resume_download=True,  # Good for large files
            )
            # print(f"Successfully downloaded {shard_filename}")  # tqdm indicates completion
        except HfHubHTTPError as e:
            print(f"Error downloading {shard_filename}: {e}")
            if e.response.status_code == 404:
                print(f"  {shard_filename} not found. Please check repository and shard count.")
            return False  # Stop if a shard download fails
        except Exception as e:
            print(f"An unexpected error occurred while downloading {shard_filename}: {e}")
            return False
    print(f"All {num_shards} shards for {repo_id} downloaded successfully (or skipped if not found).")
    return True

if __name__ == "__main__":
    print("Starting model download process...")
    all_successful = True
    for model_config in MODELS_TO_DOWNLOAD:
        success = download_model_files(
            repo_id=model_config["repo_id"],
            local_base_path=model_config["local_base_path"],
            num_shards=model_config["num_shards"]
        )
        if not success:
            all_successful = False
            print(f"Failed to download all files for {model_config['repo_id']}.")

    if all_successful:
        print("\nAll specified model files downloaded successfully.")
    else:
        print("\nSome model files failed to download. Please check the logs.")
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import argparse
from safetensors import safe_open
from collections import Counter
import os
import math  # math.prod is Python 3.8+

# --- Dtype to Bytes Mapping ---
# Safetensors Dtype strings:
# BOOL, F8_E5M2, F8_E4M3FN, F16, BF16, F32, F64,
# I8, I16, I32, I64, U8, U16, U32, U64,
# F8_E5M2FNUZ, F8_E4M3FNUZ
DTYPE_TO_BYTES = {
    "BOOL": 1,
    # Float8 variants
    "F8_E5M2": 1,
    "F8E5M2": 1,  # Common alternative naming
    "F8_E4M3FN": 1,
    "F8E4M3FN": 1,  # Common alternative naming
    "F8_E4M3": 1,  # As seen in user example, likely E4M3FN
    "F8_E5M2FNUZ": 1,
    "F8E5M2FNUZ": 1,  # Common alternative naming
    "F8_E4M3FNUZ": 1,
    "F8E4M3FNUZ": 1,  # Common alternative naming
    # Standard floats
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # Integers
    "I8": 1,
    "I16": 2,
    "I32": 4,
    "I64": 8,
    # Unsigned Integers
    "U8": 1,
    "U16": 2,
    "U32": 4,
    "U64": 8,
}

def get_bytes_per_element(dtype_str):
    """Returns the number of bytes for a given safetensors dtype string."""
    return DTYPE_TO_BYTES.get(dtype_str.upper(), None)

def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    if not shape:  # Scalar tensor (shape is ())
        return 1
    if 0 in shape:  # If any dimension is 0, total elements is 0
        return 0
    # math.prod would be more concise (Python 3.8+);
    # a loop is used for broader compatibility:
    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements

def inspect_safetensors_precision_and_size(filepath):
    """
    Reads a .safetensors file, iterates through its tensors,
    and reports the precision (dtype), actual size, and theoretical FP32 size.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    if not filepath.lower().endswith(".safetensors"):
        print(f"Warning: File '{filepath}' does not have a .safetensors extension. Attempting to read anyway.")

    dtype_counts = Counter()
    total_actual_mb = 0.0
    total_fp32_equiv_mb = 0.0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            tensor_keys = list(f.keys())
            if not tensor_keys:
                print("No tensors found in the file.")
                return

            max_key_len = len("Tensor Name")  # Default/minimum
            if tensor_keys:
                max_key_len = max(max_key_len, max(len(k) for k in tensor_keys))

            header = (
                f"{'Tensor Name':<{max_key_len}} | "
                f"{'Precision (dtype)':<17} | "
                f"{'Actual Size (MB)':>16} | "
                f"{'FP32 Equiv. (MB)':>18}"
            )
            print(header)
            print(
                f"{'-' * max_key_len}-|-------------------|------------------|-------------------"
            )

            for key in tensor_keys:
                tensor_slice = f.get_slice(key)
                dtype_str = tensor_slice.get_dtype()
                shape = tensor_slice.get_shape()

                num_elements = calculate_num_elements(shape)
                bytes_per_el_actual = get_bytes_per_element(dtype_str)

                actual_size_mb_str = "N/A"
                fp32_equiv_size_mb_str = "N/A"
                actual_size_mb_val = 0.0

                if bytes_per_el_actual is not None:
                    actual_bytes = num_elements * bytes_per_el_actual
                    actual_size_mb_val = actual_bytes / (1024 * 1024)
                    total_actual_mb += actual_size_mb_val
                    actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                    # Theoretical FP32 size (FP32 is 4 bytes per element)
                    fp32_equiv_bytes = num_elements * 4
                    fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                    total_fp32_equiv_mb += fp32_equiv_size_mb_val
                    fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
                else:
                    print(f"Warning: Unknown dtype '{dtype_str}' for tensor '{key}'. Cannot calculate size.")

                print(
                    f"{key:<{max_key_len}} | "
                    f"{dtype_str:<17} | "
                    f"{actual_size_mb_str:>16} | "
                    f"{fp32_equiv_size_mb_str:>18}"
                )
                dtype_counts[dtype_str] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensor_keys)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype, count in dtype_counts.most_common():
                print(f"  - {dtype:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        if total_fp32_equiv_mb > 0.00001:  # Avoid division by zero or near-zero
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it's a valid .safetensors file and the 'safetensors' (and 'torch') libraries are installed correctly.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .safetensors file to inspect."
    )
    args = parser.parse_args()

    inspect_safetensors_precision_and_size(args.filepath)