# Source path: clique / src /.ipynb_checkpoints /2.3_lrmc_bilevel-checkpoint.py
# Uploaded by qingy2024 ("Upload folder using huggingface_hub"), commit bf620c6 (verified).
# 2.3_lrmc_bilevel.py
# Top-1 LRMC ablation with: cluster refinement (k-core), gated residual fusion,
# sparsified cluster graph (drop self-loops + per-row top-k), and A + γA² mix.
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
import argparse, json, hashlib
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_mean
from torch_sparse import coalesce, spspmm
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
# ---------------------------
# Utilities: edges and seeds
# ---------------------------
def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Return the edge list and weights of A + scale * I.

    Args:
        edge_index: [2, E] edge list.
        edge_weight: per-edge weights, or None for all-ones.
        num_nodes: number of nodes (one diagonal entry is added per node).
        scale: weight given to each added diagonal entry.

    Returns:
        (edge_index, edge_weight). When ``scale == 0.0`` the input edges are
        returned untouched (with all-ones weights if none were given).
        Otherwise diagonal entries are appended and duplicates are summed via
        ``coalesce(op='add')``, so pre-existing self loops gain ``scale``.
    """
    dev = edge_index.device
    weights = torch.ones(edge_index.size(1), device=dev) if edge_weight is None else edge_weight
    if scale == 0.0:
        return edge_index, weights

    diag = torch.arange(num_nodes, device=dev)
    loop_index = diag.unsqueeze(0).repeat(2, 1)  # [[0..n-1], [0..n-1]]
    loop_weight = torch.full((num_nodes,), float(scale), device=dev)

    merged_index = torch.cat([edge_index, loop_index], dim=1)
    merged_weight = torch.cat([weights, loop_weight], dim=0)
    return coalesce(merged_index, merged_weight, num_nodes, num_nodes, op='add')
def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """Return the off-diagonal sparsity pattern of A^k as a coalesced edge index.

    (i, j) is present iff i != j and there is a walk of exactly ``k`` steps from
    i to j in the (unweighted) graph given by ``edge_index``. Values are
    discarded, so the result is binary.

    Args:
        edge_index: [2, E] edge list of A.
        num_nodes: number of nodes.
        k: power to raise A to; must be >= 1 (previously this argument was
            ignored and A^2 was always returned).

    Returns:
        [2, E_k] coalesced edge index of A^k with self loops removed.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"adjacency_power requires k >= 1, got k={k}")
    device = edge_index.device
    a_val = torch.ones(edge_index.size(1), device=device)
    r_idx, r_val = edge_index, a_val
    for _ in range(k - 1):
        r_idx, _ = spspmm(r_idx, r_val, edge_index, a_val, num_nodes, num_nodes, num_nodes)
        # Only the pattern matters; reset values to ones so walk counts do not
        # grow from one power to the next.
        r_val = torch.ones(r_idx.size(1), device=device)
    r_idx = r_idx[:, r_idx[0] != r_idx[1]]
    r_idx, _ = coalesce(r_idx, torch.ones(r_idx.size(1), device=device),
                        num_nodes, num_nodes, op='add')
    return r_idx
