|
import numpy as np |
|
from rdkit.Chem import AllChem as Chem |
|
from rdkit import Geometry |
|
from openbabel import openbabel as ob |
|
from openbabel import pybel |
|
from scipy.spatial.distance import pdist |
|
from scipy.spatial.distance import squareform |
|
|
|
from .protein_ligand import ATOM_FAMILIES_ID |
|
|
|
|
|
class MolReconsError(Exception):
    """Raised when the reconstructed RDKit molecule fails sanitization."""
|
|
|
|
|
def reachable_r(a, b, seenbonds):
    '''Helper for ``reachable``: depth-first search over bonds.

    Starting at OpenBabel atom ``a``, walk bonds whose index is not yet in
    ``seenbonds``; every traversed bond index is added to ``seenbonds``
    (the set is mutated in place).  Returns True as soon as atom ``b`` is
    reached, False once every reachable bond has been explored.

    Uses an explicit stack of (atom, neighbour-iterator) frames; a frame's
    iterator resumes where it left off after a deeper frame is exhausted,
    which reproduces the visiting order of the recursive formulation.
    '''
    frames = [(a, iter(ob.OBAtomAtomIter(a)))]
    while frames:
        atom, neighbours = frames[-1]
        for nbr in neighbours:
            bond_idx = atom.GetBond(nbr).GetIdx()
            if bond_idx in seenbonds:
                continue
            seenbonds.add(bond_idx)
            if nbr == b:
                return True
            # Descend into nbr; this frame continues after nbr is exhausted.
            frames.append((nbr, iter(ob.OBAtomAtomIter(nbr))))
            break
        else:
            # Every neighbour of `atom` was explored without reaching b.
            frames.pop()
    return False
|
|
|
|
|
def reachable(a, b):
    '''Return True if atom b can be reached from atom a without crossing
    the bond that directly joins them.

    ``connect_the_dots`` uses this to avoid deleting a bond whose removal
    would disconnect the molecule.
    '''
    # An atom with a single explicit bond has no alternative route.
    if any(atom.GetExplicitDegree() == 1 for atom in (a, b)):
        return False
    # Search the bond graph with the direct a-b bond pre-marked as visited.
    direct_bond = a.GetBond(b).GetIdx()
    return reachable_r(a, b, {direct_bond})
|
|
|
|
|
def forms_small_angle(a, b, cutoff=45):
    '''Return True if the a-b bond makes an angle below ``cutoff`` degrees
    (measured at vertex ``a``) with any other bond of atom ``a``.

    Only neighbours of ``a`` are considered, not neighbours of ``b``.
    '''
    return any(
        b.GetAngle(a, other) < cutoff
        for other in ob.OBAtomAtomIter(a)
        if other != b
    )
|
|
|
|
|
def make_obmol(xyz, atomic_numbers):
    '''Create an OBMol with one atom per (position, atomic number) pair.

    Args:
        xyz: iterable of (x, y, z) coordinates.
        atomic_numbers: iterable of atomic numbers, paired with ``xyz`` by
            position (``zip`` semantics: surplus entries are ignored).

    Returns:
        (mol, atoms): the new molecule and the list of created OBAtom
        objects in input order.  No bonds are added.

    Note: ``mol.BeginModify()`` is called here but not closed; the caller
    is responsible for the matching ``EndModify()``.
    '''
    mol = ob.OBMol()
    mol.BeginModify()
    atoms = []
    for position, atomic_num in zip(xyz, atomic_numbers):
        x, y, z = position
        atom = mol.NewAtom()
        atom.SetAtomicNum(atomic_num)
        atom.SetVector(x, y, z)
        atoms.append(atom)
    return mol, atoms
