cathyi committed
Commit 39239f4 · 1 Parent(s): 05b2014

add handler

Files changed (2)
  1. config.json +1 -1
  2. handler.py +37 -0
config.json CHANGED
@@ -1,5 +1,5 @@
  {
- "_name_or_path": "openai/whisper-small",
+ "_name_or_path": "EricChang/openai-whisper-large-v2-Lora",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
handler.py ADDED
@@ -0,0 +1,37 @@
+ from typing import Dict, List, Any
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
+ from peft import PeftConfig, PeftModel
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # Preload all the elements you are going to need at inference.
+         peft_config = PeftConfig.from_pretrained(path)
+         self.model = WhisperForConditionalGeneration.from_pretrained(
+             peft_config.base_model_name_or_path
+         )
+         # Load the LoRA adapter weights stored at `path` on top of the base model.
+         self.model = PeftModel.from_pretrained(self.model, path)
+         # The pipeline is handed a model instance, so the tokenizer and feature
+         # extractor must be passed explicitly; reuse the base model's processor.
+         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path)
+         self.pipeline = pipeline(
+             task="automatic-speech-recognition",
+             model=self.model,
+             tokenizer=processor.tokenizer,
+             feature_extractor=processor.feature_extractor,
+         )
+         # Force Chinese transcription regardless of the detected language.
+         self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
+         self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`str` | `bytes` | `np.ndarray`): audio to transcribe
+             kwargs
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
+         """
+         inputs = data.pop("inputs", data)
+         prediction = self.pipeline(inputs, return_timestamps=False)
+         return prediction
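
For reference, a minimal local smoke test of this handler could look like the sketch below. It assumes the repository files (handler.py, config.json, and the adapter weights) sit in the current directory and that sample.wav is a short audio clip on disk; both the path and the filename are illustrative, not part of the commit.

# Hypothetical smoke test (not part of this commit).
# Assumes handler.py and the LoRA adapter live in the current directory
# and that sample.wav is a short audio file; both names are assumptions.
from handler import EndpointHandler

handler = EndpointHandler(path=".")          # load base Whisper model + LoRA adapter
result = handler({"inputs": "sample.wav"})   # transcribe the audio file
print(result)                                # e.g. {"text": "..."}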