Commit cb7a4c9 · 1 parent: 38198b1
Removed explicit references to cuda in functions where spaces GPU are loaded

Files changed:
- funcs/embeddings.py (+12 -12)
- funcs/representation_model.py (+22 -7)
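The commit replaces import-time CUDA probing: embeddings.py drops its probe outright, while representation_model.py switches to a USE_GPU environment variable read through the repo's get_or_create_env_var helper. The helper's body is not part of this diff; a minimal sketch of what the name and the existing call sites (RUNNING_ON_AWS, USE_GPU) suggest it does, as an assumption rather than the repo's actual code:

import os

def get_or_create_env_var(var_name: str, default_value: str) -> str:
    # Assumed behaviour: return the variable if already set,
    # otherwise set it to the default and return that.
    value = os.environ.get(var_name)
    if value is None:
        os.environ[var_name] = default_value
        value = default_value
    return value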
funcs/embeddings.py CHANGED

@@ -35,18 +35,18 @@ def make_or_load_embeddings(docs: list, file_list: list, embeddings_out: np.ndar
     """
 
     # Check for torch cuda
-    from torch import cuda, backends, version
-
-    print("Is CUDA enabled? ", cuda.is_available())
-    print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
-    if cuda.is_available():
-        torch_device = "gpu"
-        print("Cuda version installed is: ", version.cuda)
-        high_quality_mode = "Yes"
-        os.system("nvidia-smi")
-    else:
-        torch_device = "cpu"
-        high_quality_mode = "No"
+    # from torch import cuda, backends, version
+
+    # print("Is CUDA enabled? ", cuda.is_available())
+    # print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
+    # if cuda.is_available():
+    #     torch_device = "gpu"
+    #     print("Cuda version installed is: ", version.cuda)
+    #     high_quality_mode = "Yes"
+    #     os.system("nvidia-smi")
+    # else:
+    #     torch_device = "cpu"
+    #     high_quality_mode = "No"
 
     if high_quality_mode_opt == "Yes":
         # Define a list of possible local locations to search for the model
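In embeddings.py the probe is commented out rather than replaced, which keeps torch's CUDA machinery off the module's import path, presumably to avoid eager CUDA calls in code paths where the Spaces GPU is loaded, per the commit message. If runtime detection were needed again, one option would be to defer it into a helper that only touches torch when called; a sketch, not code from this commit:

def detect_torch_device() -> str:
    # Hypothetical helper: import and probe CUDA lazily, at call time
    # rather than at module import.
    from torch import cuda
    return "gpu" if cuda.is_available() else "cpu"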
funcs/representation_model.py CHANGED

@@ -19,16 +19,30 @@ random_seed = 42
 RUNNING_ON_AWS = get_or_create_env_var('RUNNING_ON_AWS', '0')
 print(f'The value of RUNNING_ON_AWS is {RUNNING_ON_AWS}')
 
-from torch import cuda, backends, version, get_num_threads
-
-print("Is CUDA enabled? ", cuda.is_available())
-print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
-if cuda.is_available():
+USE_GPU = get_or_create_env_var('USE_GPU', '0')
+print(f'The value of USE_GPU is {USE_GPU}')
+
+# from torch import cuda, backends, version, get_num_threads
+
+# print("Is CUDA enabled? ", cuda.is_available())
+# print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
+# if cuda.is_available():
+#     torch_device = "gpu"
+#     print("Cuda version installed is: ", version.cuda)
+#     high_quality_mode = "Yes"
+#     os.system("nvidia-smi")
+# else:
+#     torch_device = "cpu"
+#     high_quality_mode = "No"
+
+if USE_GPU == "1":
+    print("Using GPU for representation functions")
     torch_device = "gpu"
     print("Cuda version installed is: ", version.cuda)
     high_quality_mode = "Yes"
     os.system("nvidia-smi")
-else:
+else:
+    print("Using CPU for representation functions")
     torch_device = "cpu"
     high_quality_mode = "No"
 
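Device selection is now configuration-driven: the GPU branch runs only when USE_GPU is set to "1", and the default of "0" keeps everything on CPU. One thing to watch: the GPU branch still calls version.cuda while the module-level torch import above it is now commented out, so USE_GPU=1 would presumably need the import restored locally, e.g. (a sketch, not part of this commit):

if USE_GPU == "1":
    from torch import version  # local import; the module-level one is commented out
    print("Using GPU for representation functions")
    print("Cuda version installed is: ", version.cuda)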
@@ -42,6 +56,7 @@ else: # torch_device = "cpu"
     n_gpu_layers = 0
 
 #print("Running on device:", torch_device)
+from torch import get_num_threads
 n_threads = get_num_threads()
 print("CPU n_threads:", n_threads)
 
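get_num_threads comes from PyTorch and reports the intra-op CPU thread count; importing it directly beside its single call site replaces the commented-out module-wide torch import. The related torch APIs, for reference (this snippet is illustrative, not part of the commit):

from torch import get_num_threads, set_num_threads

set_num_threads(8)                           # optionally cap torch's intra-op CPU threads
print("CPU n_threads:", get_num_threads())   # what representation_model.py now prints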
@@ -56,7 +71,7 @@ seed: int = random_seed
 reset: bool = True
 stream: bool = False
 n_threads: int = n_threads
-n_batch:int =
+n_batch:int = 512
 n_ctx:int = 8192 #4096. # Set to 8192 just to avoid any exceeded context window issues
 sample:bool = True
 trust_remote_code:bool =True
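Several of the fields above (n_batch, n_ctx, n_threads, seed) match llama-cpp-python's Llama constructor arguments, which is presumably where this config ends up; the wiring is not shown in this diff, so the following consumer is an assumption, with model_path as a placeholder:

from llama_cpp import Llama          # assumed consumer of these settings
from torch import get_num_threads

llm = Llama(
    model_path="model.gguf",         # placeholder; the real path is resolved elsewhere in the repo
    n_ctx=8192,                      # matches n_ctx above
    n_batch=512,                     # the value this commit fills in
    n_threads=get_num_threads(),     # as computed in representation_model.py
    n_gpu_layers=0,                  # the CPU-path value from the earlier hunk
    seed=42,                         # random_seed in this module
)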