import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from .. import ModelWrapper
from cogdl.wrappers.tools.wrapper_utils import evaluate_clustering
[docs]class DAEGCModelWrapper(ModelWrapper):
[docs] @staticmethod
def add_args(parser):
# fmt: off
parser.add_argument("--num-clusters", type=int, default=7)
parser.add_argument("--cluster-method", type=str, default="kmeans", help="option: kmeans or spectral")
parser.add_argument("--evaluation", type=str, default="full", help="option: full or NMI")
parser.add_argument("--T", type=int, default=5)
# fmt: on
def __init__(self, model, optimizer_cfg, num_clusters, cluster_method="kmeans", evaluation="full", T=5):
super(DAEGCModelWrapper, self).__init__()
self.model = model
self.num_clusters = num_clusters
self.optimizer_cfg = optimizer_cfg
self.cluster_method = cluster_method
self.full = evaluation == "full"
self.t = T
self.stage = 0
self.count = 0
[docs] def train_step(self, subgraph):
graph = subgraph
if self.stage == 0:
z = self.model(graph)
loss = self.recon_loss(z, graph.adj_mx)
else:
cluster_center = self.model.get_cluster_center()
z = self.model(graph)
Q = self.getQ(z, cluster_center)
self.count += 1
if self.count % self.t == 0:
P = self.getP(Q).detach()
loss = self.recon_loss(z, graph.adj_mx) + self.gamma * self.cluster_loss(P, Q)
return loss
[docs] def test_step(self, subgraph):
graph = subgraph
features_matrix = self.model(graph)
features_matrix = features_matrix.detach().cpu().numpy()
return evaluate_clustering(
features_matrix, graph.y, self.cluster_method, self.num_clusters, graph.num_nodes, self.full
)
[docs] def recon_loss(self, z, adj):
return F.binary_cross_entropy(F.softmax(torch.mm(z, z.t())), adj, reduction="sum")
[docs] def cluster_loss(self, P, Q):
return torch.nn.KLDivLoss(reduce=True, size_average=False)(P.log(), Q)
[docs] def setup_optimizer(self):
lr, wd = self.optimizer_cfg["lr"], self.optimizer_cfg["weight_decay"]
return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)
[docs] def pre_stage(self, stage, data_w):
self.stage = stage
if stage == 0:
data = data_w.get_dataset().data
data.add_remaining_self_loops()
data.store("edge_index")
data.adj_mx = torch.sparse_coo_tensor(
torch.stack(data.edge_index),
torch.ones(data.edge_index[0].shape[0]),
torch.Size([data.x.shape[0], data.x.shape[0]]),
).to_dense()
edge_index_2hop = data.edge_index
data.edge_index = edge_index_2hop
[docs] def post_stage(self, stage, data_w):
if stage == 0:
data = data_w.get_dataset().data
data.restore("edge_index")
data.to(self.device)
kmeans = KMeans(n_clusters=self.num_clusters, random_state=0).fit(self.model(data).detach().cpu().numpy())
self.model.set_cluster_center(torch.tensor(kmeans.cluster_centers_, device=self.device))
[docs] def getQ(self, z, cluster_center):
Q = None
for i in range(z.shape[0]):
dis = torch.sum((z[i].repeat(self.num_clusters, 1) - cluster_center) ** 2, dim=1)
t = 1 / (1 + dis)
t = t / torch.sum(t)
if Q is None:
Q = t.clone().unsqueeze(0)
else:
Q = torch.cat((Q, t.unsqueeze(0)), 0)
return Q
[docs] def getP(self, Q):
P = torch.sum(Q, dim=0).repeat(Q.shape[0], 1)
P = Q ** 2 / P
P = P / (torch.ones(1, self.num_clusters, device=self.device) * torch.sum(P, dim=1).unsqueeze(-1))
# print("P=", P)
return P