import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

from models.mlp_and_gnn import MLPBiasFree
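# Note: `MLPBiasFree` (defined in models/mlp_and_gnn.py) is used below as a
# bias-free multi-layer perceptron with the signature
# MLPBiasFree(in_dim, out_dim, hidden_dim, num_layer), mapping tensors of shape
# (..., in_dim) to (..., out_dim). This description is inferred from its usage
# in this file, not from its definition.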

class EGNNLayer(MessagePassing):
    """E(n) Equivariant GNN Layer.

    Paper: "E(n) Equivariant Graph Neural Networks", Satorras et al., 2021.
    """
    def __init__(self, emb_dim, num_mlp_layers, aggr="add"):
        r"""
        Args:
            emb_dim: (int) - hidden dimension `d`
            num_mlp_layers: (int) - number of layers in each internal MLP
            aggr: (str) - aggregation function `\oplus` (add/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # Embeds the scalar pairwise distance feature into `emb_dim` channels
        self.dist_embedding = Linear(1, emb_dim, bias=False)
        # Scalar gate computed from the inner product of the two endpoint features
        self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # MLP `\psi_h` for computing messages `m_ij`
        # (input is the concatenation [h_i, h_j, dist_embedding(d_ij)], size 3 * emb_dim)
        self.mlp_msg = MLPBiasFree(in_dim=3 * emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # MLP `\psi_x` for computing the scalar weights of the displacement messages `\overrightarrow{m}_ij`
        self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - edge indices (source, target)
        Returns:
            out: [(n, d), (n, 3)] - updated node features and node coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out
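
    # Equivariance note (informal): `h`, the pairwise inner products and the
    # distances used in `message` are invariant under rotations, reflections and
    # translations of `pos`, while the differences `pos_i - pos_j` transform
    # equivariantly. Since coordinates are only updated by adding invariantly
    # weighted difference vectors, the layer is E(n)-equivariant in `pos` and
    # invariant in `h`.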

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute messages
        pos_diff = pos_i - pos_j
        # Soft distance feature in (0, 1]; reference distance: 30 um
        dists = torch.exp(-torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30)
        # Scalar similarity between the two endpoint features, used to gate the message
        inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)
        # NOTE: some papers divide pos_diff by (dists + 1) to stabilise the model.
        # NOTE: lucidrains clamps pos_diff to some range [-n, +n], also for stability.
        return msg, pos_diff, inner_prod

    def aggregate(self, inputs, index):
        msgs, pos_diffs, inner_prod = inputs
        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce="add")
        # Aggregate displacement vectors
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="add")
        # Average displacements over "active" edges only, i.e. edges whose
        # inner-product gate is non-zero; nodes with no active edges keep a
        # count of 1 so the division is a no-op.
        counts = torch.ones_like(inner_prod)
        counts[inner_prod == 0] = 0
        counts = scatter(counts, index, dim=0, reduce="add")
        counts[counts == 0] = 1
        pos_aggr = pos_aggr / counts
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        # Updated node features come from the aggregated messages alone
        # (no concatenation with the previous features `h` here)
        upd_out = self.mlp_upd(msg_aggr)
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
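

# Minimal usage sketch (illustrative, not part of the layer): build a small ring
# graph with random features and run one forward pass. The node/edge counts and
# the 3-D coordinates are arbitrary assumptions for this smoke test; `EGNNLayer`
# and `MLPBiasFree` are the real classes defined/imported above.
if __name__ == "__main__":
    num_nodes, emb_dim = 8, 32
    h = torch.randn(num_nodes, emb_dim)   # node features, shape (n, d)
    pos = torch.randn(num_nodes, 3)       # node coordinates, shape (n, 3)

    # Ring graph so every node receives at least one message; edge_index has
    # shape (2, e) as expected by PyTorch Geometric.
    src = torch.arange(num_nodes)
    dst = (src + 1) % num_nodes
    edge_index = torch.stack([src, dst], dim=0)

    layer = EGNNLayer(emb_dim=emb_dim, num_mlp_layers=2, aggr="add")
    h_out, pos_out = layer(h, pos, edge_index)
    print(h_out.shape, pos_out.shape)     # expected: (8, 32) and (8, 3)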