Spaces:
Runtime error
Runtime error
| # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang) | |
| # | |
| # See LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from functools import lru_cache | |
| import sherpa_onnx | |
| from huggingface_hub import hf_hub_download | |
# Sample rate passed to the recognizers and the VAD configured in this module.
sample_rate = 16_000
def _get_nn_model_filename(
    repo_id: str,
    filename: str,
    subfolder: str = "exp",
) -> str:
    """Fetch a model file from a Hugging Face repo.

    Args:
      repo_id:
        Hugging Face repo to fetch from.
      filename:
        Name of the file inside ``subfolder``.
      subfolder:
        Folder inside the repo that holds the file.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)


# Public alias for the generic download helper.
get_file = _get_nn_model_filename
def _get_bpe_model_filename(
    repo_id: str,
    filename: str = "bpe.model",
    subfolder: str = "data/lang_bpe_500",
) -> str:
    """Fetch a BPE model file from a Hugging Face repo.

    Args:
      repo_id:
        Hugging Face repo to fetch from.
      filename:
        Name of the BPE model file inside ``subfolder``.
      subfolder:
        Folder inside the repo that holds the file.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
def _get_token_filename(
    repo_id: str,
    filename: str = "tokens.txt",
    subfolder: str = "data/lang_char",
) -> str:
    """Fetch a token table file from a Hugging Face repo.

    Args:
      repo_id:
        Hugging Face repo to fetch from.
      filename:
        Name of the token file inside ``subfolder``.
      subfolder:
        Folder inside the repo that holds the file.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
def _get_whisper_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline Whisper recognizer for a ``whisper-<variant>`` alias.

    Args:
      repo_id:
        Short alias of the form ``whisper-<variant>``, e.g. ``whisper-tiny.en``
        or ``whisper-distil-small.en``. The variant is appended to
        ``csukuangfj/sherpa-onnx-whisper-`` to form the real repo name.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_whisper``.

    Raises:
      AssertionError: If the variant is not one of the supported names.
    """
    # Split on the first dash only: the distil variants contain a dash
    # themselves ("whisper-distil-small.en" -> "distil-small.en"). Splitting
    # on every dash reduced them to "distil", which always failed the check
    # below even though both aliases are registered in ``english_models``.
    name = repo_id.split("-", maxsplit=1)[1]
    assert name in (
        "tiny.en",
        "base.en",
        "small.en",
        "medium.en",
        "distil-small.en",
        "distil-medium.en",
    ), repo_id

    full_repo_id = "csukuangfj/sherpa-onnx-whisper-" + name

    # All files live at the repo root and are prefixed with the variant name.
    encoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename=f"{name}-encoder.int8.onnx",
        subfolder=".",
    )
    decoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename=f"{name}-decoder.int8.onnx",
        subfolder=".",
    )
    tokens = _get_token_filename(
        repo_id=full_repo_id,
        subfolder=".",
        filename=f"{name}-tokens.txt",
    )

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        num_threads=2,
        tail_paddings=2000,
    )
    return recognizer
def _get_paraformer_zh_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline Paraformer recognizer from the given repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_paraformer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",)
    assert repo_id in supported_repos, repo_id

    # Model and token table both sit at the repo root.
    model_path = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=model_path,
        tokens=tokens_path,
        num_threads=2,
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=False,
    )
def _get_chinese_dialect_models(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline TeleSpeech CTC recognizer from the given repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by
      ``sherpa_onnx.OfflineRecognizer.from_telespeech_ctc``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04",)
    assert repo_id in supported_repos, repo_id

    # Model and token table both sit at the repo root.
    model_path = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
        model=model_path,
        tokens=tokens_path,
        num_threads=2,
    )
def _get_russian_pre_trained_model_ctc(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline NeMo CTC recognizer from the given Russian repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_nemo_ctc``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("csukuangfj/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24",)
    assert repo_id in supported_repos, repo_id

    # Model and token table both sit at the repo root.
    model_path = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
        model=model_path,
        tokens=tokens_path,
        num_threads=2,
    )
