# Source: clique / src/2.1_lrmc_bilevel.py
# Uploaded by qingy2024 via huggingface_hub (commit bf620c6, verified; 12 kB).
# lrmc_bilevel.py
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
#
# Usage examples:
# python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
# python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant pool
#
# Notes:
# - We read your LRMC JSON, pick the single cluster with the highest 'score',
# assign it to cluster id 0, and make all other nodes singletons (1..K-1).
# - For --variant pool: Node-GCN -> pool (means) -> Cluster-GCN -> broadcast + skip -> Node-GCN -> classifier
# - For --variant baseline: Standard 2-layer GCN.
# - Set --self_loop_scale (A+λI on the node graph) and/or --use_a2 (A^2 for building the cluster graph) to enable these augmentations.
import argparse, json
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_add, scatter_mean
from torch_sparse import coalesce, spspmm
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
# ---------------------------
# Utilities: edges and seeds
# ---------------------------
def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Append a self-loop of weight ``scale`` to every node.

    Args:
        edge_index: [2, E] edge list.
        edge_weight: per-edge weights, or None to use all-ones weights.
        num_nodes: number of nodes N; self-loops are added for ids 0..N-1.
        scale: weight of each new self-loop. When exactly 0.0 the edge list is
            returned untouched (unit weights are materialized if
            ``edge_weight`` is None).

    Returns:
        (edge_index, edge_weight). When loops are added the result is
        coalesced with ``op='add'``, so duplicate entries (e.g. a pre-existing
        self-loop plus the new one) have their weights summed.
    """
    dev = edge_index.device
    weights = (torch.ones(edge_index.size(1), device=dev)
               if edge_weight is None else edge_weight)
    if scale == 0.0:
        return edge_index, weights

    diag = torch.arange(num_nodes, device=dev)
    loop_index = diag.repeat(2, 1)  # [[0..N-1], [0..N-1]]
    loop_weight = torch.full((num_nodes,), float(scale), device=dev)

    merged_index = torch.cat((edge_index, loop_index), dim=1)
    merged_weight = torch.cat((weights, loop_weight), dim=0)
    return coalesce(merged_index, merged_weight, num_nodes, num_nodes, op='add')
def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """Return the binary sparsity pattern of A^k with the diagonal removed.

    A is the unweighted pattern of ``edge_index`` (every listed edge has
    value 1). ``k - 1`` sparse products are computed with
    ``torch_sparse.spspmm``. After each product the values are reset to 1,
    because only reachability is needed, not walk counts. Diagonal entries
    (i -> i) are dropped only at the end, so walks that revisit their start
    node at intermediate steps still contribute. Callers can add their own
    self-loops afterwards if desired.

    Args:
        edge_index: [2, E] edge list of A.
        num_nodes: number of nodes N (A is N x N).
        k: power to raise A to; must be >= 1 (default 2, i.e. A^2).

    Returns:
        Coalesced [2, E'] edge_index of the off-diagonal entries of A^k,
        without weights.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    device = edge_index.device
    base_val = torch.ones(edge_index.size(1), device=device)

    power_index, power_val = edge_index, base_val
    for _ in range(k - 1):
        power_index, _ = spspmm(power_index, power_val, edge_index, base_val,
                                num_nodes, num_nodes, num_nodes)
        # Keep only the pattern; walk counts are irrelevant for a binary result.
        power_val = torch.ones(power_index.size(1), device=device)

    # Drop the diagonal; custom self-loops can be added later if desired.
    off_diag = power_index[0] != power_index[1]
    power_index = power_index[:, off_diag]
    power_index, _ = coalesce(power_index,
                              torch.ones(power_index.size(1), device=device),
                              num_nodes, num_nodes, op='add')
    return power_index
def build_cluster_graph(edge_index: Tensor,
                        num_nodes: int,
                        node2cluster: Tensor,
                        weight_per_edge: Optional[Tensor] = None,
                        num_clusters: Optional[int] = None
                        ) -> Tuple[Tensor, Tensor, int]:
    """Project a node graph onto clusters: A_c = S^T A S.

    Each node edge (u, v) becomes a cluster edge (c(u), c(v)). Parallel
    cluster edges are merged and their weights summed, so with the default
    unit weights an edge weight is the number of node edges between the two
    clusters.

    Args:
        edge_index: [2, E] node-level edge list.
        num_nodes: number of nodes (not used by the computation).
        node2cluster: [N] long tensor mapping node id -> cluster id.
        weight_per_edge: optional [E] weights; defaults to ones.
        num_clusters: number of clusters K; inferred as max id + 1 if None.

    Returns:
        (cluster edge_index, cluster edge weights, K).
    """
    if num_clusters is None:
        num_clusters = int(node2cluster.max().item()) + 1
    # Relabel both endpoints of every edge with their cluster ids -> [2, E].
    cluster_edges = node2cluster[edge_index]
    if weight_per_edge is None:
        weight_per_edge = torch.ones(cluster_edges.size(1), device=cluster_edges.device)
    cluster_edges, cluster_weights = coalesce(
        cluster_edges, weight_per_edge, num_clusters, num_clusters, op='add'
    )
    return cluster_edges, cluster_weights, num_clusters
# -----
# Seeds
# -----
def _pick_top1_cluster(obj: dict) -> List[int]:
"""
From LRMC JSON with structure: {"clusters":[{"members":[...], "score":float, ...}, ...]}
choose the cluster with max (score, size) and return its members.
"""
clusters = obj.get("clusters", [])
if not clusters:
return []
# choose by highest score, then by size (tiebreaker)
best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("members", []))))
return list(best.get("members", []))
def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor]:
    """Build a hard node->cluster assignment from the top-1 LRMC cluster.

    Layout:
      - cluster 0 = members of the top-1 LRMC cluster (chosen by _pick_top1_cluster)
      - every other node is its own singleton cluster, numbered 1..K-1 in
        increasing node-id order

    Args:
        seeds_json: path to the LRMC JSON file (read as UTF-8).
        n_nodes: number of nodes N in the graph.

    Returns:
        node2cluster: [N] long tensor of cluster ids.
        cluster_scores: [K, 1] float32, 1.0 for cluster 0 when it is
            non-empty and 0.0 elsewhere.

    Raises:
        ValueError: if a member of the chosen cluster is not a node id in [0, N).
    """
    obj = json.loads(Path(seeds_json).read_text(encoding="utf-8"))
    members = sorted({int(m) for m in _pick_top1_cluster(obj)})

    # Validate explicitly: a negative id would otherwise silently wrap around
    # during tensor indexing and mislabel a node at the end of the graph.
    bad = [m for m in members if not 0 <= m < n_nodes]
    if bad:
        raise ValueError(
            f"Top-1 cluster contains node ids outside [0, {n_nodes}): {bad[:10]}"
        )

    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    if members:
        node2cluster[torch.tensor(members, dtype=torch.long)] = 0

    # Nodes still at -1 (returned in increasing id order) become singletons 1..K-1.
    outside = (node2cluster < 0).nonzero(as_tuple=True)[0]
    node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long)

    K = 1 + outside.numel()
    cluster_scores = torch.zeros(K, 1, dtype=torch.float32)
    if members:
        cluster_scores[0, 0] = 1.0  # emphasize the supercluster
    return node2cluster, cluster_scores
