#!/usr/bin/env python3
"""
Standalone inference script for Pyrosage TMP AttentiveFP Model
Usage: python inference.py "SMILES_STRING"
"""
import sys

import torch
from rdkit import Chem
from torch_geometric.data import Data
from torch_geometric.nn import AttentiveFP
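
# Dependencies (an assumed install route; adjust to your environment):
#   pip install torch torch_geometric rdkit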


def smiles_to_data(smiles):
    """Convert a SMILES string to a PyG Data object with enhanced features"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Enhanced atom features (10 dimensions)
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            # Hybridization as one-hot (3 dimensions)
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3),
        ]
        atom_features.append(features)
    x = torch.tensor(atom_features, dtype=torch.float)

    # Enhanced bond features (6 dimensions); each bond is added in both
    # directions so the molecular graph is undirected
    edges_list = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges_list.extend([[i, j], [j, i]])
        features = [
            # Bond type as one-hot (4 dimensions)
            int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
            # Additional features (2 dimensions)
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]
        edge_features.extend([features, features])

    # Molecules without bonds (e.g. single atoms) cannot be represented
    if not edges_list:
        return None

    edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
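
# For example, benzene ("c1ccccc1") has 6 atoms and 6 ring bonds, so
# smiles_to_data yields x of shape [6, 10], edge_index of shape [2, 12]
# (each bond counted in both directions) and edge_attr of shape [12, 6].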


def load_model():
    """Load the AttentiveFP model from the bundled checkpoint"""
    model_dict = torch.load('pytorch_model.pt', map_location='cpu')
    state_dict = model_dict['model_state_dict']
    hyperparams = model_dict['hyperparameters']
    model = AttentiveFP(
        in_channels=10,  # Enhanced atom features
        hidden_channels=hyperparams["hidden_channels"],
        out_channels=1,
        edge_dim=6,  # Enhanced bond features
        num_layers=hyperparams["num_layers"],
        num_timesteps=hyperparams["num_timesteps"],
        dropout=hyperparams["dropout"],
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
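
# Checkpoint layout assumed by load_model (keys as used above; the actual
# hyperparameter values ship inside pytorch_model.pt):
#   {
#       "model_state_dict": <AttentiveFP state_dict>,
#       "hyperparameters": {"hidden_channels": ..., "num_layers": ...,
#                           "num_timesteps": ..., "dropout": ...},
#   }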


def predict(model, smiles):
    """Make a prediction for a single SMILES string"""
    data = smiles_to_data(smiles)
    if data is None:
        return None
    # All nodes belong to graph 0: a single-molecule "batch"
    batch = torch.zeros(data.num_nodes, dtype=torch.long)
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_attr, batch)
    return output.item()
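

# A minimal batched variant (a sketch, not part of the original script):
# torch_geometric.data.Batch.from_data_list collates several Data objects
# into one disjoint graph with a `batch` assignment vector, so multiple
# molecules can be scored in a single forward pass.
def predict_many(model, smiles_list):
    """Score several SMILES strings; invalid entries map to None."""
    from torch_geometric.data import Batch

    data_list = [smiles_to_data(s) for s in smiles_list]
    valid = [d for d in data_list if d is not None]
    if not valid:
        return [None] * len(smiles_list)
    big = Batch.from_data_list(valid)
    with torch.no_grad():
        out = model(big.x, big.edge_index, big.edge_attr, big.batch)
    preds = iter(out.view(-1).tolist())
    return [next(preds) if d is not None else None for d in data_list]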


def main():
    if len(sys.argv) != 2:
        print("Usage: python inference.py 'SMILES_STRING'")
        print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'")
        sys.exit(1)
    smiles = sys.argv[1]
    print("Loading TMP AttentiveFP model...")
    try:
        model = load_model()
        print(f"Making prediction for: {smiles}")
        prediction = predict(model, smiles)
        if prediction is not None:
            print(f"Regression result: {prediction:.4f}")
        else:
            print("Error: Could not process SMILES string")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()