update handler

Files changed:
- .gitignore +1 -0
- __pycache__/handler.cpython-38.pyc +0 -0 (binary)
- handler.py +5 -4
.gitignore ADDED
@@ -0,0 +1 @@
+test_handler.py

__pycache__/handler.cpython-38.pyc ADDED
Binary file (1.88 kB).
handler.py CHANGED
@@ -1,3 +1,4 @@
+import torch
 from typing import Dict, List, Any
 from transformers import (
     AutomaticSpeechRecognitionPipeline,
@@ -10,9 +11,10 @@ from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, P
 class EndpointHandler():
     def __init__(self, path=""):
         # Preload all the elements you are going to need at inference.
+        peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
         language = "Chinese"
         task = "transcribe"
-
+        peft_config = PeftConfig.from_pretrained(peft_model_id)
         model= WhisperForConditionalGeneration.from_pretrained(
             peft_config.base_model_name_or_path
         )
@@ -20,9 +22,8 @@ class EndpointHandler():
         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         feature_extractor = processor.feature_extractor
-        forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
+        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
         self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
-
         # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
         # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
         # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
@@ -38,5 +39,5 @@ class EndpointHandler():
 
         inputs = data.pop("inputs", data)
         with torch.cuda.amp.autocast():
-
+            prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
         return {"prediction": prediction}
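
For reference, the post-commit handler.py, stitched together from the hunks above, plausibly reads as follows. This is a sketch, not the verbatim file: the full (truncated) peft import list, the code between the hunks (notably whatever attaches the LoRA adapter to the base model), and the __call__ signature are not visible in the diff, so those lines are assumptions marked in comments.

import torch
from typing import Dict, List, Any
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,  # assumed: used in __init__ below
    WhisperTokenizer,                 # assumed
    WhisperProcessor,                 # assumed
)
# The hunk header shows a longer, truncated peft import line; PeftConfig and
# PeftModel are the two names this sketch actually needs.
from peft import PeftConfig, PeftModel


class EndpointHandler():
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
        language = "Chinese"
        task = "transcribe"
        # Resolve the base checkpoint the LoRA adapter was trained on.
        peft_config = PeftConfig.from_pretrained(peft_model_id)
        model = WhisperForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path
        )
        # Not visible in the diff: the usual PEFT pattern would wrap the base
        # model with the adapter at this point, e.g.
        # model = PeftModel.from_pretrained(model, peft_model_id)
        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
        feature_extractor = processor.feature_extractor
        # Stored on self so __call__ can force Chinese transcription at decode time.
        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
        self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)

    # Signature assumed from the standard Inference Endpoints handler contract.
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.pop("inputs", data)
        with torch.cuda.amp.autocast():
            prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
        return {"prediction": prediction}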
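
A minimal smoke test under the same assumptions (a CUDA machine, since __call__ enters torch.cuda.amp.autocast, and a hypothetical audio file "sample.wav" in a format the ASR pipeline accepts):

from handler import EndpointHandler

handler = EndpointHandler()
# "inputs" mirrors data.pop("inputs", data) in __call__; a file path,
# raw bytes, or a numpy waveform all work with the ASR pipeline.
result = handler({"inputs": "sample.wav"})
print(result["prediction"])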