Spaces:
Running
on
Zero
Running
on
Zero
Fix ZeroGPU
Browse files- model/model.py +0 -4
model/model.py
CHANGED
|
@@ -17,7 +17,6 @@ import torch
|
|
| 17 |
from typing import *
|
| 18 |
from rdkit import RDLogger
|
| 19 |
RDLogger.DisableLog("rdApp.*")
|
| 20 |
-
import spaces
|
| 21 |
|
| 22 |
from xgboost import XGBClassifier, DMatrix
|
| 23 |
|
|
@@ -68,7 +67,6 @@ class DTIModel:
|
|
| 68 |
morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
|
| 69 |
return np.array(morgan)
|
| 70 |
|
| 71 |
-
@spaces.GPU
|
| 72 |
def _encode_sequence(self, sequence: str):
|
| 73 |
# Clear torch cache
|
| 74 |
torch.cuda.empty_cache()
|
|
@@ -88,12 +86,10 @@ class DTIModel:
|
|
| 88 |
print(e)
|
| 89 |
return None
|
| 90 |
|
| 91 |
-
@spaces.GPU
|
| 92 |
def _encode_sequence_mult(self, sequences: List[str]):
|
| 93 |
seq = [self._encode_sequence(sequence) for sequence in sequences]
|
| 94 |
return np.array(seq)
|
| 95 |
|
| 96 |
-
@spaces.GPU
|
| 97 |
def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool):
|
| 98 |
if drug_emb.shape[0] < target_emb.shape[0]:
|
| 99 |
drug_emb = np.tile(drug_emb, (len(target_emb), 1))
|
|
|
|
| 17 |
from typing import *
|
| 18 |
from rdkit import RDLogger
|
| 19 |
RDLogger.DisableLog("rdApp.*")
|
|
|
|
| 20 |
|
| 21 |
from xgboost import XGBClassifier, DMatrix
|
| 22 |
|
|
|
|
| 67 |
morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
|
| 68 |
return np.array(morgan)
|
| 69 |
|
|
|
|
| 70 |
def _encode_sequence(self, sequence: str):
|
| 71 |
# Clear torch cache
|
| 72 |
torch.cuda.empty_cache()
|
|
|
|
| 86 |
print(e)
|
| 87 |
return None
|
| 88 |
|
|
|
|
| 89 |
def _encode_sequence_mult(self, sequences: List[str]):
|
| 90 |
seq = [self._encode_sequence(sequence) for sequence in sequences]
|
| 91 |
return np.array(seq)
|
| 92 |
|
|
|
|
| 93 |
def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool):
|
| 94 |
if drug_emb.shape[0] < target_emb.shape[0]:
|
| 95 |
drug_emb = np.tile(drug_emb, (len(target_emb), 1))
|