from abc import ABC
from typing import Optional

from pydantic import BaseModel, Field


class QuantizationConfig(ABC, BaseModel):
    """Abstract base for method-specific quantization settings."""


class ConvertRequest(ABC, BaseModel):
    """Abstract base for model conversion requests."""

    hf_model_name: str = Field(..., description="Hugging Face model name to convert")
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model to. If not provided, the model is only downloaded.")
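
# Sketch of the JSON body these request models accept (values are
# illustrative placeholders, not real repos or tokens):
#
#     {
#         "hf_model_name": "org/model",
#         "hf_tokenizer_name": null,
#         "hf_token": null,
#         "hf_push_repo": "org/model-quantized"
#     }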


class AWQQuantizationConfig(QuantizationConfig):
    """Settings for AWQ (Activation-aware Weight Quantization)."""

    zero_point: bool = Field(True, description="Use zero-point quantization")
    q_group_size: int = Field(128, description="Quantization group size")
    w_bit: int = Field(4, description="Weight bit width")
    version: str = Field("GEMM", description="Quantization kernel version, e.g. GEMM")
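
# A quick illustration of the defaults above (a sketch, assuming Pydantic v2's
# model_dump(); on v1, use .dict() instead):
#
#     >>> AWQQuantizationConfig().model_dump()
#     {'zero_point': True, 'q_group_size': 128, 'w_bit': 4, 'version': 'GEMM'}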


class GPTQQuantizationConfig(QuantizationConfig):
    """Settings for GPTQ quantization; currently exposes no tunable options."""


class GGUFQuantizationConfig(QuantizationConfig):
    """Settings for GGUF conversion; currently exposes no tunable options."""


class AWQConversionRequest(ConvertRequest):
    """Request to convert a model with AWQ quantization."""

    quantization_config: AWQQuantizationConfig = Field(
        default_factory=AWQQuantizationConfig,
        description="AWQ quantization configuration",
    )


class GPTQConversionRequest(ConvertRequest):
    """Request to convert a model with GPTQ quantization."""

    quantization_config: GPTQQuantizationConfig = Field(
        default_factory=GPTQQuantizationConfig,
        description="GPTQ quantization configuration",
    )


class GGUFConversionRequest(ConvertRequest):
    """Request to convert a model to GGUF."""

    quantization_config: GGUFQuantizationConfig = Field(
        default_factory=GGUFQuantizationConfig,
        description="GGUF quantization configuration",
    )
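

# A minimal usage sketch, assuming Pydantic v2 (model_dump_json(); on v1,
# use .json() instead). The repo id below is a hypothetical placeholder.
if __name__ == "__main__":
    request = AWQConversionRequest(
        hf_model_name="org/model",  # hypothetical repo id, illustration only
        quantization_config=AWQQuantizationConfig(w_bit=4, q_group_size=128),
    )
    # Serialize to the JSON payload an API client would send.
    print(request.model_dump_json(indent=2))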