def _md5(path: Path) -> str:
h = hashlib.md5()
with path.open('rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
h.update(chunk)
return h.hexdigest()
# -----
# Seeds
# -----
def _extract_members(cluster_obj: dict) -> List[int]:
m = cluster_obj.get("members", None)
if isinstance(m, list) and len(m) > 0:
return list(dict.fromkeys(int(x) for x in m))
m2 = cluster_obj.get("seed_nodes", None)
if isinstance(m2, list) and len(m2) > 0:
return list(dict.fromkeys(int(x) for x in m2))
if isinstance(m, list) or isinstance(m2, list):
return []
raise KeyError("Cluster object has neither 'members' nor 'seed_nodes'.")
def _pick_top1_cluster(obj: dict) -> List[int]:
clusters = obj.get("clusters", [])
if not isinstance(clusters, list) or len(clusters) == 0:
return []
def keyfun(c):
score = float(c.get("score", 0.0))
try:
mem = _extract_members(c)
except KeyError:
mem = []
return (score, len(mem))
best = max(clusters, key=keyfun)
try:
members = _extract_members(best)
except KeyError:
members = []
return sorted(set(int(x) for x in members))
def refine_k_core(C_star: List[int], edge_index: Tensor, k: int = 2, rounds: int = 50) -> List[int]:
    """Refine cluster by taking a k-core of its induced subgraph (label-free purity boost).

    Repeatedly removes members that have fewer than ``k`` distinct neighbours
    inside the current member set. This stops when nothing more is removed or
    after ``rounds`` peeling passes. If ``rounds`` runs out first, the result may
    still contain members whose internal degree is below ``k``.

    Args:
        C_star: node ids of the cluster to refine.
        edge_index: [2, E] node-graph edges. Each undirected edge may be stored
            once or in both directions. Duplicates and self loops are ignored,
            so the degree used is the number of distinct internal neighbours.
        k: minimum internal degree; ``k <= 0`` disables refinement.
        rounds: maximum number of peeling passes.

    Returns:
        Sorted list of surviving member ids. When ``k <= 0`` or C_star is empty,
        C_star itself is returned unchanged.
    """
    if k <= 0 or len(C_star) == 0:
        return C_star
    device = edge_index.device
    members = sorted({int(u) for u in C_star})
    # Size the mask to cover every member as well as every edge endpoint. A
    # member with no incident edges (e.g. an isolated highest-numbered node)
    # must still be indexable.
    size = members[-1] + 1
    if edge_index.numel() > 0:
        size = max(size, int(edge_index.max().item()) + 1)
    inS = torch.zeros(size, dtype=torch.bool, device=device)
    inS[torch.tensor(members, dtype=torch.long, device=device)] = True

    # Canonicalise to unique undirected (lo, hi) pairs with both ends in the
    # cluster, dropping self loops. Counting raw directed pairs would double
    # every degree when both directions of each edge are stored (as in PyG's
    # Planetoid edge_index), so k=2 would keep degree-1 leaves.
    lo = torch.minimum(edge_index[0], edge_index[1])
    hi = torch.maximum(edge_index[0], edge_index[1])
    internal = (lo != hi) & inS[lo] & inS[hi]
    pairs = torch.stack([lo[internal], hi[internal]], dim=0)
    if pairs.size(1) > 0:
        pairs = torch.unique(pairs, dim=1)

    for _ in range(rounds):
        alive = inS[pairs[0]] & inS[pairs[1]]
        a, b = pairs[0][alive], pairs[1][alive]
        deg = torch.zeros(size, dtype=torch.long, device=device)
        deg.scatter_add_(0, a, torch.ones_like(a))
        deg.scatter_add_(0, b, torch.ones_like(b))
        # Do not exit early when no internal edges remain. Every remaining
        # member then has degree 0 < k and must be peeled off.
        kill = inS & (deg < k)
        if not kill.any():
            break
        inS = inS & ~kill
    # inS only ever contains original members, so no intersection is needed.
    return torch.nonzero(inS, as_tuple=False).view(-1).tolist()
def load_top1_assignment(seeds_json: str, n_nodes: int,
                         debug: bool = False,
                         refine_k: int = 0,
                         edge_index_for_refine: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, dict]:
    """
    Hard assignment for top-1 LRMC cluster with optional k-core refinement.
    cluster 0 = top cluster; others are singletons.

    Args:
        seeds_json: path to the LRMC seeds JSON file.
        n_nodes: number of nodes in the graph.
        debug: print a one-line summary of the loaded cluster.
        refine_k: if > 0, replace the top cluster by its k-core (see refine_k_core).
        edge_index_for_refine: node-graph edges, required when refine_k > 0.

    Returns:
        node2cluster: long tensor of shape [n_nodes]. It is 0 for members of the
            top cluster and 1..n_outside for every other node, in ascending
            node-id order.
        cluster_scores: float32 tensor of shape [K, 1]; 1.0 for cluster 0, 0.0 elsewhere.
        info: dict with "json_md5", "top_cluster_size", "K", "n_outside",
            "first_members".

    Raises:
        ValueError: refine_k > 0 without edge_index_for_refine, or member ids
            outside [0, n_nodes) after the 1-indexing adjustment.
        RuntimeError: the top cluster is empty, either in the JSON or after refinement.
    """
    p = Path(seeds_json)
    obj = json.loads(p.read_text(encoding='utf-8'))
    C_star = _pick_top1_cluster(obj)
    if len(C_star) > 0 and max(C_star) == n_nodes:
        # An id equal to n_nodes is impossible with 0-indexing → treat as 1-indexed and shift down
        C_star = [u - 1 for u in C_star]
    if C_star and (min(C_star) < 0 or max(C_star) >= n_nodes):
        # Validate here: negative ids would otherwise silently wrap around in the
        # tensor indexing below, and too-large ids give an opaque IndexError.
        raise ValueError(
            f"Top-1 cluster in {seeds_json} has node ids outside [0, {n_nodes}): "
            f"min={min(C_star)} max={max(C_star)}."
        )
    n_before_refine = len(C_star)
    if refine_k > 0:
        if edge_index_for_refine is None:
            raise ValueError("--refine_k requires access to edge_index for refinement.")
        C_star = refine_k_core(C_star, edge_index_for_refine, k=refine_k)
    C = torch.tensor(C_star, dtype=torch.long)
    if C.numel() == 0:
        if n_before_refine > 0:
            raise RuntimeError(
                f"k-core refinement (k={refine_k}) removed every member of the top-1 cluster "
                f"in {seeds_json} ({n_before_refine} members before refinement)."
            )
        raise RuntimeError(
            f"No members found for top-1 cluster in {seeds_json}. "
            f"Expected 'members' or 'seed_nodes' to be non-empty."
        )
    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    node2cluster[C] = 0
    # Every still-unassigned node becomes its own singleton cluster (ascending id order).
    outside = (node2cluster < 0).nonzero(as_tuple=False).view(-1)
    if outside.numel() > 0:
        node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long)
    K = 1 + outside.numel()
    cluster_scores = torch.zeros(K, 1, dtype=torch.float32)
    cluster_scores[0, 0] = 1.0
    info = {
        "json_md5": _md5(p),
        "top_cluster_size": int(C.numel()),
        "K": int(K),
        "n_outside": int(outside.numel()),
        "first_members": [int(x) for x in C[:10].tolist()],
    }
    if debug:
        print(f"[LRMC] Loaded {seeds_json} (md5={info['json_md5']}) | "
              f"top_size={info['top_cluster_size']} K={info['K']} outside={info['n_outside']} "
              f"first10={info['first_members']}")
    return node2cluster, cluster_scores, info
# ---------------------------
# Cluster graph construction
# ---------------------------
def _sparsify_topk(edge_index: Tensor, edge_weight: Tensor, K: int, topk: int) -> Tuple[Tensor, Tensor]:
"""Keep per-row top-k neighbors by weight; symmetrize and coalesce."""
if topk <= 0:
return edge_index, edge_weight
row, col = edge_index
keep = torch.zeros(edge_weight.numel(), dtype=torch.bool, device=edge_weight.device)
# simple per-row loop (K ~ 2k is fine)
for r in range(K):
idx = (row == r).nonzero(as_tuple=False).view(-1)
if idx.numel():
k = min(topk, idx.numel())
_, order = torch.topk(edge_weight[idx], k)
keep[idx[order]] = True
ei = edge_index[:, keep]
ew = edge_weight[keep]
# symmetrize
rev = torch.stack([ei[1], ei[0]], dim=0)
ei2 = torch.cat([ei, rev], dim=1)
ew2 = torch.cat([ew, ew], dim=0)
ei2, ew2 = coalesce(ei2, ew2, K, K, op='max')
return ei2, ew2
def build_cluster_graph_mixed(edge_index_node: Tensor,
                              num_nodes: int,
                              node2cluster: Tensor,
                              use_a2: bool,
                              a2_gamma: float,
                              drop_self_loops: bool,
                              topk_per_row: int) -> Tuple[Tensor, Tensor, int]:
    """
    Build A_c = S^T (A + γ A²) S, optionally drop diag, then per-row top-k sparsify.

    Node edges get weight 1. When ``use_a2`` is set and ``a2_gamma > 0``, the
    binary off-diagonal A² pattern is appended with weight ``a2_gamma``. Each
    endpoint is then mapped through ``node2cluster`` and parallel cluster edges
    are summed.

    Returns:
        (cluster edge_index, cluster edge_weight, K), where K = max cluster id + 1.
    """
    dev = edge_index_node.device
    edge_parts = [edge_index_node]
    weight_parts = [torch.ones(edge_index_node.size(1), device=dev)]
    if use_a2 and a2_gamma > 0.0:
        two_hop = adjacency_power(edge_index_node, num_nodes, k=2)
        edge_parts.append(two_hop)
        weight_parts.append(torch.full((two_hop.size(1),), float(a2_gamma), device=dev))
    node_edges = torch.cat(edge_parts, dim=1)
    node_weights = torch.cat(weight_parts, dim=0)

    # Project to clusters: indexing the [2, E] edge tensor maps both endpoints at once.
    K = int(node2cluster.max().item()) + 1
    cluster_edges = node2cluster[node_edges]
    cluster_edges, cluster_weights = coalesce(cluster_edges, node_weights, K, K, op='add')

    if drop_self_loops:
        off_diag = cluster_edges[0] != cluster_edges[1]
        cluster_edges = cluster_edges[:, off_diag]
        cluster_weights = cluster_weights[off_diag]
    if topk_per_row > 0:
        cluster_edges, cluster_weights = _sparsify_topk(cluster_edges, cluster_weights, K, topk_per_row)
    return cluster_edges, cluster_weights, K
