add handler
Browse files- config.json +1 -1
- handler.py +28 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "openai
|
3 |
"activation_dropout": 0.0,
|
4 |
"activation_function": "gelu",
|
5 |
"architectures": [
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "EricChang/openai-whisper-large-v2-Lora",
|
3 |
"activation_dropout": 0.0,
|
4 |
"activation_function": "gelu",
|
5 |
"architectures": [
|
handler.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import WhisperForConditionalGeneration, pipeline
|
3 |
+
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
|
4 |
+
|
5 |
+
class EndpointHandler():
|
6 |
+
def __init__(self, path=""):
|
7 |
+
# Preload all the elements you are going to need at inference.
|
8 |
+
peft_config = PeftConfig.from_pretrained(path)
|
9 |
+
self.model= WhisperForConditionalGeneration.from_pretrained(
|
10 |
+
peft_config.base_model_name_or_path
|
11 |
+
)
|
12 |
+
self.model = PeftModel.from_pretrained(self.model, peft_model_id)
|
13 |
+
self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
|
14 |
+
self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
15 |
+
self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
|
16 |
+
|
17 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
18 |
+
"""
|
19 |
+
data args:
|
20 |
+
inputs (:obj: `str` | `PIL.Image` | `np.array`)
|
21 |
+
kwargs
|
22 |
+
Return:
|
23 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
24 |
+
"""
|
25 |
+
|
26 |
+
inputs = data.pop("inputs",data)
|
27 |
+
prediction = self.pipeline(inputs, return_timestamps=False)
|
28 |
+
return prediction
|