Clybius committed (verified)
Commit 043e8b2 · 1 Parent(s): 61063b1

Upload convert_fp8_scaled_stochastic.py


Usage: `python convert_fp8_scaled_stochastic.py --input /path/to/chroma-unlocked.safetensors`, where the input file stores its weights at a dtype of higher precision than FP8.

It writes a new .safetensors file to the same directory, containing FP8 weights plus per-tensor scale tensors, under a name that includes the target quant type (e.g. `*_float8_e4m3fn_scaled_stochastic.safetensors`).
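For anyone consuming the output outside of a loader that already understands scaled-FP8 checkpoints, here is a minimal dequantization sketch. The file name and layer key below are hypothetical placeholders (substitute keys from your own checkpoint); each quantized `*.weight` is recovered by multiplying it with its matching `*.scale_weight`, which stores `1.0 / scale_factor`.

```python
# Hypothetical example: dequantize one layer from a converted checkpoint.
# The path and layer name are placeholders, not values produced by this commit.
import torch
from safetensors import safe_open

path = "model_float8_e4m3fn_scaled_stochastic.safetensors"
with safe_open(path, framework="pt", device="cpu") as f:
    assert "scaled_fp8" in f.keys()                     # marker tensor written by the script
    w_fp8 = f.get_tensor("some_layer.weight")           # stored as float8_e4m3fn
    scale = f.get_tensor("some_layer.scale_weight")     # dequant factor = 1 / scale_factor

# Undo the scaling that was applied before quantization.
w_approx = w_fp8.to(torch.float32) * scale.to(torch.float32)
```

Loaders that recognize the `scaled_fp8` marker tensor (e.g. ComfyUI's scaled-FP8 support) are expected to apply this multiplication themselves at load/compute time.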

Files changed (1)
  1. convert_fp8_scaled_stochastic.py +328 -0
convert_fp8_scaled_stochastic.py ADDED
@@ -0,0 +1,328 @@
import argparse
import os
import torch
import numpy as np
from safetensors import safe_open
from safetensors.torch import save_file
from typing import Dict, Tuple

# --- Configuration ---
# Keys containing these substrings will not be quantized if --t5xxl is set
AVOID_KEY_NAMES = ["norm", "bias", "embed_tokens", "shared"]  # T5XXL; may need to be changed for other TEs.
# Target FP8 format
TARGET_FP8_DTYPE = torch.float8_e4m3fn
# Intermediate dtype for calculations
COMPUTE_DTYPE = torch.float64  # Don't think more hurts here since we're working tensor by tensor.
# Dtype for storing scale factors
SCALE_DTYPE = torch.float64  # Might be overkill, float32 should do just fine, but since these are so tiny may as well :3
# --- End Configuration ---

def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    mantissa_scaled = torch.where(
        normal_mask,
        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
    )

    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
    return mantissa_scaled.floor() / (2**MANTISSA_BITS)


# Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    x = x.half()
    sign = torch.sign(x)
    abs_x = x.abs()
    sign = torch.where(abs_x == 0, 0, sign)

    # Combine exponent calculation and clamping
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
        0, 2**EXPONENT_BITS - 1
    )

    # Combine mantissa calculation and rounding
    normal_mask = ~(exponent == 0)

    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)

    sign *= torch.where(
        normal_mask,
        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
    )

    inf = torch.finfo(dtype)
    torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
    return sign


def stochastic_rounding(value, dtype=TARGET_FP8_DTYPE, seed=0):
    if dtype == torch.float32:
        return value.to(dtype=torch.float32)
    if dtype == torch.float16:
        return value.to(dtype=torch.float16)
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        output = torch.empty_like(value, dtype=dtype)
        num_slices = max(1, (value.numel() / (1536 * 1536)))
        slice_size = max(1, round(value.shape[0] / num_slices))
        for i in range(0, value.shape[0], slice_size):
            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
        #output.copy_(manual_stochastic_round_to_float8(value, dtype, generator=generator))
        return output

    return value.to(dtype=dtype)

def get_fp8_constants(fp8_dtype: torch.dtype) -> Tuple[float, float, float]:
    """Gets the min, max, and smallest positive representable (subnormal) value for a given FP8 dtype."""
    finfo = torch.finfo(fp8_dtype)
    # Smallest positive value approximation (may vary based on exact FP8 spec interpretation)
    # For E4M3FN: exponent bias 7, smallest normal exp is -6. 1.0 * 2^-6 = 1/64
    # Smallest subnormal is 2^-9 for E4M3FN from the paper. Let's use the subnormal min.
    # Finding the smallest positive representable value (subnormal) is tricky,
    # as finfo.tiny is often the smallest *normal*.
    # Let's hardcode based on the E4M3FN spec (S=0, E=0000, M=001) -> 2^-9
    if fp8_dtype == torch.float8_e4m3fn:
        fp8_min_pos = 2**-9  # Smallest subnormal for E4M3FN
    elif fp8_dtype == torch.float8_e5m2:
        # E5M2: exponent bias 15, smallest normal exp -14. Smallest subnormal 2^-16
        fp8_min_pos = 2**-16  # Smallest subnormal for E5M2
    else:
        # Fallback using finfo.tiny (likely smallest normal)
        fp8_min_pos = finfo.tiny * finfo.eps  # A guess if unknown type

    # Ensure min_pos is a Python float for consistency
    fp8_min_pos = float(fp8_min_pos)

    return float(finfo.min), float(finfo.max), fp8_min_pos


# Global FP8 constants
FP8_MIN, FP8_MAX, FP8_MIN_POS = get_fp8_constants(TARGET_FP8_DTYPE)

def convert_to_fp8_scaled(input_file: str, output_file: str, t5xxl: bool):
    """
    Converts a safetensors file to a version with FP8 scaled weights using stochastic rounding.

    For each tensor ending with '.weight' (unless excluded):
    1. Calculates a scale factor based on the tensor's max absolute value.
    2. Scales the tensor to fit within the FP8 range [-FP8_MAX, FP8_MAX].
    3. Clamps the scaled tensor.
    4. Applies stochastic rounding during quantization to TARGET_FP8_DTYPE.
    5. Stores the quantized tensor.
    6. Stores a '.scale_weight' tensor: the factor to dequantize the weight (1.0 / scale_factor).

    No '.scale_input' tensor is written; only the weight scale is needed here.
    """
    print(f"Processing: {input_file}")
    print(f"Output will be saved to: {output_file}")
    print(f"Using FP8 format: {TARGET_FP8_DTYPE}")
    print(f"FP8 Range: [{FP8_MIN}, {FP8_MAX}], Min Pos Subnormal: {FP8_MIN_POS:.2e}")
    print("Using Stochastic Rounding: True")

    # Load the original model
    tensors: Dict[str, torch.Tensor] = {}
    try:
        with safe_open(input_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                # Load directly to CPU to avoid potential GPU OOM for large models
                tensors[key] = f.get_tensor(key).cpu()
    except Exception as e:
        print(f"Error loading '{input_file}': {e}")
        return

    # Keep track of new/modified tensors
    new_tensors: Dict[str, torch.Tensor] = {}

    # Process each tensor ending with '.weight'
    weight_keys = sorted([key for key in tensors.keys() if key.endswith('.weight')])
    total_weights = len(weight_keys)
    skipped_count = 0
    processed_count = 0

    print(f"Found {total_weights} weight tensors to potentially process.")

    for i, key in enumerate(weight_keys):
        process_this_key = True
        if t5xxl:
            for avoid_name in AVOID_KEY_NAMES:
                if avoid_name in key:
                    print(f"({i+1}/{total_weights}) Skipping excluded tensor: {key}")
                    # Keep original tensor
                    new_tensors[key] = tensors[key]
                    process_this_key = False
                    skipped_count += 1
                    break  # Stop checking avoid names for this key

        if not process_this_key:
            continue

        print(f"({i+1}/{total_weights}) Processing tensor: {key}")
        processed_count += 1

        # Get the original tensor and convert to high precision for calculations
        original_tensor = tensors[key].to(COMPUTE_DTYPE)

        if original_tensor.numel() == 0:
            print(f" - Skipping empty tensor: {key}")
            new_tensors[key] = tensors[key].to(TARGET_FP8_DTYPE)  # Store as empty FP8
            # Add a dummy scale
            base_name = key[:-len('.weight')]
            scale_weight_key = f"{base_name}.scale_weight"
            dequant_scale = torch.tensor([1.0], dtype=SCALE_DTYPE)
            new_tensors[scale_weight_key] = dequant_scale.detach().clone()
            continue

        # Calculate the scaling factor needed to map the max absolute value to FP8_MAX
        abs_max = torch.max(torch.abs(original_tensor))
        # Handle all-zero tensors or edge cases
        if abs_max < 1e-12:  # Use a small threshold instead of exact zero
            print(f" - Tensor has near-zero max value ({abs_max.item():.2e}). Using scale factor 1.0.")
            scale_factor = torch.tensor(1.0, dtype=COMPUTE_DTYPE)
            scaled_tensor = original_tensor  # No scaling needed
        else:
            # Ensure abs_max is positive before division
            abs_max = abs_max.clamp(min=FP8_MIN_POS)  # Clamp to smallest positive FP8 value
            scale_factor = (FP8_MAX - FP8_MIN_POS) / abs_max
            # Scale the tensor
            scaled_tensor = original_tensor.mul(scale_factor)

        # Clamp the scaled tensor to the representable FP8 range
        #print(scale_factor)
        clamped_tensor = torch.clamp(scaled_tensor, FP8_MIN, FP8_MAX)

        # Perform stochastic rounding and quantization to FP8
        quantized_fp8_tensor = stochastic_rounding(clamped_tensor)

        # Store the quantized tensor
        new_tensors[key] = quantized_fp8_tensor

        # Calculate dequantization scale factor (inverse of the scaling factor)
        dequant_scale = scale_factor.reciprocal()

        # Create scale tensor keys
        base_name = key[:-len('.weight')]
        scale_weight_key = f"{base_name}.scale_weight"
        # scale_input_key = f"{base_name}.scale_input"  # scale_input shouldn't be necessary, I think; leaving this here as a trail in case it's needed in the future.

        # Store the scale tensor
        new_tensors[scale_weight_key] = dequant_scale.detach().clone()

        # --- Debug/Info Printing ---
        print(f" - Abs Max : {abs_max.item():.5}")
        print(f" - Scale Factor : {scale_factor.item():.5}")
        print(f" - Dequant Scale : {dequant_scale.item():.5}")

    # Combine original non-weight tensors with new/modified ones
    added_scale_keys = set()
    for key in new_tensors:
        if key.endswith(".scale_weight") or key.endswith(".scale_input"):
            added_scale_keys.add(key)

    original_keys = set(tensors.keys())
    processed_weight_keys = set(k for k, v in new_tensors.items() if k.endswith(".weight"))

    for key, tensor in tensors.items():
        # Add if it's not a weight tensor OR if it's a weight tensor that was skipped
        is_weight = key.endswith(".weight")
        if key not in new_tensors:
            if not is_weight:
                # Non-weight tensor, just copy it over
                new_tensors[key] = tensor
                print(f"(+) Adding original non-weight tensor: {key}")

    # Add FP8 marker key for compatibility (e.g., ComfyUI)
    new_tensors["scaled_fp8"] = torch.empty((2), dtype=TARGET_FP8_DTYPE) if not t5xxl else torch.empty((0), dtype=TARGET_FP8_DTYPE)

    # Save the modified model
    print("-" * 40)
    print(f"Saving {len(new_tensors)} tensors to {output_file}")
    try:
        # Ensure the parent directory exists (dirname may be empty when saving to the current directory)
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # Metadata can be useful
        #metadata = {'format': f'pt_scaled_{TARGET_FP8_DTYPE.__str__().split(".")[-1]}'}
        save_file(new_tensors, output_file)
        print("Conversion complete!")
    except Exception as e:
        print(f"Error saving file '{output_file}': {e}")
        return

    # Print summary
    final_tensor_count = len(new_tensors)
    original_tensor_count = len(tensors)
    added_tensors_count = final_tensor_count - original_tensor_count
    added_scales_count = len(added_scale_keys)

    print("-" * 40)
    print(f"Summary:")
    print(f" - Original tensor count : {original_tensor_count}")
    print(f" - Weight tensors found : {total_weights}")
    print(f" - Weights processed : {processed_count}")
    print(f" - Weights skipped : {skipped_count}")
    print(f" - Added scale tensors : {added_scales_count}")  # Should equal processed_count (one .scale_weight per quantized weight)
    print(f" - Added marker tensor : 1")
    print(f" - Final tensor count : {final_tensor_count}")
    print("-" * 40)

def main():
    parser = argparse.ArgumentParser(
        description=f"Convert safetensors weights to Scaled {TARGET_FP8_DTYPE} format using stochastic rounding.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input safetensors file path."
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output safetensors file path. If not provided, generated based on the input name."
    )
    parser.add_argument(
        "--t5xxl",
        action='store_true',  # Boolean flag
        help="Exclude certain layers from quantization while quantizing T5XXL."
    )
    args = parser.parse_args()

    input_file = args.input
    output_file = args.output
    t5xxl = args.t5xxl

    if not os.path.exists(input_file):
        print(f"Error: Input file not found: {input_file}")
        return

    fp8_type_str = str(TARGET_FP8_DTYPE).split('.')[-1]  # e.g., float8_e4m3fn

    if not output_file:
        # Generate the output file name based on the input file
        base_name = os.path.splitext(input_file)[0]
        output_file = f"{base_name}_{fp8_type_str}_scaled_stochastic.safetensors"

    # Prevent overwriting the input file
    if os.path.abspath(input_file) == os.path.abspath(output_file):
        print("Error: Output file cannot be the same as the input file.")
        # Suggest a modified name
        base, ext = os.path.splitext(output_file)
        output_file = f"{base}_converted{ext}"
        print(f"Suggestion: Use --output {output_file}")
        return

    convert_to_fp8_scaled(input_file, output_file, t5xxl)


if __name__ == "__main__":
    main()