def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline transducer recognizer from one of the Russian repos.

    Args:
      repo_id:
        Hugging Face repo; must be one of the keys of the layout table below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    # repo -> (folder holding the ONNX files, encoder filename, model_type,
    #          folder holding tokens.txt)
    repo_layouts = {
        "alphacep/vosk-model-ru": ("am-onnx", "encoder.onnx", "transducer", "lang"),
        "alphacep/vosk-model-small-ru": ("am", "encoder.onnx", "transducer", "lang"),
        "csukuangfj/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24": (
            ".",
            "encoder.int8.onnx",
            "nemo_transducer",
            ".",
        ),
    }
    assert repo_id in repo_layouts, repo_id
    onnx_dir, encoder_file, model_type, tokens_dir = repo_layouts[repo_id]

    # Only the encoder filename varies per repo; decoder and joiner names
    # are the same for all three.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder=onnx_dir)
        for onnx_file in (encoder_file, "decoder.onnx", "joiner.onnx")
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=tokens_dir)

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        model_type=model_type,
    )
def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
    """Download the CT-Transformer punctuation model and wrap it.

    Returns:
      A ``sherpa_onnx.OfflinePunctuation`` built from the downloaded model.
    """
    model_path = _get_nn_model_filename(
        repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
        filename="model.onnx",
        subfolder=".",
    )
    model_config = sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model_path)
    punct_config = sherpa_onnx.OfflinePunctuationConfig(model=model_config)
    return sherpa_onnx.OfflinePunctuation(punct_config)
def get_vad() -> sherpa_onnx.VoiceActivityDetector:
    """Download the Silero VAD model and build a voice activity detector.

    Returns:
      A ``sherpa_onnx.VoiceActivityDetector`` configured with the settings
      below and a 180-second buffer.
    """
    silero_model_path = _get_nn_model_filename(
        repo_id="csukuangfj/vad",
        filename="silero_vad_v5.onnx",
        subfolder=".",
    )

    config = sherpa_onnx.VadModelConfig()
    config.sample_rate = sample_rate

    # Silero-specific settings.
    config.silero_vad.model = silero_model_path
    config.silero_vad.threshold = 0.3
    config.silero_vad.min_silence_duration = 0.25
    config.silero_vad.min_speech_duration = 0.25
    config.silero_vad.max_speech_duration = 20  # seconds

    return sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=180)
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Look up the loader registered for ``repo_id`` and call it.

    Args:
      repo_id:
        A key of one of the per-language registries defined in this module.

    Returns:
      The recognizer returned by the registered loader.

    Raises:
      ValueError: If no registry contains ``repo_id``.
    """
    # Registries are searched in this order; the first one holding the key wins.
    registries = (
        chinese_models,
        chinese_dialect_models,
        english_models,
        chinese_english_mixed_models,
        russian_models,
        korean_models,
        thai_models,
        japanese_models,
        zh_en_ko_ja_yue_models,
    )
    for registry in registries:
        if repo_id in registry:
            loader = registry[repo_id]
            return loader(repo_id)

    raise ValueError(f"Unsupported repo_id: {repo_id}")
def _get_wenetspeech_pre_trained_model(repo_id):
    """Build an offline transducer recognizer from the conformer-zh repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23",)
    assert repo_id in supported_repos, repo_id

    # Fetched in encoder, decoder, joiner order; all sit at the repo root.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder=".")
        for onnx_file in (
            "encoder-epoch-99-avg-1.onnx",
            "decoder-epoch-99-avg-1.onnx",
            "joiner-epoch-99-avg-1.onnx",
        )
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )
def _get_multi_zh_hans_pre_trained_model(repo_id):
    """Build an offline transducer recognizer from the multi-zh-hans repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2",)
    assert repo_id in supported_repos, repo_id

    # Fetched in encoder, decoder, joiner order; all sit at the repo root.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder=".")
        for onnx_file in (
            "encoder-epoch-20-avg-1.onnx",
            "decoder-epoch-20-avg-1.onnx",
            "joiner-epoch-20-avg-1.onnx",
        )
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )
def _get_moonshine_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline Moonshine recognizer for a short alias.

    Args:
      repo_id:
        Either ``moonshine-tiny`` or ``moonshine-base``; mapped to the
        corresponding ``csukuangfj/sherpa-onnx-moonshine-*-en-int8`` repo.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_moonshine``.

    Raises:
      AssertionError: If ``repo_id`` is not a known alias.
      ValueError: Same condition, when assertions are disabled (``python -O``).
    """
    # Single source of truth for the accepted aliases; previously the list
    # was spelled out once in an assert and again in an if/elif/else chain.
    full_repo_ids = {
        "moonshine-tiny": "csukuangfj/sherpa-onnx-moonshine-tiny-en-int8",
        "moonshine-base": "csukuangfj/sherpa-onnx-moonshine-base-en-int8",
    }
    assert repo_id in full_repo_ids, repo_id
    if repo_id not in full_repo_ids:
        # Only reachable when assertions are stripped.
        raise ValueError(f"Unknown repo_id: {repo_id}")
    full_repo_id = full_repo_ids[repo_id]

    # File names are fixed, so plain string literals are used (the f-prefix
    # on placeholder-free strings was redundant).
    preprocessor = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename="preprocess.onnx",
        subfolder=".",
    )
    encoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename="encode.int8.onnx",
        subfolder=".",
    )
    uncached_decoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename="uncached_decode.int8.onnx",
        subfolder=".",
    )
    cached_decoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename="cached_decode.int8.onnx",
        subfolder=".",
    )
    tokens = _get_token_filename(
        repo_id=full_repo_id,
        subfolder=".",
        filename="tokens.txt",
    )

    recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
        preprocessor=preprocessor,
        encoder=encoder,
        uncached_decoder=uncached_decoder,
        cached_decoder=cached_decoder,
        tokens=tokens,
        num_threads=2,
    )
    return recognizer