|
|
|
|
|
def connect_the_dots(mol, atoms, indicators, maxbond=4):
    '''Custom implementation of ConnectTheDots.

    Similar to OpenBabel's version, but more willing to make long bonds
    (up to ``maxbond`` long) to keep the molecule connected.  Bonds between
    two atoms whose ``indicators`` row has the 'Aromatic' family set are
    created with the aromatic flag.

    Args:
        mol: OBMol that owns ``atoms``; modified in place.
        atoms: list of OBAtom; must correspond in order to ``indicators``.
        indicators: per-atom boolean rows indexed with ``ATOM_FAMILIES_ID``.
        maxbond: maximum interatomic distance at which a bond is proposed.

    Assumes no hydrogens or existing bonds.  Returns None.
    '''
    pt = Chem.GetPeriodicTable()

    if len(atoms) == 0:
        return

    mol.BeginModify()

    # Propose a single bond for every atom pair closer than maxbond (O(n^2)).
    coords = np.array([(a.GetX(), a.GetY(), a.GetZ()) for a in atoms])
    dists = squareform(pdist(coords))
    aromatic = ATOM_FAMILIES_ID['Aromatic']

    for i, a in enumerate(atoms):
        for j in range(i):  # each unordered pair once
            b = atoms[j]
            if dists[i, j] < 0.01:
                continue  # don't bond (nearly) coincident atoms
            if dists[i, j] < maxbond:
                flag = 0
                if indicators[i][aromatic] and indicators[j][aromatic]:
                    flag = ob.OB_AROMATIC_BOND
                mol.AddBond(a.GetIdx(), b.GetIdx(), 1, flag)

    # Per-atom bond limit: the smaller of OpenBabel's GetMaxBonds and
    # RDKit's default valence.
    atom_maxb = {}
    for a in atoms:
        maxb = min(ob.GetMaxBonds(a.GetAtomicNum()),
                   pt.GetDefaultValence(a.GetAtomicNum()))
        # Sulfur with at least two oxygen neighbours is allowed 6 bonds.
        if a.GetAtomicNum() == 16 and count_nbrs_of_elem(a, 8) >= 2:
            maxb = 6
        atom_maxb[a.GetIdx()] = maxb

    # Remove bonds between two atoms that may only form a single bond
    # (e.g. halogen-halogen).  Collect first, then delete: deleting through
    # a live OBMolBondIter invalidates the iterator (bonds get skipped or
    # the iterator runs past the end of the bond vector).
    impossible = [
        bond for bond in ob.OBMolBondIter(mol)
        if atom_maxb[bond.GetBeginAtom().GetIdx()] == 1
        and atom_maxb[bond.GetEndAtom().GetIdx()] == 1
    ]
    for bond in impossible:
        mol.DeleteBond(bond)

    def get_bond_info(biter):
        '''Return (stretch, length, bond) tuples, most stretched first.

        ``stretch`` is the bond length minus the sum of the two atoms'
        covalent radii.  The bonds are materialised into a list, so the
        caller may delete bonds while walking the result.
        '''
        binfo = []
        for bond in biter:
            bdist = bond.GetLength()
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            ideal = (ob.GetCovalentRad(a1.GetAtomicNum())
                     + ob.GetCovalentRad(a2.GetAtomicNum()))
            binfo.append((bdist - ideal, bdist, bond))
        binfo.sort(reverse=True, key=lambda t: t[:2])
        return binfo

    # Eliminate geometrically poor bonds (overstretched, or forming a tight
    # angle), most stretched first, unless removal would disconnect the
    # molecule.
    for stretch, _bdist, bond in get_bond_info(ob.OBMolBondIter(mol)):
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if stretch > 0.45 or forms_small_angle(a1, a2) or forms_small_angle(a2, a1):
            if not reachable(a1, a2):
                continue  # this bond is the only link between a1 and a2
            mol.DeleteBond(bond)

    mol.EndModify()
