kajdun committed
Commit 9a254c2 · 1 Parent(s): a5a6053

Update handler.py

Files changed (1): handler.py (+5, -3)
handler.py CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Any
 from transformers import AutoTokenizer
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 import torch
+from loguru import logger
 
 # check for GPU
-#device = 0 if torch.cuda.is_available() else -1
+device = 0 if torch.cuda.is_available() else -1
+
+logger.info(f"cuda: {device}")
 
 MAX_INPUT_TOKEN_LENGTH = 4000
 MAX_MAX_NEW_TOKENS = 2048
@@ -32,8 +35,7 @@ class EndpointHandler():
         if input_token_length > MAX_INPUT_TOKEN_LENGTH:
             return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
 
-        #input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
-        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
+        input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
 
         outputs = self.model.generate(**input_ids, **parameters)
 
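For context, the changed lines sit in an inference-endpoint handler roughly like the sketch below. Only the lines shown in the diff are from this commit; the __init__ body, the from_quantized arguments, and the request parsing are assumptions filled in for illustration. The substantive fix is the tokenizer call: tokenizer(inputs, return_tensors="pt") returns a BatchEncoding (a mapping holding input_ids and attention_mask), so .to(self.model.device) moves every tensor onto the model's device and **input_ids unpacks cleanly into generate(). The replaced .input_ids variant produced a bare tensor, which cannot be unpacked with ** and would also have stayed on the CPU.

from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from loguru import logger

# check for GPU
device = 0 if torch.cuda.is_available() else -1

logger.info(f"cuda: {device}")

MAX_INPUT_TOKEN_LENGTH = 4000
MAX_MAX_NEW_TOKENS = 2048


class EndpointHandler():
    def __init__(self, path: str = ""):
        # Assumed loading code: the commit does not show __init__, so these
        # from_quantized arguments are illustrative, not from the repo.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = AutoGPTQForCausalLM.from_quantized(
            path,
            device="cuda:0" if torch.cuda.is_available() else "cpu",
            use_safetensors=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Assumed request parsing; a real handler may validate more strictly.
        inputs = data["inputs"]
        parameters = data.get("parameters", {"max_new_tokens": MAX_MAX_NEW_TOKENS})

        # Reject over-long prompts before generating (as in the diff).
        input_token_length = len(self.tokenizer(inputs).input_ids)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        # The committed fix: keep the BatchEncoding mapping and move its
        # tensors (input_ids, attention_mask) to the model's device.
        input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(**input_ids, **parameters)
        return [{"generated_text": self.tokenizer.decode(outputs[0], skip_special_tokens=True)}]

Calling .to(...) on the whole BatchEncoding rather than on a single tensor keeps the attention_mask on the same device as input_ids, which avoids the device-mismatch errors that occur when the model runs on GPU but the inputs stay on CPU.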