# --------------------------
# Models (baseline + pooled)
# --------------------------
class GCN2(nn.Module):
    """Plain 2-layer GCN baseline: GCNConv -> ReLU -> dropout -> GCNConv.

    Both convolutions use GCNConv's default settings. The dropout rate is a
    constructor argument; its default of 0.5 keeps the previous behaviour.
    """

    def __init__(self, in_dim, hid, out_dim, dropout: float = 0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return [N, out_dim] logits for node features ``x``."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)
class OneClusterPool(nn.Module):
    """
    Node-GCN -> pool to one-cluster + singletons -> Cluster-GCN -> broadcast + skip -> Node-GCN -> classifier

    The node->cluster map and the precomputed graphs are registered as
    buffers, so ``model.to(device)`` moves them together with the weights.
    There are two precomputed graphs: the node graph with optional scaled
    self-loops, and the cluster graph built from A or A^2.
    """

    def __init__(self,
                 in_dim: int,
                 hid: int,
                 out_dim: int,
                 node2cluster: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2_for_clusters: bool = False):
        super().__init__()
        # Buffer, not a plain attribute, so it follows .to(device). It is
        # non-persistent, so state_dict keys are the same as before.
        self.register_buffer("n2c", node2cluster.long(), persistent=False)

        # Node graph (A + λI if desired)
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes,
                                                 scale=self_loop_scale)
        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)

        # Cluster graph from A or A^2
        ei_for_c = (adjacency_power(edge_index_node, num_nodes, k=2)
                    if use_a2_for_clusters else edge_index_node)
        edge_index_c, edge_weight_c, K = build_cluster_graph(ei_for_c, num_nodes, self.n2c)
        self.register_buffer("edge_index_c", edge_index_c)
        self.register_buffer("edge_weight_c", edge_weight_c)
        self.K = K  # number of clusters (max cluster id + 1)

        # Layers
        # NOTE(review): gcn_node1 adds no self-loops itself. With
        # self_loop_scale=0 (the default), a node's own features do not enter
        # h1 unless edge_index_node already contains self-loops. Confirm this
        # is intended.
        self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True)
        self.gcn_node2 = GCNConv(hid * 2, out_dim)  # on concatenated [h_node, h_broadcast]

    def forward(self, x: Tensor, edge_index_node: Tensor) -> Tensor:
        """Return [N, out_dim] logits.

        The first layer uses the stored (optionally self-loop-scaled, weighted)
        node graph. The final layer uses the ``edge_index_node`` passed in,
        with GCNConv's default settings.
        """
        # Node GCN (uses stored weights)
        h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node))
        # Pool to clusters: mean per cluster
        z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K)  # [K, H]
        # Cluster GCN
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))
        # Broadcast back + skip concat
        hb = z2[self.n2c]  # [N, H]
        hcat = torch.cat([h1, hb], dim=1)  # [N, 2H]
        # Final node GCN head -> logits
        return self.gcn_node2(hcat, edge_index_node)
