import os
import sys
import time
import random

import torch
from collections import UserDict
from packaging.version import Version
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.common import logger
from neural_compressor.torch.utils import is_hpex_available, get_torch_version


# ====== utils.py content inlined and fixed ======
class DataloaderPreprocessor:
    """Collect a fixed number of calibration samples from a dataloader.

    When ``use_max_length`` is True, only samples with at least ``max_seq_length``
    tokens are kept (longer ones are randomly cropped to exactly that length);
    otherwise every sample is kept and long ones are randomly cropped.
    """

    def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
        self.dataloader_original = dataloader_original
        self.use_max_length = use_max_length
        self.max_seq_length = max_seq_length
        self.nsamples = nsamples
        self.dataloader = []
        self.is_ready = False

    def get_prepared_dataloader(self):
        if not self.is_ready:
            self.prepare_dataloader()
        return self.dataloader

    def prepare_dataloader(self):
        if self.use_max_length:
            self.obtain_first_n_samples_fulllength()
        else:
            self.obtain_first_n_samples()
        self.is_ready = True

    def obtain_first_n_samples(self, seed=0):
        """Get the first nsamples batches as the real calibration dataset."""
        self.dataloader.clear()
        random.seed(seed)
        for batch in self.dataloader_original:
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                break
            # list or tuple
            if isinstance(batch, (list, tuple)):
                if batch[0].shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.ndim == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    batch_final = batch[:]
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except Exception:
                    logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                    continue
                batch_final = {}
                if length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]
                        else:
                            batch_final[key] = batch[key]
                else:
                    batch_final = batch
            # tensor
            else:
                if batch.shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = batch[:, i:j]
                else:
                    batch_final = batch
            self.dataloader.append(batch_final)

        if len(self.dataloader) < self.nsamples:
            logger.warning(
                f"Tried to collect {self.nsamples} samples, but the entire dataset only yields {len(self.dataloader)}."
            )

    def obtain_first_n_samples_fulllength(self, seed=0):
        """Get the first nsamples batches whose length is at least max_seq_length."""
        self.dataloader.clear()
        random.seed(seed)
        unified_length = self.max_seq_length
        for batch in self.dataloader_original:
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collected {self.nsamples} calibration samples.")
                break
            # list or tuple
            if isinstance(batch, (list, tuple)):
                if batch[0].shape[-1] == unified_length:
                    batch_final = batch[:]
                elif batch[0].shape[-1] > unified_length:
                    i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.ndim == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    # shorter than the unified length: skip this sample
                    continue
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except Exception:
                    logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
                    continue
                batch_final = {}
                if length == self.max_seq_length:
                    batch_final = batch
                elif length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]
                        else:
                            batch_final[key] = batch[key]
                else:
                    # shorter than the unified length: skip this sample
                    continue
            # tensor
            else:
                if batch.shape[-1] == unified_length:
                    batch_final = batch
                elif batch.shape[-1] > unified_length:
                    i = random.randint(0, batch.shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = batch[:, i:j]
                else:
                    continue
            self.dataloader.append(batch_final)

        if len(self.dataloader) < self.nsamples:
            logger.warning(
                f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, "
                f"but only {len(self.dataloader)} samples are found. Please use a smaller 'self.max_seq_length' value."
            )


def get_example_inputs(model, dataloader):
    """Pick one batch from the dataloader and convert it into example_inputs."""
    version = get_torch_version()
    from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device

    if dataloader is None:
        return None
    device = next(model.parameters()).device
    try:
        for idx, (input, label) in enumerate(dataloader):
            input = move_input_to_device(input, device)
            if isinstance(input, (dict, UserDict)):
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
                if "label" in input.keys():
                    input.pop("label")
                if version.release <= Version("2.0.1").release:
                    return tuple(input.values())
                else:
                    return dict(input)
            if isinstance(input, (list, tuple)):
                return tuple(input)
            if isinstance(input, torch.Tensor):
                return input
            break
    except Exception:  # the dataloader may yield inputs without labels
        for idx, input in enumerate(dataloader):
            input = move_input_to_device(input, device)
            if isinstance(input, (dict, UserDict)):
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
                if "label" in input.keys():
                    input.pop("label")
                if version.release <= Version("2.0.1").release:
                    return tuple(input.values())
                else:
                    return dict(input)
            if isinstance(input, (list, tuple)):
                return tuple(input)
            if isinstance(input, torch.Tensor):
                return input
            break
    if idx == 0:
        assert False, "Please check the example_inputs format."
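

# --- Illustrative usage of DataloaderPreprocessor (sketch only; the TEQ flow below does
# --- not use it, but algorithms that need fixed-length calibration batches, e.g. GPTQ,
# --- would typically use it this way). The dummy dataloader, sample shapes, and the
# --- helper name _demo_dataloader_preprocessor are placeholders for illustration.
def _demo_dataloader_preprocessor():
    dummy_loader = [torch.randint(0, 1000, (1, 4096)) for _ in range(4)]
    preprocessor = DataloaderPreprocessor(
        dataloader_original=dummy_loader,
        use_max_length=True,   # keep only samples >= max_seq_length, cropped to that length
        max_seq_length=2048,
        nsamples=2,
    )
    prepared = preprocessor.get_prepared_dataloader()
    assert all(batch.shape[-1] == 2048 for batch in prepared)
    return prepared
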
# ====== End of utils.py content ======

# ====== Hardcoded arguments ======
class Args:
    model = "meta-llama/Llama-3.2-3B"
    trust_remote_code = True
    revision = None
    dataset = "neuralmagic/LLM_compression_calibration"
    output_dir = "meta-llama_Llama-3.2-3B-TEQ-int4-gs128-asym"
    quantize = True
    seed = 42
    load = False
    accuracy = False
    performance = False
    iters = 100
    batch_size = 1
    pad_max_length = 512
    calib_iters = 512
    tasks = "lambada_openai,hellaswag,winogrande,piqa"
    peft_model_id = None
    # Weight-only quantization configs
    woq_algo = "TEQ"
    woq_bits = 4
    woq_dtype = "int"
    woq_group_size = 128
    woq_group_dim = 1
    woq_scheme = "asym"
    woq_use_mse_search = False
    woq_use_full_range = False
    quant_lm_head = True
    use_hf_format = False
    # TEQ/AWQ configs
    use_auto_scale = False
    use_auto_clip = False
    folding = False
    absorb_layer_dict = {}
    # DoubleQuant configs
    double_quant_type = None
    double_quant_dtype = "fp32"
    double_quant_bits = 8
    double_quant_use_sym = True
    double_quant_group_size = 256


args = Args()
calib_size = 1

if is_hpex_available():
    import habana_frameworks.torch.core as htcore

    htcore.hpu_set_inference_env()
    device = "hpu"
else:
    device = "cpu"


# ====== Helper functions ======
def get_user_model():
    # AWQ/TEQ trace the model, so load it in torchscript mode for those algorithms.
    torchscript = False
    if args.woq_algo in ["AWQ", "TEQ"]:
        torchscript = True
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torchscript=torchscript,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    user_model = user_model.float()
    user_model = user_model.to(memory_format=torch.channels_last)
    user_model.eval()
    return user_model, tokenizer


def calib_func(prepared_model):
    for i, calib_input in enumerate(calib_dataloader):
        if i > args.calib_iters:
            break
        prepared_model(calib_input[0])


# ====== Main quantization logic ======
if args.quantize:
    user_model, tokenizer = get_user_model()
    calib_dataset = load_dataset(args.dataset, split="train")
    calib_dataset = calib_dataset.shuffle(seed=args.seed)

    class Evaluator:
        def __init__(self, dataset, tokenizer, batch_size=8, pad_val=1, pad_max=196, is_calib=False):
            self.dataset = dataset
            self.tokenizer = tokenizer
            self.batch_size = batch_size
            self.pad_val = pad_val
            self.pad_max = pad_max
            self.is_calib = is_calib
            self.dataset = self.dataset.map(self.tokenize_function, batched=True)
            self.dataset.set_format(type="torch", columns=["input_ids"])

        @torch.no_grad()
        def tokenize_function(self, examples):
            if args.woq_algo in ["TEQ"]:
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                example = self.tokenizer(examples["text"], padding="max_length", max_length=self.pad_max)
            else:
                example = self.tokenizer(examples["text"])
            return example

        @torch.no_grad()
        def collate_batch(self, batch):
            input_ids_padded = []
            last_ind = []
            for text in batch:
                input_ids = text["input_ids"]
                last_ind.append(input_ids.shape[0] - 1)
                # Truncate first, then pad, so every sample ends up exactly pad_max tokens long.
                input_ids = input_ids[: self.pad_max] if len(input_ids) > self.pad_max else input_ids
                pad_len = self.pad_max - input_ids.shape[0]
                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=self.pad_val)
                input_ids_padded.append(input_ids)
            return (torch.vstack(input_ids_padded), torch.tensor(last_ind))

    calib_evaluator = Evaluator(calib_dataset, tokenizer, args.batch_size, pad_max=args.pad_max_length, is_calib=True)
    calib_dataloader = DataLoader(
        calib_evaluator.dataset,
        batch_size=calib_size,
        shuffle=False,
        collate_fn=calib_evaluator.collate_batch,
    )

    # === TEQ quantization ===
    from neural_compressor.torch.quantization import TEQConfig, prepare, convert
    weight_sym = args.woq_scheme == "sym"
    quant_config = TEQConfig(
        dtype=args.woq_dtype,
        bits=args.woq_bits,
        use_sym=weight_sym,
        group_size=args.woq_group_size,
        group_dim=args.woq_group_dim,
        folding=args.folding,
        quant_lm_head=args.quant_lm_head,
    )
    example_inputs = torch.ones([1, args.pad_max_length], dtype=torch.long)
    run_fn = calib_func

    user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
    run_fn(user_model)
    user_model = convert(user_model)

    # === Save quantized model ===
    os.makedirs(args.output_dir, exist_ok=True)
    print("Saving weight-only quantized model to", args.output_dir)
    if args.use_hf_format:
        user_model.save(args.output_dir, format="huggingface")
        tokenizer.save_pretrained(args.output_dir)
    else:
        user_model.save(args.output_dir)
    print("Saved weight-only quantized model.")
else:
    print("Quantization not enabled. Exiting.")
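

# --- Optional reload sanity check (illustrative sketch, not part of the original flow) ---
# Assumption: neural_compressor.torch.quantization exposes a load() helper for models saved
# in the default INC format; the exact signature varies across INC versions and may require
# the original FP32 model, or format="huggingface" for models saved with use_hf_format.
# The helper name reload_and_check is hypothetical.
def reload_and_check(output_dir=args.output_dir):
    from neural_compressor.torch.quantization import load  # assumed INC 3.x API

    qmodel = load(os.path.abspath(os.path.expanduser(output_dir)))
    with torch.no_grad():
        # One dummy forward pass with the same shape used as example_inputs above.
        qmodel(torch.ones([1, args.pad_max_length], dtype=torch.long))
    print("Reloaded quantized model ran a forward pass successfully.")
    return qmodel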