DeepSeek-R1-Block-INT8 / inference /bf16_cast_block_int8.py
pkumc's picture
Add files using upload-large-folder tool
8e8c28e verified
raw
history blame
2.47 kB
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from huggingface_hub import snapshot_download
from kernel import weight_quant
def main(bf16_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"):
    """Quantize the 16-bit safetensors shards in ``bf16_path`` to block INT8.

    Which tensors get quantized is decided by the weight map in
    ``model.safetensors.index.json``: a tensor ``name`` is quantized iff
    ``f"{name}_scale_inv"`` is a key of that map. If the index file is not
    already present in ``int8_path``, it is downloaded from the Hugging Face
    Hub repository ``model_name``.

    Each ``*.safetensors`` shard is loaded onto CUDA. Every selected tensor
    is replaced by the quantized weight and the ``<name>_scale_inv`` tensor
    returned by ``weight_quant``. All other tensors are copied through
    unchanged. The result is saved to ``int8_path`` under the shard's
    original file name.

    Args:
        bf16_path: Directory containing the input ``*.safetensors`` shards.
        int8_path: Output directory (created if missing); also where the
            model index file is read from / downloaded to.
        model_name: Hub repo id used to fetch the index file when absent.

    Raises:
        ValueError: If a tensor selected for quantization does not have a
            2-byte element size (e.g. is not BF16).
        RuntimeError: If the number of quantized tensors differs from the
            number of ``*_scale_inv`` entries in the index (for example,
            because shards are missing from ``bf16_path``).
    """
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(int8_path, exist_ok=True)

    model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
    if not os.path.exists(model_index_file):
        snapshot_download(
            repo_id=model_name,
            allow_patterns=["model.safetensors.index.json"],
            local_dir=int8_path,
            local_dir_use_symlinks=False,
        )
        print(f"model index file downloaded to {model_index_file}")

    with open(model_index_file, "r", encoding="utf-8") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
    scale_count = sum(1 for key in weight_map if key.endswith("_scale_inv"))

    safetensor_files = sorted(glob(os.path.join(bf16_path, "*.safetensors")))

    quant_count = 0
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        state_dict = load_file(safetensor_file, device="cuda")
        new_state_dict = {}
        for weight_name, weight in state_dict.items():
            scale_inv_name = f"{weight_name}_scale_inv"
            if scale_inv_name not in weight_map:
                # No scale entry in the index: keep the tensor unchanged.
                new_state_dict[weight_name] = weight
                continue
            # Explicit check instead of `assert`, which `python -O` removes.
            if weight.element_size() != 2:
                raise ValueError(
                    f"tensor {weight_name!r} in {file_name} has dtype "
                    f"{weight.dtype}; expected a 2-byte (BF16) tensor"
                )
            quant_count += 1
            int8_weight, scale_inv = weight_quant(weight)
            new_state_dict[weight_name] = int8_weight
            new_state_dict[scale_inv_name] = scale_inv
        save_file(new_state_dict, os.path.join(int8_path, file_name))
        # Release this shard's CUDA tensors before the next load_file call;
        # otherwise they stay referenced while the next shard is loaded.
        del state_dict, new_state_dict

    if quant_count != scale_count:
        raise RuntimeError(
            f"quantized {quant_count} weights, but the model index lists "
            f"{scale_count} '_scale_inv' entries; check that every shard "
            f"is present in {bf16_path}"
        )
    print(f"{quant_count} weights are quantized.")
if __name__ == "__main__":
    # Command-line entry point: read the input/output directories (and
    # optionally the Hub repo id for the index file), then run main().
    cli = ArgumentParser()
    for required_flag in ("--input-bf16-hf-path", "--output-int8-hf-path"):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument("--model-name", type=str, default="deepseek-ai/DeepSeek-R1")
    opts = cli.parse_args()
    main(
        bf16_path=opts.input_bf16_hf_path,
        int8_path=opts.output_int8_hf_path,
        model_name=opts.model_name,
    )
    print("done")