"""
Standalone inference script for the Pyrosage TMP AttentiveFP model.

Usage: python inference.py "SMILES_STRING"
"""

import sys

import torch
from rdkit import Chem
from torch_geometric.data import Data
from torch_geometric.nn import AttentiveFP

def smiles_to_data(smiles):
    """Convert a SMILES string to a PyG Data object with enhanced features."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Atom (node) features: 10 values per atom, matching the model's in_channels.
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            # One-hot hybridization (SP, SP2, SP3)
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3),
        ]
        atom_features.append(features)

    x = torch.tensor(atom_features, dtype=torch.float)

    # Bond (edge) features: 6 values per bond; each bond is added in both
    # directions, so its features are duplicated accordingly.
    edges_list = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges_list.extend([[i, j], [j, i]])

        features = [
            # One-hot bond type
            int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]
        edge_features.extend([features, features])

    # Molecules without bonds (e.g. single atoms) cannot be featurized here.
    if not edges_list:
        return None

    edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

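# Illustrative check of the featurization above (an assumed example, not
# executed by this script): aspirin "CC(=O)OC1=CC=CC=C1C(=O)O" has 13 heavy
# atoms and 13 bonds, so smiles_to_data should return tensors shaped roughly as
#   data.x          -> [13, 10]  (atoms x atom features)
#   data.edge_index -> [2, 26]   (each bond stored in both directions)
#   data.edge_attr  -> [26, 6]   (directed edges x bond features)
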
def load_model():
    """Load the trained AttentiveFP model from pytorch_model.pt."""
    model_dict = torch.load('pytorch_model.pt', map_location='cpu')
    state_dict = model_dict['model_state_dict']
    hyperparams = model_dict['hyperparameters']

    model = AttentiveFP(
        in_channels=10,   # number of atom features produced by smiles_to_data
        hidden_channels=hyperparams["hidden_channels"],
        out_channels=1,   # single regression target
        edge_dim=6,       # number of bond features produced by smiles_to_data
        num_layers=hyperparams["num_layers"],
        num_timesteps=hyperparams["num_timesteps"],
        dropout=hyperparams["dropout"],
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model

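# For reference, load_model() expects pytorch_model.pt to be a dict saved
# roughly like the sketch below (key names taken from the lookups above; the
# hyperparameter values shown are placeholders, not the trained ones):
#
#   torch.save({
#       "model_state_dict": model.state_dict(),
#       "hyperparameters": {
#           "hidden_channels": 64,
#           "num_layers": 2,
#           "num_timesteps": 2,
#           "dropout": 0.2,
#       },
#   }, "pytorch_model.pt")
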
def predict(model, smiles):
    """Make a prediction for a single SMILES string."""
    data = smiles_to_data(smiles)
    if data is None:
        return None

    # A single molecule forms one graph, so every node belongs to batch 0.
    batch = torch.zeros(data.num_nodes, dtype=torch.long)
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_attr, batch)
    return output.item()

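# Optional convenience, not used by main(): score several SMILES strings in a
# single forward pass by collating their graphs with PyG's Batch. This is a
# sketch under the same featurization and model as above; SMILES that cannot
# be parsed (or that have no bonds) get None in the returned list.
def predict_many(model, smiles_list):
    """Make predictions for a list of SMILES strings (illustrative sketch)."""
    from torch_geometric.data import Batch

    data_list = [smiles_to_data(s) for s in smiles_list]
    valid = [(idx, d) for idx, d in enumerate(data_list) if d is not None]
    results = [None] * len(smiles_list)
    if not valid:
        return results

    batch = Batch.from_data_list([d for _, d in valid])
    with torch.no_grad():
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    for (idx, _), value in zip(valid, out.view(-1).tolist()):
        results[idx] = value
    return results
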
def main():
    if len(sys.argv) != 2:
        print("Usage: python inference.py 'SMILES_STRING'")
        print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'")
        sys.exit(1)

    smiles = sys.argv[1]
    print("Loading TMP AttentiveFP model...")

    try:
        model = load_model()
        print(f"Making prediction for: {smiles}")

        prediction = predict(model, smiles)
        if prediction is not None:
            print(f"Regression result: {prediction:.4f}")
        else:
            print("Error: Could not process SMILES string")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()