Mir-2002 committed
Commit 832a366 · verified · 1 Parent(s): 1efe044

Upload handler.py

Files changed (1):
  handler.py +49 -0
handler.py ADDED
@@ -0,0 +1,49 @@
+from typing import Any, Dict, List
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
+
+MAX_INPUT_LENGTH = 256
+MAX_OUTPUT_LENGTH = 128
+
+class EndpointHandler:
+    def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+        self.model.eval()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        inputs = data.get("inputs")
+        if not inputs:
+            raise ValueError("No 'inputs' found in the request data.")
+
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        tokenized_inputs = self.tokenizer(
+            inputs,
+            max_length=MAX_INPUT_LENGTH,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        try:
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    tokenized_inputs["input_ids"],
+                    attention_mask=tokenized_inputs["attention_mask"],
+                    max_length=MAX_OUTPUT_LENGTH,
+                    num_beams=4,  # beam search; a small beam keeps latency reasonable
+                    no_repeat_ngram_size=3,
+                    early_stopping=True,
+                    do_sample=False,
+                    pad_token_id=self.tokenizer.pad_token_id
+                )
+            decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            results = [{"generated_text": text} for text in decoded_outputs]
+            return results
+        except Exception as e:
+            # No logger is configured here; surface the error in the response body
+            return [{"generated_text": f"Error: {str(e)}"}]