lj1995 commited on
Commit
0830dc8
·
verified ·
1 Parent(s): 71d6acc

Update text/g2pw/onnx_api.py

Browse files
Files changed (1) hide show
  1. text/g2pw/onnx_api.py +0 -47
text/g2pw/onnx_api.py CHANGED
@@ -1,52 +1,5 @@
1
  # This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
2
  # This code is modified from https://github.com/GitYCC/g2pW
3
- def load_nvrtc():
4
- import torch,sys,os,ctypes
5
- from pathlib import Path
6
-
7
- if not torch.cuda.is_available():
8
- print("[INFO] CUDA is not available, skipping nvrtc setup.")
9
- return
10
-
11
- if sys.platform == "win32":
12
- torch_lib_dir = Path(torch.__file__).parent / "lib"
13
- if torch_lib_dir.exists():
14
- os.add_dll_directory(str(torch_lib_dir))
15
- print(f"[INFO] Added DLL directory: {torch_lib_dir}")
16
- matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
17
- if not matching_files:
18
- print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
19
- return
20
- for dll_path in matching_files:
21
- dll_name = os.path.basename(dll_path)
22
- try:
23
- ctypes.CDLL(dll_name)
24
- print(f"[INFO] Loaded: {dll_name}")
25
- except OSError as e:
26
- print(f"[WARNING] Failed to load {dll_name}: {e}")
27
- else:
28
- print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
29
-
30
- elif sys.platform == "linux":
31
- site_packages = Path(torch.__file__).resolve().parents[1]
32
- nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
33
-
34
- if not nvrtc_dir.exists():
35
- print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
36
- return
37
-
38
- matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
39
- if not matching_files:
40
- print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
41
- return
42
-
43
- for so_path in matching_files:
44
- try:
45
- ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
46
- print(f"[INFO] Loaded: {so_path}")
47
- except OSError as e:
48
- print(f"[WARNING] Failed to load {so_path}: {e}")
49
- load_nvrtc()
50
  import warnings
51
 
52
  warnings.filterwarnings("ignore")
 
1
  # This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
2
  # This code is modified from https://github.com/GitYCC/g2pW
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import warnings
4
 
5
  warnings.filterwarnings("ignore")