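"""Convert diffusion model weights stored as safetensors into a single GGUF file.

Example invocation (the script filename below is only illustrative):

    python convert_to_gguf.py /path/to/model.safetensors --arch flux --outtype Q8_0 --outfile "model-{ftype}.gguf"

The positional model argument may also be a directory containing safetensors shards or a
Hugging Face Hub repo ID of the form "user/repo".
"""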
from __future__ import annotations

import logging
import argparse
import json
import safetensors.torch
import os
import sys
from pathlib import Path
from typing import Any, ContextManager, cast
from torch import Tensor

import numpy as np
import torch
import gguf


SUPPORTED_ARCHS = ["flux", "sd3", "ltxv", "hyvid", "wan", "hidream", "qwen"]

logger = logging.getLogger(__name__)


class QuantConfig:
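    """Pairs a GGUF file-type tag (written to the header) with the quantization type applied to tensors."""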
    ftype: gguf.LlamaFileType
    qtype: gguf.GGMLQuantizationType

    def __init__(self, ftype: gguf.LlamaFileType, qtype: gguf.GGMLQuantizationType):
        self.ftype = ftype
        self.qtype = qtype
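

# Quantization schemes selectable via --outtype.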
qconfig_map: dict[str, QuantConfig] = {
    "F16": QuantConfig(gguf.LlamaFileType.MOSTLY_F16, gguf.GGMLQuantizationType.F16),
    "BF16": QuantConfig(gguf.LlamaFileType.MOSTLY_BF16, gguf.GGMLQuantizationType.BF16),
    "Q8_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q8_0, gguf.GGMLQuantizationType.Q8_0),
    "Q6_K": QuantConfig(gguf.LlamaFileType.MOSTLY_Q6_K, gguf.GGMLQuantizationType.Q6_K),
    "Q5_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_K_S, gguf.GGMLQuantizationType.Q5_K),
    "Q5_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_1, gguf.GGMLQuantizationType.Q5_1),
    "Q5_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_0, gguf.GGMLQuantizationType.Q5_0),
    "Q4_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_K_S, gguf.GGMLQuantizationType.Q4_K),
    "Q4_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_1, gguf.GGMLQuantizationType.Q4_1),
    "Q4_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_0, gguf.GGMLQuantizationType.Q4_0),
    "Q3_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q3_K_S, gguf.GGMLQuantizationType.Q3_K),
}


class LazyTorchTensor(gguf.LazyBase):
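    """Torch tensor wrapper that defers reading data from a safetensors slice.

    Tensors are represented by meta tensors until they are materialized, so a checkpoint
    can be enumerated and converted without holding every weight in memory at once.
    """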
    _tensor_type = torch.Tensor

    dtype: torch.dtype
    shape: torch.Size

    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
    }

    _dtype_str_map: dict[str, torch.dtype] = {
        "F64": torch.float64,
        "F32": torch.float32,
        "BF16": torch.bfloat16,
        "F16": torch.float16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "U8": torch.uint8,
        "I8": torch.int8,
        "BOOL": torch.bool,
        "F8_E4M3": torch.float8_e4m3fn,
        "F8_E5M2": torch.float8_e5m2,
    }

    def numpy(self) -> gguf.LazyNumpyTensor:
        dtype = self._dtype_map[self.dtype]
        return gguf.LazyNumpyTensor(
            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
            args=(self,),
            func=(lambda s: s.numpy()),
        )

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
        return torch.empty(size=shape, dtype=dtype, device="meta")

    @classmethod
    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
        dtype = cls._dtype_str_map[st_slice.get_dtype()]
        shape: tuple[int, ...] = tuple(st_slice.get_shape())
        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
        return cast(torch.Tensor, lazy)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        del types

        if kwargs is None:
            kwargs = {}

        if func is torch.Tensor.numpy:
            return args[0].numpy()

        return cls._wrap_fn(func)(*args, **kwargs)


class Converter:
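    """Streams tensors from a safetensors checkpoint into a gguf.GGUFWriter.

    Constructing the converter reads and registers every tensor with the writer; call
    write() afterwards to serialize the header, key/value metadata, and tensor data.
    """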
    path_safetensors: Path
    endianess: gguf.GGUFEndian
    outtype: QuantConfig
    outfile: Path
    gguf_writer: gguf.GGUFWriter

    def __init__(
        self,
        arch: str,
        path_safetensors: Path,
        endianess: gguf.GGUFEndian,
        outtype: QuantConfig,
        outfile: Path,
        subfolder: str | None = None,
        repo_id: str | None = None,
        is_diffusers: bool = False,
    ):
        self.path_safetensors = path_safetensors
        self.endianess = endianess
        self.outtype = outtype
        self.outfile = outfile

        self.gguf_writer = gguf.GGUFWriter(path=None, arch=arch, endianess=self.endianess)
        self.gguf_writer.add_file_type(self.outtype.ftype)
        self.gguf_writer.add_type("diffusion")
        if repo_id:
            self.gguf_writer.add_string("repo_id", repo_id)
        if subfolder:
            self.gguf_writer.add_string("subfolder", subfolder)
        if is_diffusers:
            self.gguf_writer.add_bool("is_diffusers", True)

        from safetensors import safe_open

        # Read each tensor as a lazy slice so the full checkpoint is never resident in memory at once.
        ctx = cast(ContextManager[Any], safe_open(path_safetensors, framework="pt", device="cpu"))
        with ctx as model_part:
            for name in model_part.keys():
                data = model_part.get_slice(name)
                data = LazyTorchTensor.from_safetensors_slice(data)
                self.process_tensor(name, data)

    def process_tensor(self, name: str, data_torch: LazyTorchTensor) -> None:
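        """Quantize a single tensor and add it to the GGUF writer.

        1D tensors are stored as F32; everything else is converted to the requested output
        type, falling back to F16 if quantization to that type fails.
        """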
        is_1d = len(data_torch.shape) == 1
        current_dtype = data_torch.dtype
        target_dtype = gguf.GGMLQuantizationType.F32 if is_1d else self.outtype.qtype

        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.numpy()

        if current_dtype != target_dtype:
            from custom_quants import quantize as custom_quantize, QuantError

            try:
                data = custom_quantize(data, target_dtype)
            except QuantError as e:
                logger.warning("%s, falling back to F16", e)
                target_dtype = gguf.GGMLQuantizationType.F16
                data = custom_quantize(data, target_dtype)

        shape = gguf.quant_shape_from_byte_shape(data.shape, target_dtype) if data.dtype == np.uint8 else data.shape
        shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
        logger.info(f"{f'%-32s' % f'{name},'} {current_dtype} --> {target_dtype.name}, shape = {shape_str}")

        self.gguf_writer.add_tensor(name, data, raw_dtype=target_dtype)

    def write(self) -> None:
        self.gguf_writer.write_header_to_file(path=self.outfile)
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file(progress=True)
        self.gguf_writer.close()


def _merge_sharded_checkpoints(folder: Path):
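    """Merge a sharded diffusers checkpoint into a single state dict.

    The shard layout is read from diffusion_pytorch_model.safetensors.index.json in
    `folder`, and every tensor listed in its weight_map is loaded from the shard files.
    """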
    with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f:
        ckpt_metadata = json.load(f)
    weight_map = ckpt_metadata.get("weight_map", None)
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    files_to_load = set(weight_map.values())
    merged_state_dict = {}

    for file_name in files_to_load:
        part_file_path = folder / file_name
        if not os.path.exists(part_file_path):
            raise FileNotFoundError(f"Part file {file_name} not found.")

        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
            for tensor_key in f.keys():
                if tensor_key in weight_map:
                    merged_state_dict[tensor_key] = f.get_tensor(tensor_key)

    return merged_state_dict


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a diffusion model checkpoint to GGUF")
    parser.add_argument(
        "--outfile",
        type=Path,
        default=Path("model-{ftype}.gguf"),
        help="path to write to; default: 'model-{ftype}.gguf' ; note: {ftype} will be replaced by the outtype",
    )
    parser.add_argument(
        "--outtype",
        type=str,
        choices=qconfig_map.keys(),
        default="F16",
        help="output quantization scheme",
    )
    parser.add_argument(
        "--arch",
        type=str,
        choices=SUPPORTED_ARCHS,
        help="output model architecture",
    )
    parser.add_argument(
        "--bigendian",
        action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model",
        type=Path,
        help="directory containing safetensors model file",
        nargs="?",
    )
    parser.add_argument("--cache_dir", type=Path, help="Directory to store the intermediate files when needed.")
    parser.add_argument(
        "--subfolder", type=Path, default=None, help="Subfolder on the HF Hub to load checkpoints from."
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="Hugging Face token to use when downloading from the Hub."
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: model")
    if args.arch is None:
        parser.error("the following arguments are required: --arch")
    if args.arch not in SUPPORTED_ARCHS:
        parser.error(f"Unsupported architecture: {args.arch}. Supported architectures: {', '.join(SUPPORTED_ARCHS)}")
    return args


def convert(args):
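    """Run the conversion described by the parsed CLI arguments.

    `args.model` may be a local .safetensors file, a local directory holding one or more
    shards, or a "user/repo" Hub ID. Sharded checkpoints are merged into an intermediate
    safetensors file first (a local sharded directory requires --cache_dir), and that
    intermediate file is removed once the GGUF file has been written.
    """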
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.model.is_dir() and not args.model.is_file():
        if len(str(args.model).split("/")) != 2:
            logger.error(f"Model path {args.model} does not exist.")
            sys.exit(1)

    is_diffusers = False
    repo_id = None
    merged_state_dict = None
    if args.model.is_dir():
        logger.info("Supplied a directory.")
        files = list(args.model.glob("*.safetensors"))
        n = len(files)
        if n == 0:
            logger.error("No safetensors files found.")
            sys.exit(1)
        if n == 1:
            logger.info(f"Assigning {files[0]} to `args.model`")
            args.model = files[0]
        if n > 1:
            index_file = args.model / "diffusion_pytorch_model.safetensors.index.json"
            assert index_file.exists(), f"{index_file.name} not found in {args.model}."
            assert args.cache_dir, "--cache_dir is required to merge a sharded checkpoint."
            merged_state_dict = _merge_sharded_checkpoints(args.model)
            filepath = args.cache_dir / "merged_state_dict.safetensors"
            safetensors.torch.save_file(merged_state_dict, filepath)
            logger.info(f"Serialized merged state dict to {filepath}")
            args.model = Path(filepath)

    elif len(str(args.model).split("/")) == 2:
        from huggingface_hub import snapshot_download

        logger.info("Hub repo ID detected.")
        allow_patterns = f"{args.subfolder}/*.*" if args.subfolder else None
        local_dir = snapshot_download(
            repo_id=str(args.model), local_dir=args.cache_dir, allow_patterns=allow_patterns, token=args.hf_token
        )
        repo_id = str(args.model)
        local_dir = Path(local_dir)
        local_dir = local_dir / args.subfolder if args.subfolder else local_dir
        merged_state_dict = _merge_sharded_checkpoints(local_dir)
        filepath = (
            args.cache_dir / "merged_state_dict.safetensors" if args.cache_dir else Path("merged_state_dict.safetensors")
        )
        safetensors.torch.save_file(merged_state_dict, filepath)
        logger.info(f"Serialized merged state dict to {filepath}")
        args.model = Path(filepath)
        is_diffusers = True

    if args.model.suffix != ".safetensors":
        logger.error(f"Model path {args.model} is not a safetensors file.")
        sys.exit(1)

    if args.outfile.suffix != ".gguf":
        logger.error("Output file must have .gguf extension.")
        sys.exit(1)

    qconfig = qconfig_map[args.outtype]
    outfile = Path(str(args.outfile).format(ftype=args.outtype.upper()))

    logger.info(f"Converting model in {args.model} to {outfile} with quantization {args.outtype}")
    converter = Converter(
        arch=args.arch,
        path_safetensors=args.model,
        endianess=gguf.GGUFEndian.BIG if args.bigendian else gguf.GGUFEndian.LITTLE,
        outtype=qconfig,
        outfile=outfile,
        repo_id=repo_id,
        subfolder=str(args.subfolder) if args.subfolder else None,
        is_diffusers=is_diffusers,
    )
    converter.write()
    logger.info(
        f"Conversion complete. Output written to {outfile}, architecture: {args.arch}, quantization: {qconfig.qtype.name}"
    )
    if merged_state_dict is not None:
        os.remove(filepath)
        logger.info(f"Removed the intermediate {filepath}.")