import argparse, json

from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch_scatter import scatter_add
from torch_sparse import coalesce, spspmm

from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.nn import GCNConv

from rich import print

def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Add self-loops with a chosen weight (scale). If scale == 0, return the graph unchanged."""
    if scale == 0.0:
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
        return edge_index, edge_weight
    device = edge_index.device
    self_loops = torch.arange(num_nodes, device=device)
    self_index = torch.stack([self_loops, self_loops], dim=0)
    self_weight = torch.full((num_nodes,), float(scale), device=device)
    base_w = edge_weight if edge_weight is not None else torch.ones(edge_index.size(1), device=device)
    ei = torch.cat([edge_index, self_index], dim=1)
    ew = torch.cat([base_w, self_weight], dim=0)
    # Merge duplicate entries (e.g. pre-existing self-loops) by summing weights.
    ei, ew = coalesce(ei, ew, num_nodes, num_nodes, op='add')
    return ei, ew

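# Illustrative sanity check (hypothetical input, not part of the pipeline):
# with scale=2.0 this reproduces the A + 2I connectivity that the
# --self_loop_scale flag is meant to mimic.
#
#   ei = torch.tensor([[0, 1], [1, 0]])
#   ei2, ew2 = add_scaled_self_loops(ei, None, num_nodes=2, scale=2.0)
#   # ei2 -> [[0, 0, 1, 1], [0, 1, 0, 1]], ew2 -> [2., 1., 1., 2.]
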
def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """
    Compute the (binary) k-th power adjacency via sparse matmul (torch_sparse.spspmm).
    Returns a coalesced edge_index with self-loops and duplicates removed.
    """
    device = edge_index.device
    val = torch.ones(edge_index.size(1), device=device)
    Ri, Rv = edge_index, val
    # Repeatedly multiply by A to obtain A^k (the original body hardcoded k=2).
    for _ in range(k - 1):
        Ri, Rv = spspmm(Ri, Rv, edge_index, val, num_nodes, num_nodes, num_nodes)
    # Drop self-loops introduced by the product.
    mask = Ri[0] != Ri[1]
    Ri = Ri[:, mask]
    return coalesce(Ri, torch.ones(Ri.size(1), device=device), num_nodes, num_nodes)[0]

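# Hypothetical example: on the directed path 0 -> 1 -> 2, the square A^2
# contains only the two-hop edge 0 -> 2 (self-loops are dropped):
#
#   ei = torch.tensor([[0, 1], [1, 2]])
#   adjacency_power(ei, num_nodes=3, k=2)  # -> tensor([[0], [2]])
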
def build_cluster_graph(edge_index: Tensor,
                        num_nodes: int,
                        node2cluster: Tensor,
                        weight_per_edge: Optional[Tensor] = None,
                        num_clusters: Optional[int] = None
                        ) -> Tuple[Tensor, Tensor, int]:
    """
    Build the cluster graph A_c = S^T A S, where S is the [N, K] hard-assignment matrix.
    node2cluster: [N] long tensor with the cluster id of each node (hard assignment).
    Returns (edge_index_c, edge_weight_c, K).
    """
    if num_clusters is None:
        K = int(node2cluster.max().item()) + 1
    else:
        K = num_clusters

    # Relabel each endpoint by its cluster id; parallel edges between the same
    # pair of clusters are then merged by summing their weights.
    src, dst = edge_index
    edge_c = torch.stack([node2cluster[src], node2cluster[dst]], dim=0)
    if weight_per_edge is None:
        w = torch.ones(edge_c.size(1), device=edge_c.device)
    else:
        w = weight_per_edge
    edge_c, w = coalesce(edge_c, w, K, K, op='add')
    return edge_c, w, K

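# Hypothetical example: a 4-node graph with clusters {0,1} and {2,3}. The two
# cross edges collapse onto cluster edges of weight 1 each, while intra-cluster
# edges become cluster self-loops of weight 2:
#
#   ei = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
#   n2c = torch.tensor([0, 0, 1, 1])
#   build_cluster_graph(ei, 4, n2c)
#   # -> edge_index_c [[0, 0, 1, 1], [0, 1, 0, 1]], weights [2., 1., 1., 2.], K = 2
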
def _parse_clusters_single_file(obj: dict, n_nodes: int) -> Tuple[List[List[int]], Tensor]:
    """
    Expects the JSON to have a top-level "clusters": [{"members": [...], "score": ...}, ...].
    Unassigned nodes become singleton clusters.
    If a node appears in multiple clusters, the cluster with the largest 'score'
    wins, with ties broken by cluster size.
    """
    clusters = obj.get("clusters", [])

    cinfo = []
    for idx, c in enumerate(clusters):
        members = c.get("members", [])
        score = float(c.get("score", 0.0))
        cinfo.append((members, score, len(members), idx))

    # Resolve overlaps: each node keeps the cluster with the best (score, size) key.
    chosen = torch.full((n_nodes,), -1, dtype=torch.long)
    best_key = [(-1e18, -10) for _ in range(n_nodes)]
    for c_idx, (members, score, size, _) in enumerate(cinfo):
        key = (score, size)
        for u in members:
            if key > best_key[u]:
                best_key[u] = key
                chosen[u] = c_idx

    # Unassigned nodes become singleton clusters.
    next_c = len(cinfo)
    for u in range(n_nodes):
        if chosen[u] == -1:
            chosen[u] = next_c
            next_c += 1

    # Rebuild member lists from the resolved assignment so that overlapping
    # memberships cannot be silently overwritten downstream (the original
    # computed 'chosen' but returned the raw, possibly overlapping lists).
    out: List[List[int]] = [[] for _ in range(next_c)]
    for u in range(n_nodes):
        out[int(chosen[u])].append(u)

    # Min-max normalise the seed scores to [0, 1]; singleton clusters keep score 0.
    base_scores = [s for (_, s, _, _) in cinfo]
    K = len(out)
    scores = torch.zeros(K, dtype=torch.float32)
    for i, sc in enumerate(base_scores):
        scores[i] = sc
    if base_scores:
        smin, smax = min(base_scores), max(base_scores)
        if smax > smin:
            scores[:len(base_scores)] = (scores[:len(base_scores)] - smin) / (smax - smin)
        else:
            # All seed scores equal: treat every seed cluster as fully trusted.
            scores[:len(base_scores)] = 1.0

    return out, scores.view(-1, 1)

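# Hypothetical seeds object for a 5-node graph (node 4 is unassigned and
# becomes a singleton; node 2 appears twice and goes to the higher-scoring
# cluster):
#
#   obj = {"clusters": [{"members": [0, 1, 2], "score": 0.9},
#                       {"members": [2, 3],    "score": 0.4}]}
#   clusters, scores = _parse_clusters_single_file(obj, n_nodes=5)
#   # clusters -> [[0, 1, 2], [3], [4]], scores -> [[1.0], [0.0], [0.0]]
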
def seeds_to_node2cluster(n_nodes: int, clusters: List[List[int]]) -> Tensor:
    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    for cid, members in enumerate(clusters):
        for u in members:
            node2cluster[u] = cid
    assert int(node2cluster.min()) >= 0, "All nodes must be assigned a cluster."
    return node2cluster

def load_lrmc_seeds_single_graph(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor]:
    """Load LRMC seeds for a single large graph (Planetoid-style datasets)."""
    with open(seeds_json, "r") as f:
        obj = json.load(f)
    clusters, cluster_scores = _parse_clusters_single_file(obj, n_nodes)
    node2cluster = seeds_to_node2cluster(n_nodes, clusters)
    return node2cluster, cluster_scores

class BiLevelLRMC(nn.Module):
    """
    One round:
      1) Node GCN:    H1 = GCN_node(X, A_node)
      2) Up:          Z  = mean_{i in c} H1[i]  (cluster means via scatter)
                      cluster graph A_c = S^T A_node S
      3) Cluster GCN: Z2 = GCN_cluster(Z, A_c)
      4) Down:        H2 = H1 + W (S Z2)
    """

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 node2cluster: Tensor,
                 cluster_scores: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2: bool = False):
        super().__init__()
        self.num_nodes = num_nodes
        self.node2cluster = node2cluster.clone().long()
        self.register_buffer("node2cluster_buf", self.node2cluster)
        self.register_buffer("cluster_scores", cluster_scores.clone().float())

        # Node-level graph, optionally with weighted self-loops (A + scale*I).
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes, scale=self_loop_scale)
        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)

        # Cluster-level graph, optionally built from A^2 connectivity.
        ei_base_for_clusters = edge_index_node
        if use_a2:
            ei_base_for_clusters = adjacency_power(edge_index_node, num_nodes, k=2)

        edge_index_c, edge_weight_c, K = build_cluster_graph(
            ei_base_for_clusters, num_nodes, self.node2cluster
        )
        self.register_buffer("edge_index_c", edge_index_c)
        self.register_buffer("edge_weight_c", edge_weight_c)
        self.num_clusters = K

        self.gcn_node = GCNConv(in_dim, hidden_dim, add_self_loops=False, normalize=True)
        # Cluster-level GCN and down-projection; these mirror the per-graph path
        # in GraphLRMCGCN below (the original left this part elided).
        self.gcn_cluster = GCNConv(hidden_dim, hidden_dim, add_self_loops=True, normalize=True)
        self.down = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        # 1) Node-level GCN.
        h1 = F.relu(self.gcn_node(x, self.edge_index_node, self.edge_weight_node))

        # 2) Up: cluster means via scatter. Steps 2-4 follow the docstring and
        # the per-graph path in GraphLRMCGCN.forward; the original body stopped
        # at step 1.
        counts = torch.bincount(self.node2cluster_buf, minlength=self.num_clusters).clamp(min=1).unsqueeze(-1)
        z = scatter_add(h1, self.node2cluster_buf, dim=0, dim_size=self.num_clusters) / counts

        # 3) Cluster-level GCN.
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))

        # 4) Down: broadcast cluster features back to nodes, residual injection.
        h2 = h1 + self.down(z2[self.node2cluster_buf])
        return h2

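# Hypothetical shape check for one bi-level round (random features):
#
#   n2c = torch.tensor([0, 0, 1, 1])
#   ei = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
#   layer = BiLevelLRMC(in_dim=8, hidden_dim=16, node2cluster=n2c,
#                       cluster_scores=torch.ones(2, 1), edge_index_node=ei,
#                       num_nodes=4)
#   layer(torch.randn(4, 8)).shape  # -> torch.Size([4, 16])
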
class NodeLRMCGCN(nn.Module):
    def __init__(self, in_dim: int, hidden: int, num_classes: int,
                 node2cluster: Tensor, cluster_scores: Tensor, edge_index: Tensor, num_nodes: int,
                 layers: int = 1, self_loop_scale: float = 0.0, use_a2: bool = False, dropout: float = 0.5):
        super().__init__()
        # 'layers' is accepted for API compatibility; a single bi-level round is used.
        self.layer = BiLevelLRMC(in_dim, hidden, node2cluster, cluster_scores,
                                 edge_index, num_nodes, self_loop_scale, use_a2)
        self.cls = nn.Linear(hidden, num_classes)
        self.dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        h = self.layer(x)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.cls(h)

class GraphLRMCProvider:
    """
    Holds per-graph LRMC assignments and cluster graphs.
    Expects either a directory with one JSON per graph (graph_XXXXXX.json) or a
    single JSON of the form {"graphs": [{"graph_id": int, "clusters": [...]}, ...]}.
    Node indices are local per graph, i.e. in [0 .. n_i - 1].
    """

    def __init__(self, dataset, seeds_path: str, use_a2: bool = True):
        """dataset: any iterable/sequence of torch_geometric.data.Data."""
        self.dataset = dataset
        self.root = Path(seeds_path)
        self.per_graph: Dict[int, Dict[str, Tensor]] = {}

        single_json = None
        if self.root.is_file() and self.root.suffix.lower() == ".json":
            single_json = json.loads(self.root.read_text())
        for gid, data in enumerate(dataset):
            n = data.num_nodes
            if single_json is not None and "graphs" in single_json:
                # Single-file layout: find the entry for this graph id.
                entry = None
                for g in single_json["graphs"]:
                    if int(g.get("graph_id", -1)) == gid:
                        entry = g
                        break
                if entry is None:
                    # No seeds for this graph: fall back to singleton clusters.
                    node2cluster = torch.arange(n, dtype=torch.long)
                    cluster_scores = torch.ones(n, 1, dtype=torch.float32)
                else:
                    clusters, cluster_scores = _parse_clusters_single_file(entry, n)
                    node2cluster = seeds_to_node2cluster(n, clusters)
            else:
                # Directory layout: one JSON per graph.
                guess = self.root / f"graph_{gid:06d}.json"
                if guess.exists():
                    obj = json.loads(guess.read_text())
                    clusters, cluster_scores = _parse_clusters_single_file(obj, n)
                    node2cluster = seeds_to_node2cluster(n, clusters)
                else:
                    node2cluster = torch.arange(n, dtype=torch.long)
                    cluster_scores = torch.ones(n, 1, dtype=torch.float32)

            ei = data.edge_index
            if use_a2:
                ei = adjacency_power(ei, n, k=2)
            ei_c, ew_c, K = build_cluster_graph(ei, n, node2cluster)
            self.per_graph[gid] = {
                "node2cluster": node2cluster,
                "cluster_scores": cluster_scores,
                "edge_index_c": ei_c,
                "edge_weight_c": ew_c,
                "num_clusters": torch.tensor([K]),
            }

    def get(self, graph_id: int):
        rec = self.per_graph[graph_id]
        return (rec["node2cluster"], rec["cluster_scores"], rec["edge_index_c"], rec["edge_weight_c"],
                int(rec["num_clusters"][0].item()))

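# Sketch of the intended lookup flow (file layout is hypothetical):
#
#   provider = GraphLRMCProvider(graphs, "seeds/", use_a2=True)
#   node2cluster, scores, ei_c, ew_c, K = provider.get(0)
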
class GraphLRMCGCN(nn.Module):
    """
    Per-graph version (the training loop feeds one graph at a time):
    - Run the node-level GCN over the graph (standard).
    - Up: scatter to cluster means; fetch the precomputed cluster graph.
    - Cluster GCN over the cluster graph.
    - Down: broadcast cluster features back to nodes, gated residual.
    - Graph head: mean-pool node and cluster features -> linear classifier.
    """

    def __init__(self, in_dim: int, hidden: int, num_classes: int,
                 self_loop_scale: float = 0.0, use_a2: bool = False, dropout: float = 0.5):
        super().__init__()
        self.gcn_node = GCNConv(in_dim, hidden, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
        self.down = nn.Linear(hidden, hidden)
        # Classifier sees concatenated node- and cluster-level readouts.
        self.cls = nn.Linear(2 * hidden, num_classes)
        self.self_loop_scale = self_loop_scale
        self.use_a2 = use_a2
        self.dropout = dropout
        # Dynamic per-node gate on the injected cluster signal.
        self.gate = nn.Sequential(
            nn.Linear(2 * hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1)
        )
        # Learnable mixing between the seed-score gate and the dynamic gate.
        self.lambda_logit = nn.Parameter(torch.tensor(0.0))

    def forward(self, data: Data, provider: GraphLRMCProvider) -> Tensor:
        x, edge_index = data.x, data.edge_index
        num_nodes = x.size(0)

        # Node-level connectivity (optionally A^2, optionally scaled self-loops).
        ei = edge_index
        if self.use_a2:
            ei = adjacency_power(ei, num_nodes, k=2)
        ei, ew = add_scaled_self_loops(ei, None, num_nodes, scale=self.self_loop_scale)

        h1 = F.relu(self.gcn_node(x, ei, ew))

        # Look up the precomputed cluster assignment and cluster graph.
        assert hasattr(data, 'gid'), "Each graph must carry a 'gid' attribute for provider lookup."
        gid = int(data.gid.view(-1)[0].item())
        node2cluster_g, cluster_scores_g, edge_index_c, edge_weight_c, K = provider.get(gid)
        node2cluster_g = node2cluster_g.to(x.device)
        edge_index_c = edge_index_c.to(x.device)
        edge_weight_c = edge_weight_c.to(x.device)
        cluster_scores_g = cluster_scores_g.to(x.device)

        # Up: cluster means.
        counts = torch.bincount(node2cluster_g, minlength=K).clamp(min=1).unsqueeze(-1)
        z = scatter_add(h1, node2cluster_g, dim=0, dim_size=K) / counts

        # Cluster-level GCN.
        z2 = F.relu(self.gcn_cluster(z, edge_index_c, edge_weight_c))

        # Down: broadcast back to nodes; gate the injected signal.
        z2_nodes = z2[node2cluster_g]
        inj = self.down(z2_nodes)
        gate_dyn = torch.sigmoid(self.gate(torch.cat([h1, inj], dim=-1)))

        # Seed gate lies in [0.25, 1.0]; lambda interpolates seed vs. dynamic gate.
        alpha_seed = 0.25 + 0.75 * cluster_scores_g[node2cluster_g]
        lam = torch.sigmoid(self.lambda_logit)
        alpha = lam * alpha_seed + (1 - lam) * gate_dyn
        h2 = h1 + alpha * inj

        # Graph head: mean-pool nodes and clusters, concatenate, classify.
        h2 = F.dropout(h2, p=self.dropout, training=self.training)
        g_nodes = h2.mean(dim=0, keepdim=True)
        g_clust = z2.mean(dim=0, keepdim=True)
        g = torch.cat([g_nodes, g_clust], dim=-1)
        return self.cls(g)

def train_node(task_ds: str, seeds_json: str, hidden=64, layers=1, epochs=300,
               lr=0.01, weight_decay=5e-4, dropout=0.5, self_loop_scale=0.0, use_a2=False, seed=0):
    torch.manual_seed(seed)
    ds = Planetoid(root=f"./data/{task_ds}", name=task_ds)
    data = ds[0]
    n, c_in, n_cls = data.num_nodes, ds.num_node_features, ds.num_classes

    node2cluster, cluster_scores = load_lrmc_seeds_single_graph(seeds_json, n)
    model = NodeLRMCGCN(c_in, hidden, n_cls, node2cluster, cluster_scores, data.edge_index, n,
                        layers=layers, self_loop_scale=self_loop_scale, use_a2=use_a2, dropout=dropout).to('cpu')
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    def step():
        # One optimisation step followed by a full-graph evaluation.
        model.train()
        opt.zero_grad(set_to_none=True)
        out_train = model(data.x)
        loss = F.cross_entropy(out_train[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        with torch.no_grad():
            model.eval()
            out_eval = model(data.x)

        def acc(mask):
            pred = out_eval[mask].argmax(dim=1)
            return (pred == data.y[mask]).float().mean().item()

        return loss.item(), acc(data.train_mask), acc(data.val_mask), acc(data.test_mask)

    best_val, best_test = 0.0, 0.0
    for ep in range(1, epochs + 1):
        loss, tr, va, te = step()
        if va > best_val:
            best_val, best_test = va, te
        if ep % 20 == 0:
            print(f"[{ep:04d}] loss={loss:.4f} train={tr:.3f} val={va:.3f} test={te:.3f} best_test={best_test:.3f}")
    print(f"Best val={best_val:.3f} test@best={best_test:.3f}")

def train_graph(dataset_name: str, seeds_path: str, hidden=64, epochs=100,
                lr=0.001, weight_decay=1e-4, dropout=0.5, self_loop_scale=0.0, use_a2=False, seed=0):
    torch.manual_seed(seed)
    ds = TUDataset(root=f"./data/{dataset_name}", name=dataset_name)
    num_classes = ds.num_classes
    c_in = ds.num_node_features if ds.num_node_features > 0 else 1

    # Clone graphs and tag each with its index so the provider can look it up.
    graphs: List[Data] = []
    for i, g in enumerate(ds):
        gc = g.clone()
        gc.gid = torch.tensor([i], dtype=torch.long)
        graphs.append(gc)

    # Datasets without node features get node degree as a 1-d feature.
    if ds.num_node_features == 0:
        for g in graphs:
            g.x = torch.bincount(g.edge_index[0], minlength=g.num_nodes).float().view(-1, 1)

    # Note: the provider builds cluster graphs from A^2 by default (use_a2=True),
    # independently of the model's use_a2 flag.
    provider = GraphLRMCProvider(graphs, seeds_path)
    idx = torch.randperm(len(graphs))
    ntrain = int(0.8 * len(graphs))
    nval = int(0.1 * len(graphs))

    train_ds = [graphs[i] for i in idx[:ntrain]]
    val_ds = [graphs[i] for i in idx[ntrain:ntrain + nval]]
    test_ds = [graphs[i] for i in idx[ntrain + nval:]]

    device = 'cpu'
    model = GraphLRMCGCN(c_in, hidden, num_classes,
                         self_loop_scale=self_loop_scale, use_a2=use_a2, dropout=dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    @torch.no_grad()
    def evaluate(graph_list: List[Data]):
        model.eval()
        tot, correct = 0, 0
        for g in graph_list:
            g = g.to(device)
            logits = model(g, provider)
            pred = logits.argmax(dim=1)
            correct += (pred == g.y).sum().item()
            tot += g.y.size(0)
        return correct / tot

    best_val, best_test = 0.0, 0.0
    for ep in range(1, epochs + 1):
        model.train()
        for g in train_ds:
            g = g.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(g, provider)
            loss = F.cross_entropy(logits, g.y)
            loss.backward()
            opt.step()
        if ep % 5 == 0:
            va = evaluate(val_ds)
            te = evaluate(test_ds)
            if va > best_val:
                best_val, best_test = va, te
            print(f"[{ep:03d}] val={va:.3f} test={te:.3f} best_test@val={best_test:.3f}")
    print(f"Best val={best_val:.3f} test@best={best_test:.3f}")

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--task", choices=["node", "graph"], required=True)
    p.add_argument("--dataset", required=True, help="Cora/Citeseer/Pubmed or DD/PROTEINS/COLLAB/ENZYMES")
    p.add_argument("--seeds", required=True, help="Path to seeds JSON (node task) or dir/single JSON (graph task)")
    p.add_argument("--hidden", type=int, default=64)
    p.add_argument("--layers", type=int, default=1)
    p.add_argument("--epochs", type=int, default=300)
    p.add_argument("--batch_size", type=int, default=64, help="currently unused; graphs are processed one at a time")
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--wd", type=float, default=5e-4)
    p.add_argument("--dropout", type=float, default=0.5)
    p.add_argument("--self_loop_scale", type=float, default=0.0, help="use 2.0 to mimic A+2I")
    p.add_argument("--use_a2", action="store_true", help="use A^2 connectivity augmentation")
    p.add_argument("--seed", type=int, default=0, help="ignored: both tasks sweep seeds 42..59 below")
    args = p.parse_args()

    # Both branches sweep seeds 42..59 (args.seed is not used directly).
    if args.task == "node":
        for i in range(42, 60):
            train_node(args.dataset, args.seeds, hidden=args.hidden, layers=args.layers, epochs=args.epochs, lr=args.lr,
                       weight_decay=args.wd, dropout=args.dropout, self_loop_scale=args.self_loop_scale,
                       use_a2=args.use_a2, seed=i)
    else:
        # Graph task clamps epochs to at least 100 and lr to at most 1e-3.
        for i in range(42, 60):
            train_graph(args.dataset, args.seeds, hidden=args.hidden, epochs=max(100, args.epochs),
                        lr=min(args.lr, 0.001), weight_decay=args.wd, dropout=args.dropout,
                        self_loop_scale=args.self_loop_scale, use_a2=args.use_a2, seed=i)


if __name__ == "__main__":
    main()
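# Example invocations (script name and seeds paths are hypothetical):
#
#   python lrmc_gcn.py --task node  --dataset Cora     --seeds seeds/cora.json
#   python lrmc_gcn.py --task graph --dataset PROTEINS --seeds seeds/proteins/ --use_a2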