# --------------------------
# Models (baseline + pooled)
# --------------------------
class GCN2(nn.Module):
    """Two-layer GCN baseline: conv -> ReLU -> dropout -> conv, returning raw logits."""

    def __init__(self, in_dim, hid, out_dim, dropout=0.5):
        super().__init__()
        self.dropout = dropout  # dropout probability between the two convolutions
        # Construction order kept (conv1 then conv2) so seeded initialisation is unchanged.
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid, out_dim)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)
class OneClusterPoolGated(nn.Module):
    """
    Node-GCN -> pool (means) -> Cluster-GCN over sparsified A_c -> residual gate -> Node-GCN -> logits

    Both graphs are precomputed once in ``__init__`` and stored as buffers:
    - the node graph, A + λI with λ = ``self_loop_scale``;
    - the cluster graph, A_c = S^T (A + γA²) S, optionally without its
      diagonal, then sparsified to the per-row top-k.
    ``forward`` always uses these stored graphs.
    """
    def __init__(self,
                 in_dim: int,
                 hid: int,
                 out_dim: int,
                 node2cluster: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2_for_clusters: bool = False,
                 a2_gamma: float = 0.2,
                 drop_cluster_self_loops: bool = True,
                 cluster_topk: int = 24,
                 debug_header: str = "",
                 dropout: float = 0.5):
        super().__init__()
        # Register as a non-persistent buffer rather than a plain attribute.
        # model.to(device) then moves it together with the edge buffers below,
        # and state_dict() keys stay unchanged.
        self.register_buffer("n2c", node2cluster.long(), persistent=False)
        self.dropout = dropout  # dropout rate applied to the fused hidden state
        # Node graph (A + λI)
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes, scale=self_loop_scale)
        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)
        # Cluster graph: A_c = S^T (A + γA²) S → drop diag → per-row top-k
        eC, wC, K = build_cluster_graph_mixed(
            edge_index_node, num_nodes, self.n2c,
            use_a2=use_a2_for_clusters, a2_gamma=a2_gamma,
            drop_self_loops=drop_cluster_self_loops, topk_per_row=cluster_topk
        )
        self.register_buffer("edge_index_c", eC)
        self.register_buffer("edge_weight_c", wC)
        self.K = K  # number of clusters (max cluster id + 1)
        if debug_header:
            print(f"[POOL] {debug_header} | cluster_edges={eC.size(1)} (K={K})")
        # Layers: gated residual fusion (creation order kept so seeded init is unchanged)
        self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True)
        self.down = nn.Linear(hid, hid)
        self.gate = nn.Sequential(nn.Linear(2*hid, hid//2), nn.ReLU(), nn.Linear(hid//2, 1))
        self.lambda_logit = nn.Parameter(torch.tensor(0.0))
        self.gcn_node2 = GCNConv(hid, out_dim)  # final node conv on gated residual

    def forward(self, x: Tensor, edge_index_node: Tensor) -> Tensor:
        """Return node logits.

        ``x`` holds one row per node; rows are pooled via ``self.n2c``.
        ``edge_index_node`` is accepted for call compatibility with GCN2 but is
        ignored: the node graph stored at construction time is used instead.
        """
        # node step
        h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node))
        # pool: mean of member embeddings per cluster
        z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K)  # [K, H]
        # cluster step
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))
        # broadcast cluster embedding back to nodes + gated residual
        hb = z2[self.n2c]  # [N, H]
        inj = self.down(hb)
        gate_dyn = torch.sigmoid(self.gate(torch.cat([h1, inj], dim=1)))  # [N, 1]
        lam = torch.sigmoid(self.lambda_logit)  # scalar in (0, 1)
        # Mix a constant full-strength gate (1.0) with the learned per-node gate.
        alpha = lam * 1.0 + (1.0 - lam) * gate_dyn
        h2 = h1 + alpha * inj
        h2 = F.dropout(h2, p=self.dropout, training=self.training)
        # final node conv (use same weighted adjacency)
        return self.gcn_node2(h2, self.edge_index_node, self.edge_weight_node)
