import argparse
import sys
import time

import numpy as np
import onnxruntime as ort
import soundfile
from rknnlite.api import RKNNLite

from utils import Lexicon, intersperse, split_sentences_zh


def get_args():
    parser = argparse.ArgumentParser(
        prog="melotts",
        description="Run TTS on input sentence"
    )
    parser.add_argument("--sentence", "-s", type=str, required=False,
                        default="爱芯元智半导体股份有限公司,致力于打造世界领先的人工智能感知与边缘计算芯片。服务智慧城市、智能驾驶、机器人的海量普惠的应用")
    parser.add_argument("--wav", "-w", type=str, required=False, default="output.wav")
    parser.add_argument("--encoder", "-e", type=str, required=False, default="encoder.onnx")
    parser.add_argument("--decoder", "-d", type=str, required=False, default="decoder.rknn")
    parser.add_argument("--sample_rate", "-sr", type=int, required=False, default=44100)
    parser.add_argument("--speed", type=float, required=False, default=0.8)
    parser.add_argument("--lexicon", type=str, required=False, default="lexicon.txt")
    parser.add_argument("--token", type=str, required=False, default="tokens.txt")
    return parser.parse_args()


def audio_numpy_concat(segment_data_list, sr, speed=1.):
    # Concatenate per-sentence audio, inserting ~50 ms of silence
    # (scaled by speed) between sentences
    audio_segments = []
    for segment_data in segment_data_list:
        audio_segments += segment_data.reshape(-1).tolist()
        audio_segments += [0] * int((sr * 0.05) / speed)
    return np.array(audio_segments, dtype=np.float32)


def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    # Average the overlapping pad regions of adjacent slices, then concatenate
    if pad_size > 0:
        for i in range(len(sub_audio_list) - 1):
            sub_audio_list[i][-pad_size:] += sub_audio_list[i + 1][:pad_size]
            sub_audio_list[i][-pad_size:] /= 2
            if i > 0:
                sub_audio_list[i] = sub_audio_list[i][pad_size:]
    sub_audio = np.concatenate(sub_audio_list, axis=-1)
    return sub_audio[:audio_len]


def generate_pronounce_slice(yinjie_num):
    # Build one phone slice per word from its syllable count
    pron_slice = []
    start = 0
    for n in yinjie_num:
        end = start + n
        pron_slice.append(slice(start, end))
        start = end
    return pron_slice


def generate_word_pron_num(pron_lens, pron_slices):
    # Sum per-phone frame counts to get each word's pronunciation length
    return [pron_lens[s].sum() for s in pron_slices]


def decode_long_word(sess_dec, z_p, g, dec_len):
    # A single word longer than one decoder window: split it into fixed-size
    # slices, zero-pad the last one to the window size, and decode each slice
    z_p_len = z_p.shape[-1]
    slice_num = int(np.ceil(z_p_len / dec_len))
    sub_audio_list = []
    for i in range(slice_num):
        z_p_slice = z_p[..., i * dec_len : (i + 1) * dec_len]
        sub_dec_len = z_p_slice.shape[-1]
        sub_audio_len = 512 * sub_dec_len  # 512 output samples per latent frame
        if z_p_slice.shape[-1] < dec_len:
            z_p_slice = np.concatenate(
                (z_p_slice,
                 np.zeros((*z_p_slice.shape[:-1], dec_len - z_p_slice.shape[-1]),
                          dtype=np.float32)),
                axis=-1)
        start = time.time()
        audio = sess_dec.inference(inputs=[z_p_slice, g])[0].flatten()
        audio = audio[:sub_audio_len]  # drop samples produced by the padding
        print(f"Long word slice[{i}]: decoder run take {1000 * (time.time() - start):.2f}ms")
        sub_audio_list.append(audio)
    return sub_audio_list
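# Worked example (a sketch, not executed): with dec_len = 128 latent frames
# (65536 samples at a 512-sample hop, matching main() below), a single word
# spanning 300 frames is decoded by decode_long_word() in ceil(300 / 128) = 3
# slices of 128, 128, and 44 frames. The last slice is zero-padded to 128
# frames to satisfy the decoder's fixed input shape, and its output is
# trimmed back to 44 * 512 = 22528 samples before the slices are returned.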
def generate_decode_slices(pron_num, dec_len):
    # Pack whole words into fixed-size decoder windows. When a window reuses
    # the previous window's last two words as left context, the overlapping
    # audio is stripped later according to strip_flags
    pron_slices = []
    zp_slices = []
    strip_flags = []  # [strip_head, strip_tail] per window
    pron_lens = []    # number of words per window
    is_long = []      # True if the window is a single over-long word
    start = end = 0
    zp_start, zp_end = 0, 0
    while end < len(pron_num):
        if len(pron_lens) > 0 and pron_lens[-1] > 2 and pron_num[end] < dec_len:
            # Re-decode the previous window's last two words as context
            prev_end = pron_slices[-1][1]
            pad_size = pron_num[prev_end - 1] + pron_num[prev_end - 2]
            start = end - 2
            zp_start = zp_end - pad_size
            strip_head, strip_tail = True, False
            # The overlap is stripped from the previous window's tail unless
            # that window was a long word decoded without context
            strip_flags[-1][1] = not is_long[-1]
        else:
            pad_size = 0
            start = end
            zp_start = zp_end
            strip_head, strip_tail = False, False
            if len(strip_flags) > 0:
                strip_flags[-1][1] = False
        # Greedily add words until the window is full
        sub_dec_len = abs(zp_end - zp_start)
        while end < len(pron_num) and sub_dec_len < dec_len:
            sub_dec_len += pron_num[end]
            end += 1
        if end - start == 1 and sub_dec_len > dec_len:
            # A single word that exceeds the window: handled by decode_long_word
            is_long.append(True)
            if len(strip_flags) > 0:
                strip_flags[-1][1] = False
        else:
            is_long.append(False)
            if sub_dec_len > dec_len:
                # Overshot: put the last word back
                sub_dec_len -= pron_num[end - 1]
                end -= 1
        zp_end = zp_start + sub_dec_len
        pron_slices.append([start, end])
        zp_slices.append([zp_start, zp_end])
        strip_flags.append([strip_head, strip_tail])
        pron_lens.append(end - start)
    return pron_slices, zp_slices, strip_flags, pron_lens, is_long
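# A minimal trace (hedged; the frame counts are hypothetical): with per-word
# frame counts [3, 4, 5] and dec_len = 8,
#
#     generate_decode_slices([3, 4, 5], dec_len=8)
#
# packs words 0-1 (3 + 4 = 7 frames) into the first window and word 2 into
# the second, returning pron_slices [[0, 2], [2, 3]], zp_slices
# [[0, 7], [7, 12]], and is_long [False, False]. No context re-decoding is
# triggered, because the first window covers only two words.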
def main():
    args = get_args()
    sentence = args.sentence
    sample_rate = args.sample_rate
    lexicon_filename = args.lexicon
    token_filename = args.token
    enc_model = args.encoder  # default: encoder.onnx
    dec_model = args.decoder  # default: decoder.rknn

    print(f"sentence: {sentence}")
    print(f"sample_rate: {sample_rate}")
    print(f"lexicon: {lexicon_filename}")
    print(f"token: {token_filename}")
    print(f"encoder: {enc_model}")
    print(f"decoder: {dec_model}")

    # Split the input into sentences
    sens = split_sentences_zh(sentence)

    # Load lexicon
    lexicon = Lexicon(lexicon_filename, token_filename)

    # Load models: the encoder runs on CPU via ONNX Runtime,
    # the decoder on the NPU via RKNNLite
    start = time.time()
    sess_enc = ort.InferenceSession(enc_model,
                                    providers=["CPUExecutionProvider"],
                                    sess_options=ort.SessionOptions())

    decoder = RKNNLite()
    ret = decoder.load_rknn(dec_model)
    if ret != 0:
        print('Load decoder RKNN model failed')
        sys.exit(ret)
    ret = decoder.init_runtime()
    if ret != 0:
        print('Init runtime failed')
        sys.exit(ret)

    # The decoder consumes fixed windows of 128 latent frames
    # (65536 samples at a 512-sample hop)
    dec_len = 65536 // 512
    print(f"load models take {1000 * (time.time() - start):.2f}ms")

    # Load the static speaker embedding
    g = np.fromfile("g.bin", dtype=np.float32).reshape(1, 256, 1)

    # Per-sentence audio for the final wav
    audio_list = []

    # Iterate over the split sentences
    for n, se in enumerate(sens):
        print(f"\nSentence[{n}]: {se}")
        # Convert the sentence to phones and tones
        phone_str, yinjie_num, phones, tones = lexicon.convert(se)

        # Add a blank token between phones
        phone_str = intersperse(phone_str, 0)
        phones = np.array(intersperse(phones, 0), dtype=np.int32)
        tones = np.array(intersperse(tones, 0), dtype=np.int32)

        # Interspersing doubles each word's phone count; the leading blank
        # is assigned to the first word
        yinjie_num = np.array(yinjie_num, dtype=np.int32) * 2
        yinjie_num[0] += 1
        assert yinjie_num.sum() == phones.shape[0]
        pron_slices = generate_pronounce_slice(yinjie_num)

        phone_len = phones.shape[-1]
        language = np.array([3] * phone_len, dtype=np.int32)
        start = time.time()
        z_p, pronoun_lens, audio_len = sess_enc.run(
            None,
            input_feed={
                'phone': phones,
                'g': g,
                'tone': tones,
                'language': language,
                'noise_scale': np.array([0], dtype=np.float32),
                'length_scale': np.array([1.0 / args.speed], dtype=np.float32),
                'noise_scale_w': np.array([0], dtype=np.float32),
                'sdp_ratio': np.array([0], dtype=np.float32)})
        print(f"encoder run take {1000 * (time.time() - start):.2f}ms")

        audio_len = audio_len[0]
        actual_size = z_p.shape[-1]
        dec_slice_num = int(np.ceil(actual_size / dec_len))
        # Zero-pad z_p to a whole number of decoder windows
        z_p = np.pad(z_p,
                     pad_width=((0, 0), (0, 0), (0, dec_slice_num * dec_len - actual_size)),
                     mode="constant", constant_values=0)

        # Pronunciation length (in latent frames) of each word
        pron_num = generate_word_pron_num(pronoun_lens, pron_slices)

        sub_audio_list = []
        pron_num_slices, zp_slices, strip_flags, pron_lens, is_long = \
            generate_decode_slices(pron_num, dec_len)
        for i in range(len(pron_num_slices)):
            pron_start, pron_end = pron_num_slices[i]
            zp_start, zp_end = zp_slices[i]
            # Phones covered by this window (useful for debugging); the loop
            # variable must not shadow the sentence index n used above
            phone_strs = []
            for w in range(pron_start, pron_end):
                phone_strs.extend(phone_str[pron_slices[w]])
            if is_long[i]:
                sub_audio_list.extend(
                    decode_long_word(decoder, z_p[..., zp_start : zp_end], g, dec_len))
            else:
                sub_dec_len = zp_end - zp_start
                sub_audio_len = 512 * sub_dec_len
                zp_slice = z_p[..., zp_start : zp_end]
                if zp_slice.shape[-1] < dec_len:
                    zp_slice = np.concatenate(
                        (zp_slice,
                         np.zeros((*zp_slice.shape[:-1], dec_len - zp_slice.shape[-1]),
                                  dtype=np.float32)),
                        axis=-1)
                start = time.time()
                outputs = decoder.inference(inputs=[zp_slice, g])
                audio = outputs[0].flatten()
                audio = audio[:sub_audio_len]
                print(f"Sentence[{n}] Slice[{i}]: decoder run take {1000 * (time.time() - start):.2f}ms")
                if strip_flags[i][0]:
                    # Strip the head overlap (first context word)
                    head = 512 * pron_num[pron_start]
                    audio = audio[head:]
                if strip_flags[i][1]:
                    # Strip the tail overlap (last word, re-decoded in the next window)
                    tail = 512 * pron_num[pron_end - 1]
                    audio = audio[:-tail]
                sub_audio_list.append(audio)

        sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
        audio_list.append(sub_audio)

    audio = audio_numpy_concat(audio_list, sr=sample_rate, speed=args.speed)
    soundfile.write(args.wav, audio, sample_rate)
    print(f"Save to {args.wav}")

    # Release RKNN resources
    decoder.release()


if __name__ == "__main__":
    main()
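# Example invocation (a sketch; the script name is hypothetical, and the
# default assets encoder.onnx, decoder.rknn, g.bin, lexicon.txt and
# tokens.txt must be present in the working directory):
#
#     python melotts_rknn.py -s "你好,世界。" -w hello.wav --speed 0.9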