"""Node classification on Planetoid graphs with L-RMC cluster pooling.

Pipeline: load a hard L-RMC cluster assignment from JSON, optionally augment
connectivity with 2-hop edges (A OR A^2), pool node features per cluster,
run a coarse GCN on the cluster graph, broadcast back, and classify.
"""

import argparse
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import BatchNorm, GCNConv
from torch_geometric.utils import coalesce, remove_self_loops, to_undirected
from torch_scatter import scatter_mean
from torch_sparse import spspmm
|
|
def load_lrmc_assignment(json_path, num_nodes):
    """
    Build a single hard assignment: node -> cluster_id in [0, K-1].

    If a node appears in multiple clusters, keep the one with the highest
    'score'. Any unassigned nodes are placed in their own singleton clusters.

    Returns:
        assignment: LongTensor [num_nodes] with cluster ids
        clusters: list of lists (members per cluster), aligned to the
            remapped cluster ids
    """
    with open(json_path, 'r') as f:
        seeds = json.load(f)

    clusters_raw = seeds.get("clusters", [])

    # Greedy pass in order of decreasing score: the first cluster to claim a
    # node wins, which implements "keep the highest-scoring cluster".
    clusters_raw = sorted(clusters_raw, key=lambda c: float(c.get("score", 0.0)), reverse=True)

    chosen_cluster_for_node = [-1] * num_nodes
    tmp_clusters = []

    for c in clusters_raw:
        members = c.get("members", [])
        if not members:
            continue

        # Keep only in-range nodes that have not been claimed yet.
        new_members = [u for u in members if 0 <= u < num_nodes and chosen_cluster_for_node[u] == -1]
        if not new_members:
            continue

        tmp_clusters.append(new_members)
        cid = len(tmp_clusters) - 1
        for u in new_members:
            chosen_cluster_for_node[u] = cid

    # Nodes not covered by any cluster become singletons.
    for u in range(num_nodes):
        if chosen_cluster_for_node[u] == -1:
            tmp_clusters.append([u])
            chosen_cluster_for_node[u] = len(tmp_clusters) - 1

    assignment = torch.tensor(chosen_cluster_for_node, dtype=torch.long)
    return assignment, tmp_clusters
|
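# A minimal sketch of the JSON layout load_lrmc_assignment assumes (field
# names taken from the code above; the surrounding file format is otherwise
# unspecified):
#
#   {"clusters": [{"score": 3.2, "members": [0, 5, 9]},
#                 {"score": 1.7, "members": [5, 12]}]}
#
# Here node 5 joins the first cluster (higher score), node 12 the second,
# and every node absent from all "members" lists becomes a singleton cluster.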
|
def lrmc_stats(assignment, clusters, edge_index):
    N = assignment.numel()
    K = int(assignment.max()) + 1
    sizes = [len(c) for c in clusters]
    sing = sum(1 for s in sizes if s == 1)
    print(f"[L-RMC] N={N} K={K} mean|C|={np.mean(sizes):.2f} "
          f"median|C|={np.median(sizes):.0f} singleton%={100 * sing / K:.1f}%")

    # Fraction of edges whose two endpoints fall in the same cluster.
    same = (assignment[edge_index[0]] == assignment[edge_index[1]]).sum().item()
    print(f"[L-RMC] intra-cluster edge ratio = {same / edge_index.size(1):.3f}")
|
|
def compute_A2_union(edge_index, num_nodes, device):
    """
    Compute the binary A^2 and return the union A OR A^2, undirected and coalesced.
    """
    ei = to_undirected(coalesce(edge_index, num_nodes=num_nodes), num_nodes=num_nodes)

    E = ei.size(1)
    if E == 0:
        return ei
    val = torch.ones(E, device=device)

    # Sparse-sparse matmul A @ A enumerates all 2-hop paths.
    ei2, _ = spspmm(ei, val, ei, val, num_nodes, num_nodes, num_nodes)
    ei2, _ = remove_self_loops(ei2)

    # Union of 1-hop and 2-hop connectivity.
    ei_aug = torch.cat([ei, ei2], dim=1)
    ei_aug = to_undirected(coalesce(ei_aug, num_nodes=num_nodes), num_nodes=num_nodes)
    return ei_aug
|
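# Worked example for compute_A2_union: on the path graph 0-1-2, A^2 contains
# the 2-hop pair (0, 2), so the returned union holds {(0,1), (1,2), (0,2)}
# (each in both directions after to_undirected), with self-loops removed.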
|
def build_cluster_edges(edge_index_aug, assignment, num_clusters):
    """
    Map node edges to cluster edges: (u, v) -> (c(u), c(v)), undirected and coalesced.
    """
    c_src = assignment[edge_index_aug[0]]
    c_dst = assignment[edge_index_aug[1]]
    c_ei = torch.stack([c_src, c_dst], dim=0)
    c_ei = to_undirected(coalesce(c_ei, num_nodes=num_clusters), num_nodes=num_clusters)
    return c_ei
|
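# Example for build_cluster_edges: with assignment c = [0, 0, 1] and node
# edges (0,1), (1,2), the cluster edges are (0,0) and (0,1); intra-cluster
# self-loops such as (0,0) are kept, only duplicates are merged by coalesce.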
|
class Gate(nn.Module):
    """
    Elementwise gate that fuses node features with broadcast cluster features
    as h_enc + g * h_cluster; the residual sum requires d_enc == d_c.
    """

    def __init__(self, d_enc, d_c):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_enc + d_c, d_enc, bias=True),
            nn.ReLU(),
            nn.Linear(d_enc, d_enc, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, h_enc, h_cluster_broadcast):
        g = self.mlp(torch.cat([h_enc, h_cluster_broadcast], dim=-1))
        return h_enc + g * h_cluster_broadcast
|
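# Note: Gate needs matching dimensions for its residual sum (true for the
# defaults --hidden 64 --cluster_hidden 64); the classifier below instead
# concatenates the two feature streams and leaves its gate instance unused.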
|
class GCN_LRMC_NodeClassifier(nn.Module):
    """
    Encoder: GCN -> GCN on the original graph
    Pool:    mean-aggregate encoder features per L-RMC cluster
    Coarse:  GCN -> GCN on the cluster graph
    Unpool:  broadcast cluster features back to nodes
    Decoder: GCN (on the original graph) -> logits
    """

    def __init__(self, in_dim, hidden_dim, cluster_hidden_dim, out_dim,
                 edge_index, assignment, cluster_edge_index, dropout=0.5):
        super().__init__()
        self.edge_index = edge_index
        self.assignment = assignment
        self.cluster_edge_index = cluster_edge_index
        self.num_clusters = int(assignment.max().item() + 1)
        self.dropout = dropout

        # Node-level encoder.
        self.enc1 = GCNConv(in_dim, hidden_dim, improved=True)
        self.enc2 = GCNConv(hidden_dim, hidden_dim, improved=True)

        # Coarse GCN on the cluster graph.
        self.cgc1 = GCNConv(hidden_dim, cluster_hidden_dim, improved=True)
        self.cgc2 = GCNConv(cluster_hidden_dim, cluster_hidden_dim, improved=True)

        # Decoder consumes [node features || broadcast cluster features].
        dec_in = hidden_dim + cluster_hidden_dim
        self.dec1 = GCNConv(dec_in, hidden_dim, improved=True)
        self.cls = GCNConv(hidden_dim, out_dim, improved=True)

        self.bn_e1 = BatchNorm(hidden_dim)
        self.bn_e2 = BatchNorm(hidden_dim)
        self.bn_c1 = BatchNorm(cluster_hidden_dim)
        self.bn_c2 = BatchNorm(cluster_hidden_dim)
        self.bn_d1 = BatchNorm(hidden_dim)
        # Alternative gated fusion; forward() concatenates instead.
        self.gate = Gate(hidden_dim, cluster_hidden_dim)
|
    def forward(self, x):
        # Node-level encoder on the original graph, with a residual sum.
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.relu(self.bn_e1(self.enc1(h, self.edge_index)))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h2 = F.relu(self.bn_e2(self.enc2(h, self.edge_index)))
        h_enc = h + h2

        # Pool: mean of encoder features per L-RMC cluster, e.g. with
        # assignment=[0, 0, 1], cluster 0 averages node rows 0 and 1.
        cluster_x = scatter_mean(h_enc, self.assignment, dim=0, dim_size=self.num_clusters)

        # Coarse GCN on the cluster graph, with a residual sum.
        hc = F.dropout(cluster_x, p=self.dropout, training=self.training)
        hc = F.relu(self.bn_c1(self.cgc1(hc, self.cluster_edge_index)))
        hc = F.dropout(hc, p=self.dropout, training=self.training)
        hc2 = F.relu(self.bn_c2(self.cgc2(hc, self.cluster_edge_index)))
        hc = hc + hc2

        # Unpool: broadcast each cluster's features back to its member nodes.
        hc_broadcast = hc[self.assignment]

        # Decoder on the original graph over the fused [node || cluster] features.
        h_dec_in = torch.cat([h_enc, hc_broadcast], dim=1)
        h = F.dropout(h_dec_in, p=self.dropout, training=self.training)
        h = F.relu(self.bn_d1(self.dec1(h, self.edge_index)))
        h = F.dropout(h, p=self.dropout, training=self.training)
        out = self.cls(h, self.edge_index)
        return out
|
|
@torch.no_grad()
def evaluate(model, data):
    model.eval()
    pred = model(data.x).argmax(dim=-1)
    y = data.y

    def acc(mask):
        m = mask if mask.dtype == torch.bool else mask.bool()
        if m.sum() == 0:
            return 0.0
        return (pred[m] == y[m]).float().mean().item()

    return acc(data.train_mask), acc(data.val_mask), acc(data.test_mask)
|
|
def train_loop(model, data, epochs=200, lr=5e-3, weight_decay=5e-4, patience=100):
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val, best_test = 0.0, 0.0
    best_state = None
    no_improve = 0

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits = model(data.x)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Track the test accuracy at the best-validation epoch.
        tr, va, te = evaluate(model, data)
        if va > best_val:
            best_val, best_test = va, te
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1

        print(f"Epoch {epoch:03d} | loss={loss.item():.4f} | "
              f"train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}% test@best={best_test*100:.2f}%")

        if no_improve >= patience:
            print(f"Early stopping at epoch {epoch} (no val improvement for {patience} epochs)")
            break

    # Reload the best-on-validation weights before the final report.
    if best_state is not None:
        model.load_state_dict(best_state)
    tr, va, te = evaluate(model, data)
    print(f"\nFinal (reloaded best): train={tr*100:.2f}% val={va*100:.2f}% test={te*100:.2f}%")
    return te
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "Citeseer", "Pubmed"])
    parser.add_argument("--lrmc_json", type=str, required=True)
    parser.add_argument("--use_a2", type=str, default="true", help="Use A^2 before pooling (true/false)")
    parser.add_argument("--hidden", type=int, default=64)
    parser.add_argument("--cluster_hidden", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = Planetoid(root=os.path.join("data", args.dataset), name=args.dataset)
    data = dataset[0].to(device)
    num_nodes = data.num_nodes
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes

    # Hard cluster assignment from the L-RMC output.
    assignment, clusters = load_lrmc_assignment(args.lrmc_json, num_nodes)
    assignment = assignment.to(device)
    num_clusters = int(assignment.max().item() + 1)
    print(f"[L-RMC] Loaded clusters: K={num_clusters} (N={num_nodes})")
    lrmc_stats(assignment, clusters, data.edge_index)

    # Optionally augment connectivity with 2-hop edges before pooling.
    use_a2 = args.use_a2.lower() in ("1", "true", "yes", "y")
    if use_a2:
        edge_index_aug = compute_A2_union(data.edge_index, num_nodes, device)
        print("[L-RMC] Using A^2 ∪ A before pooling (connectivity augmentation).")
    else:
        edge_index_aug = to_undirected(coalesce(data.edge_index, num_nodes=num_nodes), num_nodes=num_nodes)
        print("[L-RMC] Using original A for pooling.")

    cluster_edge_index = build_cluster_edges(edge_index_aug, assignment, num_clusters)

    model = GCN_LRMC_NodeClassifier(
        in_dim=in_dim,
        hidden_dim=args.hidden,
        cluster_hidden_dim=args.cluster_hidden,
        out_dim=out_dim,
        edge_index=data.edge_index,
        assignment=assignment,
        cluster_edge_index=cluster_edge_index,
        dropout=args.dropout,
    ).to(device)

    train_loop(model, data, epochs=args.epochs, lr=args.lr,
               weight_decay=args.weight_decay, patience=100)
|
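# Hypothetical invocation (the script name and JSON path are placeholders):
#   python lrmc_gcn.py --dataset Cora --lrmc_json out/cora_clusters.json --use_a2 true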
|
if __name__ == "__main__":
    main()