basic model definition
Browse files- model_definition.py +14 -0
model_definition.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Net(torch.nn.Module):
|
| 2 |
+
def __init__(self, num_relations, num_classes, num_nodes=None, input_dim=None, hidden_dim=16, num_bases=30):
|
| 3 |
+
super().__init__()
|
| 4 |
+
assert num_nodes is not None or input_dim is not None, "Please provide input feature dimensionality or number of nodes"
|
| 5 |
+
self.conv1 = RGCNConv(num_nodes if input_dim is None else input_dim, hidden_dim, num_relations,
|
| 6 |
+
num_bases)
|
| 7 |
+
self.conv2 = RGCNConv(hidden_dim, num_classes, dataset.num_relations,
|
| 8 |
+
num_bases)
|
| 9 |
+
|
| 10 |
+
def forward(self, x, edge_index, edge_type):
|
| 11 |
+
# if x is None, uses an embedding based on num_nodes
|
| 12 |
+
x = F.relu(self.conv1(x, edge_index, edge_type))
|
| 13 |
+
x = self.conv2(x, edge_index, edge_type)
|
| 14 |
+
return F.log_softmax(x, dim=1)
|