# gcn_lrmc_node_classify.py # Node classification with GCN + L-RMC (static pooling + unpool + skip) # Usage: # python gcn_lrmc_node_classify.py --dataset Cora --lrmc_json /path/to/lrmc_seeds.json # Options: # --use_a2 true|false (default true; use A^2 before pooling as in Graph U-Nets) # --epochs 200 --lr 0.005 --hidden 64 --cluster_hidden 64 --dropout 0.5 import argparse, json, os import numpy as np, torch import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import Adam from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv, BatchNorm from torch_geometric.utils import coalesce, to_undirected, remove_self_loops from torch_geometric.utils import add_self_loops from torch_scatter import scatter_mean from torch_sparse import spspmm # ----------------------------- # L-RMC assignment utilities # ----------------------------- def load_lrmc_assignment(json_path, num_nodes): """ Build a single hard assignment: node -> cluster_id in [0, K-1]. If nodes appear in multiple clusters, keep the one with highest 'score'. If any nodes are unassigned, put them into their own singleton clusters. Returns: assignment: LongTensor [num_nodes] with cluster ids clusters: list of lists (members per cluster) aligned to remapped cluster ids """ with open(json_path, 'r') as f: seeds = json.load(f) clusters_raw = seeds.get("clusters", []) # Sort clusters by score descending to prefer higher-scoring clusters on conflicts clusters_raw = sorted(clusters_raw, key=lambda c: float(c.get("score", 0.0)), reverse=True) chosen_cluster_for_node = [-1] * num_nodes tmp_clusters = [] # will collect chosen clusters (members), before remap for c in clusters_raw: members = c.get("members", []) # skip empty if not members: continue # take only members not yet assigned new_members = [u for u in members if 0 <= u < num_nodes and chosen_cluster_for_node[u] == -1] if not new_members: continue # tentatively assign this cluster to those nodes (others in the cluster were already taken) tmp_clusters.append(new_members) cid = len(tmp_clusters) - 1 for u in new_members: chosen_cluster_for_node[u] = cid # Any nodes still -1 → singleton clusters for u in range(num_nodes): if chosen_cluster_for_node[u] == -1: tmp_clusters.append([u]) cid = len(tmp_clusters) - 1 chosen_cluster_for_node[u] = cid # Remap cluster ids to [0..K-1] (already contiguous by construction) assignment = torch.tensor(chosen_cluster_for_node, dtype=torch.long) clusters = tmp_clusters return assignment, clusters def lrmc_stats(assignment, clusters, edge_index): N = assignment.numel(); K = int(assignment.max()) + 1 sizes = [len(c) for c in clusters] sing = sum(1 for s in sizes if s==1) print(f"[L-RMC] N={N} K={K} mean|C|={np.mean(sizes):.2f} " f"median|C|={np.median(sizes):.0f} singleton%={100*sing/K:.1f}%") # how many edges are intra-cluster? same = (assignment[edge_index[0]] == assignment[edge_index[1]]).sum().item() print(f"[L-RMC] intra-cluster edge ratio = {same/edge_index.size(1):.3f}") # ----------------------------- # Graph helpers # ----------------------------- def compute_A2_union(edge_index, num_nodes, device): """ Compute A^2 (binary) and return union edges A OR A^2, undirected & coalesced. """ # Make undirected and coalesced (no weights) ei = to_undirected(coalesce(edge_index, num_nodes=num_nodes), num_nodes=num_nodes) # Build ones weights for sparse-sparse multiply E = ei.size(1) if E == 0: return ei # empty graph val = torch.ones(E, device=device) # spspmm: (m x k) @ (k x n) where here m=n=k=num_nodes ei2, val2 = spspmm(ei, val, ei, val, num_nodes, num_nodes, num_nodes) # Remove self-loops from A2 (optional; GCNConv adds its own self-loops later) ei2, _ = remove_self_loops(ei2) # Binarize & union with A # (coalesce later will drop duplicates anyway) ei_aug = torch.cat([ei, ei2], dim=1) ei_aug = to_undirected(coalesce(ei_aug, num_nodes=num_nodes), num_nodes=num_nodes) return ei_aug def build_cluster_edges(edge_index_aug, assignment, num_clusters): """ Map node edges to cluster edges: (u,v) -> (c(u), c(v)), undirected + coalesced. """ c_src = assignment[edge_index_aug[0]] c_dst = assignment[edge_index_aug[1]] c_ei = torch.stack([c_src, c_dst], dim=0) c_ei = to_undirected(coalesce(c_ei, num_nodes=num_clusters), num_nodes=num_clusters) return c_ei # ----------------------------- # Model # ----------------------------- class Gate(nn.Module): def __init__(self, d_enc, d_c): super().__init__() self.mlp = nn.Sequential( nn.Linear(d_enc + d_c, d_enc, bias=True), nn.ReLU(), nn.Linear(d_enc, d_enc, bias=True), nn.Sigmoid(), ) def forward(self, h_enc, h_cluster_broadcast): g = self.mlp(torch.cat([h_enc, h_cluster_broadcast], dim=-1)) return h_enc + g * h_cluster_broadcast # residual gated add class GCN_LRMC_NodeClassifier(nn.Module): """ Encoder: GCN -> GCN on original graph Pool: aggregate encoder features per L-RMC cluster Coarse: GCN -> (optional GCN) on cluster graph Unpool: broadcast cluster features back to nodes Decoder: GCN (on original graph) -> logits """ def __init__(self, in_dim, hidden_dim, cluster_hidden_dim, out_dim, edge_index, assignment, cluster_edge_index, dropout=0.5): super().__init__() self.edge_index = edge_index # original graph edges self.assignment = assignment # [N] self.cluster_edge_index = cluster_edge_index # edges on cluster graph self.num_clusters = int(assignment.max().item() + 1) self.dropout = dropout # Encoder on node graph self.enc1 = GCNConv(in_dim, hidden_dim, improved=True) self.enc2 = GCNConv(hidden_dim, hidden_dim, improved=True) # GCN(s) on cluster graph self.cgc1 = GCNConv(hidden_dim, cluster_hidden_dim, improved=True) self.cgc2 = GCNConv(cluster_hidden_dim, cluster_hidden_dim, improved=True) # Decoder on node graph (combine skip from encoder + broadcast from cluster) dec_in = hidden_dim + cluster_hidden_dim self.dec1 = GCNConv(dec_in, hidden_dim, improved=True) self.cls = GCNConv(hidden_dim, out_dim, improved=True) # final logits self.bn_e1 = BatchNorm(hidden_dim) self.bn_e2 = BatchNorm(hidden_dim) self.bn_c1 = BatchNorm(cluster_hidden_dim) self.bn_c2 = BatchNorm(cluster_hidden_dim) self.bn_d1 = BatchNorm(hidden_dim) self.gate = Gate(hidden_dim, cluster_hidden_dim) def forward(self, x): # Encoder on original graph h = F.dropout(x, p=self.dropout, training=self.training) h = F.relu(self.bn_e1(self.enc1(h, self.edge_index))) h = F.dropout(h, p=self.dropout, training=self.training) h2 = F.relu(self.bn_e2(self.enc2(h, self.edge_index))) h = h + h2 h_enc = h # skip for decoder # Pool: aggregate encoder features to clusters (mean) # cluster_x: [K, hidden_dim] cluster_x = scatter_mean(h_enc, self.assignment, dim=0, dim_size=self.num_clusters) # Coarse GCN(s) on cluster graph hc = F.dropout(cluster_x, p=self.dropout, training=self.training) hc = F.relu(self.bn_c1(self.cgc1(cluster_x, self.cluster_edge_index))) hc = F.dropout(hc, p=self.dropout, training=self.training) hc2 = F.relu(self.bn_c2(self.cgc2(hc, self.cluster_edge_index))) hc = hc + hc2 # Unpool: broadcast coarse features back to nodes via assignment hc_broadcast = hc[self.assignment] # [N, cluster_hidden_dim] # # after hc_broadcast is computed # g_in = torch.cat([h_enc, hc_broadcast], dim=1) # gate = torch.sigmoid(nn.Linear(g_in.size(1), h_enc.size(1)).to(g_in.device)(g_in)) # h_dec_in = h_enc + gate * hc_broadcast # gated residual instead of concat # Decoder on original graph h_dec_in = torch.cat([h_enc, hc_broadcast], dim=1) # [N, hidden_dim + cluster_hidden_dim] h = F.dropout(h_dec_in, p=self.dropout, training=self.training) h = F.relu(self.dec1(h, self.edge_index)) h = F.dropout(h, p=self.dropout, training=self.training) out = self.cls(h, self.edge_index) # logits [N, C] return out # ----------------------------- # Train / Eval # ----------------------------- @torch.no_grad() def evaluate(model, data): model.eval() out = model(data.x) y = data.y pred = out.argmax(dim=-1) def acc(mask): m = mask if mask.dtype == torch.bool else mask.bool() if m.sum() == 0: return 0.0 return (pred[m] == y[m]).float().mean().item() return acc(data.train_mask), acc(data.val_mask), acc(data.test_mask) def train_loop(model, data, epochs=200, lr=5e-3, weight_decay=5e-4, patience=100): optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) best_val, best_test = 0.0, 0.0 best_state = None no_improve = 0 for epoch in range(1, epochs + 1): model.train() optimizer.zero_grad() logits = model(data.x) loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() tr, va, te = evaluate(model, data) if va > best_val: best_val, best_test = va, te best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} no_improve = 0 else: no_improve += 1 print(f"Epoch {epoch:03d} | loss={loss.item():.4f} | " f"train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}% test@best={best_test*100:.2f}%") if no_improve >= patience: print(f"Early stopping at epoch {epoch} (no val improvement for {patience})") break if best_state is not None: model.load_state_dict(best_state) tr, va, te = evaluate(model, data) print(f"\nFinal (reloaded best): train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}%") return te # ----------------------------- # Main # ----------------------------- def main(): parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "Citeseer", "Pubmed"]) parser.add_argument("--lrmc_json", type=str, required=True) parser.add_argument("--use_a2", type=str, default="true", help="Use A^2 before pooling (true/false)") parser.add_argument("--hidden", type=int, default=64) parser.add_argument("--cluster_hidden", type=int, default=64) parser.add_argument("--dropout", type=float, default=0.5) parser.add_argument("--epochs", type=int, default=200) parser.add_argument("--lr", type=float, default=5e-3) parser.add_argument("--weight_decay", type=float, default=5e-4) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dataset = Planetoid(root=os.path.join("data", args.dataset), name=args.dataset) data = dataset[0].to(device) num_nodes = data.num_nodes in_dim = dataset.num_node_features out_dim = dataset.num_classes # Load L-RMC assignment assignment, clusters = load_lrmc_assignment(args.lrmc_json, num_nodes) assignment = assignment.to(device) num_clusters = int(assignment.max().item() + 1) print(f"[L-RMC] Loaded clusters: K={num_clusters} (N={num_nodes})") lrmc_stats(assignment, clusters, data.edge_index) # Build augmented node edge_index (A or A^2 ∪ A), then cluster edges use_a2 = args.use_a2.lower() in ("1", "true", "yes", "y") if use_a2: edge_index_aug = compute_A2_union(data.edge_index, num_nodes, device) print("[L-RMC] Using A^2 ∪ A before pooling (connectivity augmentation).") else: edge_index_aug = to_undirected(coalesce(data.edge_index, num_nodes=num_nodes), num_nodes=num_nodes) print("[L-RMC] Using original A for pooling.") cluster_edge_index = build_cluster_edges(edge_index_aug, assignment, num_clusters) # Build model model = GCN_LRMC_NodeClassifier( in_dim=in_dim, hidden_dim=args.hidden, cluster_hidden_dim=args.cluster_hidden, out_dim=out_dim, edge_index=data.edge_index, # original graph for enc/dec assignment=assignment, # node -> cluster cluster_edge_index=cluster_edge_index, # cluster graph for coarse GCN dropout=args.dropout, ).to(device) # Train / evaluate test_acc = train_loop(model, data, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, patience=100) if __name__ == "__main__": main()