Rúben Almeida committed
Commit 3081464 · Parent: edebf90

Update version of requirements

Files changed (4)
  1. dto.py +45 -0
  2. main.py +4 -45
  3. requirements.txt +3 -3
  4. tests/test_awq.py +5 -3
dto.py ADDED
@@ -0,0 +1,45 @@
+from abc import ABC
+from typing import Optional
+from pydantic import BaseModel, Field
+
+
+### DTO Definitions
+class QuantizationConfig(ABC, BaseModel):
+    pass
+class ConvertRequest(ABC, BaseModel):
+    hf_model_name: str
+    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
+    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
+    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
+### -------
+
+### Quantization Configurations
+class AWQQuantizationConfig(QuantizationConfig):
+    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
+    q_group_size: Optional[int] = Field(128, description="Quantization group size")
+    w_bit: Optional[int] = Field(4, description="Weight bit")
+    version: Optional[str] = Field("GEMM", description="Quantization version")
+
+class GPTQQuantizationConfig(QuantizationConfig):
+    pass
+
+class GGUFQuantizationConfig(QuantizationConfig):
+    pass
+class AWQConvertionRequest(ConvertRequest):
+    quantization_config: Optional[AWQQuantizationConfig] = Field(
+        default_factory=lambda: AWQQuantizationConfig(),
+        description="AWQ quantization configuration"
+    )
+
+class GPTQConvertionRequest(ConvertRequest):
+    quantization_config: Optional[GPTQQuantizationConfig] = Field(
+        default_factory=lambda: GPTQQuantizationConfig(),
+        description="GPTQ quantization configuration"
+    )
+
+class GGUFConvertionRequest(ConvertRequest):
+    quantization_config: Optional[GGUFQuantizationConfig] = Field(
+        default_factory=lambda: GGUFQuantizationConfig(),
+        description="GGUF quantization configuration"
+    )
+### -------
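
A minimal usage sketch of the new DTOs (not part of the commit; it assumes dto.py is on the import path and the model name is only illustrative):

    from dto import AWQConvertionRequest

    # quantization_config is optional; default_factory supplies the AWQ defaults.
    req = AWQConvertionRequest(hf_model_name="some-org/some-model")
    assert req.quantization_config.w_bit == 4
    assert req.quantization_config.version == "GEMM"

    # A nested dict is validated into AWQQuantizationConfig, overriding defaults.
    req = AWQConvertionRequest(
        hf_model_name="some-org/some-model",
        quantization_config={"w_bit": 8, "q_group_size": 64},
    )
    assert req.quantization_config.w_bit == 8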
main.py CHANGED
@@ -1,13 +1,12 @@
 import zipfile
-from abc import ABC
-from typing import Optional, Union
+from typing import Union
 from awq import AutoAWQForCausalLM
-from pydantic import BaseModel, Field
 from transformers import AutoTokenizer
 from tempfile import NamedTemporaryFile
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import RedirectResponse, FileResponse
+from .dto import AWQConvertionRequest, GGUFConvertionRequest, GPTQConvertionRequest
 
 ### FastAPI Initialization
 @asynccontextmanager
@@ -17,46 +16,6 @@ async def lifespan(app:FastAPI):
 app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
 ### -------
 
-### DTO Definitions
-class QuantizationConfig(ABC, BaseModel):
-    pass
-class ConvertRequest(ABC, BaseModel):
-    hf_model_name: str
-    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
-    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
-    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
-### -------
-
-### Quantization Configurations
-class AWQQuantizationConfig(QuantizationConfig):
-    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
-    q_group_size: Optional[int] = Field(128, description="Quantization group size")
-    w_bit: Optional[int] = Field(4, description="Weight bit")
-    version: Optional[str] = Field("GEMM", description="Quantization version")
-
-class GPTQQuantizationConfig(QuantizationConfig):
-    pass
-
-class GGUFQuantizationConfig(QuantizationConfig):
-    pass
-class AWQConvertionRequest(ConvertRequest):
-    quantization_config: Optional[AWQQuantizationConfig] = Field(
-        default_factory=lambda: AWQQuantizationConfig(),
-        description="AWQ quantization configuration"
-    )
-
-class GPTQConvertionRequest(ConvertRequest):
-    quantization_config: Optional[GPTQQuantizationConfig] = Field(
-        default_factory=lambda: GPTQQuantizationConfig(),
-        description="GPTQ quantization configuration"
-    )
-
-class GGUFConvertionRequest(ConvertRequest):
-    quantization_config: Optional[GGUFQuantizationConfig] = Field(
-        default_factory=lambda: GGUFQuantizationConfig(),
-        description="GGUF quantization configuration"
-    )
-### -------
 
 @app.get("/", include_in_schema=False)
 def redirect_to_docs():
@@ -102,11 +61,11 @@ def convert(request: AWQConvertionRequest)->Union[FileResponse, dict]:
     raise HTTPException(status_code=500, detail="Failed to convert model")
 
 @app.post("/convert_gpt_q", response_model=None)
-def convert_gpt_q(request: ConvertRequest)->Union[FileResponse, dict]:
+def convert_gpt_q(request: GPTQConvertionRequest)->Union[FileResponse, dict]:
     raise HTTPException(status_code=501, detail="Not implemented yet")
 
 @app.post("/convert_gguf", response_model=None)
-def convert_gguf(request: ConvertRequest)->Union[FileResponse, dict]:
+def convert_gguf(request: GGUFConvertionRequest)->Union[FileResponse, dict]:
     raise HTTPException(status_code=501, detail="Not implemented yet")
 
 @app.get("/health")
requirements.txt CHANGED
@@ -5,10 +5,10 @@ torchaudio
 setuptools
 wheel
 pydantic
-fastapi[standard]
-transformers
+fastapi[standard]>=0.115.12
+transformers>=4.51.3
 huggingface_hub
-autoawq[kernels]
+autoawq[kernels]>=0.2.8
 starlette>=0.46.2
 pytest
 requests
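
A small sketch (an assumption, not part of the repo) for checking that an installed environment meets the newly pinned minimums, using importlib.metadata:

    from importlib.metadata import version

    # Minimum versions introduced by this commit; printed for manual comparison.
    for pkg, minimum in [("fastapi", "0.115.12"),
                         ("transformers", "4.51.3"),
                         ("autoawq", "0.2.8")]:
        print(f"{pkg}: installed {version(pkg)}, required >= {minimum}")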
tests/test_awq.py CHANGED
@@ -1,7 +1,6 @@
 import pytest
 import requests
 from environs import Env
-from huggingface_hub import login
 
 env = Env()
 env.read_env(override=True)
@@ -16,6 +15,9 @@ def test_incompatible_model():
             "hf_push_repo": None,
         }
     )
+
+    response.raise_for_status()
+
     assert response.status_code == 400
 
 
@@ -23,7 +25,7 @@ def test_convert_download():
     response = requests.post(
         f"{env.str('ENDPOINT')}/convert_awq",
         json={
-            "hf_model_name": "Qwen/Qwen2.5-14B-Instruct",
+            "hf_model_name": "Qwen/Qwen2.5-7B-Instruct",
         }
     )
 
@@ -33,7 +35,7 @@
 
 
 def test_convert_push():
-    model_name = "Qwen/Qwen2.5-14B-Instruct"
+    model_name = "Qwen/Qwen2.5-7B-Instruct"
 
     response = requests.post(
         f"{env.str('ENDPOINT')}/convert_awq",