# gcn_lrmc_node_classification.py
# Node classification with GCN + L-RMC (static pooling + unpool + skip)
# Usage:
#   python gcn_lrmc_node_classification.py --dataset Cora --lrmc_json /path/to/lrmc_seeds.json
# Options:
#   --use_a2 true|false (default true; use A^2 before pooling as in Graph U-Nets)
#   --epochs 200 --lr 0.005 --hidden 64 --cluster_hidden 64 --dropout 0.5
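#
# Expected --lrmc_json schema (as consumed by load_lrmc_assignment below):
#   {"clusters": [{"members": [node_id, ...], "score": <float>}, ...]}
# "score" is optional (defaults to 0.0) and only decides which cluster wins
# when clusters overlap.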
import argparse
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import BatchNorm, GCNConv
from torch_geometric.utils import coalesce, remove_self_loops, to_undirected
from torch_scatter import scatter_mean
from torch_sparse import spspmm

# -----------------------------
# L-RMC assignment utilities
# -----------------------------
def load_lrmc_assignment(json_path, num_nodes):
"""
Build a single hard assignment: node -> cluster_id in [0, K-1].
If nodes appear in multiple clusters, keep the one with highest 'score'.
If any nodes are unassigned, put them into their own singleton clusters.
Returns:
assignment: LongTensor [num_nodes] with cluster ids
clusters: list of lists (members per cluster) aligned to remapped cluster ids
"""
with open(json_path, 'r') as f:
seeds = json.load(f)
clusters_raw = seeds.get("clusters", [])
# Sort clusters by score descending to prefer higher-scoring clusters on conflicts
clusters_raw = sorted(clusters_raw, key=lambda c: float(c.get("score", 0.0)), reverse=True)
chosen_cluster_for_node = [-1] * num_nodes
tmp_clusters = [] # will collect chosen clusters (members), before remap
for c in clusters_raw:
members = c.get("members", [])
# skip empty
if not members:
continue
# take only members not yet assigned
new_members = [u for u in members if 0 <= u < num_nodes and chosen_cluster_for_node[u] == -1]
if not new_members:
continue
# tentatively assign this cluster to those nodes (others in the cluster were already taken)
tmp_clusters.append(new_members)
cid = len(tmp_clusters) - 1
for u in new_members:
chosen_cluster_for_node[u] = cid
# Any nodes still -1 → singleton clusters
for u in range(num_nodes):
if chosen_cluster_for_node[u] == -1:
tmp_clusters.append([u])
cid = len(tmp_clusters) - 1
chosen_cluster_for_node[u] = cid
# Remap cluster ids to [0..K-1] (already contiguous by construction)
assignment = torch.tensor(chosen_cluster_for_node, dtype=torch.long)
clusters = tmp_clusters
return assignment, clusters
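
# Example (hypothetical 4-node graph, two seed clusters):
#   {"clusters": [{"members": [0, 1], "score": 0.9},
#                 {"members": [1, 2], "score": 0.5}]}
# -> assignment = [0, 0, 1, 2]: node 1 goes to the higher-scoring cluster,
#    node 2 keeps the remainder of the second cluster, and unassigned node 3
#    becomes a singleton.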

def lrmc_stats(assignment, clusters, edge_index):
    N = assignment.numel()
    K = int(assignment.max()) + 1
sizes = [len(c) for c in clusters]
sing = sum(1 for s in sizes if s==1)
print(f"[L-RMC] N={N} K={K} mean|C|={np.mean(sizes):.2f} "
f"median|C|={np.median(sizes):.0f} singleton%={100*sing/K:.1f}%")
# how many edges are intra-cluster?
same = (assignment[edge_index[0]] == assignment[edge_index[1]]).sum().item()
print(f"[L-RMC] intra-cluster edge ratio = {same/edge_index.size(1):.3f}")
# -----------------------------
# Graph helpers
# -----------------------------
def compute_A2_union(edge_index, num_nodes, device):
"""
Compute A^2 (binary) and return union edges A OR A^2, undirected & coalesced.
"""
# Make undirected and coalesced (no weights)
ei = to_undirected(coalesce(edge_index, num_nodes=num_nodes), num_nodes=num_nodes)
# Build ones weights for sparse-sparse multiply
E = ei.size(1)
if E == 0:
return ei # empty graph
val = torch.ones(E, device=device)
# spspmm: (m x k) @ (k x n) where here m=n=k=num_nodes
ei2, val2 = spspmm(ei, val, ei, val, num_nodes, num_nodes, num_nodes)
# Remove self-loops from A2 (optional; GCNConv adds its own self-loops later)
ei2, _ = remove_self_loops(ei2)
# Binarize & union with A
# (coalesce later will drop duplicates anyway)
ei_aug = torch.cat([ei, ei2], dim=1)
ei_aug = to_undirected(coalesce(ei_aug, num_nodes=num_nodes), num_nodes=num_nodes)
return ei_aug
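
# Example: on the path graph 0-1-2, A^2 adds the 2-hop edge (0, 2), so the
# union A ∪ A^2 lets pooling see 2-hop neighbourhoods (as in Graph U-Nets).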
def build_cluster_edges(edge_index_aug, assignment, num_clusters):
"""
Map node edges to cluster edges: (u,v) -> (c(u), c(v)), undirected + coalesced.
"""
    c_src = assignment[edge_index_aug[0]]
    c_dst = assignment[edge_index_aug[1]]
    c_ei = torch.stack([c_src, c_dst], dim=0)
    # Intra-cluster node edges collapse to cluster self-loops; drop them here,
    # as with A^2 above (GCNConv adds its own self-loops later).
    c_ei, _ = remove_self_loops(c_ei)
    c_ei = to_undirected(coalesce(c_ei, num_nodes=num_clusters), num_nodes=num_clusters)
    return c_ei
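
# Example: with assignment = [0, 0, 1], the node edge (1, 2) maps to the
# cluster edge (0, 1), while the intra-cluster edge (0, 1) collapses to a
# self-loop on cluster 0 (removed above).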
# -----------------------------
# Model
# -----------------------------
class Gate(nn.Module):
def __init__(self, d_enc, d_c):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(d_enc + d_c, d_enc, bias=True),
nn.ReLU(),
nn.Linear(d_enc, d_enc, bias=True),
nn.Sigmoid(),
)
def forward(self, h_enc, h_cluster_broadcast):
g = self.mlp(torch.cat([h_enc, h_cluster_broadcast], dim=-1))
return h_enc + g * h_cluster_broadcast # residual gated add
class GCN_LRMC_NodeClassifier(nn.Module):
"""
Encoder: GCN -> GCN on original graph
Pool: aggregate encoder features per L-RMC cluster
Coarse: GCN -> (optional GCN) on cluster graph
Unpool: broadcast cluster features back to nodes
Decoder: GCN (on original graph) -> logits
"""
def __init__(self, in_dim, hidden_dim, cluster_hidden_dim, out_dim,
edge_index, assignment, cluster_edge_index, dropout=0.5):
super().__init__()
self.edge_index = edge_index # original graph edges
self.assignment = assignment # [N]
self.cluster_edge_index = cluster_edge_index # edges on cluster graph
self.num_clusters = int(assignment.max().item() + 1)
self.dropout = dropout
# Encoder on node graph
self.enc1 = GCNConv(in_dim, hidden_dim, improved=True)
self.enc2 = GCNConv(hidden_dim, hidden_dim, improved=True)
# GCN(s) on cluster graph
self.cgc1 = GCNConv(hidden_dim, cluster_hidden_dim, improved=True)
self.cgc2 = GCNConv(cluster_hidden_dim, cluster_hidden_dim, improved=True)
# Decoder on node graph (combine skip from encoder + broadcast from cluster)
dec_in = hidden_dim + cluster_hidden_dim
self.dec1 = GCNConv(dec_in, hidden_dim, improved=True)
self.cls = GCNConv(hidden_dim, out_dim, improved=True) # final logits
self.bn_e1 = BatchNorm(hidden_dim)
self.bn_e2 = BatchNorm(hidden_dim)
self.bn_c1 = BatchNorm(cluster_hidden_dim)
self.bn_c2 = BatchNorm(cluster_hidden_dim)
self.bn_d1 = BatchNorm(hidden_dim)
        # Gated-residual merge (optional alternative to the concat in forward)
        self.gate = Gate(hidden_dim, cluster_hidden_dim)
def forward(self, x):
# Encoder on original graph
h = F.dropout(x, p=self.dropout, training=self.training)
h = F.relu(self.bn_e1(self.enc1(h, self.edge_index)))
h = F.dropout(h, p=self.dropout, training=self.training)
h2 = F.relu(self.bn_e2(self.enc2(h, self.edge_index)))
h = h + h2
h_enc = h # skip for decoder
# Pool: aggregate encoder features to clusters (mean)
# cluster_x: [K, hidden_dim]
cluster_x = scatter_mean(h_enc, self.assignment, dim=0, dim_size=self.num_clusters)
# Coarse GCN(s) on cluster graph
        hc = F.dropout(cluster_x, p=self.dropout, training=self.training)
        hc = F.relu(self.bn_c1(self.cgc1(hc, self.cluster_edge_index)))
hc = F.dropout(hc, p=self.dropout, training=self.training)
hc2 = F.relu(self.bn_c2(self.cgc2(hc, self.cluster_edge_index)))
hc = hc + hc2
# Unpool: broadcast coarse features back to nodes via assignment
hc_broadcast = hc[self.assignment] # [N, cluster_hidden_dim]
        # Optional alternative to the concat below: a gated residual merge,
        #   h_dec_in = self.gate(h_enc, hc_broadcast)  # [N, hidden_dim]
        # (self.gate is built in __init__ but unused on the default path; dec1's
        # input width would then be hidden_dim rather than hidden + cluster_hidden.)
# Decoder on original graph
h_dec_in = torch.cat([h_enc, hc_broadcast], dim=1) # [N, hidden_dim + cluster_hidden_dim]
h = F.dropout(h_dec_in, p=self.dropout, training=self.training)
h = F.relu(self.dec1(h, self.edge_index))
h = F.dropout(h, p=self.dropout, training=self.training)
out = self.cls(h, self.edge_index) # logits [N, C]
return out
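
# Shape flow through forward():
#   x [N, in_dim] -> h_enc [N, hidden] -> cluster_x [K, hidden]
#   -> hc [K, cluster_hidden] -> hc_broadcast [N, cluster_hidden]
#   -> h_dec_in [N, hidden + cluster_hidden] -> out [N, num_classes]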
# -----------------------------
# Train / Eval
# -----------------------------
@torch.no_grad()
def evaluate(model, data):
model.eval()
out = model(data.x)
y = data.y
pred = out.argmax(dim=-1)
def acc(mask):
m = mask if mask.dtype == torch.bool else mask.bool()
if m.sum() == 0:
return 0.0
return (pred[m] == y[m]).float().mean().item()
return acc(data.train_mask), acc(data.val_mask), acc(data.test_mask)
def train_loop(model, data, epochs=200, lr=5e-3, weight_decay=5e-4, patience=100):
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
best_val, best_test = 0.0, 0.0
best_state = None
no_improve = 0
for epoch in range(1, epochs + 1):
model.train()
optimizer.zero_grad()
logits = model(data.x)
loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
tr, va, te = evaluate(model, data)
if va > best_val:
best_val, best_test = va, te
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch:03d} | loss={loss.item():.4f} | "
f"train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}% test@best={best_test*100:.2f}%")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch} (no val improvement for {patience})")
break
if best_state is not None:
model.load_state_dict(best_state)
tr, va, te = evaluate(model, data)
print(f"\nFinal (reloaded best): train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}%")
return te
# -----------------------------
# Main
# -----------------------------
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "Citeseer", "Pubmed"])
parser.add_argument("--lrmc_json", type=str, required=True)
parser.add_argument("--use_a2", type=str, default="true", help="Use A^2 before pooling (true/false)")
parser.add_argument("--hidden", type=int, default=64)
parser.add_argument("--cluster_hidden", type=int, default=64)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--lr", type=float, default=5e-3)
parser.add_argument("--weight_decay", type=float, default=5e-4)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = Planetoid(root=os.path.join("data", args.dataset), name=args.dataset)
data = dataset[0].to(device)
num_nodes = data.num_nodes
in_dim = dataset.num_node_features
out_dim = dataset.num_classes
# Load L-RMC assignment
assignment, clusters = load_lrmc_assignment(args.lrmc_json, num_nodes)
assignment = assignment.to(device)
num_clusters = int(assignment.max().item() + 1)
print(f"[L-RMC] Loaded clusters: K={num_clusters} (N={num_nodes})")
lrmc_stats(assignment, clusters, data.edge_index)
# Build augmented node edge_index (A or A^2 ∪ A), then cluster edges
use_a2 = args.use_a2.lower() in ("1", "true", "yes", "y")
if use_a2:
edge_index_aug = compute_A2_union(data.edge_index, num_nodes, device)
print("[L-RMC] Using A^2 ∪ A before pooling (connectivity augmentation).")
else:
edge_index_aug = to_undirected(coalesce(data.edge_index, num_nodes=num_nodes), num_nodes=num_nodes)
print("[L-RMC] Using original A for pooling.")
cluster_edge_index = build_cluster_edges(edge_index_aug, assignment, num_clusters)
# Build model
model = GCN_LRMC_NodeClassifier(
in_dim=in_dim,
hidden_dim=args.hidden,
cluster_hidden_dim=args.cluster_hidden,
out_dim=out_dim,
edge_index=data.edge_index, # original graph for enc/dec
assignment=assignment, # node -> cluster
cluster_edge_index=cluster_edge_index, # cluster graph for coarse GCN
dropout=args.dropout,
).to(device)
# Train / evaluate
    train_loop(model, data, epochs=args.epochs, lr=args.lr,
               weight_decay=args.weight_decay, patience=100)
if __name__ == "__main__":
main()