"""Text cleaning front end: normalize text and convert it to phoneme symbols.

The symbol table and per-language g2p modules are selected by a "version"
string ("v1" or anything else, treated as v2), taken either from an explicit
argument or from the ``version`` environment variable.
"""

import os

from text import cleaned_text_to_sequence

# if os.environ.get("version","v1")=="v1":
#   from text import chinese
#   from text.symbols import symbols
# else:
#   from text import chinese2 as chinese
#   from text.symbols2 import symbols

from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

# (marker character, language it applies to, symbol emitted in its place).
# When a marker occurs in text of the matching language, clean_text()
# delegates to clean_special().
special = [
    # ("%", "zh", "SP"),
    ("¥", "zh", "SP2"),
    ("^", "zh", "SP3"),
    # ('@', 'zh', "SP4")  # dropped the remix-style gimmick; keep consistent with the second version
]


def load_nvrtc():
    """Best-effort preload of the NVRTC shared libraries that ship with torch.

    Does nothing when CUDA is unavailable. On Windows the torch ``lib``
    directory is registered as a DLL directory and every ``nvrtc*.dll`` in it
    is loaded; on Linux every ``libnvrtc*.so*`` under
    ``site-packages/nvidia/cuda_nvrtc/lib`` is loaded with ``RTLD_GLOBAL``.
    Failures are reported with ``print`` and never raised.
    """
    import ctypes
    import os
    import sys
    from pathlib import Path

    import torch

    if not torch.cuda.is_available():
        print("[INFO] CUDA is not available, skipping nvrtc setup.")
        return

    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            print(f"[INFO] Added DLL directory: {torch_lib_dir}")
            matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
            if not matching_files:
                print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
                return
            for dll_path in matching_files:
                # Load by bare file name; resolution relies on the DLL
                # directory registered above.
                dll_name = os.path.basename(dll_path)
                try:
                    ctypes.CDLL(dll_name)
                    print(f"[INFO] Loaded: {dll_name}")
                except OSError as e:
                    print(f"[WARNING] Failed to load {dll_name}: {e}")
        else:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
    elif sys.platform == "linux":
        site_packages = Path(torch.__file__).resolve().parents[1]
        nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
        if not nvrtc_dir.exists():
            print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
            return
        matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
        if not matching_files:
            print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
            return
        for so_path in matching_files:
            try:
                # Pass a plain string: path-like names are only documented as
                # accepted by ctypes.CDLL from Python 3.12 onwards.
                ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)
                print(f"[INFO] Loaded: {so_path}")
            except OSError as e:
                print(f"[WARNING] Failed to load {so_path}: {e}")


# Executed at import time, exactly as before.
load_nvrtc()


def _select_version_tables(version):
    """Return ``(version, symbols, language_module_map)`` for *version*.

    ``None`` falls back to the ``version`` environment variable and then to
    ``"v2"``. Any value other than ``"v1"`` selects the v2 tables.
    """
    if version is None:
        version = os.environ.get("version", "v2")
    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {
            "zh": "chinese2",
            "ja": "japanese",
            "en": "english",
            "ko": "korean",
            "yue": "cantonese",
        }
    return version, symbols, language_module_map


def _import_language_module(module_name):
    """Import and return the ``text.<module_name>`` submodule."""
    return __import__("text." + module_name, fromlist=[module_name])


def clean_text(text, language, version=None):
    """Normalize *text* and convert it to phoneme symbols.

    Args:
        text: Input text.
        language: Language code. A code that is not supported by the selected
            version falls back to ``"en"`` and the text is replaced by a
            single space.
        version: ``"v1"`` or another value (treated as v2). ``None`` reads the
            ``version`` environment variable, defaulting to ``"v2"``.

    Returns:
        A ``(phones, word2ph, norm_text)`` tuple. ``word2ph`` is ``None``
        except for ``"zh"`` and ``"yue"``, where it has one entry per
        character of ``norm_text`` and sums to ``len(phones)``. Phones that
        are not in the selected symbol table are replaced by ``"UNK"``.

    Raises:
        AssertionError: If the ``"zh"``/``"yue"`` g2p output is inconsistent
            with the normalized text.
    """
    version, symbols, language_module_map = _select_version_tables(version)

    if language not in language_module_map:
        language = "en"
        text = " "

    # Text containing a special marker for this language takes a separate
    # path that maps the marker to its dedicated symbol.
    for special_s, special_l, target_symbol in special:
        if special_s in text and language == special_l:
            return clean_special(text, language, special_s, target_symbol, version)

    language_module = _import_language_module(language_module_map[language])
    if hasattr(language_module, "text_normalize"):
        norm_text = language_module.text_normalize(text)
    else:
        norm_text = text

    if language == "zh" or language == "yue":
        phones, word2ph = language_module.g2p(norm_text)
        assert len(phones) == sum(word2ph)
        assert len(norm_text) == len(word2ph)
    elif language == "en":
        phones = language_module.g2p(norm_text)
        # Very short English results are padded with a leading comma.
        if len(phones) < 4:
            phones = [","] + phones
        word2ph = None
    else:
        phones = language_module.g2p(norm_text)
        word2ph = None

    phones = ["UNK" if ph not in symbols else ph for ph in phones]
    return phones, word2ph, norm_text


def clean_special(text, language, special_s, target_symbol, version=None):
    """Handle text containing a special silence-segment (SP) marker.

    Every occurrence of *special_s* is replaced by ``","`` before
    normalization and g2p; afterwards every ``","`` phone is emitted as
    *target_symbol* instead.

    Args:
        text: Input text containing the marker.
        language: Language code; must be a key of the selected version's
            language map.
        special_s: The marker character to replace.
        target_symbol: The symbol emitted for each ``","`` phone.
        version: Same meaning as in :func:`clean_text`.

    Returns:
        A ``(phones, word2ph, norm_text)`` tuple, where ``word2ph`` is the
        second element returned by the language module's ``g2p``.

    Raises:
        AssertionError: If g2p yields a phone that is not in the symbol table.
    """
    version, symbols, language_module_map = _select_version_tables(version)

    text = text.replace(special_s, ",")
    language_module = _import_language_module(language_module_map[language])
    norm_text = language_module.text_normalize(text)
    phones = language_module.g2p(norm_text)

    new_ph = []
    for ph in phones[0]:
        assert ph in symbols
        if ph == ",":
            new_ph.append(target_symbol)
        else:
            new_ph.append(ph)
    return new_ph, phones[1], norm_text


def text_to_sequence(text, language, version=None):
    """Clean *text* and pass its phones to ``cleaned_text_to_sequence``.

    Args:
        text: Input text.
        language: Language code, forwarded to :func:`clean_text`.
        version: Used only when the ``version`` environment variable is not
            set (the environment takes precedence here, as it always has);
            defaults to ``"v2"``.

    Returns:
        Whatever ``cleaned_text_to_sequence(phones, version)`` returns.
    """
    version = os.environ.get("version", version)
    if version is None:
        version = "v2"
    # clean_text() needs the language and returns a 3-tuple; only the phone
    # list is forwarded. Passing the resolved version keeps cleaning and
    # sequencing on the same symbol table.
    phones, _word2ph, _norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(phones, version)


if __name__ == "__main__":
    print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))