import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

from models.mlp_and_gnn import MLPBiasFree
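
# NOTE: `MLPBiasFree` is defined in models.mlp_and_gnn and is not reproduced
# here. As a reading aid, the sketch below shows one plausible shape for it,
# inferred only from its name and its call sites in `EGNNLayer`; the real
# implementation may differ (the activation choice, any normalisation, and the
# exact meaning of `num_layer` are guesses). It is unused and illustrative only.
class _MLPBiasFreeSketch(torch.nn.Module):
    """Illustrative stand-in: a stack of bias-free Linear layers with ReLU."""

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layer - 1) + [out_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(Linear(dims[i], dims[i + 1], bias=False))
            if i < len(dims) - 2:  # no non-linearity after the output layer
                layers.append(ReLU())
        self.net = Sequential(*layers)

    def forward(self, x):
        return self.net(x)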


class EGNNLayer(MessagePassing):
    """E(n) Equivariant GNN Layer

    Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.
    """
    def __init__(self, emb_dim, num_mlp_layers, aggr="add"):
        """
        Args:
            emb_dim: (int) - hidden dimension `d`
            num_mlp_layers: (int) - number of layers in each internal MLP
            aggr: (str) - aggregation function `\oplus` (add/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # Linear embedding of the scalar edge distance into the feature dimension
        self.dist_embedding = Linear(1, emb_dim, bias=False)
        # Scalar gate computed from the inner product of the endpoint features
        self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = MLPBiasFree(in_dim=3 * emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)

        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
        self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)

        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - pairs of edges (i, j)
        Returns:
            out: [(n, d), (n, 3)] - updated node features and coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute messages
        pos_diff = pos_i - pos_j
        # Encode distances with an exponential decay (30 um reference length scale)
        dists = torch.exp(-torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30)
        # Scalar similarity of the endpoint features, used to gate the message
        inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)
        # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
        # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
        return msg, pos_diff, inner_prod

    def aggregate(self, inputs, index):
        msgs, pos_diffs, inner_prod = inputs
        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce="add")
        # Aggregate displacement vectors
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="add")

        # Average displacements over the contributing edges: edges whose feature
        # inner product is exactly zero are excluded from the count, and zero
        # counts are clamped to 1 to avoid division by zero.
        counts = torch.ones_like(inner_prod)
        counts[inner_prod == 0] = 0
        counts = scatter(counts, index, dim=0, reduce="add")
        counts[counts == 0] = 1
        pos_aggr = pos_aggr / counts
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        # Node features are updated from the aggregated messages; coordinates
        # are shifted by the (normalised) aggregated displacement vectors.
        upd_out = self.mlp_upd(msg_aggr)
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"