# ICLR_FLAG/utils/chemutils.py
# zaixizhang - commit "renew" (10efe81) - 21.6 kB
import rdkit
import rdkit.Chem as Chem
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from rdkit.Chem.Descriptors import MolLogP, qed
from torch_geometric.data import Data, Batch
from random import sample
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
import numpy as np
from math import sqrt
import torch
from copy import deepcopy
# tree_decomp stores each clique-graph edge weight as
# MST_MAX_WEIGHT - (number of shared atoms), so that scipy's *minimum*
# spanning tree favours the edges with the largest overlap.
MST_MAX_WEIGHT = 100
# Not referenced anywhere in the visible part of this module.
# NOTE(review): presumably a cap on the number of enumerated candidates - confirm.
MAX_NCAND = 2000
def vina_score(mol, use_uff=True):
    """Add hydrogens to a copy of ``mol`` and optionally relax it with UFF.

    The original body read an undefined global ``use_uff``; it is now an
    explicit keyword argument so the function can actually be called.

    NOTE(review): despite the name, no docking score is computed here and the
    function returns ``None`` - the scoring step appears to be missing from
    this file; confirm against the intended implementation.

    Args:
        mol: RDKit molecule to prepare.
        use_uff: When true, run ``UFFOptimizeMolecule`` on the hydrogenated
            copy.
    """
    # AddHs returns a new molecule; ``mol`` itself is left untouched.
    ligand_rdmol = Chem.AddHs(mol, addCoords=True)
    if use_uff:
        UFFOptimizeMolecule(ligand_rdmol)
def lipinski(mol):
    """Return True when ``mol`` satisfies this module's Lipinski-style filter.

    The checks are: Crippen logP <= 5, H-bond donors <= 5, H-bond acceptors
    <= 10, exact molecular weight <= 500 and rotatable bonds <= 5.

    The previous version tested ``qed(mol) <= 5``; QED lies in [0, 1], so that
    condition was always true and the lipophilicity criterion was never
    applied. ``MolLogP`` is used instead.

    Args:
        mol: RDKit molecule.

    Returns:
        bool: whether every criterion holds.
    """
    return bool(
        MolLogP(mol) <= 5
        and Chem.Lipinski.NumHDonors(mol) <= 5
        and Chem.Lipinski.NumHAcceptors(mol) <= 10
        and Chem.Descriptors.ExactMolWt(mol) <= 500
        and Chem.Lipinski.NumRotatableBonds(mol) <= 5
    )
def list_filter(a, b):
    """Return the items of ``a`` that also occur in ``b``.

    The order of ``a`` is kept and duplicates in ``a`` are preserved.
    """
    return [item for item in a if item in b]
def rand_rotate(dir, ref, pos, alpha=None, device=None):
    """Rotate points about the axis ``dir`` passing through the point ``ref``.

    Args:
        dir: 3-vector giving the axis direction (normalised internally).
        ref: 3-vector, a point on the rotation axis.
        pos: tensor of shape (N, 3) holding the points to rotate.
        alpha: rotation angle in radians; when ``None`` a single angle is
            drawn from a standard normal distribution.
        device: torch device for the transform; defaults to ``'cpu'``.

    Returns:
        Tensor of shape (N, 3) with the rotated points.
    """
    device = 'cpu' if device is None else device
    axis = dir / torch.norm(dir)
    if alpha is None:
        alpha = torch.randn(1).to(device)

    s = torch.sin(alpha).to(device)
    c = torch.cos(alpha).to(device)
    k = 1 - c
    axis_dot_ref = torch.dot(axis, ref)
    ax, ay, az = axis[0], axis[1], axis[2]
    rx, ry, rz = ref[0], ref[1], ref[2]

    # Homogeneous 4x4 transform: Rodrigues rotation in the upper-left 3x3
    # block, and in the last column the translation that keeps ``ref`` fixed.
    entries = [
        ax ** 2 * K_ if False else ax ** 2 * k + c, ax * ay * k - az * s, ax * az * k + ay * s,
        (rx - ax * axis_dot_ref) * k + (az * ry - ay * rz) * s,
        ax * ay * k + az * s, ay ** 2 * k + c, ay * az * k - ax * s,
        (ry - ay * axis_dot_ref) * k + (ax * rz - az * rx) * s,
        ax * az * k - ay * s, ay * az * k + ax * s, az ** 2 * k + c,
        (rz - az * axis_dot_ref) * k + (ay * rx - ax * ry) * s,
        0, 0, 0, 1,
    ]
    transform = torch.tensor(entries, device=device).reshape(4, 4)

    n_pos = pos.shape[0]
    ones_row = torch.ones(n_pos, device=device).unsqueeze(0)
    homogeneous = torch.cat([pos.t(), ones_row], dim=0)  # (4, N)
    return torch.mm(transform, homogeneous)[:3].t()
