#!/usr/bin/env python3
"""
Standalone inference script for Pyrosage TMP AttentiveFP Model

Usage: python inference.py "SMILES_STRING"
"""

import sys

import torch
from torch_geometric.nn import AttentiveFP
from rdkit import Chem
from torch_geometric.data import Data

# Widths of the per-atom and per-bond feature vectors built by
# smiles_to_data(); the model is constructed with matching dimensions.
ATOM_FEATURE_DIM = 10
BOND_FEATURE_DIM = 6

# Checkpoint file name, resolved relative to the current working directory.
DEFAULT_MODEL_PATH = 'pytorch_model.pt'


def smiles_to_data(smiles):
    """Convert SMILES string to PyG Data object with enhanced features.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A ``Data`` object with ``x`` (one 10-value float row per atom),
        ``edge_index`` (2 x E long tensor, each bond stored in both
        directions) and ``edge_attr`` (one 6-value float row per directed
        edge). Returns ``None`` when RDKit cannot parse the SMILES or when
        the molecule contains no bonds.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    hybridization_type = Chem.rdchem.HybridizationType
    bond_kind = Chem.rdchem.BondType

    # Enhanced atom features (10 dimensions)
    atom_features = []
    for atom in mol.GetAtoms():
        # Query once instead of once per one-hot slot.
        hybridization = atom.GetHybridization()
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            # Hybridization as one-hot (3 dimensions); any other
            # hybridization state leaves all three slots at zero.
            int(hybridization == hybridization_type.SP),
            int(hybridization == hybridization_type.SP2),
            int(hybridization == hybridization_type.SP3),
        ])

    x = torch.tensor(atom_features, dtype=torch.float)

    # Enhanced bond features (6 dimensions)
    edges_list = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # Every bond becomes two directed edges, i -> j and j -> i.
        edges_list.extend([[i, j], [j, i]])

        # Query once instead of once per one-hot slot.
        bond_type = bond.GetBondType()
        features = [
            # Bond type as one-hot (4 dimensions)
            int(bond_type == bond_kind.SINGLE),
            int(bond_type == bond_kind.DOUBLE),
            int(bond_type == bond_kind.TRIPLE),
            int(bond_type == bond_kind.AROMATIC),
            # Additional features (2 dimensions)
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]
        # Both directions of the bond share the same feature row.
        edge_features.extend([features, features])

    # Molecules without any bond (e.g. a single atom) are not supported.
    if not edges_list:
        return None

    edge_index = torch.tensor(edges_list, dtype=torch.long).t()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def load_model(model_path=DEFAULT_MODEL_PATH):
    """Load the AttentiveFP model from a checkpoint file.

    The checkpoint is expected to be a dict with a ``'model_state_dict'``
    entry and a ``'hyperparameters'`` entry providing ``hidden_channels``,
    ``num_layers``, ``num_timesteps`` and ``dropout``.

    Args:
        model_path: Path of the checkpoint file. Defaults to
            ``'pytorch_model.pt'`` in the current working directory.

    Returns:
        The AttentiveFP model with weights loaded, switched to eval mode.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file,
        ``KeyError`` for a checkpoint lacking the expected entries, and
        whatever ``load_state_dict`` raises on a weight mismatch.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    model_dict = torch.load(model_path, map_location='cpu')
    state_dict = model_dict['model_state_dict']
    hyperparams = model_dict['hyperparameters']

    model = AttentiveFP(
        in_channels=ATOM_FEATURE_DIM,  # Enhanced atom features
        hidden_channels=hyperparams["hidden_channels"],
        out_channels=1,
        edge_dim=BOND_FEATURE_DIM,  # Enhanced bond features
        num_layers=hyperparams["num_layers"],
        num_timesteps=hyperparams["num_timesteps"],
        dropout=hyperparams["dropout"],
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(model, smiles):
    """Make prediction for a SMILES string.

    Args:
        model: Model as returned by ``load_model()``.
        smiles: SMILES string describing the molecule.

    Returns:
        The model output as a Python float, or ``None`` when
        ``smiles_to_data()`` could not build a graph for the input.
    """
    data = smiles_to_data(smiles)
    if data is None:
        return None

    # All nodes belong to graph 0: a single molecule is scored per call.
    batch = torch.zeros(data.num_nodes, dtype=torch.long)
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_attr, batch)
        return output.item()


def main():
    """Command-line entry point: predict for the SMILES given in argv[1].

    Exits with status 1 on wrong usage, on any exception raised while
    loading the model or predicting, and when the SMILES string could not
    be processed.
    """
    if len(sys.argv) != 2:
        print("Usage: python inference.py 'SMILES_STRING'")
        print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'")
        sys.exit(1)

    smiles = sys.argv[1]

    print("Loading TMP AttentiveFP model...")
    try:
        model = load_model()
        print(f"Making prediction for: {smiles}")
        prediction = predict(model, smiles)
    except Exception as e:
        # Top-level boundary: report the failure and signal it to the shell.
        print(f"Error: {e}")
        sys.exit(1)

    if prediction is None:
        print("Error: Could not process SMILES string")
        # A failed prediction must not look like success to the caller.
        sys.exit(1)

    print(f'Regression result: {prediction:.4f}')


if __name__ == "__main__":
    main()