from abc import ABC
from typing import Optional
from pydantic import BaseModel, Field
### DTO Definitions
class QuantizationConfig(ABC, BaseModel):
    """Marker base class for all quantization configuration DTOs.

    NOTE(review): no abstract methods are declared, so inheriting ABC
    does not actually block instantiation — confirm this is intended.
    """
class ConvertRequest(ABC, BaseModel):
    """Base request DTO for converting a Hugging Face model.

    NOTE(review): inherits ABC but declares no abstract methods, so it
    remains instantiable — confirm intent.
    """
    # Required Hugging Face model identifier (e.g. "org/model-name").
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
### -------
### Quantization Configurations
class AWQQuantizationConfig(QuantizationConfig):
    """Configuration parameters for AWQ (Activation-aware Weight Quantization).

    NOTE(review): each field is Optional yet carries a non-None default,
    so callers may also pass an explicit None — confirm downstream code
    tolerates None for these values.
    """
    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit")
    version: Optional[str] = Field("GEMM", description="Quantization version")
class GPTQQuantizationConfig(QuantizationConfig):
    """GPTQ quantization configuration; currently declares no parameters."""
class GGUFQuantizationConfig(QuantizationConfig):
    """GGUF quantization configuration; currently declares no parameters."""
class AWQConvertionRequest(ConvertRequest):
    """Convert request carrying an AWQ quantization configuration.

    NOTE(review): "Convertion" is a misspelling of "Conversion"; the name
    is public API and is kept for backward compatibility.
    """
    quantization_config: Optional[AWQQuantizationConfig] = Field(
        # The class itself is a zero-argument callable, so the original
        # `lambda: AWQQuantizationConfig()` wrapper was redundant; pydantic
        # invokes the factory once per model instantiation either way.
        default_factory=AWQQuantizationConfig,
        description="AWQ quantization configuration"
    )
class GPTQConvertionRequest(ConvertRequest):
    """Convert request carrying a GPTQ quantization configuration.

    NOTE(review): "Convertion" is a misspelling of "Conversion"; the name
    is public API and is kept for backward compatibility.
    """
    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        # The class itself is a zero-argument callable, so the original
        # `lambda: GPTQQuantizationConfig()` wrapper was redundant.
        default_factory=GPTQQuantizationConfig,
        description="GPTQ quantization configuration"
    )
class GGUFConvertionRequest(ConvertRequest):
    """Convert request carrying a GGUF quantization configuration.

    NOTE(review): "Convertion" is a misspelling of "Conversion"; the name
    is public API and is kept for backward compatibility.
    """
    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        # The class itself is a zero-argument callable, so the original
        # `lambda: GGUFQuantizationConfig()` wrapper was redundant.
        default_factory=GGUFQuantizationConfig,
        description="GGUF quantization configuration"
    )
### ------- |