import io
from pathlib import Path

import gradio as gr
from PIL import Image
from vq_gan_3d import load_VQGAN

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn.functional as F

# Utilities to calculate grids from SMILES and visualization
from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size

ckpt_path = Path("./vq_gan_3d/weights/3DGrid-VQGAN_43.pt")
folder = str(ckpt_path.parent)
ckpt_file = ckpt_path.name

vqgan = load_VQGAN(folder=folder, ckpt_filename=ckpt_file).eval()

# Input size fed to the VQGAN; grids of any other shape are resized to it.
GRID_SHAPE = torch.Size([128, 128, 128])
# Grids are downsampled to this size before plotting, only to speed up rendering.
PLOT_SIZE = (48, 48, 48)
# Files offered for download in the UI (written to the current working directory).
ORIGINAL_NPY = "original_grid.npy"
RECONSTRUCTED_NPY = "reconstructed_grid.npy"


def _model_device(model):
    """Return the device holding *model*'s parameters, falling back to CPU.

    Inputs must live on the same device as the weights. The fallback covers
    objects that expose no ``parameters()`` method or have no parameters.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


DEVICE = _model_device(vqgan)


def _prepare_density(rho, smi):
    """Turn a raw density array into a log1p-normalized float32 tensor of GRID_SHAPE.

    Args:
        rho: NumPy array produced by the cube generation step.
        smi: SMILES string of the molecule, used only for logging.

    Returns:
        A float32 tensor of shape GRID_SHAPE.
    """
    # Convert to float32, then apply log(rho + 1) normalization.
    tensor = torch.log1p(torch.from_numpy(rho).float())
    if tensor.shape != GRID_SHAPE:
        # F.interpolate needs explicit batch and channel dims: [1, 1, D, H, W].
        tensor = F.interpolate(
            tensor.unsqueeze(0).unsqueeze(0),
            size=tuple(GRID_SHAPE),
            mode="trilinear",
            align_corners=False,
        )[0, 0]  # drop the extra dims again after resizing
        print(f"[info] {smi} was interpolated to 128³")
    return tensor


def _reconstruct(vol):
    """Encode *vol* to latent codes with the VQGAN, decode it, and return a NumPy grid."""
    # Shape [1, 1, 128, 128, 128], placed on the same device as the model weights.
    x = vol.unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        indices = vqgan.encode(x)  # map input volume to discrete latent codes
        recon = vqgan.decode(indices)  # reconstruct volume from latent codes
    return recon.cpu().numpy()[0, 0]


def _plot_to_image(grid, title):
    """Downsample *grid* to PLOT_SIZE, plot it, and return the plot as a PIL image.

    The figure is rendered into an in-memory PNG buffer rather than a file on disk.
    """
    small = change_grid_size(
        torch.from_numpy(grid).unsqueeze(0).unsqueeze(0),
        size=PLOT_SIZE,
    )
    fig = plot_voxel_grid(small, title=title)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # Image.open is lazy; decode now so the image is self-contained
    # NOTE(review): if plot_voxel_grid returns a pyplot-managed figure, it is never
    # closed here; closing it (matplotlib.pyplot.close) would release its memory.
    return image


def comparison(SMILES):
    """Build a density grid from *SMILES*, reconstruct it with the VQGAN, and compare.

    Args:
        SMILES: SMILES string selected or typed in the demo's dropdown.

    Returns:
        A 4-tuple matching the Gradio outputs:
        - a list of two PIL images (original and reconstructed grid plots);
        - the mean squared error between the original and reconstructed grid;
        - the path of the saved original grid (``.npy``);
        - the path of the saved reconstructed grid (``.npy``).

    Raises:
        gr.Error: if no SMILES is provided or no density grid is produced for it.
    """
    if not SMILES or not SMILES.strip():
        raise gr.Error("Please provide a SMILES string.")
    SMILES = SMILES.strip()

    density_grids = get_grid_from_smiles([SMILES])

    # 1) Prepare density grids -> list of ready-to-use tensors plus metadata.
    processed_tensors = []
    for item in density_grids:
        smi = item["smiles"]
        tensor = _prepare_density(item["rho"], smi)
        processed_tensors.append({
            "name": item["name"],
            "smiles": smi,
            "tensor": tensor,
        })
        # Log shape and min/max to verify normalization and sizing.
        print(
            f"{smi}: shape={tuple(tensor.shape)}, "
            f"min={tensor.min():.4f}, max={tensor.max():.4f}"
        )

    if not processed_tensors:
        raise gr.Error(f"Could not generate a density grid for {SMILES}.")

    # 2) Encode -> Decode (inference with VQGAN).
    reconstructions = []
    for item in processed_tensors:
        smi = item["smiles"]
        vol = item["tensor"]
        recon_np = _reconstruct(vol)
        orig_np = vol.cpu().numpy()
        mse = float(np.mean((orig_np - recon_np) ** 2))
        print(f"{smi} → reconstruction done | MSE={mse:.6f}")
        reconstructions.append({
            "smiles": smi,
            "name": item["name"],
            "original": orig_np,
            "reconstructed": recon_np,
            "mse": mse,
        })

    # Only the first grid is visualized and exported, so report its MSE.
    first = reconstructions[0]
    original_grid_plot = _plot_to_image(
        first["original"], f"Original 3D Grid Plot from {SMILES}"
    )
    rec_grid_plot = _plot_to_image(
        first["reconstructed"], f"Reconstructed 3D Grid Plot from {SMILES}"
    )

    np.save(ORIGINAL_NPY, first["original"])
    np.save(RECONSTRUCTED_NPY, first["reconstructed"])

    return [original_grid_plot, rec_grid_plot], first["mse"], ORIGINAL_NPY, RECONSTRUCTED_NPY


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 3DGrid-VQGAN: SMILES to 3D Grid Reconstruction

        In this demo, provide a SMILES to generate a 3D electron density grid of shape `128x128x128` and then the demo uses the 3DGrid-VQGAN model to reconstruct the original grid.
        To speed up the visualization process, we reduced the 3D grid size to `48x48x48`.

        _This is just a demo environment; for heavy-duty usage, please visit:_ https://github.com/IBM/materials/tree/main/models/3dgrid_vqgan
        to download the model and run your own experiments.

        Please, be aware that long and complex SMILES sequences may take very long time to compute. Consider using simple SMILES molecules in this demo.
        """
    )
    gr.Interface(
        fn=comparison,
        inputs=[
            gr.Dropdown(
                choices=["CCCO", "CC", "CCO"],
                label="Provide a SMILES or pre-select one",
                allow_custom_value=True,
            )
        ],
        outputs=[
            gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
            gr.Number(label="Mean Squared Error (MSE)"),
            gr.File(label="Original 3D Grid numpy file"),
            gr.File(label="Reconstructed 3D Grid numpy file"),
        ],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")