# Yuan (Cyrus) Chiang
# Add convenient ZBL torch calculator (#44)
# 7cc6c4a unverified
import numpy as np
import torch
# TODO: consider using vesin
from matscipy.neighbours import neighbour_list
from torch_geometric.data import Data
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
def get_neighbor(
    atoms: Atoms, cutoff: float, self_interaction: bool = False
):
    """Compute the neighbor list of ``atoms`` within ``cutoff``.

    Args:
        atoms: Structure whose positions, cell and periodic flags are passed
            to ``neighbour_list``.
        cutoff: Cutoff forwarded to ``neighbour_list``.
        self_interaction: When False (default), edges that join an atom to
            itself with an all-zero shift vector are removed.

    Returns:
        A tuple ``(edge_index, edge_shift)``. ``edge_index`` stacks the first
        and second atom indices of every pair as an int64 array;
        ``edge_shift`` is the shift-vector array multiplied by the cell
        matrix.
    """
    lattice = atoms.cell.array

    src, dst, shifts = neighbour_list(
        quantities="ijS",
        pbc=atoms.pbc,
        cell=lattice,
        positions=atoms.positions,
        cutoff=cutoff,
    )

    if not self_interaction:
        # Keep an edge unless it links an atom to itself without crossing
        # a periodic boundary (i == j and every shift component is zero).
        keep = (src != dst) | np.any(shifts != 0, axis=1)
        src, dst, shifts = src[keep], dst[keep], shifts[keep]

    edge_index = np.stack((src, dst)).astype(np.int64)
    edge_shift = shifts @ lattice

    return edge_index, edge_shift
def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
    """Collate a list of Atoms objects into a single batched ``Data`` graph.

    Per-atom quantities of all structures are concatenated, and the edge
    indices of each structure are shifted by the number of atoms that precede
    it so the individual graphs stay disconnected inside the batch.

    Args:
        batch: Structures to merge. Must be non-empty: stacking the cells of
            an empty list fails inside torch.
        cutoff: Neighbor-list cutoff passed to ``get_neighbor``; also stored
            in the result as a tensor.

    Returns:
        A ``Data`` object with the fields ``cell`` (cells stacked along a new
        leading dimension), ``positions``, ``numbers``, ``edge_index``
        (concatenated along dim 1), ``edge_shift``, ``batch`` (index of the
        source structure for every atom), ``natoms`` (atom count per
        structure), ``num_nodes`` (total atom count), ``charges``,
        ``magmoms`` and ``cutoff``. For a structure whose initial charges
        (or magnetic moments) are all zero, the corresponding entries are
        filled with NaN.
    """
    # Running atom count; used to offset the edge indices of each graph
    offset = 0

    node_batch = []
    numbers_batch = []
    positions_batch = []
    charges_batch = []
    magmoms_batch = []
    edge_index_batch = []
    edge_shift_batch = []
    cell_batch = []
    natoms_batch = []

    for i, atoms in enumerate(batch):
        edge_index, edge_shift = get_neighbor(
            atoms, cutoff=cutoff, self_interaction=False
        )

        # Shift both rows (source and target indices) into batch numbering
        edge_index += offset

        edge_index_batch.append(torch.tensor(edge_index))
        edge_shift_batch.append(torch.tensor(edge_shift))

        natoms = len(atoms)
        offset += natoms

        node_batch.append(torch.full((natoms,), i, dtype=torch.long))
        natoms_batch.append(natoms)

        cell_batch.append(torch.tensor(atoms.cell.array))
        numbers_batch.append(torch.tensor(atoms.numbers))
        positions_batch.append(torch.tensor(atoms.positions))

        # The getters return NumPy arrays; they must be converted to tensors
        # because torch.cat below only accepts tensors.
        charges = atoms.get_initial_charges()
        charges_batch.append(
            torch.tensor(charges)
            if charges.any()
            else torch.full((natoms,), torch.nan)
        )

        # NOTE(review): the NaN placeholder is 1-D; if some structures carry
        # per-atom vector moments, mixing them with placeholders will not
        # concatenate -- confirm whether that case needs support.
        magmoms = atoms.get_initial_magnetic_moments()
        magmoms_batch.append(
            torch.tensor(magmoms)
            if magmoms.any()
            else torch.full((natoms,), torch.nan)
        )

    # Merge the per-structure pieces into batch-level tensors
    arrays_batch_concatenated = {
        "cell": torch.stack(cell_batch, dim=0),
        "positions": torch.cat(positions_batch, dim=0),
        "edge_index": torch.cat(edge_index_batch, dim=1),
        "edge_shift": torch.cat(edge_shift_batch, dim=0),
        "numbers": torch.cat(numbers_batch, dim=0),
        "num_nodes": offset,
        "batch": torch.cat(node_batch, dim=0),
        "charges": torch.cat(charges_batch, dim=0),
        "magmoms": torch.cat(magmoms_batch, dim=0),
        "natoms": torch.tensor(natoms_batch, dtype=torch.long),
        "cutoff": torch.tensor(cutoff),
    }

    # TODO: custom fields (e.g. electronic encoding, forces, dipoles)

    return Data.from_dict(arrays_batch_concatenated)
def decollate_fn(batch_data: Data) -> list[Atoms]:
    """Decollate a batched Data object into a list of individual Atoms objects.

    One ``Atoms`` is rebuilt per distinct value in ``batch_data.batch`` from
    the matching cell, atomic numbers and positions. When the batch carries
    ``energy``, ``forces`` or ``stress``, they are attached to the structure
    through a ``SinglePointCalculator``.

    Periodic boundary flags are not passed to ``Atoms``, so the constructor's
    default applies. Charges, magnetic moments and dipoles are not restored.

    Args:
        batch_data: Batched graph providing ``batch``, ``cell``, ``numbers``
            and ``positions``, and optionally ``energy``, ``forces`` and
            ``stress``.

    Returns:
        The reconstructed structures, ordered by ascending batch index.
    """
    # FIXME: this function is not working properly when the batch_data is on GPU.
    # TODO: create a new Cell class using torch tensor to handle device placement.
    # As a temporary fix, detach the batch_data from the GPU and move it to CPU.
    batch_data = batch_data.detach().cpu()

    individual_entries = []

    for i in batch_data.batch.unique(sorted=True):
        # Indices of the atoms that belong to the current structure
        entry_indices = (batch_data.batch == i).nonzero(as_tuple=True)[0]

        cell = batch_data.cell[i]
        numbers = batch_data.numbers[entry_indices]
        positions = batch_data.positions[entry_indices]

        # Optional fields
        energy = batch_data.energy[i] if "energy" in batch_data else None
        forces = batch_data.forces[entry_indices] if "forces" in batch_data else None
        stress = batch_data.stress[i] if "stress" in batch_data else None

        if energy is not None:
            # NOTE(review): ASE appears to store ``energy`` as given rather
            # than coercing it, so unwrap a single-element value to a plain
            # Python scalar to keep torch tensors out of the results.
            energy_array = np.asarray(energy)
            if energy_array.size == 1:
                energy = energy_array.item()

        # TODO: custom fields (charges, magmoms, dipoles)

        atoms = Atoms(
            cell=cell,
            positions=positions,
            numbers=numbers,
        )

        # SinglePointCalculator takes the structure as its first positional
        # argument; omitting it raises a TypeError.
        atoms.calc = SinglePointCalculator(
            atoms,
            energy=energy,
            forces=forces,
            stress=stress,
        )

        individual_entries.append(atoms)

    return individual_entries