cathyi committed on
Commit
5745cca
·
1 Parent(s): cc8e20d

update handler and some files

Browse files
Files changed (4) hide show
  1. handler.py +23 -9
  2. merges.txt +0 -0
  3. tokenizer.json +0 -0
  4. tokenizer_config.json +1 -1
handler.py CHANGED
@@ -1,18 +1,31 @@
from typing import Dict, List, Any

from transformers import WhisperForConditionalGeneration, pipeline
from peft import PeftModel, PeftConfig


class EndpointHandler():
    """Inference endpoint handler for a LoRA-adapted Whisper ASR model.

    Loads the base Whisper checkpoint named in the PEFT adapter config,
    applies the LoRA adapter stored at ``path``, and wraps the result in a
    transformers ASR pipeline forced to Chinese transcription.
    """

    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        peft_config = PeftConfig.from_pretrained(path)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path
        )
        # BUG FIX: the original passed the undefined name `peft_model_id`;
        # the LoRA adapter lives at `path` (same repo the config came from).
        self.model = PeftModel.from_pretrained(self.model, path)
        self.pipeline = pipeline(task="automatic-speech-recognition", model=self.model)
        # Force Chinese transcription on both the model config and the
        # generation config (generate() reads the latter on recent versions).
        self.pipeline.model.config.forced_decoder_ids = (
            self.pipeline.tokenizer.get_decoder_prompt_ids(
                language="Chinese", task="transcribe"
            )
        )
        self.pipeline.model.generation_config.forced_decoder_ids = (
            self.pipeline.model.config.forced_decoder_ids
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Transcribe the audio payload.

        Args:
            data: request payload; the audio is taken from ``data["inputs"]``
                when present, otherwise ``data`` itself is used as the input.

        Returns:
            The raw pipeline output (a dict with a ``"text"`` key for a
            single input).
        """
        inputs = data.pop("inputs", data)
        prediction = self.pipeline(inputs, return_timestamps=False)
        return prediction
 
 
from typing import Dict, List, Any

import torch
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig


class EndpointHandler():
    """Inference endpoint handler for a LoRA-adapted Whisper ASR model.

    Loads the base Whisper checkpoint named in the PEFT adapter config,
    applies the LoRA adapter stored at ``path``, and builds an explicit
    ASR pipeline (tokenizer + feature extractor) forced to Chinese
    transcription.
    """

    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        language = "Chinese"
        task = "transcribe"
        # BUG FIX: the config was stored as `self.peft_config` but read as a
        # bare `peft_config` below; use one consistent local name.
        peft_config = PeftConfig.from_pretrained(path)
        model = WhisperForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path
        )
        # BUG FIX: the original passed the undefined name `peft_model_id`;
        # the LoRA adapter lives at `path` (same repo the config came from).
        model = PeftModel.from_pretrained(model, path)
        tokenizer = WhisperTokenizer.from_pretrained(
            peft_config.base_model_name_or_path, language=language, task=task
        )
        processor = WhisperProcessor.from_pretrained(
            peft_config.base_model_name_or_path, language=language, task=task
        )
        # BUG FIX: forced_decoder_ids was a local of __init__ but was read in
        # __call__; keep it on the instance so inference can see it.
        self.forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=language, task=task
        )
        self.pipeline = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=processor.feature_extractor,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe the audio payload.

        Args:
            data: request payload; the audio is taken from ``data["inputs"]``
                when present, otherwise ``data`` itself is used as the input.

        Returns:
            ``{"prediction": <transcribed text>}``.
        """
        inputs = data.pop("inputs", data)
        # NOTE(review): torch.cuda.amp.autocast assumes CUDA inference;
        # it is a no-op warning on CPU — confirm the endpoint runs on GPU.
        with torch.cuda.amp.autocast():
            # BUG FIX: original assigned to misspelled `predicion` and then
            # returned the undefined `prediction`.
            prediction = self.pipeline(
                inputs,
                generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids},
                max_new_tokens=255,
            )["text"]
        return {"prediction": prediction}
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -9,6 +9,7 @@
9
  "rstrip": false,
10
  "single_word": false
11
  },
 
12
  "dropout": 0.0,
13
  "eos_token": {
14
  "__type": "AddedToken",
@@ -23,7 +24,6 @@
23
  "pad_token": null,
24
  "processor_class": "WhisperProcessor",
25
  "return_attention_mask": false,
26
- "special_tokens_map_file": null,
27
  "tokenizer_class": "WhisperTokenizer",
28
  "unk_token": {
29
  "__type": "AddedToken",
 
9
  "rstrip": false,
10
  "single_word": false
11
  },
12
+ "clean_up_tokenization_spaces": true,
13
  "dropout": 0.0,
14
  "eos_token": {
15
  "__type": "AddedToken",
 
24
  "pad_token": null,
25
  "processor_class": "WhisperProcessor",
26
  "return_attention_mask": false,
 
27
  "tokenizer_class": "WhisperTokenizer",
28
  "unk_token": {
29
  "__type": "AddedToken",