Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,183 Bytes
be6f6cd c3ca968 be6f6cd c3ca968 be6f6cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from text import cleaned_text_to_sequence
import os
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
# Special silence markers, one tuple per marker:
#   (marker character, language code, replacement symbol).
# clean_text() hands the text to clean_special() when a marker occurs in text
# of the matching language.
special = [
    # Disabled: ("%", "zh", "SP")
    ("¥", "zh", "SP2"),
    ("^", "zh", "SP3"),
    # Disabled: ("@", "zh", "SP4") -- no more remix-meme tricks; stay
    # consistent with the second version.
]
def load_nvrtc():
    """Preload the NVRTC shared libraries found next to the torch package.

    When CUDA is not available a notice is printed and nothing is loaded.

    * Windows (``win32``): the ``lib`` directory inside the torch package is
      added to the DLL search path and every ``nvrtc*.dll`` in it is loaded.
    * Linux: every ``libnvrtc*.so*`` in ``nvidia/cuda_nvrtc/lib`` -- located
      beside the torch package directory -- is loaded with ``RTLD_GLOBAL``.
    * Any other platform: nothing is done.

    Missing directories or libraries, and libraries that fail to load with
    ``OSError``, are reported with ``print`` rather than raised.
    """
    import torch
    import sys
    import os
    import ctypes
    from pathlib import Path

    if not torch.cuda.is_available():
        print("[INFO] CUDA is not available, skipping nvrtc setup.")
        return

    platform_name = sys.platform

    if platform_name == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if not torch_lib_dir.exists():
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
            return

        os.add_dll_directory(str(torch_lib_dir))
        print(f"[INFO] Added DLL directory: {torch_lib_dir}")

        dll_paths = sorted(torch_lib_dir.glob("nvrtc*.dll"))
        if not dll_paths:
            print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
            return

        for dll_path in dll_paths:
            # Load by bare file name; the directory was registered above.
            dll_name = dll_path.name
            try:
                ctypes.CDLL(dll_name)
                print(f"[INFO] Loaded: {dll_name}")
            except OSError as err:
                print(f"[WARNING] Failed to load {dll_name}: {err}")
        return

    if platform_name != "linux":
        return

    package_parent = Path(torch.__file__).resolve().parents[1]
    nvrtc_dir = package_parent / "nvidia" / "cuda_nvrtc" / "lib"
    if not nvrtc_dir.exists():
        print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
        return

    so_paths = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
    if not so_paths:
        print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
        return

    for so_path in so_paths:
        try:
            ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)  # type: ignore
            print(f"[INFO] Loaded: {so_path}")
        except OSError as err:
            print(f"[WARNING] Failed to load {so_path}: {err}")
# Runs at import time: preload the NVRTC libraries before the cleaner is used.
load_nvrtc()
def clean_text(text, language, version=None):
    """Normalize ``text`` and convert it to a phone sequence.

    Args:
        text: Input text.
        language: Language code. Codes missing from the selected version's
            language map fall back to ``"en"`` and the text is replaced by a
            single space.
        version: ``"v1"`` selects the v1 symbol table and language modules;
            any other value selects the v2 ones. When ``None`` it is read
            from the ``version`` environment variable (default ``"v2"``).

    Returns:
        A ``(phones, word2ph, norm_text)`` tuple. ``word2ph`` is ``None``
        unless the language is ``"zh"`` or ``"yue"``. Phones missing from the
        symbol table are replaced by ``"UNK"``. If the text contains a special
        marker registered for the language, the result of ``clean_special``
        is returned instead.
    """
    if version is None:
        version = os.environ.get("version", "v2")

    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {
            "zh": "chinese2",
            "ja": "japanese",
            "en": "english",
            "ko": "korean",
            "yue": "cantonese",
        }

    # Unsupported language: process a blank English utterance instead.
    if language not in language_module_map:
        language, text = "en", " "

    for marker, marker_language, replacement in special:
        if not (marker in text and language == marker_language):
            continue
        return clean_special(text, language, marker, replacement, version)

    module_name = language_module_map[language]
    language_module = __import__("text." + module_name, fromlist=[module_name])

    # Not every language module provides a normalizer.
    norm_text = (
        language_module.text_normalize(text)
        if hasattr(language_module, "text_normalize")
        else text
    )

    if language in ("zh", "yue"):
        # These modules also return a per-character phone count.
        phones, word2ph = language_module.g2p(norm_text)
        assert len(phones) == sum(word2ph)
        assert len(norm_text) == len(word2ph)
    else:
        phones = language_module.g2p(norm_text)
        if language == "en" and len(phones) < 4:
            # Very short English output gets a leading "," phone.
            phones = [","] + phones
        word2ph = None

    phones = [ph if ph in symbols else "UNK" for ph in phones]
    return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol, version=None):
    """Handle text that contains a special silence (SP) marker.

    Every occurrence of ``special_s`` is replaced by ``","`` before the text
    is normalized and converted to phones; each ``","`` phone in the result is
    then replaced by ``target_symbol``.

    Args:
        text: Input text containing the marker.
        language: Language code; must be a key of the selected version's
            language map.
        special_s: Marker string to look for in ``text``.
        target_symbol: Phone symbol substituted for each ``","`` phone.
        version: ``"v1"`` selects the v1 symbol table and language modules;
            any other value selects the v2 ones. When ``None`` it is read
            from the ``version`` environment variable (default ``"v2"``).

    Returns:
        ``(phones, word2ph, norm_text)`` where ``word2ph`` is the second item
        returned by the language module's ``g2p``.

    Raises:
        AssertionError: If ``g2p`` yields a phone missing from the symbol table.
    """
    if version is None:
        version = os.environ.get("version", "v2")

    use_v1 = version == "v1"
    symbols = symbols_v1.symbols if use_v1 else symbols_v2.symbols
    if use_v1:
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        language_module_map = {
            "zh": "chinese2",
            "ja": "japanese",
            "en": "english",
            "ko": "korean",
            "yue": "cantonese",
        }

    text = text.replace(special_s, ",")

    module_name = language_module_map[language]
    language_module = __import__("text." + module_name, fromlist=[module_name])
    norm_text = language_module.text_normalize(text)
    g2p_result = language_module.g2p(norm_text)

    new_ph = []
    for phone in g2p_result[0]:
        assert phone in symbols
        new_ph.append(target_symbol if phone == "," else phone)
    return new_ph, g2p_result[1], norm_text
def text_to_sequence(text, language, version=None):
    """Clean ``text`` and convert the resulting phones to a symbol sequence.

    Args:
        text: Input text.
        language: Language code forwarded to ``clean_text``.
        version: Fallback symbol-set version. The ``version`` environment
            variable, when set, takes precedence over this argument; if
            neither is available ``"v2"`` is used.

    Returns:
        Whatever ``cleaned_text_to_sequence`` returns for the cleaned phones
        and the resolved version.
    """
    version = os.environ.get("version", version)
    if version is None:
        version = "v2"
    # clean_text needs the language and returns (phones, word2ph, norm_text);
    # only the phones are converted to a sequence.
    phones, _word2ph, _norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(phones, version)
if __name__ == "__main__":
    # Manual smoke test: run the cleaner on a sample Chinese sentence.
    demo_result = clean_text("你好%啊啊啊额、还是到付红四方。", "zh")
    print(demo_result)
|