cathyi commited on
Commit
e60d835
·
1 Parent(s): 03cd09d

update handler

Browse files
Files changed (1) hide show
  1. handler.py +4 -1
handler.py CHANGED
@@ -23,7 +23,10 @@ class EndpointHandler():
23
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
24
  feature_extractor = processor.feature_extractor
25
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
26
- self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
 
 
 
27
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
28
  # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
29
  # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
 
23
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
24
  feature_extractor = processor.feature_extractor
25
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
26
+ # self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
27
+ self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
28
+ self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
29
+ self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
30
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
31
  # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
32
  # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids