# 2.1_lrmc_bilevel.py
# Top-1 LRMC ablation with debug guards so that differences between seed files are visible.
# Requires: torch, torch_geometric, torch_scatter, torch_sparse

import argparse
import hashlib
import json
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_mean
from torch_sparse import coalesce, spspmm
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv


# ---------------------------
# Utilities: edges and seeds
# ---------------------------

def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Return (edge_index, edge_weight) with self loops of weight ``scale`` added.

    With ``scale == 0.0`` the graph is returned unchanged. Missing weights are
    filled with ones. Otherwise a (i, i) edge of weight ``scale`` is appended for
    every node, and the result is coalesced with ``op='add'``. Any pre-existing
    self-loop weight is therefore summed with ``scale``.
    """
    if scale == 0.0:
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
        return edge_index, edge_weight
    device = edge_index.device
    self_loops = torch.arange(num_nodes, device=device)
    self_index = torch.stack([self_loops, self_loops], dim=0)
    self_weight = torch.full((num_nodes,), float(scale), device=device)
    base_w = edge_weight if edge_weight is not None else torch.ones(edge_index.size(1), device=device)
    ei = torch.cat([edge_index, self_index], dim=1)
    ew = torch.cat([base_w, self_weight], dim=0)
    ei, ew = coalesce(ei, ew, num_nodes, num_nodes, op='add')
    return ei, ew


def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """Return the off-diagonal sparsity pattern of A^k as a coalesced edge_index.

    A (i, j) pair is kept when the sparse product of ``k`` copies of A has a
    stored entry there and i != j. Values are discarded, so the result is a
    binary pattern. The default ``k=2`` gives A^2.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"adjacency_power requires k >= 1, got k={k}")
    device = edge_index.device
    base_val = torch.ones(edge_index.size(1), device=device)
    Ri, Rv = edge_index, base_val
    for _ in range(k - 1):
        Ri, Rv = spspmm(Ri, Rv, edge_index, base_val, num_nodes, num_nodes, num_nodes)
        # Only the pattern is used; reset values so walk counts do not grow with k.
        Rv = torch.ones_like(Rv)
    # Drop diagonal entries (i == j).
    mask = Ri[0] != Ri[1]
    Ri = Ri[:, mask]
    Ri, _ = coalesce(Ri, torch.ones(Ri.size(1), device=device), num_nodes, num_nodes, op='add')
    return Ri


def build_cluster_graph(edge_index: Tensor,
                        num_nodes: int,
                        node2cluster: Tensor,
                        weight_per_edge: Optional[Tensor] = None,
                        num_clusters: Optional[int] = None
                        ) -> Tuple[Tensor, Tensor, int]:
    """Contract a node graph into a cluster graph.

    Each node edge (u, v) becomes (node2cluster[u], node2cluster[v]). Parallel
    edges are merged with summed weights, and intra-cluster edges become
    weighted self loops.

    Returns:
        (cluster edge_index, cluster edge weights, K), where K is
        ``num_clusters`` if given, else ``node2cluster.max() + 1``.
    """
    K = int(node2cluster.max().item()) + 1 if num_clusters is None else num_clusters
    src, dst = edge_index
    csrc = node2cluster[src]
    cdst = node2cluster[dst]
    edge_c = torch.stack([csrc, cdst], dim=0)
    w = weight_per_edge if weight_per_edge is not None else torch.ones(edge_c.size(1), device=edge_c.device)
    edge_c, w = coalesce(edge_c, w, K, K, op='add')
    return edge_c, w, K


# -----
# Seeds
# -----