def _get_english_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline transducer recognizer from the English icefall repo.

    Args:
      repo_id:
        Hugging Face repo; must equal the single repo accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not the accepted repo.
    """
    expected_repo = (
        "yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
    )
    assert repo_id == expected_repo, repo_id

    # Fetched in encoder, decoder, joiner order from the "exp" folder.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder="exp")
        for onnx_file in (
            "encoder-epoch-30-avg-4.onnx",
            "decoder-epoch-30-avg-4.onnx",
            "joiner-epoch-30-avg-4.onnx",
        )
    ]
    # NOTE(review): other icefall-style loaders in this file read tokens from
    # "data/lang_*"; confirm this repo really keeps tokens.txt in "lang_bpe_500".
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder="lang_bpe_500")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )
def _get_korean_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline transducer recognizer from the Korean zipformer repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24",)
    assert repo_id in supported_repos, repo_id

    # Fetched in encoder, decoder, joiner order; only the encoder is the
    # int8 file.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder=".")
        for onnx_file in (
            "encoder-epoch-99-avg-1.int8.onnx",
            "decoder-epoch-99-avg-1.onnx",
            "joiner-epoch-99-avg-1.onnx",
        )
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )
def _get_japanese_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline transducer recognizer from the ReazonSpeech repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("reazon-research/reazonspeech-k2-v2",)
    assert repo_id in supported_repos, repo_id

    # Fetched in encoder, decoder, joiner order; only the encoder is the
    # int8 file.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder=".")
        for onnx_file in (
            "encoder-epoch-99-avg-1.int8.onnx",
            "decoder-epoch-99-avg-1.onnx",
            "joiner-epoch-99-avg-1.onnx",
        )
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )
def _get_yifan_thai_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline transducer recognizer from the Thai zipformer repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("yfyeung/icefall-asr-gigaspeech2-th-zipformer-2024-06-20",)
    assert repo_id in supported_repos, repo_id

    # Fetched in encoder, decoder, joiner order from the "exp" folder;
    # encoder and joiner are int8 files, the decoder is not.
    encoder_path, decoder_path, joiner_path = [
        _get_nn_model_filename(repo_id=repo_id, filename=onnx_file, subfolder="exp")
        for onnx_file in (
            "encoder-epoch-12-avg-5.int8.onnx",
            "decoder-epoch-12-avg-5.onnx",
            "joiner-epoch-12-avg-5.int8.onnx",
        )
    ]
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder="data/lang_bpe_2000")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens_path,
        encoder=encoder_path,
        decoder=decoder_path,
        joiner=joiner_path,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )
def _get_sense_voice_pre_trained_model(
    repo_id: str,
) -> sherpa_onnx.OfflineRecognizer:
    """Build an offline SenseVoice recognizer from the given repo.

    Args:
      repo_id:
        Hugging Face repo; must be one of the repos accepted below.

    Returns:
      The recognizer created by
      ``sherpa_onnx.OfflineRecognizer.from_sense_voice``.

    Raises:
      AssertionError: If ``repo_id`` is not an accepted repo.
    """
    supported_repos = ("csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17",)
    assert repo_id in supported_repos, repo_id

    # Model and token table both sit at the repo root.
    model_path = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    # NOTE(review): debug=True here, whereas the paraformer loader in this
    # file passes debug=False -- confirm this is intentional.
    return sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=model_path,
        tokens=tokens_path,
        num_threads=2,
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=True,
        use_itn=True,
    )
# Per-language registries mapping a repo id (or short alias) to the loader
# that builds its recognizer. ``get_pretrained_model`` dispatches through
# these tables; key order is preserved when they are listed.

chinese_dialect_models = {
    "csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04": _get_chinese_dialect_models,
}

zh_en_ko_ja_yue_models = {
    "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17": _get_sense_voice_pre_trained_model,
}

chinese_models = {
    "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28": _get_paraformer_zh_pre_trained_model,
    "csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23": _get_wenetspeech_pre_trained_model,  # noqa
    "zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2": _get_multi_zh_hans_pre_trained_model,  # noqa
}

# Whisper and Moonshine entries are short aliases that their loaders expand
# into full repo names.
english_models = {
    "whisper-tiny.en": _get_whisper_model,
    "moonshine-tiny": _get_moonshine_model,
    "moonshine-base": _get_moonshine_model,
    "whisper-base.en": _get_whisper_model,
    "whisper-small.en": _get_whisper_model,
    "whisper-distil-small.en": _get_whisper_model,
    "whisper-medium.en": _get_whisper_model,
    "whisper-distil-medium.en": _get_whisper_model,
    "yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04": _get_english_model,  # noqa
}

# The paraformer repo is listed here as well as in ``chinese_models``.
chinese_english_mixed_models = {
    "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28": _get_paraformer_zh_pre_trained_model,
}

korean_models = {
    "k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24": _get_korean_pre_trained_model,
}

russian_models = {
    "csukuangfj/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24": _get_russian_pre_trained_model,
    "csukuangfj/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24": _get_russian_pre_trained_model_ctc,
    "alphacep/vosk-model-ru": _get_russian_pre_trained_model,
    "alphacep/vosk-model-small-ru": _get_russian_pre_trained_model,
}

thai_models = {
    "yfyeung/icefall-asr-gigaspeech2-th-zipformer-2024-06-20": _get_yifan_thai_pretrained_model,
}

japanese_models = {
    "reazon-research/reazonspeech-k2-v2": _get_japanese_pre_trained_model,
}
# Language label -> list of repo ids / aliases available for that language.
# NOTE(review): "Cantoes" looks like a typo for "Cantonese"; it is left
# unchanged because the key is a runtime string other code may look up.
language_to_models = {
    "超多种中文方言": list(chinese_dialect_models),
    "Chinese+English": list(chinese_english_mixed_models),
    "Chinese+English+Korean+Japanese+Cantoes(中英韩日粤语)": list(zh_en_ko_ja_yue_models),
    "Chinese": list(chinese_models),
    "English": list(english_models),
    "Russian": list(russian_models),
    "Korean": list(korean_models),
    "Thai": list(thai_models),
    "Japanese": list(japanese_models),
}