import re
import sys
import unicodedata

import yaml
import numpy as np
import librosa
import noisereduce as nr
import phonemizer
import torch
import torchaudio
from munch import Munch
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt_tab')

from models import ProsodyPredictor, TextEncoder, StyleEncoder
from Modules.hifigan import Decoder

if sys.platform.startswith("win"):
    # On Windows, point phonemizer at the bundled espeak-ng library.
    try:
        from phonemizer.backend.espeak.wrapper import EspeakWrapper
        import espeakng_loader
        EspeakWrapper.set_library(espeakng_loader.get_library_path())
    except Exception as e:
        print(e)


def espeak_phn(text, lang):
    try:
        my_phonemizer = phonemizer.backend.EspeakBackend(
            language=lang,
            preserve_punctuation=True,
            with_stress=True,
            language_switch='remove-flags',
        )
        return my_phonemizer.phonemize([text])[0]
    except Exception as e:
        print(e)


# IPA Phonemizer: https://github.com/bootphon/phonemizer
# Total including extended chars: 189

_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
_extend = "∫̆ăη͡123456"

# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_extend)

dicts = {symbol: i for i, symbol in enumerate(symbols)}


class TextCleaner:
    def __init__(self, dummy=None):
        self.word_index_dictionary = dicts

    def __call__(self, text):
        # Map each character to its symbol index; skip unknown characters.
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                continue
        return indexes


class Preprocess:
    def __text_normalize(self, text):
        punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
        map_to = "."
        punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")

        # Ensure consistency.
        text = unicodedata.normalize('NFKC', text)
        # Replace punctuation that acts like a comma or period.
        text = punctuation_pattern.sub(map_to, text)
        # Remove or replace special chars except . , { } % $ & ' - [ ] /
        text = re.sub(r'[^\w\s.,{}%$&\'\-\[\]\/]', ' ', text)
        # Collapse consecutive whitespace and strip leading/trailing spaces.
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def __merge_fragments(self, texts, n):
        # Greedily merge consecutive fragments until each has at least n words.
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j
        if len(merged[-1].split()) < n and len(merged) > 1:
            # Fold a too-short last fragment into the previous one.
            merged[-2] = merged[-2] + ", " + merged[-1]
            del merged[-1]
        return merged

    def wave_preprocess(self, wave):
        to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
        mean, std = -4, 4
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = to_mel(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
        return mel_tensor

    def text_preprocess(self, text, n_merge=12):
        text_norm = self.__text_normalize(text).replace(",", ".").split(".")  # split into fragments
        text_norm = [s.strip() for s in text_norm]
        text_norm = list(filter(lambda x: x != '', text_norm))  # drop empty fragments
        text_norm = self.__merge_fragments(text_norm, n=n_merge)  # merge fragments shorter than n words
        return text_norm

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask + 1, lengths.unsqueeze(1))
        return mask
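
# Quick sanity-check sketch (illustrative; run manually, exact indices depend
# on the symbol table above):
#
#   >>> Preprocess().text_preprocess("Hello there! How are you? I am fine.", n_merge=12)
#   ['Hello there, How are you, I am fine']      # short fragments get merged
#   >>> TextCleaner()("ˈhɛloʊ")
#   [...]                                        # one index per recognized character;
#                                                # unknown characters are silently dropped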

# For inference only
class StyleTTS2(torch.nn.Module):
    def __init__(self, config_path, models_path):
        super().__init__()
        # Dummy buffer used only to track which device the module lives on.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()

        with open(config_path) as f:
            config = yaml.safe_load(f)
        args = self.__recursive_munch(config['model_params'])

        assert args.decoder.type in ['hifigan'], 'Decoder type unknown'

        self.decoder = Decoder(
            dim_in=args.hidden_dim,
            style_dim=args.style_dim,
            dim_out=args.n_mels,
            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
            upsample_rates=args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
        )
        self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim,
                                          nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
        self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5,
                                        depth=args.n_layer, n_symbols=args.n_token)
        self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim,
                                          max_conv_dim=args.hidden_dim)  # acoustic style encoder

        self.__load_models(models_path)

    def __recursive_munch(self, d):
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        elif isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        else:
            return d

    def __init_replacement_func(self, replacements):
        # Return a re.sub callback that yields the given replacements in order.
        replacement_iter = iter(replacements)

        def replacement(match):
            return next(replacement_iter)
        return replacement

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / std

        # Identify outliers.
        outlier_mask = torch.abs(z) > threshold

        # Compute replacement value, respecting sign: pull outliers back to
        # mean +/- threshold * std * factor.
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)

        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        module_params = []
        model = {'decoder': self.decoder, 'predictor': self.predictor,
                 'text_encoder': self.text_encoder, 'style_encoder': self.style_encoder}

        params_whole = torch.load(models_path, map_location='cpu')
        params = params_whole['net']
        params = {key: value for key, value in params.items() if key in model.keys()}

        for key in model:
            try:
                model[key].load_state_dict(params[key])
            except Exception:
                from collections import OrderedDict
                state_dict = params[key]
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the `module.` prefix left by DataParallel
                    new_state_dict[name] = v
                model[key].load_state_dict(new_state_dict, strict=False)
            total_params = sum(p.numel() for p in model[key].parameters())
            print(key, ":", total_params)
            module_params.append(total_params)
        print('\nTotal', ":", sum(module_params))
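
    # Expected checkpoint layout, as consumed above (illustrative):
    #   checkpoint['net'] = {
    #       'decoder': {...}, 'predictor': {...},
    #       'text_encoder': {...}, 'style_encoder': {...},
    #   }
    # State-dict keys may carry a leading 'module.' prefix when the model was
    # saved from torch.nn.DataParallel; the fallback path above strips it.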
    def __compute_style(self, path, denoise, split_dur):
        device = self.get_device.device
        denoise = min(denoise, 1)
        if split_dur != 0:
            split_dur = max(int(split_dur), 1)
        max_samples = 24000 * 20  # max 20 seconds of ref audio
        print("Computing the style for:", path)

        wave, sr = librosa.load(path, sr=24000)
        audio, index = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:
                # Only effective if audio length is >= 4s.
                # Split the ref audio into multiple parts, compute styles, and average them.
                count = 0
                jump = sr * split_dur
                total_len = len(audio)

                # Need to init before the loop.
                mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
                count += 1
                for i in range(jump, total_len, jump):
                    if i + jump >= total_len:
                        left_dur = (total_len - i) / sr
                        if left_dur >= 0.5:
                            # Still count if the leftover duration is >= 0.5s.
                            mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
                            ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                            count += 1
                        continue
                    mel_tensor = self.preprocess.wave_preprocess(audio[i:i + jump]).to(device)
                    ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                    count += 1
                ref_s /= count
            else:
                mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))

        return ref_s

    def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
        device = self.get_device.device
        speed = min(max(speed, 0.0001), 2)  # clamp speed to (0, 2]

        phonem = ' '.join(word_tokenize(phonem))
        tokens = TextCleaner()(phonem)
        tokens.insert(0, 0)
        tokens.append(0)
        tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)

            # Encode.
            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            # Calculate alignment.
            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)
            if prev_d_mean != 0:
                # Stabilize speaking speed between splits.
                dur_stats = torch.empty(duration.shape).normal_(mean=prev_d_mean, std=duration.std()).to(device)
            else:
                dur_stats = torch.empty(duration.shape).normal_(mean=duration.mean(), std=duration.std()).to(device)
            duration = duration * (1 - t) + dur_stats * t
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])  # normalize outliers
            duration /= speed

            pred_dur = torch.round(duration.squeeze()).clamp(min=1)
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            # Encode prosody.
            en = (d.transpose(-1, -2) @ alignment)
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = (t_en @ alignment)

            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), duration.mean()

    def get_styles(self, speakers, denoise=0.3, avg_style=True):
        # If avg_style is set, split each reference into 2s chunks and average their styles.
        split_dur = 2 if avg_style else 0
        styles = {}
        for speaker_id in speakers:
            ref_s = self.__compute_style(speakers[speaker_id]['path'], denoise=denoise, split_dur=split_dur)
            styles[speaker_id] = {
                'style': ref_s,
                'path': speakers[speaker_id]['path'],
                'lang': speakers[speaker_id]['lang'],
                'speed': speakers[speaker_id]['speed'],
            }
        return styles
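
    # Input markup accepted by generate() (illustrative):
    #   "[id_1] Hello world. [id_2] Bonjour [fr-fr]{tout le monde}."
    # [id_N] switches to the speaker stored under key 'id_N' in the styles dict;
    # [lang]{text} phonemizes the braced text with the given espeak language code.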
    def generate(self, text, styles, stabilize=True, n_merge=16, default_speaker="[id_1]"):
        smooth_value = 0.2 if stabilize else 0
        list_wav = []
        prev_d_mean = 0
        lang_pattern = r'\[([^\]]+)\]\{([^}]+)\}'

        text = re.sub(r'[\n\r\t\f\v]', '', text)

        # Fix lang tokens that span multiple sentences.
        find_lang_tokens = re.findall(lang_pattern, text)
        if find_lang_tokens:
            cus_text = []
            for lang, t in find_lang_tokens:
                parts = self.preprocess.text_preprocess(t, n_merge=0)
                parts = ".".join([f"[{lang}]" + f"{{{p}}}" for p in parts])
                cus_text.append(parts)
            replacement_func = self.__init_replacement_func(cus_text)
            text = re.sub(lang_pattern, replacement_func, text)

        # Split the text by speaker ids while keeping the ids.
        texts = re.split(r'(\[id_\d+\])', text)
        if len(texts) <= 1 or not re.match(r'(\[id_\d+\])', texts[0]):
            # Add a default speaker.
            texts.insert(0, default_speaker)

        # Remove consecutive duplicate ids.
        curr_id = None
        for i in range(len(texts)):
            if re.match(r'(\[id_\d+\])', texts[i]):
                if texts[i] != curr_id:
                    curr_id = texts[i]
                else:
                    texts[i] = ''
        del curr_id
        texts = list(filter(lambda x: x != '', texts))

        print("Generating Audio...")
        for chunk in texts:
            if re.match(r'(\[id_\d+\])', chunk):
                # Set up for the matched speaker.
                speaker_id = chunk.strip('[]')
                current_ref_s = styles[speaker_id]['style']
                speed = styles[speaker_id]['speed']
                continue

            text_norm = self.preprocess.text_preprocess(chunk, n_merge=n_merge)
            for sentence in text_norm:
                cus_phonem = []
                find_lang_tokens = re.findall(lang_pattern, sentence)
                if find_lang_tokens:
                    for lang, t in find_lang_tokens:
                        try:
                            phonem = espeak_phn(t, lang)
                            cus_phonem.append(phonem)
                        except Exception as e:
                            print(e)

                # Phonemize in the speaker's language, then splice in any
                # per-token phonemizations collected above.
                replacement_func = self.__init_replacement_func(cus_phonem)
                phonem = espeak_phn(sentence, styles[speaker_id]['lang'])
                phonem = re.sub(lang_pattern, replacement_func, phonem)

                wav, prev_d_mean = self.__inference(phonem, current_ref_s, speed=speed,
                                                    prev_d_mean=prev_d_mean, t=smooth_value)
                wav = wav[4000:-4000]  # remove weird pulses and silent tokens
                list_wav.append(wav)

        final_wav = np.concatenate(list_wav)
        final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0)  # add padding
        return final_wav
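

if __name__ == "__main__":
    # Minimal usage sketch (the config path, checkpoint path, and reference WAV
    # below are hypothetical placeholders, not files shipped with this module).
    model = StyleTTS2(config_path="config.yaml", models_path="model.pth")

    speakers = {
        "id_1": {
            "path": "reference_audio/speaker1.wav",  # hypothetical reference clip
            "lang": "en-us",                         # espeak-ng language code
            "speed": 1.0,
        },
    }
    styles = model.get_styles(speakers, denoise=0.3, avg_style=True)
    wav = model.generate("[id_1] Hello world. This is a test.", styles, stabilize=True)

    # The pipeline operates at 24 kHz throughout.
    import soundfile as sf  # assumes soundfile is installed
    sf.write("output.wav", wav, 24000)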