# -------------
# Training glue
# -------------
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Return the fraction of masked rows whose argmax logit equals the label."""
    hits = logits[mask].argmax(dim=1).eq(y[mask])
    return hits.float().mean().item()
def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, log_every: int = 20):
    """Train ``model`` with Adam, restore the best-validation weights, and evaluate.

    Args:
        model: module called as ``model(data.x, data.edge_index)`` returning logits.
        data: object exposing ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask`` and ``test_mask``.
        epochs: number of full-batch training steps.
        lr: Adam learning rate.
        wd: Adam weight decay.
        log_every: print progress every this many epochs; ``<= 0`` disables it.

    Returns:
        {"val": ..., "test": ...} accuracies of the restored best-val model.
        If validation accuracy never rises above 0.0, the final-epoch weights
        are evaluated instead.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_val, best_state = 0.0, None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        model.eval()
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
        val_acc = accuracy(logits, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if log_every > 0 and ep % log_every == 0:
            tr = accuracy(logits, data.y, data.train_mask)
            te = accuracy(logits, data.y, data.test_mask)
            print(f"[{ep:04d}] loss={loss.item():.4f} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    return {"val": accuracy(logits, data.y, data.val_mask),
            "test": accuracy(logits, data.y, data.test_mask)}
# -----------
# Entrypoint
# -----------
def main():
    """Parse CLI flags, load a Planetoid dataset, then train and report one variant.

    ``--variant baseline`` trains GCN2. ``--variant pool`` loads the top-1 LRMC
    cluster from ``--seeds`` and trains OneClusterPoolGated on it.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    add("--seeds", required=True, help="Path to LRMC seeds JSON (single large graph).")
    add("--variant", choices=["baseline", "pool"], default="pool")
    add("--hidden", type=int, default=128)
    add("--epochs", type=int, default=200)
    add("--lr", type=float, default=0.01)
    add("--wd", type=float, default=5e-4)
    add("--dropout", type=float, default=0.5)  # consumed by the baseline only
    add("--self_loop_scale", type=float, default=0.0)
    # Cluster-graph construction and refinement knobs (pool variant)
    add("--use_a2", action="store_true", help="Include A^2 in cluster graph.")
    add("--a2_gamma", type=float, default=0.2, help="Weight for A^2 in A + γA^2.")
    add("--cluster_topk", type=int, default=24, help="Top-k neighbors per cluster row to keep.")
    add("--drop_cluster_self_loops", action="store_true", help="Drop (c,c) in cluster graph.")
    add("--refine_k", type=int, default=0, help="k-core refinement on the top cluster (e.g., 2).")
    add("--seed", type=int, default=42)
    add("--debug", action="store_true")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    dataset = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    graph = dataset[0]
    n_feats = dataset.num_node_features
    n_classes = dataset.num_classes
    n_nodes = graph.num_nodes
    fit_kwargs = dict(epochs=args.epochs, lr=args.lr, wd=args.wd)

    if args.variant == "baseline":
        net = GCN2(n_feats, args.hidden, n_classes, dropout=args.dropout)
        scores = run_train_eval(net, graph, **fit_kwargs)
        print(f"Baseline GCN: val={scores['val']:.4f} test={scores['test']:.4f}")
        return

    # pool variant
    node2cluster, _, info = load_top1_assignment(
        args.seeds,
        n_nodes,
        debug=args.debug,
        refine_k=args.refine_k,
        edge_index_for_refine=graph.edge_index,
    )
    header = f"seeds_md5={info['json_md5']} top_size={info['top_cluster_size']} K={info['K']}"
    net = OneClusterPoolGated(
        in_dim=n_feats,
        hid=args.hidden,
        out_dim=n_classes,
        node2cluster=node2cluster,
        edge_index_node=graph.edge_index,
        num_nodes=n_nodes,
        self_loop_scale=args.self_loop_scale,
        use_a2_for_clusters=args.use_a2,
        a2_gamma=args.a2_gamma,
        drop_cluster_self_loops=args.drop_cluster_self_loops,
        cluster_topk=args.cluster_topk,
        debug_header=header,
    )
    scores = run_train_eval(net, graph, **fit_kwargs)
    print(f"L-RMC (top-1 pool, gated): val={scores['val']:.4f} test={scores['test']:.4f}")


if __name__ == "__main__":
    main()