def kabsch(A, B):
    """Find the rigid transform that best maps point set ``B`` onto ``A``.

    The previous code used ``*`` for matrix products, which is only correct
    for ``np.matrix`` operands; with plain ndarrays it multiplied
    element-wise (or failed to broadcast). Real matrix products are used now.

    Args:
        A: nominal points, shape (N, 3).
        B: measured points, shape (N, 3), row-wise paired with ``A``.

    Returns:
        (R, t): ``R`` is the 3x3 rotation (B to A) and ``t`` the 3x1
        translation (B to A), i.e. ``R @ B.T + t`` approximates ``A.T``.
        ``np.matrix`` objects are returned when either input is an
        ``np.matrix``; otherwise plain ndarrays.

    Raises:
        AssertionError: if ``A`` and ``B`` have different lengths.
    """
    # Keep backward compatibility for callers that relied on np.matrix results.
    return_matrix = isinstance(A, np.matrix) or isinstance(B, np.matrix)
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    assert len(A) == len(B)

    centroid_A = A.mean(axis=0)
    centroid_B = B.mean(axis=0)

    # Center both point sets on their centroids.
    AA = A - centroid_A
    BB = B - centroid_B

    # Cross-covariance and its SVD.
    H = BB.T @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Special reflection case: flip the axis of the smallest singular value
    # so that R is a proper rotation (det = +1).
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = (centroid_A - R @ centroid_B).reshape(-1, 1)
    if return_matrix:
        return np.matrix(R), np.matrix(t)
    return R, t
def kabsch_torch(A, B, C):
    """Align ``C`` with the rigid transform that best maps ``A`` onto ``B``.

    All inputs are converted to double precision. Unlike the previous
    version, an improper solution (det(R) < 0) is corrected so that ``R`` is
    always a proper rotation, matching the NumPy ``kabsch`` in this module.
    ``torch.linalg.svd`` replaces the deprecated ``torch.svd``.

    Args:
        A: source points, shape (N, 3).
        B: target points, shape (N, 3), row-wise paired with ``A``.
        C: points of shape (M, 3) to transform with the fitted (R, t).

    Returns:
        (C_aligned, R, t): transformed points (M, 3), rotation (3, 3) and
        translation (1, 3), such that ``A @ R.T + t`` approximates ``B``.
    """
    A = A.double()
    B = B.double()
    C = C.double()

    a_mean = A.mean(dim=0, keepdim=True)
    b_mean = B.mean(dim=0, keepdim=True)
    A_c = A - a_mean
    B_c = B - b_mean

    # Covariance matrix, shape [3, 3].
    H = torch.matmul(A_c.transpose(0, 1), B_c)
    U, S, Vh = torch.linalg.svd(H)
    V = Vh.transpose(0, 1)

    # Rotation matrix, shape [3, 3].
    R = torch.matmul(V, U.transpose(0, 1))
    if torch.det(R) < 0:
        # Reflection case: negate the direction of the smallest singular
        # value (last column of V). Done out of place to avoid mutating a
        # view of the SVD output.
        flip = torch.ones(V.shape[1], dtype=V.dtype, device=V.device)
        flip[-1] = -1
        R = torch.matmul(V * flip, U.transpose(0, 1))

    # Translation vector, shape [1, 3].
    t = b_mean - torch.matmul(R, a_mean.transpose(0, 1)).transpose(0, 1)
    C_aligned = torch.matmul(R, C.transpose(0, 1)).transpose(0, 1) + t
    return C_aligned, R, t