# -------------
# Training glue
# -------------
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Return the fraction of masked nodes whose argmax prediction equals ``y``.

    Runs without autograd tracking. An all-False mask yields NaN, because it
    takes the mean of an empty tensor.
    """
    selected_logits, selected_labels = logits[mask], y[mask]
    hits = selected_logits.argmax(dim=1).eq(selected_labels)
    return hits.float().mean().item()
def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4):
    """Train ``model`` full-batch with Adam, then evaluate at the best-val epoch.

    Each epoch does one optimisation step on the training nodes. It then runs
    an evaluation forward pass without autograd to track validation accuracy.
    The state_dict from the epoch with the highest validation accuracy is
    restored before the final evaluation; on ties the earliest epoch wins.

    Args:
        model: called as ``model(data.x, data.edge_index)``, returning logits.
        data: graph object with x, edge_index, y, train_mask, val_mask, test_mask.
        epochs: number of training epochs.
        lr: Adam learning rate.
        wd: Adam weight decay.

    Returns:
        dict with "val" and "test" accuracy of the restored model.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    # Start below any reachable accuracy so the first epoch is always recorded.
    best_val, best_state = -1.0, None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        # track best on val (evaluation needs no autograd graph)
        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
        val_acc = accuracy(logits, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if ep % 20 == 0:
            tr = accuracy(logits, data.y, data.train_mask)
            te = accuracy(logits, data.y, data.test_mask)
            print(f"[{ep:04d}] loss={loss.item():.4f} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    # test @ best val
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    return {
        "val": accuracy(logits, data.y, data.val_mask),
        "test": accuracy(logits, data.y, data.test_mask)
    }
# -----------
# Entrypoint
# -----------
def main():
    """Command-line entry point: train and evaluate one variant on a Planetoid dataset.

    ``--variant baseline`` trains the plain GCN2. ``--variant pool`` trains
    OneClusterPool on the top-1 LRMC cluster read from ``--seeds``, which is
    then required. Final val/test accuracies are printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    # Only the pool variant reads the seeds file; its presence is checked after parsing.
    ap.add_argument("--seeds", default=None, help="Path to LRMC seeds JSON (single large graph).")
    ap.add_argument("--variant", choices=["baseline", "pool"], default="pool",
                    help="baseline=plain GCN; pool=top-1 LRMC one-cluster pooling")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=0.01)
    ap.add_argument("--wd", type=float, default=5e-4)
    # NOTE: --dropout is parsed but not forwarded to any model; GCN2 is built
    # with its default dropout rate.
    ap.add_argument("--dropout", type=float, default=0.5)
    ap.add_argument("--self_loop_scale", type=float, default=0.0,
                    help="lambda for A + lambda*I on node graph (0 disables)")
    ap.add_argument("--use_a2", action="store_true", help="Use A^2 to build the cluster graph (recommended for pool)")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    if args.variant == "pool" and args.seeds is None:
        ap.error("--seeds is required for --variant pool")
    torch.manual_seed(args.seed)

    # Load dataset
    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    in_dim, out_dim, n = ds.num_node_features, ds.num_classes, data.num_nodes

    if args.variant == "baseline":
        model = GCN2(in_dim, args.hidden, out_dim)
        # use default add_self_loops=True behavior inside convs
        res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
        print(f"Baseline GCN: val={res['val']:.4f} test={res['test']:.4f}")
        return

    # Top-1 LRMC assignment
    node2cluster, _ = load_top1_assignment(args.seeds, n)
    # One-cluster pooled model
    model = OneClusterPool(in_dim=in_dim,
                           hid=args.hidden,
                           out_dim=out_dim,
                           node2cluster=node2cluster,
                           edge_index_node=data.edge_index,
                           num_nodes=n,
                           self_loop_scale=args.self_loop_scale,
                           use_a2_for_clusters=args.use_a2)
    res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
    print(f"L-RMC (top-1 pool): val={res['val']:.4f} test={res['test']:.4f}")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()