Update handler.py
Browse files- handler.py +5 -3
handler.py
CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Any
|
|
2 |
from transformers import AutoTokenizer
|
3 |
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
4 |
import torch
|
|
|
5 |
|
6 |
# check for GPU
|
7 |
-
|
|
|
|
|
8 |
|
9 |
MAX_INPUT_TOKEN_LENGTH = 4000
|
10 |
MAX_MAX_NEW_TOKENS = 2048
|
@@ -32,8 +35,7 @@ class EndpointHandler():
|
|
32 |
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
|
33 |
return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
|
34 |
|
35 |
-
|
36 |
-
input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
|
37 |
|
38 |
outputs = self.model.generate(**input_ids, **parameters)
|
39 |
|
|
|
2 |
from transformers import AutoTokenizer
|
3 |
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
4 |
import torch
|
5 |
+
from loguru import logger
|
6 |
|
7 |
# check for GPU
|
8 |
+
device = 0 if torch.cuda.is_available() else -1
|
9 |
+
|
10 |
+
logger.info(f"cuda: {device}")
|
11 |
|
12 |
MAX_INPUT_TOKEN_LENGTH = 4000
|
13 |
MAX_MAX_NEW_TOKENS = 2048
|
|
|
35 |
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
|
36 |
return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
|
37 |
|
38 |
+
input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
|
|
|
39 |
|
40 |
outputs = self.model.generate(**input_ids, **parameters)
|
41 |
|