def _md5(path: Path) -> str:
    """Return the hex MD5 digest of the file at ``path``, read in 8 KiB chunks."""
    h = hashlib.md5()
    with path.open('rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            h.update(chunk)
    return h.hexdigest()


def _extract_members(cluster_obj: dict) -> List[int]:
    """Return the cluster's node ids (as ints, deduplicated, first-seen order).

    Tries 'members' first, then 'seed_nodes'. Returns [] if a key holds an empty
    list.

    Raises:
        KeyError: if neither key holds a list.
    """
    m = cluster_obj.get("members", None)
    if isinstance(m, list) and len(m) > 0:
        return list(dict.fromkeys(int(x) for x in m))  # dedupe, preserve order
    m2 = cluster_obj.get("seed_nodes", None)
    if isinstance(m2, list) and len(m2) > 0:
        return list(dict.fromkeys(int(x) for x in m2))
    # Present but empty: return empty and let the caller decide.
    if isinstance(m, list) or isinstance(m2, list):
        return []
    raise KeyError("Cluster object has neither 'members' nor 'seed_nodes'.")


def _pick_top1_cluster(obj: dict) -> List[int]:
    """Pick the best cluster from ``{"clusters": [{"score": float, "members"|"seed_nodes": [...]}, ...]}``.

    Clusters are ranked by (score, member count); the first maximum wins on ties.
    A missing or null score counts as 0.0. Non-dict entries are skipped.

    Returns:
        Sorted, deduplicated member ids of the winner, or [] if there is no usable
        cluster.
    """
    clusters = obj.get("clusters", [])
    if not isinstance(clusters, list) or len(clusters) == 0:
        return []

    candidates = []  # ((score, size), members)
    for c in clusters:
        if not isinstance(c, dict):
            continue
        try:
            mem = _extract_members(c)
        except KeyError:
            mem = []
        raw_score = c.get("score")
        score = float(raw_score) if raw_score is not None else 0.0
        candidates.append(((score, len(mem)), mem))

    if not candidates:
        return []
    _, best_members = max(candidates, key=lambda item: item[0])
    return sorted(set(best_members))


def load_top1_assignment(seeds_json: str, n_nodes: int, debug: bool = False,
                         index_base: int = 1) -> Tuple[Tensor, Tensor, dict]:
    """Build a hard node->cluster assignment from the top-1 LRMC cluster.

    Cluster 0 is the top cluster. Every other node becomes its own singleton
    cluster (ids 1..K-1, in ascending node order).

    Args:
        seeds_json: path to the seeds JSON file.
        n_nodes: number of nodes N in the graph.
        debug: print a one-line summary when True.
        index_base: base of the node ids stored in the file. ``index_base`` is
            subtracted from every id. The default 1 treats the file as 1-indexed;
            pass 0 for 0-indexed files.

    Returns:
        node2cluster [N] (long), cluster_scores [K, 1] (1.0 for cluster 0, else 0.0),
        and a small debug dict (file md5, top size, K, etc.).

    Raises:
        RuntimeError: if the top-1 cluster has no members.
        ValueError: if any shifted node id falls outside [0, n_nodes).
    """
    p = Path(seeds_json)
    obj = json.loads(p.read_text(encoding='utf-8'))

    members = _pick_top1_cluster(obj)
    if not members:
        # Fail loudly instead of silently falling back to an identity assignment.
        raise RuntimeError(
            f"No members found for top-1 cluster in {seeds_json}. "
            f"Expected 'members' or 'seed_nodes' to be non-empty."
        )

    C_star = torch.tensor([u - index_base for u in members], dtype=torch.long)
    # Negative ids would silently wrap around in tensor indexing, so validate first.
    lo, hi = int(C_star.min()), int(C_star.max())
    if lo < 0 or hi >= n_nodes:
        raise ValueError(
            f"Top-1 cluster in {seeds_json} has node ids outside [0, {n_nodes}) after "
            f"subtracting index_base={index_base} (min={lo}, max={hi}). "
            f"Check the file's indexing convention (index_base) and the dataset."
        )

    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    node2cluster[C_star] = 0
    # Unassigned nodes, in ascending order, each get their own singleton cluster.
    outside = (node2cluster < 0).nonzero(as_tuple=False).view(-1)
    node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long)
    assert int(node2cluster.min()) >= 0  # internal invariant: every node assigned

    K = 1 + outside.numel()
    cluster_scores = torch.zeros(K, 1, dtype=torch.float32)
    cluster_scores[0, 0] = 1.0

    info = {
        "json_md5": _md5(p),
        "top_cluster_size": int(C_star.numel()),
        "K": int(K),
        "n_outside": int(outside.numel()),
        "first_members": [int(x) for x in C_star[:10].tolist()],
    }
    if debug:
        print(f"[LRMC] Loaded {seeds_json} (md5={info['json_md5']}) | "
              f"top_size={info['top_cluster_size']} K={info['K']} outside={info['n_outside']} "
              f"first10={info['first_members']}")
    return node2cluster, cluster_scores, info


# --------------------------
# Models (baseline + pooled)
# --------------------------

