import sys
from typing import List, Union
from tqdm import tqdm
import pandas as pd
import numpy as np
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import time
import requests
import joblib
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
from Bio import SeqIO
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import torch
from rdkit import RDLogger

RDLogger.DisableLog("rdApp.*")

import spaces
from xgboost import XGBClassifier, DMatrix

from model.barlow_twins import BarlowTwins
# sys.path.append("../utils/")
from utils.sequence import uniprot2sequence, encode_sequences


class DTIModel:
    """Drug-target interaction model: a Barlow Twins embedding head feeding an XGBoost classifier."""

    def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
        self.bt_model = BarlowTwins()
        self.bt_model.load_model(bt_model_path)
        self.gbm_model = XGBClassifier()
        self.gbm_model.load_model(gbm_model_path)
        self.encoder = encoder
        self.smiles_cache = {}
        self.sequence_cache = {}

    def _encode_smiles(self, smiles: str, radius: int = 2, bits: int = 1024, features: bool = False):
        if smiles is None:
            return None
        # Check if the SMILES is already in the cache
        if smiles in self.smiles_cache:
            return self.smiles_cache[smiles]
        else:
            # Encode the SMILES and store it in the cache
            try:
                mol = Chem.MolFromSmiles(smiles)
                morgan = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    radius=radius,
                    nBits=bits,
                    useFeatures=features,
                )
                morgan = np.array(morgan)
                self.smiles_cache[smiles] = morgan
                return morgan
            except Exception as e:
                print(f"Failed to encode SMILES: {smiles}")
                print(e)
                return None

    def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
        morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
        return np.array(morgan)

    def _encode_sequence(self, sequence: str):
        # Clear torch cache
        torch.cuda.empty_cache()
        if sequence is None:
            return None
        # Check if the sequence is already in the cache
        if sequence in self.sequence_cache:
            return self.sequence_cache[sequence]
        else:
            # Encode the sequence and store it in the cache
            try:
                encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
                self.sequence_cache[sequence] = encoded_sequence
                return encoded_sequence
            except Exception as e:
                print(f"Failed to encode sequence: {sequence}")
                print(e)
                return None

    def _encode_sequence_mult(self, sequences: List[str]):
        seq = [self._encode_sequence(sequence) for sequence in sequences]
        return np.array(seq)

    def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool = False):
        # Broadcast the smaller embedding matrix so drugs and targets pair up row by row
        if drug_emb.shape[0] < target_emb.shape[0]:
            drug_emb = np.tile(drug_emb, (len(target_emb), 1))
        elif len(drug_emb) > len(target_emb):
            target_emb = np.tile(target_emb, (len(drug_emb), 1))
        emb = self.bt_model.zero_shot(drug_emb, target_emb)
        if pred_leaf:
            d_emb = DMatrix(emb)
            return self.gbm_model.get_booster().predict(d_emb, pred_leaf=True)
        else:
            return self.gbm_model.predict_proba(emb)[:, 1]

    def predict(self, drug: Union[List[str], str], target: str, pred_leaf: bool = False):
        if isinstance(drug, str):
            drug_emb = self._encode_smiles(drug)
        else:
            drug_emb = self._encode_smiles_mult(drug)
        target_emb = self._encode_sequence(target)
        return self.__predict_pair(drug_emb, target_emb, pred_leaf)

    def get_leaf_weights(self):
        return self.gbm_model.get_booster().get_score(importance_type="weight")

    def _predict_fasta(self, drug: str, fasta_path: str):
        drug_emb = self._encode_smiles(drug)
        results = []
        # Extract targets from fasta
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"):
            target_emb = self._encode_sequence(str(target.seq))
            pred = self.__predict_pair(drug_emb, target_emb)
            results.append(
                {
                    "drug": drug,
                    "target": target.id,
                    "name": target.name,
                    "description": target.description,
                    "prediction": pred[0],
                }
            )
        return pd.DataFrame(results)

    def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120):
        def process_target(target, results):
            target_emb = self._encode_sequence(str(target.seq))
            pred = self.__predict_pair(drug_emb, target_emb)
            results.append(
                {
                    "drug": drug,
                    "target": target.id,
                    "name": target.name,
                    "description": target.description,
                    "prediction": pred[0],
                }
            )

        drug_emb = self._encode_smiles(drug)
        results = []
        # First, count the total number of records for the progress bar
        total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))
        # Extract targets from fasta with a properly initialized tqdm progress bar
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"):
            thread_results = []
            thread = threading.Thread(target=process_target, args=(target, thread_results))
            thread.start()
            thread.join(timeout_seconds)
            if thread.is_alive():
                print(f"Skipping target {target.id} due to timeout")
                continue
            results.extend(thread_results)
        return pd.DataFrame(results)

    def predict_uniprot(self, drug: Union[List[str], str], uniprot_id: str):
        return self.predict(drug, uniprot2sequence(uniprot_id))
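

# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal example of how the class above can be driven end to end, assuming local
# checkpoint files and a FASTA file exist; every path below is a placeholder, the
# UniProt accession is only an example, and "prost_t5" is simply the encoder default
# from __init__.
if __name__ == "__main__":
    model = DTIModel(
        bt_model_path="checkpoints/barlow_twins.pt",       # placeholder path
        gbm_model_path="checkpoints/xgb_classifier.json",  # placeholder path
        encoder="prost_t5",
    )

    aspirin = "CC(=O)OC1=CC=CC=C1C(=O)O"  # SMILES for aspirin

    # Single drug against a single UniProt target (sequence fetched via uniprot2sequence)
    proba = model.predict_uniprot(aspirin, "P35354")  # example UniProt accession
    print("Interaction probability:", proba)

    # Single drug screened against every record in a FASTA file, with a per-target timeout
    df = model.predict_fasta(aspirin, "targets.fasta", timeout_seconds=120)
    print(df.sort_values("prediction", ascending=False).head())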