|
|
|
|
|
def convert_ob_mol_to_rd_mol(ob_mol, struct=None):
    '''Convert an OBMol to an RDKit molecule, fixing up common issues.

    Steps:
      1. delete hydrogens from ``ob_mol`` (in place) and copy its atoms,
         coordinates and bonds into an RDKit molecule;
      2. demote double/triple bonds on atoms whose bond-order sum exceeds
         RDKit's default valence, longest bonds first;
      3. give four-coordinate nitrogen a +1 formal charge;
      4. add hydrogens with coordinates, moving any atom with non-finite
         coordinates to the centroid of the finite positions;
      5. sanitize (everything except kekulization);
      6. make each bond's aromatic flag equal "both atoms aromatic".

    Args:
        ob_mol: OpenBabel molecule; its hydrogens are deleted as a side effect.
        struct: unused; accepted for backward compatibility.

    Returns:
        The sanitized RDKit molecule with explicit hydrogens.

    Raises:
        MolReconsError: if RDKit sanitization fails (original error chained).
        Exception: if ``ob_mol`` has a bond order other than 1, 2 or 3.
    '''
    ob_mol.DeleteHydrogens()
    n_atoms = ob_mol.NumAtoms()
    rd_mol = Chem.RWMol()
    rd_conf = Chem.Conformer(n_atoms)

    # Copy atoms and coordinates.
    for ob_atom in ob.OBMolAtomIter(ob_mol):
        rd_atom = Chem.Atom(ob_atom.GetAtomicNum())
        # Keep OpenBabel's aromatic flag only for ring atoms whose
        # MemberOfRingSize() is at most 6.
        if ob_atom.IsAromatic() and ob_atom.IsInRing() and ob_atom.MemberOfRingSize() <= 6:
            rd_atom.SetIsAromatic(True)
        i = rd_mol.AddAtom(rd_atom)
        ob_coords = ob_atom.GetVector()
        rd_conf.SetAtomPosition(
            i, Geometry.Point3D(ob_coords.GetX(), ob_coords.GetY(), ob_coords.GetZ()))

    rd_mol.AddConformer(rd_conf)

    # Copy bonds; OpenBabel atom indices are 1-based, RDKit's are 0-based.
    for ob_bond in ob.OBMolBondIter(ob_mol):
        i = ob_bond.GetBeginAtomIdx() - 1
        j = ob_bond.GetEndAtomIdx() - 1
        bond_order = ob_bond.GetBondOrder()
        if bond_order == 1:
            rd_mol.AddBond(i, j, Chem.BondType.SINGLE)
        elif bond_order == 2:
            rd_mol.AddBond(i, j, Chem.BondType.DOUBLE)
        elif bond_order == 3:
            rd_mol.AddBond(i, j, Chem.BondType.TRIPLE)
        else:
            raise Exception('unknown bond order {}'.format(bond_order))

        if ob_bond.IsAromatic():
            rd_mol.GetBondBetweenAtoms(i, j).SetIsAromatic(True)

    rd_mol = Chem.RemoveHs(rd_mol, sanitize=False)

    pt = Chem.GetPeriodicTable()
    positions = rd_mol.GetConformer().GetPositions()

    # Demote one order on multiple bonds touching an over-valent atom,
    # processing the longest multiple bonds first.
    nonsingles = []
    for bond in rd_mol.GetBonds():
        if bond.GetBondType() in (Chem.BondType.DOUBLE, Chem.BondType.TRIPLE):
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            dist = np.linalg.norm(positions[i] - positions[j])
            nonsingles.append((dist, bond))
    nonsingles.sort(reverse=True, key=lambda t: t[0])

    for _dist, bond in nonsingles:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if (calc_valence(a1) > pt.GetDefaultValence(a1.GetAtomicNum())
                or calc_valence(a2) > pt.GetDefaultValence(a2.GetAtomicNum())):
            if bond.GetBondType() == Chem.BondType.TRIPLE:
                bond.SetBondType(Chem.BondType.DOUBLE)
            else:
                bond.SetBondType(Chem.BondType.SINGLE)

    # Four-coordinate nitrogen gets a +1 formal charge.
    for atom in rd_mol.GetAtoms():
        if atom.GetAtomicNum() == 7 and atom.GetDegree() == 4:
            atom.SetFormalCharge(1)

    rd_mol = Chem.AddHs(rd_mol, addCoords=True)

    # Any atom left with non-finite coordinates (presumably a hydrogen
    # placed by AddHs) is moved to the centroid of the finite positions.
    conf = rd_mol.GetConformer()
    positions = conf.GetPositions()
    finite = np.all(np.isfinite(positions), axis=1)
    center = np.mean(positions[finite], axis=0)
    for i in np.flatnonzero(~finite):
        conf.SetAtomPosition(int(i), center)

    try:
        Chem.SanitizeMol(rd_mol, Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE)
    except Exception as err:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not turned into reconstruction failures; keep the RDKit cause.
        raise MolReconsError() from err

    # A bond is aromatic exactly when both of its atoms are aromatic.
    for bond in rd_mol.GetBonds():
        bond.SetIsAromatic(
            bond.GetBeginAtom().GetIsAromatic() and bond.GetEndAtom().GetIsAromatic())

    return rd_mol
