update handler for testing
handler.py CHANGED (+29, -34)
@@ -1,53 +1,48 @@
-from typing import Dict, List, Any
-from transformers import pipeline
-
-import sys
 import torch
+from typing import Dict, List, Any
 from transformers import (
-    AutomaticSpeechRecognitionPipeline,
-    WhisperForConditionalGeneration,
-    WhisperTokenizer,
-    WhisperProcessor
+    AutomaticSpeechRecognitionPipeline,
+    WhisperForConditionalGeneration,
+    WhisperTokenizer,
+    WhisperProcessor,
+    pipeline
 )
 from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
 
 class EndpointHandler():
     def __init__(self, path=""):
-
+        # Preload all the elements you are going to need at inference.
+        peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
         language = "Chinese"
-        task = "transcribe"
-        peft_config = PeftConfig.from_pretrained(
-            model
+        task = "transcribe"
+        peft_config = PeftConfig.from_pretrained(peft_model_id)
+        model = WhisperForConditionalGeneration.from_pretrained(
             peft_config.base_model_name_or_path
         )
-        model = PeftModel.from_pretrained(model,
+        model = PeftModel.from_pretrained(model, peft_model_id)
         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         feature_extractor = processor.feature_extractor
         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
-        self.pipeline =
-        self.pipeline
-        self.pipeline.model.
-
+        # self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+        self.pipeline = pipeline(task="automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+        self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
+        self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids  # just to be sure!
+        # self.pipeline = pipeline(task="automatic-speech-recognition", model=self.model)
+        # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
+        # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-
-        inputs (:obj: `str`)
-
-
+        data args:
+            inputs (:obj: `str` | `PIL.Image` | `np.array`)
+            kwargs
+        Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-        # get inputs
-
-        # run normal prediction
-        inputs = data.pop("inputs", data)
-        print("a1", inputs)
-        print("a2", inputs, file=sys.stderr)
-        print("a3", inputs, file=sys.stdout)
-
-        prediction = self.pipeline(inputs, return_timestamps=False)
 
-
-
-
-
+        inputs = data.pop("inputs", data)
+        with torch.cuda.amp.autocast():
+            # prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
+            prediction = self.pipeline(inputs, return_timestamps=False)
+        return [prediction, {"test": 0}]
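
For context: Inference Endpoints load this file as a custom handler, constructing EndpointHandler once at startup and calling it with the deserialized request body. A minimal local smoke test might look like the sketch below; the "sample.flac" path is an assumption for illustration, and a CUDA device is assumed since __call__ wraps inference in torch.cuda.amp.autocast().

# Minimal local smoke test for handler.py (a sketch, not part of the commit).
# Assumptions: a local audio file "sample.flac" exists, the peft_model_id repo
# is reachable, and a CUDA device is available for the autocast context.
from handler import EndpointHandler

handler = EndpointHandler()  # constructed once at endpoint startup

# The request body arrives under the "inputs" key; the ASR pipeline accepts
# a file path, raw bytes, or a numpy waveform.
with open("sample.flac", "rb") as f:
    payload = {"inputs": f.read()}

result = handler(payload)
print(result)  # expected shape: [{"text": "..."}, {"test": 0}]

The trailing {"test": 0} element matches the commit message ("update handler for testing"): it marks which handler revision produced the response while the pipeline wiring is being verified.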
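One optional variation, not part of this commit: every forward pass through the PeftModel wrapper applies the LoRA adapters on the fly, so the adapter weights can instead be fused into the base Whisper model before the pipeline is built. PEFT exposes this as merge_and_unload(); the snippet below sketches how the __init__ above could use it.

# Optional variation (a sketch, not in this commit): fuse the LoRA weights
# so the pipeline runs a plain WhisperForConditionalGeneration.
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.merge_and_unload()  # returns the base model with adapters merged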
|