# lrmc_bilevel.py
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
#
# Usage examples:
#   python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
#   python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant pool --lrmc_inv_weight 0.01 --lrmc_gamma 0.7
#
# Notes:
# - We read your LRMC JSON, pick the single cluster with the highest 'score',
#   assign it to cluster id 0, and make all other nodes singletons (1..K-1).
# - For --variant pool: Node-GCN -> pool (means) -> Cluster-GCN -> broadcast + skip -> Node-GCN -> classifier
#   This variant also applies the L-RMC stability tricks.
# - For --variant baseline: standard 2-layer GCN.
# - Keep flags like --self_loop_scale and --use_a2 if you want A+λI / A^2 augmentation.

import argparse
import json
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_add, scatter_mean
from torch_sparse import coalesce, spspmm
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import subgraph, degree  # used by the stability score


# ---------------------------
# Utilities: edges and seeds
# ---------------------------

def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Add self-loops with weight `scale`. If scale=0, return the graph unchanged
    (creating unit edge weights if none were given)."""
    if scale == 0.0:
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
        return edge_index, edge_weight
    device = edge_index.device
    self_loops = torch.arange(num_nodes, device=device)
    self_index = torch.stack([self_loops, self_loops], dim=0)
    self_weight = torch.full((num_nodes,), float(scale), device=device)
    base_w = edge_weight if edge_weight is not None else torch.ones(edge_index.size(1), device=device)
    ei = torch.cat([edge_index, self_index], dim=1)
    ew = torch.cat([base_w, self_weight], dim=0)
    ei, ew = coalesce(ei, ew, num_nodes, num_nodes, op='add')
    return ei, ew


def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """
    Compute a (binary) power of the adjacency via sparse matmul (torch_sparse.spspmm).
    Only k=2 is implemented; the `k` argument is kept for API symmetry.
    Returns a coalesced edge_index without weights.
    """
    row, col = edge_index
    val = torch.ones(row.numel(), device=edge_index.device)
    Ai, Av = edge_index, val
    # A^2
    Ri, Rv = spspmm(Ai, Av, Ai, Av, num_nodes, num_nodes, num_nodes)
    mask = Ri[0] != Ri[1]  # drop diagonal; add custom self-loops later if desired
    Ri = Ri[:, mask]
    Ri, _ = coalesce(Ri, torch.ones(Ri.size(1), device=edge_index.device),
                     num_nodes, num_nodes, op='add')
    return Ri


def build_cluster_graph(edge_index: Tensor,
                        num_nodes: int,
                        node2cluster: Tensor,
                        weight_per_edge: Optional[Tensor] = None,
                        num_clusters: Optional[int] = None
                        ) -> Tuple[Tensor, Tensor, int]:
    """
    Build the cluster graph A_c = S^T A S with summed edge multiplicities as weights.
    node2cluster: [N] long tensor mapping each node -> cluster id.
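    Example (illustrative): with node2cluster = [0, 0, 1] and the two edges
    (0, 2) and (1, 2), both map to the cluster edge (0, 1), which therefore gets
    weight 2 after coalesce(op='add'); edges internal to a cluster become cluster
    self-loops.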
""" K = int(node2cluster.max().item()) + 1 if num_clusters is None else num_clusters src, dst = edge_index csrc = node2cluster[src] cdst = node2cluster[dst] edge_c = torch.stack([csrc, cdst], dim=0) w = weight_per_edge if weight_per_edge is not None else torch.ones(edge_c.size(1), device=edge_c.device) edge_c, w = coalesce(edge_c, w, K, K, op='add') # sum multiplicities return edge_c, w, K # ----- # Seeds # ----- def _pick_top1_cluster(obj: dict) -> List[int]: """ From LRMC JSON with structure: {"clusters":[{"members":[...], "score":float, ...}, ...]} choose the cluster with max (score, size) and return its members. """ clusters = obj.get("clusters", []) if not clusters: return [] # choose by highest score, then by size (tiebreaker) best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("members", [])))) return list(best.get("members", [])) def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor, Tensor]: """ Create a hard assignment for top-1 LRMC cluster: - cluster 0 = top-1 LRMC set - nodes outside are singletons (1..K-1) Returns: node2cluster: [N] long cluster_scores: [K,1] with 1.0 for top cluster, 0.0 for singletons core_nodes: [|C|] long, original indices of nodes in the top-1 LRMC cluster """ obj = json.loads(Path(seeds_json).read_text()) C_star_list = _pick_top1_cluster(obj) C_star = torch.tensor(sorted(set(C_star_list)), dtype=torch.long) # Original indices of core nodes node2cluster = torch.full((n_nodes,), -1, dtype=torch.long) node2cluster[C_star] = 0 outside = torch.tensor(sorted(set(range(n_nodes)) - set(C_star.tolist())), dtype=torch.long) if outside.numel() > 0: node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long) assert int(node2cluster.min()) >= 0, "All nodes must be assigned." K = 1 + outside.numel() cluster_scores = torch.zeros(K, 1, dtype=torch.float32) if C_star.numel() > 0: cluster_scores[0, 0] = 1.0 # emphasize the supercluster return node2cluster, cluster_scores, C_star # -------------------------- # Models (baseline + pooled) # -------------------------- class GCN2(nn.Module): """Plain 2-layer GCN baseline.""" def __init__(self, in_dim, hid, out_dim, dropout_p: float = 0.5): super().__init__() self.conv1 = GCNConv(in_dim, hid) self.conv2 = GCNConv(hid, out_dim) self.dropout_p = dropout_p def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, p=self.dropout_p, training=self.training) x = self.conv2(x, edge_index) return x class OneClusterPool(nn.Module): """ Node-GCN -> pool to one-cluster + singletons -> Cluster-GCN -> broadcast + skip -> Node-GCN -> classifier This version includes L-RMC stability tricks: 1. Backbone-invariance regularizer (loss computed in forward). 2. Boundary damping on node graph. 
""" def __init__(self, in_dim: int, hid: int, out_dim: int, node2cluster: Tensor, core_nodes: Tensor, # New: explicit core nodes edge_index_node: Tensor, num_nodes: int, self_loop_scale: float = 0.0, use_a2_for_clusters: bool = False, lrmc_gamma: float = 1.0, # New: damping factor (1.0 means no damping) dropout_p: float = 0.5): super().__init__() self.n2c = node2cluster.long() self.K = int(self.n2c.max().item()) + 1 self.core_nodes = core_nodes # Store original indices of core nodes self.lrmc_gamma = lrmc_gamma self.dropout_p = dropout_p # Node graph (A + λI if desired) ei_node = edge_index_node ew_node_base = None # Will be created by add_scaled_self_loops if None ei_node, ew_node = add_scaled_self_loops(ei_node, ew_node_base, num_nodes, scale=self_loop_scale) # --- Apply Boundary Damping --- if self.lrmc_gamma < 1.0 and self.core_nodes.numel() > 0: is_core = torch.zeros(num_nodes, dtype=torch.bool, device=ei_node.device) is_core[self.core_nodes] = True src_is_core = is_core[ei_node[0]] dst_is_core = is_core[ei_node[1]] cross_boundary_mask = (src_is_core != dst_is_core) ew_node[cross_boundary_mask] *= self.lrmc_gamma # --- End Boundary Damping --- self.register_buffer("edge_index_node", ei_node) self.register_buffer("edge_weight_node", ew_node) # Cluster graph from A or A^2 ei_for_c = adjacency_power(edge_index_node, num_nodes, k=2) if use_a2_for_clusters else edge_index_node edge_index_c, edge_weight_c, K = build_cluster_graph(ei_for_c, num_nodes, self.n2c) self.register_buffer("edge_index_c", edge_index_c) self.register_buffer("edge_weight_c", edge_weight_c) self.K = K # Layers self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True) self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True) self.gcn_node2 = GCNConv(hid * 2, out_dim) # on concatenated [h_node, h_broadcast] def forward(self, x: Tensor, edge_index_node: Tensor) -> Tuple[Tensor, Optional[Tensor]]: # Node GCN (uses stored weights) h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node)) h1 = F.dropout(h1, p=self.dropout_p, training=self.training) # Consistent dropout # --- Backbone-invariance regularizer --- lrmc_inv_loss = None # Apply only if core nodes exist AND regularizer weight is positive (handled by run_train_eval) if self.core_nodes.numel() > 0: core_embeddings = h1[self.core_nodes] # Calculate mean embedding of the core, keepdim=True for broadcasting avg_embedding = core_embeddings.mean(dim=0, keepdim=True) # MSE between each core embedding and the average core embedding lrmc_inv_loss = F.mse_loss(core_embeddings, avg_embedding.expand_as(core_embeddings), reduction='mean') # --- End Backbone-invariance regularizer --- # Pool to clusters: mean per cluster z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K) # [K, H] # Cluster GCN z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c)) # Broadcast back + skip concat hb = z2[self.n2c] # [N, H] hcat = torch.cat([h1, hb], dim=1) # [N, 2H] # Final node GCN head -> logits out = self.gcn_node2(hcat, edge_index_node) return out, lrmc_inv_loss # ------------- # Training glue # ------------- @torch.no_grad() def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float: pred = logits[mask].argmax(dim=1) return (pred == y[mask]).float().mean().item() def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, lrmc_inv_weight: float = 0.0): opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd) best_val, best_state = 0.0, None for ep in range(1, epochs + 1): model.train() 
        opt.zero_grad(set_to_none=True)

        # Model output depends on its type
        res = model(data.x, data.edge_index)
        current_lrmc_inv_loss = None
        if isinstance(model, OneClusterPool):
            # OneClusterPool returns (logits, lrmc_inv_loss)
            logits, current_lrmc_inv_loss = res
            loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
            if current_lrmc_inv_loss is not None and lrmc_inv_weight > 0:
                loss += lrmc_inv_weight * current_lrmc_inv_loss
        else:
            # GCN2 (baseline) returns only logits
            logits = res
            loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        # track best on val
        model.eval()
        # For evaluation we only need logits; OneClusterPool additionally returns the regularizer loss.
        out_eval = model(data.x, data.edge_index)
        logits_eval = out_eval[0] if isinstance(model, OneClusterPool) else out_eval
        val_acc = accuracy(logits_eval, data.y, data.val_mask)
        if val_acc > best_val:
            best_val, best_state = val_acc, {k: v.detach().clone() for k, v in model.state_dict().items()}

        if ep % 20 == 0:
            tr = accuracy(logits_eval, data.y, data.train_mask)
            te = accuracy(logits_eval, data.y, data.test_mask)
            lrmc_loss_str = f" inv_l={current_lrmc_inv_loss.item():.4f}" if current_lrmc_inv_loss is not None else ""
            print(f"[{ep:04d}] loss={loss.item():.4f}{lrmc_loss_str} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    # test @ best val
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    out_final = model(data.x, data.edge_index)
    logits_final = out_final[0] if isinstance(model, OneClusterPool) else out_final
    return {
        "val": accuracy(logits_final, data.y, data.val_mask),
        "test": accuracy(logits_final, data.y, data.test_mask),
    }


# --------------------------
# LRMC Stability Score (new)
# --------------------------

def compute_lrmc_stability_score(core_nodes: Tensor, edge_index: Tensor, num_nodes: int,
                                 epsilon: float = 1e-6) -> float:
    """
    Computes the L-RMC stability score S_L(C) = |C| / (d^T L_C d + epsilon).
    Here d^T L_C d is interpreted as the sum of squared degree differences over edges
    within the core: sum_{(u,v) in E_C} (deg_C(u) - deg_C(v))^2.
    """
    if core_nodes.numel() == 0:
        return 0.0

    # Induced subgraph on core_nodes, with nodes relabeled to [0, ..., |C|-1].
    # `subgraph` returns (sub_edge_index, sub_edge_attr).
    sub_edge_index, _ = subgraph(core_nodes, edge_index, relabel_nodes=True, num_nodes=num_nodes)
    num_core_nodes = core_nodes.numel()

    if sub_edge_index.numel() == 0:
        # The core has nodes but no internal edges: there is no degree variability over edges,
        # so d^T L_C d = 0 under this interpretation and the score is maximal.
        return float(num_core_nodes) / epsilon

    # Degrees of nodes *within the induced subgraph*.
    # Planetoid edge_index stores both directions of every undirected edge, so counting
    # source occurrences already gives each node's degree in the subgraph.
    degrees_in_subgraph_full = degree(sub_edge_index[0], num_nodes=num_core_nodes, dtype=torch.float)

    # Sum of squared degree differences for each edge within the subgraph.
    # deg_u_relabel and deg_v_relabel are the degrees of the (relabeled) source and
    # destination nodes within the induced subgraph.
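    # Note (assuming both edge directions are stored, as in Planetoid graphs): with d the
    # in-subgraph degree vector, this directed-edge sum equals 2 * d^T L_C d, since
    # x^T L x = sum over undirected edges of (x_u - x_v)^2 and each undirected edge
    # appears twice below.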
    deg_u_relabel = degrees_in_subgraph_full[sub_edge_index[0]]
    deg_v_relabel = degrees_in_subgraph_full[sub_edge_index[1]]
    degree_variability_sum = torch.sum((deg_u_relabel - deg_v_relabel) ** 2)

    # Compute S_L(C)
    score = float(num_core_nodes) / (degree_variability_sum.item() + epsilon)
    return score


# -----------
# Entrypoint
# -----------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    ap.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON (single large graph).")
    ap.add_argument("--variant", choices=["baseline", "pool"], default="pool",
                    help="baseline=plain GCN; pool=top-1 LRMC one-cluster pooling (with the L-RMC tricks)")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=0.01)
    ap.add_argument("--wd", type=float, default=5e-4)
    ap.add_argument("--dropout", type=float, default=0.5, help="Dropout rate for GCN layers.")
    ap.add_argument("--self_loop_scale", type=float, default=0.0, help="λ for A+λI on the node graph (0 disables).")
    ap.add_argument("--use_a2", action="store_true", help="Use A^2 to build the cluster graph (recommended for pool).")
    ap.add_argument("--lrmc_inv_weight", type=float, default=0.0,
                    help="Weight for the backbone-invariance regularizer (0 disables).")
    ap.add_argument("--lrmc_gamma", type=float, default=1.0,
                    help="Damping factor for cross-boundary edges (1.0 means no damping).")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    torch.manual_seed(args.seed)

    # Load dataset
    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    in_dim, out_dim, n = ds.num_node_features, ds.num_classes, data.num_nodes

    if args.variant == "baseline":
        model = GCN2(in_dim, args.hidden, out_dim, dropout_p=args.dropout)
        # use the default add_self_loops=True behavior inside the convs
        res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
        print(f"Baseline GCN: val={res['val']:.4f} test={res['test']:.4f}")
        return

    # Top-1 LRMC assignment
    node2cluster, _, core_nodes = load_top1_assignment(args.seeds, n)

    # For informational purposes, compute and print the stability score for the found core
    lrmc_score = compute_lrmc_stability_score(core_nodes, data.edge_index, n)
    print(f"LRMC Top-1 Cluster Size: {core_nodes.numel()} nodes. Stability Score (S_L): {lrmc_score:.4f}")

    # One-cluster pooled model with L-RMC tricks
    model = OneClusterPool(in_dim=in_dim, hid=args.hidden, out_dim=out_dim,
                           node2cluster=node2cluster,
                           core_nodes=core_nodes,
                           edge_index_node=data.edge_index,
                           num_nodes=n,
                           self_loop_scale=args.self_loop_scale,
                           use_a2_for_clusters=args.use_a2,
                           lrmc_gamma=args.lrmc_gamma,
                           dropout_p=args.dropout)
    res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd,
                         lrmc_inv_weight=args.lrmc_inv_weight)
    print(f"L-RMC (top-1 pool with tricks): val={res['val']:.4f} test={res['test']:.4f}")


if __name__ == "__main__":
    main()