|
|
|
|
|
def calc_valence(rdatom):
    '''Return the sum of bond orders (as floats) over all bonds of ``rdatom``.

    Computed by hand because RDKit's GetExplicitValence needs a prior
    valence calculation (normally done during sanitization), while this
    value is needed to fix the molecule up so that sanitization succeeds.
    '''
    return sum((bond.GetBondTypeAsDouble() for bond in rdatom.GetBonds()), 0.0)
|
|
|
|
|
def count_nbrs_of_elem(atom, atomic_num):
    '''Return how many neighbours of OpenBabel ``atom`` have atomic number
    ``atomic_num``.'''
    return sum(
        1 for nbr in ob.OBAtomAtomIter(atom) if nbr.GetAtomicNum() == atomic_num
    )
|
|
|
|
|
def fixup(atoms, mol, indicators):
    '''Force atom properties to match the per-atom feature indicators.

    Re-applied after OpenBabel processing steps so that OpenBabel's own
    perception does not win.  In place, it:
      - marks ``mol`` as having aromaticity already perceived;
      - sets atoms flagged 'Aromatic' in ``indicators`` aromatic, with
        hybridisation 2;
      - sets ring nitrogens/oxygens with more than one aromatic neighbour
        aromatic as well.

    ``atoms[i]`` is paired with ``indicators[i]``.
    '''
    mol.SetAromaticPerceived(True)
    for idx, atom in enumerate(atoms):
        flags = indicators[idx]
        if flags[ATOM_FAMILIES_ID['Aromatic']]:
            atom.SetAromatic(True)
            atom.SetHyb(2)

        # Only ring N (7) and O (8) are considered below.
        if atom.GetAtomicNum() not in (7, 8) or not atom.IsInRing():
            continue
        aromatic_nbrs = sum(1 for nbr in ob.OBAtomAtomIter(atom) if nbr.IsAromatic())
        if aromatic_nbrs > 1:
            atom.SetAromatic(True)
|
|
|
|
|
def raw_obmol_from_generated(data):
    '''Build a bond-less OBMol from the generated ligand context in ``data``.

    Positions come from ``data.ligand_context_pos`` and atomic numbers from
    ``data.ligand_context_element``; no bonds or hydrogens are inferred.
    The molecule is returned exactly as ``make_obmol`` produced it, so its
    ``BeginModify()`` is still open.

    Returns:
        (mol, atoms) as returned by ``make_obmol``.
    '''
    positions = data.ligand_context_pos.clone().cpu().tolist()
    elements = data.ligand_context_element.clone().cpu().tolist()
    return make_obmol(positions, elements)
