#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations import logging import argparse import json import safetensors.torch import os import sys from pathlib import Path from typing import Any, ContextManager, cast from torch import Tensor import numpy as np import torch import gguf # TODO: add more: SUPPORTED_ARCHS = ["flux", "sd3", "ltxv", "hyvid", "wan", "hidream", "qwen"] logger = logging.getLogger(__name__) class QuantConfig: ftype: gguf.LlamaFileType qtype: gguf.GGMLQuantizationType def __init__(self, ftype: gguf.LlamaFileType, qtype: gguf.GGMLQuantizationType): self.ftype = ftype self.qtype = qtype qconfig_map: dict[str, QuantConfig] = { "F16": QuantConfig(gguf.LlamaFileType.MOSTLY_F16, gguf.GGMLQuantizationType.F16), "BF16": QuantConfig(gguf.LlamaFileType.MOSTLY_BF16, gguf.GGMLQuantizationType.BF16), "Q8_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q8_0, gguf.GGMLQuantizationType.Q8_0), "Q6_K": QuantConfig(gguf.LlamaFileType.MOSTLY_Q6_K, gguf.GGMLQuantizationType.Q6_K), "Q5_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_K_S, gguf.GGMLQuantizationType.Q5_K), "Q5_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_1, gguf.GGMLQuantizationType.Q5_1), "Q5_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_0, gguf.GGMLQuantizationType.Q5_0), "Q4_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_K_S, gguf.GGMLQuantizationType.Q4_K), "Q4_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_1, gguf.GGMLQuantizationType.Q4_1), "Q4_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_0, gguf.GGMLQuantizationType.Q4_0), "Q3_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q3_K_S, gguf.GGMLQuantizationType.Q3_K), # "Q2_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q2_K, gguf.GGMLQuantizationType.Q2_K), # not yet supported in python } # tree of lazy tensors class LazyTorchTensor(gguf.LazyBase): _tensor_type = torch.Tensor # to keep the type-checker happy dtype: torch.dtype shape: torch.Size # only used when converting a torch.Tensor to a np.ndarray _dtype_map: dict[torch.dtype, type] = { torch.float16: np.float16, torch.float32: np.float32, } # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 _dtype_str_map: dict[str, torch.dtype] = { "F64": torch.float64, "F32": torch.float32, "BF16": torch.bfloat16, "F16": torch.float16, # "U64": torch.uint64, "I64": torch.int64, # "U32": torch.uint32, "I32": torch.int32, # "U16": torch.uint16, "I16": torch.int16, "U8": torch.uint8, "I8": torch.int8, "BOOL": torch.bool, "F8_E4M3": torch.float8_e4m3fn, "F8_E5M2": torch.float8_e5m2, } def numpy(self) -> gguf.LazyNumpyTensor: dtype = self._dtype_map[self.dtype] return gguf.LazyNumpyTensor( meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), args=(self,), func=(lambda s: s.numpy()), ) @classmethod def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: return torch.empty(size=shape, dtype=dtype, device="meta") @classmethod def from_safetensors_slice(cls, st_slice: Any) -> Tensor: dtype = cls._dtype_str_map[st_slice.get_dtype()] shape: tuple[int, ...] = tuple(st_slice.get_shape()) lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) return cast(torch.Tensor, lazy) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): del types # unused if kwargs is None: kwargs = {} if func is torch.Tensor.numpy: return args[0].numpy() return cls._wrap_fn(func)(*args, **kwargs) class Converter: path_safetensors: Path endianess: gguf.GGUFEndian outtype: QuantConfig outfile: Path gguf_writer: gguf.GGUFWriter def __init__( self, arch: str, path_safetensors: Path, endianess: gguf.GGUFEndian, outtype: QuantConfig, outfile: Path, subfolder: str = None, repo_id: str = None, is_diffusers: bool = False, ): self.path_safetensors = path_safetensors self.endianess = endianess self.outtype = outtype self.outfile = outfile self.gguf_writer = gguf.GGUFWriter(path=None, arch=arch, endianess=self.endianess) self.gguf_writer.add_file_type(self.outtype.ftype) self.gguf_writer.add_type("diffusion") # for HF hub to detect the type correctly if repo_id: self.gguf_writer.add_string("repo_id", repo_id) if subfolder: self.gguf_writer.add_string("subfolder", subfolder) if is_diffusers: self.gguf_writer.add_bool("is_diffusers", True) # load tensors and process from safetensors import safe_open ctx = cast(ContextManager[Any], safe_open(path_safetensors, framework="pt", device="cpu")) with ctx as model_part: for name in model_part.keys(): data = model_part.get_slice(name) data = LazyTorchTensor.from_safetensors_slice(data) self.process_tensor(name, data) def process_tensor(self, name: str, data_torch: LazyTorchTensor) -> None: is_1d = len(data_torch.shape) == 1 current_dtype = data_torch.dtype target_dtype = gguf.GGMLQuantizationType.F32 if is_1d else self.outtype.qtype if data_torch.dtype not in (torch.float16, torch.float32): data_torch = data_torch.to(torch.float32) data = data_torch.numpy() if current_dtype != target_dtype: from custom_quants import quantize as custom_quantize, QuantError try: data = custom_quantize(data, target_dtype) except QuantError as e: logger.warning("%s, %s", e, "falling back to F16") target_dtype = gguf.GGMLQuantizationType.F16 data = custom_quantize(data, target_dtype) # reverse shape to make it similar to the internal ggml dimension order shape = gguf.quant_shape_from_byte_shape(data.shape, target_dtype) if data.dtype == np.uint8 else data.shape shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" logger.info(f"{f'%-32s' % f'{name},'} {current_dtype} --> {target_dtype.name}, shape = {shape_str}") # add tensor to gguf self.gguf_writer.add_tensor(name, data, raw_dtype=target_dtype) def write(self) -> None: self.gguf_writer.write_header_to_file(path=self.outfile) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_tensors_to_file(progress=True) self.gguf_writer.close() # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 def _merge_sharded_checkpoints(folder: Path): with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f: ckpt_metadata = json.load(f) weight_map = ckpt_metadata.get("weight_map", None) if weight_map is None: raise KeyError("'weight_map' key not found in the shard index file.") # Collect all unique safetensors files from weight_map files_to_load = set(weight_map.values()) merged_state_dict = {} # Load tensors from each unique file for file_name in files_to_load: part_file_path = folder / file_name if not os.path.exists(part_file_path): raise FileNotFoundError(f"Part file {file_name} not found.") with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: for tensor_key in f.keys(): if tensor_key in weight_map: merged_state_dict[tensor_key] = f.get_tensor(tensor_key) return merged_state_dict def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Convert a flux model to GGUF") parser.add_argument( "--outfile", type=Path, default=Path("model-{ftype}.gguf"), help="path to write to; default: 'model-{ftype}.gguf' ; note: {ftype} will be replaced by the outtype", ) parser.add_argument( "--outtype", type=str, choices=qconfig_map.keys(), default="F16", help="output quantization scheme", ) parser.add_argument( "--arch", type=str, choices=SUPPORTED_ARCHS, help="output model architecture", ) parser.add_argument( "--bigendian", action="store_true", help="model is executed on big endian machine", ) parser.add_argument( "model", type=Path, help="directory containing safetensors model file", nargs="?", ) parser.add_argument("--cache_dir", type=Path, help="Directory to store the intermediate files when needed.") parser.add_argument( "--subfolder", type=Path, default=None, help="Subfolder on the HF Hub to load checkpoints from." ) parser.add_argument( "--verbose", action="store_true", help="increase output verbosity", ) args = parser.parse_args() if args.model is None: parser.error("the following arguments are required: model") if args.arch is None: parser.error("the following arguments are required: --arch") if args.arch not in SUPPORTED_ARCHS: parser.error(f"Unsupported architecture: {args.arch}. Supported architectures: {', '.join(SUPPORTED_ARCHS)}") return args def convert(args): if args.verbose: logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig(level=logging.INFO) if not args.model.is_dir() and not args.model.is_file(): if not len(str(args.model).split("/")) == 2: logging.error(f"Model path {args.model} does not exist.") sys.exit(1) is_diffusers = False repo_id = None merged_state_dict = None if args.model.is_dir(): logging.info("Supplied a directory.") files = list(args.model.glob("*.safetensors")) n = len(files) if n == 0: logging.error("No safetensors files found.") sys.exit(1) if n == 1: logging.info(f"Assinging {files[0]} to `args.model`") args.model = files[0] if n > 1: assert args.model / "diffusion_pytorch_model.safetensors.index.json" in list(args.model.glob("*.*")) assert args.cache_dir merged_state_dict = _merge_sharded_checkpoints(args.model) filepath = args.cache_dir / "merged_state_dict.safetensors" safetensors.torch.save_file(merged_state_dict, filepath) logging.info(f"Serialized merged state dict to {filepath}") args.model = Path(filepath) elif len(str(args.model).split("/")) == 2: from huggingface_hub import snapshot_download logging.info("Hub repo ID detected.") allow_patterns = f"{args.subfolder}/*.*" if args.subfolder else None local_dir = snapshot_download( repo_id=str(args.model), local_dir=args.cache_dir, allow_patterns=allow_patterns, token=args.hf_token ) repo_id = str(args.model) local_dir = Path(local_dir) local_dir = local_dir / args.subfolder if args.subfolder else local_dir merged_state_dict = _merge_sharded_checkpoints(local_dir) filepath = ( args.cache_dir / "merged_state_dict.safetensors" if args.cache_dir else "merged_state_dict.safetensors" ) safetensors.torch.save_file(merged_state_dict, filepath) logging.info(f"Serialized merged state dict to {filepath}") args.model = Path(filepath) is_diffusers = True if args.model.suffix != ".safetensors": logging.error(f"Model path {args.model} is not a safetensors file.") sys.exit(1) if args.outfile.suffix != ".gguf": logging.error("Output file must have .gguf extension.") sys.exit(1) qconfig = qconfig_map[args.outtype] outfile = Path(str(args.outfile).format(ftype=args.outtype.upper())) logger.info(f"Converting model in {args.model} to {outfile} with quantization {args.outtype}") converter = Converter( arch=args.arch, path_safetensors=args.model, endianess=gguf.GGUFEndian.BIG if args.bigendian else gguf.GGUFEndian.LITTLE, outtype=qconfig, outfile=outfile, repo_id=repo_id, subfolder=str(args.subfolder) if args.subfolder else None, is_diffusers=is_diffusers, ) converter.write() logger.info( f"Conversion complete. Output written to {outfile}, architecture: {args.arch}, quantization: {qconfig.qtype.name}" ) if merged_state_dict is not None: os.remove(filepath) logging.info(f"Removed the intermediate {filepath}.")