def eig_coord_from_dist(D):
    """Recover 3-D coordinates from a matrix of squared pairwise distances.

    A Gram matrix relative to point 0 is built from ``D`` and decomposed;
    the three leading eigen-directions, scaled by the square roots of their
    (non-negative) eigenvalues, are returned as coordinates.

    ``torch.linalg.eigh`` returns eigenvalues in ascending order. The
    previous version sorted the eigenvalues descending but left the
    eigenvectors in ascending order, pairing each eigenvalue with the wrong
    vector; both are now reordered together.

    Args:
        D: (N, N) tensor of squared distances.

    Returns:
        Detached tensor of shape (N, min(N, 3)) with coordinates that
        reproduce ``D`` up to a rigid motion or reflection.
    """
    # M[i, j] = <x_i - x_0, x_j - x_0>
    M = (D[:1, :] + D[:, :1] - D) / 2
    L, V = torch.linalg.eigh(M)
    L, order = torch.sort(L, descending=True)
    V = V[:, order]
    # Clamp tiny negative eigenvalues caused by round-off before the sqrt.
    X = torch.matmul(V, torch.diag_embed(L.clamp(min=0).sqrt()))
    return X[:, :3].detach()
def self_square_dist(X):
    """Return the matrix of squared Euclidean distances between rows of ``X``.

    For ``X`` of shape (N, 3) the result has shape (N, N).
    """
    # Broadcast [1, N, 3] against [N, 1, 3] to get all pairwise differences.
    pairwise_diff = X[None] - X[:, None]
    return pairwise_diff.pow(2).sum(dim=-1)
def set_atommap(mol, num=0):
    """Set the atom-map number of every atom in ``mol`` to ``num`` (in place)."""
    for current_atom in mol.GetAtoms():
        current_atom.SetAtomMapNum(num)