|
|
|
|
|
# Bond type obtained by raising a bond's order by one.  Used by
# postprocess_rd_mol_1 to pair radical electrons on bonded atoms.  Triple
# (and any other) bond types have no entry.
UPGRADE_BOND_ORDER = {
    Chem.BondType.SINGLE: Chem.BondType.DOUBLE,
    Chem.BondType.DOUBLE: Chem.BondType.TRIPLE,
}
|
|
|
def postprocess_rd_mol_1(rdmol):
    '''Resolve radical electrons left after reconstruction.

    Hydrogens are removed first (``Chem.RemoveHs``).  Then, visiting atoms
    in index order, an atom with radical electrons pairs them, one per
    bond, with radicals on higher-index neighbours by raising the bond
    order between them (single->double, double->triple).  Pairing stops
    once the atom has no radicals left; bonds that cannot be raised are
    skipped.  Any radicals that remain on the atom become explicit
    hydrogens.

    Args:
        rdmol: RDKit molecule.

    Returns:
        The molecule returned by ``Chem.RemoveHs``, modified in place.
    '''
    rdmol = Chem.RemoveHs(rdmol)

    # Neighbour lists in bond-iteration order (atoms without bonds have no
    # entry).
    nbh_list = {}
    for bond in rdmol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        nbh_list.setdefault(begin, []).append(end)
        nbh_list.setdefault(end, []).append(begin)

    for atom in rdmol.GetAtoms():
        idx = atom.GetIdx()
        num_radical = atom.GetNumRadicalElectrons()
        if num_radical == 0:
            continue

        # .get: a bond-less atom with radicals would otherwise raise KeyError.
        for j in nbh_list.get(idx, ()):
            if num_radical == 0:
                # All radicals paired; continuing would over-upgrade bonds and
                # drive the count negative.
                break
            if j <= idx:
                continue  # lower-index atoms have already been processed
            nb_atom = rdmol.GetAtomWithIdx(j)
            nb_radical = nb_atom.GetNumRadicalElectrons()
            if nb_radical == 0:
                continue
            bond = rdmol.GetBondBetweenAtoms(idx, j)
            upgraded = UPGRADE_BOND_ORDER.get(bond.GetBondType())
            if upgraded is None:
                continue  # e.g. triple or aromatic bond: cannot be raised
            bond.SetBondType(upgraded)
            nb_atom.SetNumRadicalElectrons(nb_radical - 1)
            num_radical -= 1

        # Unpaired radicals are saturated with explicit hydrogens.
        atom.SetNumRadicalElectrons(0)
        if num_radical > 0:
            atom.SetNumExplicitHs(atom.GetNumExplicitHs() + num_radical)

    return rdmol
|
|
|
|
|
def postprocess_rd_mol_2(rdmol):
    '''Open heteroatom-containing 3-membered rings and clear positive charges.

    For every 3-membered ring in ``rdmol``'s ring info:
      - if exactly two ring atoms are not carbon, the bond between those two
        atoms is removed;
      - if exactly two ring atoms are oxygen, the O-O bond is removed and
        each of those oxygens gets one more explicit hydrogen.
    Afterwards every atom with a positive formal charge is reset to 0.

    Args:
        rdmol: RDKit molecule whose ring info is already available.

    Returns:
        A new molecule (``GetMol()`` of an edited ``RWMol`` copy).
    '''
    rdmol_edit = Chem.RWMol(rdmol)

    for ring in rdmol.GetRingInfo().AtomRings():
        if len(ring) != 3:
            continue
        non_carbon = []
        oxygens = []
        for atom_idx in ring:
            symb = rdmol.GetAtomWithIdx(atom_idx).GetSymbol()
            if symb != 'C':
                non_carbon.append(atom_idx)
            if symb == 'O':
                oxygens.append(atom_idx)
        if len(non_carbon) == 2:
            rdmol_edit.RemoveBond(*non_carbon)
        if len(oxygens) == 2:
            # May repeat the removal above; RWMol.RemoveBond ignores a
            # bond that no longer exists.
            rdmol_edit.RemoveBond(*oxygens)
            for o_idx in oxygens:
                o_atom = rdmol_edit.GetAtomWithIdx(o_idx)
                o_atom.SetNumExplicitHs(o_atom.GetNumExplicitHs() + 1)

    rdmol = rdmol_edit.GetMol()

    for atom in rdmol.GetAtoms():
        if atom.GetFormalCharge() > 0:
            atom.SetFormalCharge(0)

    return rdmol
