update handler
Browse files- handler.py +4 -1
handler.py
CHANGED
@@ -23,7 +23,10 @@ class EndpointHandler():
|
|
23 |
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
|
24 |
feature_extractor = processor.feature_extractor
|
25 |
self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
|
26 |
-
self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
|
|
|
|
|
|
27 |
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
|
28 |
# self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
29 |
# self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
|
|
|
23 |
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
|
24 |
feature_extractor = processor.feature_extractor
|
25 |
self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
|
26 |
+
# self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
27 |
+
self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
28 |
+
self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
29 |
+
self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
|
30 |
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
|
31 |
# self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
32 |
# self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
|