class GCN2(nn.Module):
    """Two-layer GCN baseline: GCNConv -> ReLU -> dropout -> GCNConv (logits)."""

    def __init__(self, in_dim, hid, out_dim, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


class OneClusterPool(nn.Module):
    """Bilevel GCN: node layer -> mean-pool into clusters -> cluster GCN -> broadcast back.

    The forward pass runs these steps:
      1. ``gcn_node1`` on the stored node graph (A plus optional scaled self loops).
      2. Mean-pool node features into K clusters using ``node2cluster``.
      3. ``gcn_cluster`` on the contracted cluster graph, built from A or A^2.
      4. Broadcast cluster features back to nodes and concatenate with the step-1
         features.
      5. ``gcn_node2`` on the ``edge_index_node`` argument passed to ``forward``.
    """

    def __init__(self, in_dim: int, hid: int, out_dim: int,
                 node2cluster: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2_for_clusters: bool = False,
                 debug_header: str = ""):
        super().__init__()
        # Buffer so it follows .to(device) like the edge buffers; non-persistent so
        # state_dict keys are unchanged.
        self.register_buffer("n2c", node2cluster.long(), persistent=False)

        # Node graph (A + scale * I when self_loop_scale != 0).
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes, scale=self_loop_scale)
        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)

        # Cluster graph, contracted from A (default) or A^2.
        ei_for_c = adjacency_power(edge_index_node, num_nodes, k=2) if use_a2_for_clusters else edge_index_node
        edge_index_c, edge_weight_c, K = build_cluster_graph(ei_for_c, num_nodes, self.n2c)
        self.register_buffer("edge_index_c", edge_index_c)
        self.register_buffer("edge_weight_c", edge_weight_c)
        self.K = K
        if debug_header:
            print(f"[POOL] {debug_header} | cluster_edges={edge_index_c.size(1)} (K={K})")

        # Layers. gcn_node1 gets no automatic self loops; any self loops come from
        # self_loop_scale above.
        self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True)
        self.gcn_node2 = GCNConv(hid * 2, out_dim)  # input is concat [h_node, h_broadcast]

    def forward(self, x: Tensor, edge_index_node: Tensor) -> Tensor:
        h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node))
        z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K)            # [K, H]
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))
        hb = z2[self.n2c]                                                # [N, H]
        hcat = torch.cat([h1, hb], dim=1)                                # [N, 2H]
        # Note: uses the caller's edge_index (GCNConv defaults), not the stored graph.
        out = self.gcn_node2(hcat, edge_index_node)
        return out


# -------------
# Training glue
# -------------

@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Fraction of masked nodes whose argmax prediction equals the label."""
    pred = logits[mask].argmax(dim=1)
    return (pred == y[mask]).float().mean().item()


def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4):
    """Train with Adam on ``data.train_mask`` and keep the best-validation weights.

    Prints a progress line every 20 epochs.

    Returns:
        {"val": acc, "test": acc} computed with the restored best weights.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_val, best_state = 0.0, None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        # Evaluation forward: no autograd graph needed.
        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
        val_acc = accuracy(logits, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if ep % 20 == 0:
            tr = accuracy(logits, data.y, data.train_mask)
            te = accuracy(logits, data.y, data.test_mask)
            print(f"[{ep:04d}] loss={loss.item():.4f} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    return {"val": accuracy(logits, data.y, data.val_mask),
            "test": accuracy(logits, data.y, data.test_mask)}


# -----------
# Entrypoint
# -----------

def main():
    """Parse CLI args, load a Planetoid dataset, and train the baseline or pooled model."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    ap.add_argument("--seeds", default=None,
                    help="Path to LRMC seeds JSON (single large graph). Required for --variant pool.")
    ap.add_argument("--variant", choices=["baseline", "pool"], default="pool")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=0.01)
    ap.add_argument("--wd", type=float, default=5e-4)
    ap.add_argument("--dropout", type=float, default=0.5)  # baseline only
    ap.add_argument("--self_loop_scale", type=float, default=0.0)
    ap.add_argument("--use_a2", action="store_true", help="Use A^2 for the cluster graph.")
    ap.add_argument("--index_base", type=int, default=1,
                    help="Index base of node ids in the seeds JSON (1 = 1-indexed, 0 = 0-indexed).")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--debug", action="store_true", help="Print seeds md5, cluster size, K, etc.")
    args = ap.parse_args()
    if args.variant == "pool" and args.seeds is None:
        ap.error("--seeds is required when --variant is 'pool'")

    torch.manual_seed(args.seed)
    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    in_dim, out_dim, n = ds.num_node_features, ds.num_classes, data.num_nodes

    if args.variant == "baseline":
        model = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
        res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
        print(f"Baseline GCN: val={res['val']:.4f} test={res['test']:.4f}")
        return

    # pool variant
    node2cluster, _, info = load_top1_assignment(args.seeds, n, debug=args.debug,
                                                 index_base=args.index_base)
    dbg_header = f"seeds_md5={info['json_md5']} top_size={info['top_cluster_size']} K={info['K']}"
    model = OneClusterPool(in_dim=in_dim, hid=args.hidden, out_dim=out_dim,
                           node2cluster=node2cluster,
                           edge_index_node=data.edge_index,
                           num_nodes=n,
                           self_loop_scale=args.self_loop_scale,
                           use_a2_for_clusters=args.use_a2,
                           debug_header=dbg_header)
    res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
    print(f"L-RMC (top-1 pool): val={res['val']:.4f} test={res['test']:.4f}")


if __name__ == "__main__":
    main()