|
|
|
|
|
def reconstruct_from_generated(data):
    '''Reconstruct an RDKit molecule from the generated ligand atoms in ``data``.

    Inputs read from ``data``:
      - ``ligand_context_pos``: atom coordinates;
      - ``ligand_context_element``: atomic numbers;
      - the last ``len(ATOM_FAMILIES_ID)`` columns of
        ``ligand_context_feature_full``, cast to bool, as per-atom family
        indicators (consumed by ``fixup`` and ``connect_the_dots``).

    Pipeline: build a bond-less OBMol, infer bonds with
    ``connect_the_dots`` (maxbond=2), let OpenBabel add hydrogens and
    perceive bond orders (re-applying ``fixup`` after each step), promote
    mostly-aromatic 5/6-rings, convert to RDKit and post-process.

    Raises:
        MolReconsError: propagated from ``convert_ob_mol_to_rd_mol`` when
            RDKit sanitization fails.
    '''
    xyz = data.ligand_context_pos.clone().cpu().tolist()
    atomic_nums = data.ligand_context_element.clone().cpu().tolist()
    indicators = data.ligand_context_feature_full[:, -len(ATOM_FAMILIES_ID):].clone().cpu().bool().tolist()

    # make_obmol leaves mol inside BeginModify(); closed by EndModify() below.
    mol, atoms = make_obmol(xyz, atomic_nums)
    fixup(atoms, mol, indicators)

    connect_the_dots(mol, atoms, indicators, 2)
    fixup(atoms, mol, indicators)
    mol.EndModify()

    # fixup is re-applied after each OpenBabel step, presumably because those
    # steps can overwrite the aromatic flags / hybridisation it sets.
    fixup(atoms, mol, indicators)

    mol.AddPolarHydrogens()
    mol.PerceiveBondOrders()
    fixup(atoms, mol, indicators)

    for (i,a) in enumerate(atoms):
        ob.OBAtomAssignTypicalImplicitHydrogens(a)
    fixup(atoms, mol, indicators)

    mol.AddHydrogens()
    fixup(atoms, mol, indicators)

    # In 5- and 6-membered rings where at least half of the carbons are
    # aromatic, mark every ring atom aromatic.
    # NOTE(review): a ring containing no carbons satisfies 0 >= 0/2 and is
    # marked aromatic too -- confirm this is intended.
    for ring in ob.OBMolRingIter(mol):
        if 5 <= ring.Size() <= 6:
            carbon_cnt = 0
            aromatic_ccnt = 0  # aromatic carbons only
            # ring._path: ring atom indices, presumably 1-based as expected
            # by mol.GetAtom -- TODO confirm against OpenBabel bindings.
            for ai in ring._path:
                a = mol.GetAtom(ai)
                if a.GetAtomicNum() == 6:
                    carbon_cnt += 1
                    if a.IsAromatic():
                        aromatic_ccnt += 1
            if aromatic_ccnt >= carbon_cnt/2 and aromatic_ccnt != ring.Size():
                for ai in ring._path:
                    a = mol.GetAtom(ai)
                    a.SetAromatic(True)

    # Flag every bond between two aromatic atoms as aromatic.
    for bond in ob.OBMolBondIter(mol):
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if a1.IsAromatic() and a2.IsAromatic():
            bond.SetAromatic(True)

    mol.PerceiveBondOrders()

    rd_mol = convert_ob_mol_to_rd_mol(mol)

    # RDKit-side cleanup: resolve radicals, then open heteroatom 3-rings and
    # clear positive formal charges.
    rd_mol = postprocess_rd_mol_1(rd_mol)
    rd_mol = postprocess_rd_mol_2(rd_mol)

    return rd_mol
|
|