lj1995 commited on
Commit
c3ca968
·
verified ·
1 Parent(s): 0830dc8

Update text/cleaner.py

Browse files
Files changed (1) hide show
  1. text/cleaner.py +46 -0
text/cleaner.py CHANGED
@@ -17,7 +17,53 @@ special = [
17
  # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧
18
  ]
19
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def clean_text(text, language, version=None):
22
  if version is None:
23
  version = os.environ.get("version", "v2")
 
17
  # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧
18
  ]
19
 
20
+ def load_nvrtc():
21
+ import torch,sys,os,ctypes
22
+ from pathlib import Path
23
 
24
+ if not torch.cuda.is_available():
25
+ print("[INFO] CUDA is not available, skipping nvrtc setup.")
26
+ return
27
+
28
+ if sys.platform == "win32":
29
+ torch_lib_dir = Path(torch.__file__).parent / "lib"
30
+ if torch_lib_dir.exists():
31
+ os.add_dll_directory(str(torch_lib_dir))
32
+ print(f"[INFO] Added DLL directory: {torch_lib_dir}")
33
+ matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
34
+ if not matching_files:
35
+ print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
36
+ return
37
+ for dll_path in matching_files:
38
+ dll_name = os.path.basename(dll_path)
39
+ try:
40
+ ctypes.CDLL(dll_name)
41
+ print(f"[INFO] Loaded: {dll_name}")
42
+ except OSError as e:
43
+ print(f"[WARNING] Failed to load {dll_name}: {e}")
44
+ else:
45
+ print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
46
+
47
+ elif sys.platform == "linux":
48
+ site_packages = Path(torch.__file__).resolve().parents[1]
49
+ nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
50
+
51
+ if not nvrtc_dir.exists():
52
+ print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
53
+ return
54
+
55
+ matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
56
+ if not matching_files:
57
+ print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
58
+ return
59
+
60
+ for so_path in matching_files:
61
+ try:
62
+ ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
63
+ print(f"[INFO] Loaded: {so_path}")
64
+ except OSError as e:
65
+ print(f"[WARNING] Failed to load {so_path}: {e}")
66
+ load_nvrtc()
67
  def clean_text(text, language, version=None):
68
  if version is None:
69
  version = os.environ.get("version", "v2")