import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio

import utils
from Mels_preprocess import MelSpectrogramFixed

from hierspeechpp_speechsynthesizer import (
    SynthesizerTrn
)
from ttv_v1.text import text_to_sequence
from ttv_v1.t2w2v_transformer import SynthesizerTrn as Text2W2V
from speechsr24k.speechsr import SynthesizerTrn as AudioSR
from speechsr48k.speechsr import SynthesizerTrn as AudioSR48
from denoiser.generator import MPNet
from denoiser.infer import denoise
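
# Fix the random seeds for torch, CUDA, and numpy so inference is reproducible.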
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

def load_text(fp):
    with open(fp, 'r') as f:
        filelist = [line.strip() for line in f.readlines()]
    return filelist

def load_checkpoint(filepath, device):
    print(filepath)
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param
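
# Interleave `item` between the elements of `lst` (and at both ends), e.g.
# intersperse([5, 6, 7], 0) -> [0, 5, 0, 6, 0, 7, 0].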
def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

def add_blank_token(text):
    text_norm = intersperse(text, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
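
# Full synthesis pipeline: text -> phoneme tokens -> TTV (wav2vec features and
# F0) -> hierarchical speech synthesizer (16 kHz waveform) -> optional SpeechSR
# upsampling to 24/48 kHz -> int16 WAV on disk.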
def tts(text, a, hierspeech):
    net_g, text2w2v, audiosr, denoiser, mel_fn = hierspeech

    os.makedirs(a.output_dir, exist_ok=True)

    text = text_to_sequence(str(text), ["english_cleaners2"])
    token = add_blank_token(text).unsqueeze(0).cuda()
    token_length = torch.LongTensor([token.size(-1)]).cuda()

    # Load the voice prompt
    audio, sample_rate = torchaudio.load(a.input_prompt)

    # Only a single channel is supported
    audio = audio[:1, :]

    # Resample to 16 kHz
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")

    if a.scale_norm == 'prompt':
        prompt_audio_max = torch.max(audio.abs())

    # The synthesizer uses a hop size of 320 while the denoiser uses 400, so we
    # pad the prompt up to a multiple of 1600, their least common multiple.
    ori_prompt_len = audio.shape[-1]
    p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len
    audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data
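    # e.g. a 35,000-sample prompt gives p = 22 * 1600 - 35,000 = 200, so the
    # padded prompt is 35,200 samples long, a multiple of 1600.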

    file_name = os.path.splitext(os.path.basename(a.input_prompt))[0]

    # If you run out of memory while denoising the prompt, denoise it on the
    # CPU before TTS. We plan to replace this with a memory-efficient denoiser.
    if a.denoise_ratio == 0:
        audio = torch.cat([audio.cuda(), audio.cuda()], dim=0)
    else:
        with torch.no_grad():
            denoised_audio = denoise(audio.squeeze(0).cuda(), denoiser, hps_denoiser)
        audio = torch.cat([audio.cuda(), denoised_audio[:, :audio.shape[-1]]], dim=0)

    # 2023-11-08: we found that large amounts of padding degrade quality, so the
    # padding is removed after denoising.
    audio = audio[:, :ori_prompt_len]

    # Mel-spectrogram of the stacked [original, denoised] prompt pair
    src_mel = mel_fn(audio.cuda())

    src_length = torch.LongTensor([src_mel.size(2)]).to(device)
    src_length2 = torch.cat([src_length, src_length], dim=0)

    ## TTV (Text --> W2V, F0)
    with torch.no_grad():
        w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2, noise_scale=a.noise_scale_ttv, denoise_ratio=a.denoise_ratio)

        src_length = torch.LongTensor([w2v_x.size(2)]).cuda()

        ## Pitch clipping: zero out log-F0 values below log(55), i.e. below 55 Hz
        pitch[pitch < torch.log(torch.tensor([55]).cuda())] = 0

        ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
        converted_audio = \
            net_g.voice_conversion_noise_control(w2v_x, src_length, src_mel, src_length2, pitch, noise_scale=a.noise_scale_vc, denoise_ratio=a.denoise_ratio)

        ## SpeechSR (Optional) (16k Audio --> 24k or 48k Audio)
        if a.output_sr in (48000, 24000):  # `== 48000 or 24000` was always true
            converted_audio = audiosr(converted_audio)

    converted_audio = converted_audio.squeeze()

    # Peak-normalize and scale to the int16 range; with 'prompt' normalization,
    # the output peak is matched to the prompt's peak amplitude.
    if a.scale_norm == 'prompt':
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * prompt_audio_max
    else:
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * 0.999

    converted_audio = converted_audio.cpu().numpy().astype('int16')

    file_name2 = "{}.wav".format(file_name)
    output_file = os.path.join(a.output_dir, file_name2)

    if a.output_sr == 48000:
        write(output_file, 48000, converted_audio)
    elif a.output_sr == 24000:
        write(output_file, 24000, converted_audio)
    else:
        write(output_file, 16000, converted_audio)
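
# Build every model on the GPU: the mel extractor, the hierarchical speech
# synthesizer, the text-to-w2v model, the optional SpeechSR upsampler, and the
# MPNet denoiser.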
def model_load(a):
    mel_fn = MelSpectrogramFixed(
        sample_rate=hps.data.sampling_rate,
        n_fft=hps.data.filter_length,
        win_length=hps.data.win_length,
        hop_length=hps.data.hop_length,
        f_min=hps.data.mel_fmin,
        f_max=hps.data.mel_fmax,
        n_mels=hps.data.n_mel_channels,
        window_fn=torch.hann_window
    ).cuda()

    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model).cuda()
    net_g.load_state_dict(torch.load(a.ckpt))
    _ = net_g.eval()

    text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
                        hps.train.segment_size // hps.data.hop_length,
                        **hps_t2w2v.model).cuda()
    text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
    text2w2v.eval()

    if a.output_sr == 48000:
        audiosr = AudioSR48(h_sr48.data.n_mel_channels,
                            h_sr48.train.segment_size // h_sr48.data.hop_length,
                            **h_sr48.model).cuda()
        utils.load_checkpoint(a.ckpt_sr48, audiosr, None)
        audiosr.eval()
    elif a.output_sr == 24000:
        audiosr = AudioSR(h_sr.data.n_mel_channels,
                          h_sr.train.segment_size // h_sr.data.hop_length,
                          **h_sr.model).cuda()
        utils.load_checkpoint(a.ckpt_sr, audiosr, None)
        audiosr.eval()
    else:
        audiosr = None

    denoiser = MPNet(hps_denoiser).cuda()
    state_dict = load_checkpoint(a.denoiser_ckpt, device)
    denoiser.load_state_dict(state_dict['generator'])
    denoiser.eval()

    return net_g, text2w2v, audiosr, denoiser, mel_fn

def inference(a):
    hierspeech = model_load(a)

    # Input text: synthesize the first line of the input file. load_text
    # returns a list of lines, so passing the list straight to tts would
    # stringify it, brackets and all.
    text = load_text(a.input_txt)[0]
    # text = "hello I'm hierspeech"

    tts(text, a, hierspeech)

def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_prompt', default='example/reference_4.wav')
    parser.add_argument('--input_txt', default='example/reference_4.txt')
    parser.add_argument('--output_dir', default='output')
    parser.add_argument('--ckpt', default='./logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth')
    parser.add_argument('--ckpt_text2w2v', '-ct', help='text2w2v checkpoint path', default='./logs/ttv_libritts_v1/ttv_lt960_ckpt.pth')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
    parser.add_argument('--denoiser_ckpt', type=str, default='denoiser/g_best')
    parser.add_argument('--scale_norm', type=str, default='max')
    parser.add_argument('--output_sr', type=int, default=48000)  # int: compared against 48000/24000 above
    parser.add_argument('--noise_scale_ttv', type=float, default=0.333)
    parser.add_argument('--noise_scale_vc', type=float, default=0.333)
    parser.add_argument('--denoise_ratio', type=float, default=0.8)
    a = parser.parse_args()

    global device, hps, hps_t2w2v, h_sr, h_sr48, hps_denoiser
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
    hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
    h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json'))
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json'))
    hps_denoiser = utils.get_hparams_from_file(os.path.join(os.path.split(a.denoiser_ckpt)[0], 'config.json'))

    inference(a)

if __name__ == '__main__':
    main()
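
# Example invocation (a sketch, assuming this script is saved as inference.py
# and the default checkpoint paths above exist):
#   python inference.py --input_prompt example/reference_4.wav \
#       --input_txt example/reference_4.txt --output_dir output --output_sr 48000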