Spaces:
Paused
Paused
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
| import json | |
| import math | |
| import numpy as np | |
| def _load_json_cmvn(json_cmvn_file): | |
| """Load the json format cmvn stats file and calculate cmvn | |
| Args: | |
| json_cmvn_file: cmvn stats file in json format | |
| Returns: | |
| a numpy array of [means, vars] | |
| """ | |
| with open(json_cmvn_file) as f: | |
| cmvn_stats = json.load(f) | |
| means = cmvn_stats["mean_stat"] | |
| variance = cmvn_stats["var_stat"] | |
| count = cmvn_stats["frame_num"] | |
| for i in range(len(means)): | |
| means[i] /= count | |
| variance[i] = variance[i] / count - means[i] * means[i] | |
| if variance[i] < 1.0e-20: | |
| variance[i] = 1.0e-20 | |
| variance[i] = 1.0 / math.sqrt(variance[i]) | |
| cmvn = np.array([means, variance]) | |
| return cmvn | |
| def _load_kaldi_cmvn(kaldi_cmvn_file): | |
| """Load the kaldi format cmvn stats file and calculate cmvn | |
| Args: | |
| kaldi_cmvn_file: kaldi text style global cmvn file, which | |
| is generated by: | |
| compute-cmvn-stats --binary=false scp:feats.scp global_cmvn | |
| Returns: | |
| a numpy array of [means, vars] | |
| """ | |
| means = [] | |
| variance = [] | |
| with open(kaldi_cmvn_file, "r") as fid: | |
| # kaldi binary file start with '\0B' | |
| if fid.read(2) == "\0B": | |
| logging.error( | |
| "kaldi cmvn binary file is not supported, please " | |
| "recompute it by: compute-cmvn-stats --binary=false " | |
| " scp:feats.scp global_cmvn" | |
| ) | |
| sys.exit(1) | |
| fid.seek(0) | |
| arr = fid.read().split() | |
| assert arr[0] == "[" | |
| assert arr[-2] == "0" | |
| assert arr[-1] == "]" | |
| feat_dim = int((len(arr) - 2 - 2) / 2) | |
| for i in range(1, feat_dim + 1): | |
| means.append(float(arr[i])) | |
| count = float(arr[feat_dim + 1]) | |
| for i in range(feat_dim + 2, 2 * feat_dim + 2): | |
| variance.append(float(arr[i])) | |
| for i in range(len(means)): | |
| means[i] /= count | |
| variance[i] = variance[i] / count - means[i] * means[i] | |
| if variance[i] < 1.0e-20: | |
| variance[i] = 1.0e-20 | |
| variance[i] = 1.0 / math.sqrt(variance[i]) | |
| cmvn = np.array([means, variance]) | |
| return cmvn | |
def load_cmvn(cmvn_file, is_json):
    """Load global CMVN statistics from ``cmvn_file``.

    Args:
        cmvn_file: path to the CMVN stats file.
        is_json: truthy to parse it with the JSON loader, falsy to parse it
            as a Kaldi text-format stats file.

    Returns:
        Tuple ``(means, inv_std)`` taken from rows 0 and 1 of the array
        produced by the selected loader.
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    means, inv_std = stats[0], stats[1]
    return means, inv_std