def get_mol(smiles):
    """Parse ``smiles`` into a kekulized RDKit molecule.

    Returns ``None`` when RDKit cannot parse the string.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        # Kekulize modifies the molecule in place.
        Chem.Kekulize(parsed)
    return parsed
def get_smiles(mol):
    """Return the SMILES string of ``mol`` (``kekuleSmiles`` disabled)."""
    smiles_text = Chem.MolToSmiles(mol, kekuleSmiles=False)
    return smiles_text
def decode_stereo(smiles2D):
    """Enumerate isomeric SMILES for the stereoisomers of ``smiles2D``.

    Every enumerated stereoisomer is round-tripped through isomeric SMILES.
    If the first isomer contains chirality-tagged nitrogen atoms, those tags
    are cleared on every isomer and the resulting SMILES are appended as
    additional entries.

    Returns:
        list[str]: isomeric SMILES strings (may contain duplicates).
    """
    base_mol = Chem.MolFromSmiles(smiles2D)
    enumerated = list(EnumerateStereoisomers(base_mol))

    # Round-trip through SMILES to obtain freshly parsed molecules.
    isomers = []
    for iso in enumerated:
        iso_smiles = Chem.MolToSmiles(iso, isomericSmiles=True)
        isomers.append(Chem.MolFromSmiles(iso_smiles))
    smiles3D = [Chem.MolToSmiles(iso, isomericSmiles=True) for iso in isomers]

    chiral_nitrogens = []
    for atom in isomers[0].GetAtoms():
        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N":
            chiral_nitrogens.append(atom.GetIdx())

    if len(chiral_nitrogens) > 0:
        unspecified = Chem.rdchem.ChiralType.CHI_UNSPECIFIED
        for iso in isomers:
            for n_idx in chiral_nitrogens:
                iso.GetAtomWithIdx(n_idx).SetChiralTag(unspecified)
            smiles3D.append(Chem.MolToSmiles(iso, isomericSmiles=True))
    return smiles3D
def sanitize(mol):
    """Round-trip ``mol`` through SMILES and return the re-parsed molecule.

    Returns ``None`` if writing or re-parsing raises, or if the SMILES
    cannot be parsed.
    """
    try:
        return get_mol(get_smiles(mol))
    except Exception:
        return None
def copy_atom(atom):
    """Create a new atom with the symbol, formal charge and map number of ``atom``."""
    duplicate = Chem.Atom(atom.GetSymbol())
    duplicate.SetFormalCharge(atom.GetFormalCharge())
    duplicate.SetAtomMapNum(atom.GetAtomMapNum())
    return duplicate
def copy_edit_mol(mol):
    """Return an editable ``RWMol`` rebuilt from the atoms and bonds of ``mol``.

    Atoms are recreated with ``copy_atom`` (symbol, formal charge, map
    number); bonds keep their end points and bond type.
    """
    editable = Chem.RWMol(Chem.MolFromSmiles(''))
    for source_atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(source_atom))
    for source_bond in mol.GetBonds():
        begin_idx = source_bond.GetBeginAtom().GetIdx()
        end_idx = source_bond.GetEndAtom().GetIdx()
        editable.AddBond(begin_idx, end_idx, source_bond.GetBondType())
    return editable
def get_submol(mol, idxs, mark=()):
    """Extract the sub-molecule of ``mol`` induced by the atom indices ``idxs``.

    Args:
        mol: source RDKit molecule.
        idxs: iterable of (hashable) atom indices to keep.
        mark: iterable of atom indices that receive atom-map number 1 in the
            result; all other kept atoms receive map number 0.

    Returns:
        A new molecule holding copies of the selected atoms and every bond
        of ``mol`` whose two end points are both selected.
    """
    # Sets give O(1) membership tests; the default for ``mark`` is an
    # immutable tuple rather than a shared mutable list.
    keep = set(idxs)
    marked = set(mark) if mark is not None else set()

    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    old_to_new = {}
    for atom in mol.GetAtoms():
        old_idx = atom.GetIdx()
        if old_idx not in keep:
            continue
        new_atom = copy_atom(atom)
        new_atom.SetAtomMapNum(1 if old_idx in marked else 0)
        old_to_new[old_idx] = new_mol.AddAtom(new_atom)

    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if a1 in keep and a2 in keep:
            new_mol.AddBond(old_to_new[a1], old_to_new[a2], bond.GetBondType())
    return new_mol.GetMol()
def get_clique_mol(mol, atoms):
    """Build a sanitized stand-alone molecule for the atoms of one clique.

    The fragment is written as kekulized SMILES, re-parsed without
    sanitization, rebuilt with ``copy_edit_mol`` and finally passed through
    ``sanitize``.
    """
    fragment_smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    fragment = Chem.MolFromSmiles(fragment_smiles, sanitize=False)
    rebuilt = copy_edit_mol(fragment).GetMol()
    # The result of sanitize() is assumed not to be None here.
    return sanitize(rebuilt)
def get_clique_mol_simple(mol, cluster):
    """Return the unsanitized molecule parsed from a cluster's fragment SMILES.

    The fragment SMILES is canonical and kekulized.
    """
    cluster_smiles = Chem.MolFragmentToSmiles(
        mol, cluster, canonical=True, kekuleSmiles=True
    )
    return Chem.MolFromSmiles(cluster_smiles, sanitize=False)
def tree_decomp(mol, reference_vocab=None):
    """Decompose ``mol`` into clusters and a spanning tree over them.

    Clusters are the non-ring bonds (two-atom sets) plus the SSSR rings with
    at most 8 atoms, where overlapping rings may be merged:

    * without ``reference_vocab``: rings sharing more than 2 atoms merge;
    * with ``reference_vocab``: rings sharing at least 2 atoms merge, except
      when they share exactly 2 atoms and the merged fragment's SMILES has a
      vocabulary value <= 100.

    Args:
        mol: RDKit molecule.
        reference_vocab: optional mapping from fragment SMILES to a number
            that is compared against 100.

    Returns:
        (clusters, edges): list of atom-index sets, and a list of
        (cluster_i, cluster_j) pairs forming a maximum-overlap spanning tree
        (an empty list when no two clusters share an atom).
    """
    n_atoms = mol.GetNumAtoms()

    # Every bond outside a ring becomes its own two-atom cluster.
    clusters = [
        {bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()}
        for bond in mol.GetBonds()
        if not bond.IsInRing()
    ]

    # Ring clusters; overly large rings are dropped.
    rings = [set(ring) for ring in Chem.GetSymmSSSR(mol)]
    rings = [ring for ring in rings if len(ring) <= 8]

    n_rings = len(rings)
    for i in range(n_rings - 1):
        if len(rings[i]) <= 2:
            continue
        for j in range(i + 1, n_rings):
            if len(rings[j]) <= 2:
                continue
            shared = rings[i] & rings[j]
            if reference_vocab is not None:
                if len(shared) < 2:
                    continue
                merged = rings[i] | rings[j]
                merged_smiles = Chem.MolFragmentToSmiles(
                    mol, merged, canonical=True, kekuleSmiles=True
                )
                # The vocabulary lookup happens before the size test.
                if reference_vocab[merged_smiles] <= 100 and len(shared) == 2:
                    continue
            else:
                if len(shared) <= 2:
                    continue
                merged = rings[i] | rings[j]
            rings[i] = merged
            rings[j] = set()  # emptied rings are filtered out below

    clusters.extend(ring for ring in rings if len(ring) > 0)

    # For each atom, the indices of the clusters that contain it.
    atom_to_clusters = [[] for _ in range(n_atoms)]
    for cluster_idx, members in enumerate(clusters):
        for atom_idx in members:
            atom_to_clusters[atom_idx].append(cluster_idx)

    # Two clusters sharing an atom are linked; the weight is the size of
    # their overlap.
    overlaps = defaultdict(int)
    for containing in atom_to_clusters:
        if len(containing) <= 1:
            continue
        for pos, c1 in enumerate(containing):
            for c2 in containing[pos + 1:]:  # c1 < c2 by construction
                common = set(clusters[c1]) & set(clusters[c2])
                overlaps[(c1, c2)] = max(overlaps[(c1, c2)], len(common))

    weighted = [pair + (MST_MAX_WEIGHT - size,) for pair, size in overlaps.items()]
    if not weighted:
        return clusters, weighted

    # Maximum spanning tree, obtained as the minimum spanning tree of the
    # complemented weights.
    rows, cols, weights = zip(*weighted)
    n_clique = len(clusters)
    clique_graph = csr_matrix((weights, (rows, cols)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    tree_rows, tree_cols = junc_tree.nonzero()
    return clusters, list(zip(tree_rows, tree_cols))
def atom_equal(a1, a2):
    """Return True when both atoms share element symbol and formal charge."""
    if a1.GetSymbol() != a2.GetSymbol():
        return False
    return a1.GetFormalCharge() == a2.GetFormalCharge()
# Two bonds match when their end atoms match (see atom_equal) AND their bond
# types are identical.
def ring_bond_equal(bond1, bond2, reverse=False):
    """Compare two bonds by end-atom identity and bond type.

    Args:
        bond1: first bond.
        bond2: second bond.
        reverse: when true, ``bond2`` is compared end-to-begin, i.e. its end
            atom is matched against ``bond1``'s begin atom.

    Returns:
        bool: True when both atom pairs match and the bond types are equal.
    """
    begin1, end1 = bond1.GetBeginAtom(), bond1.GetEndAtom()
    if reverse:
        begin2, end2 = bond2.GetEndAtom(), bond2.GetBeginAtom()
    else:
        begin2, end2 = bond2.GetBeginAtom(), bond2.GetEndAtom()

    if not atom_equal(begin1, begin2):
        return False
    if not atom_equal(end1, end2):
        return False
    return bond1.GetBondType() == bond2.GetBondType()
def attach(ctr_mol, nei_mol, amap):
    """Graft ``nei_mol`` onto a copy of ``ctr_mol`` using the mapping ``amap``.

    Args:
        ctr_mol: centre molecule (copied, not modified).
        nei_mol: neighbour fragment to attach.
        amap: dict mapping neighbour atom index -> centre atom index for the
            shared atoms. It is extended IN PLACE with the indices of the
            newly added atoms.

    Returns:
        (mol, amap): the combined molecule and the (same, updated) mapping.
        Newly added atoms carry atom-map number 2.
    """
    combined = Chem.RWMol(ctr_mol)

    for nei_atom in nei_mol.GetAtoms():
        nei_idx = nei_atom.GetIdx()
        if nei_idx in amap:
            continue  # shared atom, already present in the centre
        added = copy_atom(nei_atom)
        added.SetAtomMapNum(2)
        amap[nei_idx] = combined.AddAtom(added)

    for nei_bond in nei_mol.GetBonds():
        begin = amap[nei_bond.GetBeginAtom().GetIdx()]
        end = amap[nei_bond.GetEndAtom().GetIdx()]
        if combined.GetBondBetweenAtoms(begin, end) is None:
            combined.AddBond(begin, end, nei_bond.GetBondType())

    return combined.GetMol(), amap
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
    """Attach several neighbour fragments to the editable ``ctr_mol``.

    Nodes of ``prev_nodes`` are processed before ``neighbors``. For every
    node, atoms missing from its mapping are added to ``ctr_mol`` and its
    bonds are created; a bond that already exists is replaced only when the
    node belongs to ``prev_nodes``.

    Args:
        ctr_mol: editable molecule, modified in place.
        neighbors: nodes exposing ``nid`` and ``mol``.
        prev_nodes: previously placed nodes exposing ``nid`` and ``mol``.
        nei_amap: dict ``nid -> {neighbour atom idx: centre atom idx}``;
            the inner dicts are extended in place.

    Returns:
        The modified ``ctr_mol``.
    """
    prev_nids = [prev.nid for prev in prev_nodes]

    for node in prev_nodes + neighbors:
        node_id = node.nid
        node_mol = node.mol
        amap = nei_amap[node_id]

        for node_atom in node_mol.GetAtoms():
            atom_idx = node_atom.GetIdx()
            if atom_idx not in amap:
                amap[atom_idx] = ctr_mol.AddAtom(copy_atom(node_atom))

        if node_mol.GetNumBonds() == 0:
            # Single-atom fragment: only propagate its atom-map number.
            source = node_mol.GetAtomWithIdx(0)
            target = ctr_mol.GetAtomWithIdx(amap[0])
            target.SetAtomMapNum(source.GetAtomMapNum())
            continue

        for node_bond in node_mol.GetBonds():
            begin = amap[node_bond.GetBeginAtom().GetIdx()]
            end = amap[node_bond.GetEndAtom().GetIdx()]
            if ctr_mol.GetBondBetweenAtoms(begin, end) is None:
                ctr_mol.AddBond(begin, end, node_bond.GetBondType())
            elif node_id in prev_nids:
                # Father node overrides an existing bond.
                ctr_mol.RemoveBond(begin, end)
                ctr_mol.AddBond(begin, end, node_bond.GetBondType())
    return ctr_mol
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
    """Attach neighbours to a copy of ``ctr_mol`` following ``amap_list``.

    Args:
        ctr_mol: centre molecule (an editable copy is made).
        neighbors: nodes exposing ``nid`` and ``mol``.
        prev_nodes: previously placed nodes exposing ``nid`` and ``mol``.
        amap_list: iterable of ``(nei_id, ctr_atom, nei_atom)`` triples.

    Returns:
        The assembled molecule.
    """
    editable = copy_edit_mol(ctr_mol)

    nei_amap = {}
    for node in prev_nodes + neighbors:
        nei_amap[node.nid] = {}
    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom

    assembled = attach_mols(editable, neighbors, prev_nodes, nei_amap)
    return assembled.GetMol()
# This version records the index mapping between ctr_mol and nei_mol.
def enum_attach(ctr_mol, nei_mol):
    """Enumerate the ways ``nei_mol`` can be attached to ``ctr_mol``.

    Only centre atoms with atom-map number 1 (and centre bonds whose two end
    atoms both have map number 1) are considered as attachment sites.

    Both molecules are kekulized IN PLACE; if kekulization raises, an empty
    list is returned.

    Args:
        ctr_mol: centre molecule.
        nei_mol: neighbour fragment.

    Returns:
        list[dict]: each dict maps neighbour atom index -> centre atom index
        for the atoms shared by one candidate attachment.
    """
    try:
        Chem.Kekulize(ctr_mol)
        Chem.Kekulize(nei_mol)
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return []

    att_confs = []
    ctr_bonds = [
        bond for bond in ctr_mol.GetBonds()
        if bond.GetBeginAtom().GetAtomMapNum() == 1
        and bond.GetEndAtom().GetAtomMapNum() == 1
    ]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetAtomMapNum() == 1]

    if nei_mol.GetNumBonds() == 1:  # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence).
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                att_confs.append({b1.GetIdx(): atom.GetIdx()})
            elif atom_equal(atom, b2):
                att_confs.append({b2.GetIdx(): atom.GetIdx()})
    else:
        # Intersection is an atom.
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if not atom_equal(a1, a2):
                    continue
                # Optimize if atom is carbon (other atoms may change valence).
                if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                    continue
                att_confs.append({a2.GetIdx(): a1.GetIdx()})

        # Intersection is a bond; both orientations are tried.
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        att_confs.append({
                            b2.GetBeginAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                            b2.GetEndAtom().GetIdx(): b1.GetEndAtom().GetIdx(),
                        })
                    if ring_bond_equal(b1, b2, reverse=True):
                        att_confs.append({
                            b2.GetEndAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                            b2.GetBeginAtom().GetIdx(): b1.GetEndAtom().GetIdx(),
                        })
    return att_confs
def enumerate_assemble(mol, idxs, current, next):
    """Build a (decoy, ground-truth) pair for attaching ``next`` to a partial molecule.

    Args:
        mol: full RDKit molecule.
        idxs: atom indices of the partial molecule built so far.
        current: node whose ``clique`` atoms are marked as attachment sites.
        next: node to attach; its ``clique`` and ``mol`` attributes are read.

    Returns:
        (labels, cand_mols): when at least one distinct wrong attachment
        exists, ``cand_mols`` is ``[random decoy, ground truth]`` with labels
        ``tensor([0, 1])``; otherwise ``[ground truth]`` with ``tensor([1])``.
    """
    ctr_mol = get_submol(mol, idxs, mark=current.clique)
    ground_truth = get_submol(mol, list(set(idxs) | set(next.clique)))
    # submol can also be obtained with get_clique_mol, future exploration
    ground_truth_smiles = get_smiles(ground_truth)

    seen_smiles = set()  # O(1) duplicate check instead of scanning a list
    cand_mols = []
    for amap in enum_attach(ctr_mol, next.mol):
        try:
            cand_mol, _ = attach(ctr_mol, next.mol, amap)
            cand_mol = sanitize(cand_mol)
        except Exception:
            # Best-effort enumeration: skip attachments RDKit rejects, but
            # let KeyboardInterrupt/SystemExit propagate.
            continue
        if cand_mol is None:
            continue
        smiles = get_smiles(cand_mol)
        if smiles in seen_smiles or smiles == ground_truth_smiles:
            continue
        seen_smiles.add(smiles)
        cand_mols.append(cand_mol)

    if cand_mols:
        # Keep one randomly chosen decoy next to the ground truth.
        cand_mols = sample(cand_mols, 1)
        cand_mols.append(ground_truth)
        labels = torch.tensor([0, 1])
    else:
        cand_mols = [ground_truth]
        labels = torch.tensor([1])
    return labels, cand_mols
# Allowable node and edge features.
# Each list defines a categorical vocabulary; a feature is encoded as the
# position of its value in the list (see mol_to_graph_data_obj_simple, which
# uses the atomic-number, chirality, bond-type and bond-direction lists).
allowable_features = {
    # Atomic numbers 1..118.
    'possible_atomic_num_list' : list(range(1, 119)),
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list' : [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds' : [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [ # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}
def mol_to_graph_data_obj_simple(mol):
    """
    Convert an RDKit molecule into a PyTorch Geometric ``Data`` graph using
    simplified, index-encoded atom and bond features.

    Atom features are [atomic-number index, chirality-tag index]; bond
    features are [bond-type index, bond-direction index]. Every bond is
    stored twice, once per direction.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    atomic_num_vocab = allowable_features['possible_atomic_num_list']
    chirality_vocab = allowable_features['possible_chirality_list']
    bond_type_vocab = allowable_features['possible_bonds']
    bond_dir_vocab = allowable_features['possible_bond_dirs']

    # atoms: one row of [atom type, chirality tag] per atom
    atom_rows = [
        [atomic_num_vocab.index(atom.GetAtomicNum()),
         chirality_vocab.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(np.array(atom_rows), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) == 0:
        # mol has no bonds: emit correctly shaped empty tensors
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:
        pairs = []
        pair_features = []
        for bond in mol.GetBonds():
            begin = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            feature = [bond_type_vocab.index(bond.GetBondType()),
                       bond_dir_vocab.index(bond.GetBondDir())]
            # add both directions; they share the same feature list
            pairs.extend([(begin, end), (end, begin)])
            pair_features.extend([feature, feature])
        # Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(pairs).T, dtype=torch.long)
        # Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(pair_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
# For inference
def assemble(mol_list, next_motif_smiles):
    """Enumerate valid attachments of the next motif for a batch of molecules.

    For molecule ``i`` the motif ``next_motif_smiles[i]`` is parsed and every
    attachment from ``enum_attach`` whose result survives ``sanitize`` is
    kept. If none survives, the unchanged molecule is emitted as a
    placeholder and ``attach_fail[i]`` is set.

    Returns:
        cand_mols: candidate molecules (or the original on failure).
        cand_batch: long tensor, index into ``mol_list`` per candidate.
        new_atoms: per candidate, all centre-side atom indices of the motif
            after attachment (empty on failure).
        one_atom_attach: bool tensor, True where exactly one atom is shared.
        intersection: per candidate, centre indices of the shared atoms.
        attach_fail: bool tensor with one flag per input molecule.
    """
    attach_fail = torch.zeros(len(mol_list)).bool()
    cand_mols, cand_batch, new_atoms = [], [], []
    cand_smiles, one_atom_attach, intersection = [], [], []

    for i, base_mol in enumerate(mol_list):
        motif = Chem.MolFromSmiles(next_motif_smiles[i])
        valid_cand = 0
        for amap in enum_attach(base_mol, motif):
            # Capture the shared atoms before attach() extends amap in place.
            shared_count = len(amap)
            shared_atoms = list(amap.values())
            cand_mol, full_amap = attach(deepcopy(base_mol), motif, amap)
            if sanitize(deepcopy(cand_mol)) is None:
                continue
            cand_smiles.append(get_smiles(cand_mol))  # collected, not returned
            cand_mols.append(cand_mol)
            cand_batch.append(i)
            new_atoms.append(list(full_amap.values()))
            one_atom_attach.append(shared_count)
            intersection.append(shared_atoms)
            valid_cand += 1

        if valid_cand == 0:
            # No attachment was enumerated or none was valid: keep the input
            # molecule as a placeholder candidate.
            attach_fail[i] = True
            cand_mols.append(base_mol)
            cand_batch.append(i)
            one_atom_attach.append(-1)
            intersection.append([])
            new_atoms.append([])

    cand_batch = torch.tensor(cand_batch)
    one_atom_attach = torch.tensor(one_atom_attach) == 1
    return cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail
if __name__ == "__main__":
    # Manual smoke test: builds a trivial MolTree, then reads SMILES from
    # stdin (first whitespace-separated token of each line) and assembles a
    # tree for each of them.
    import sys
    from mol_tree import MolTree

    # Only let CRITICAL RDKit log messages through.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    smiles = [
        "O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
        "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
        "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
        "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
        'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
        'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
        "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
        "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1",
    ]

    mol_tree = MolTree("C")
    assert len(mol_tree.nodes) > 0

    def count():
        """Accumulate candidate and node totals over the trees read from stdin."""
        total_cands = 0
        total_nodes = 0
        for line in sys.stdin:
            tree = MolTree(line.split()[0])
            tree.recover()
            tree.assemble()
            for node in tree.nodes:
                total_cands += len(node.cands)
            total_nodes += len(tree.nodes)
        # The average (total_cands * 1.0 / total_nodes) is not printed.

    count()