Source code for cogdl.wrappers.model_wrapper.clustering.gae_mw

import torch
import torch.nn.functional as F

from .. import ModelWrapper
from cogdl.wrappers.tools.wrapper_utils import evaluate_clustering


[docs]class GAEModelWrapper(ModelWrapper):
[docs] @staticmethod def add_args(parser): # fmt: off parser.add_argument("--num-clusters", type=int, default=7) parser.add_argument("--cluster-method", type=str, default="kmeans", help="option: kmeans or spectral") parser.add_argument("--evaluation", type=str, default="full", help="option: full or NMI")
# fmt: on def __init__(self, model, optimizer_cfg, num_clusters, cluster_method="kmeans", evaluation="full"): super(GAEModelWrapper, self).__init__() self.model = model self.num_clusters = num_clusters self.optimizer_cfg = optimizer_cfg self.cluster_method = cluster_method self.full = evaluation == "full"
[docs] def train_step(self, subgraph): graph = subgraph loss = self.model.make_loss(graph, graph.adj_mx) return loss
[docs] def test_step(self, subgraph): graph = subgraph features_matrix = self.model(graph) features_matrix = features_matrix.detach().cpu().numpy() return evaluate_clustering( features_matrix, graph.y, self.cluster_method, self.num_clusters, graph.num_nodes, self.full )
[docs] def pre_stage(self, stage, data_w): if stage == 0: data = data_w.get_dataset().data adj_mx = torch.sparse_coo_tensor( torch.stack(data.edge_index), torch.ones(data.edge_index[0].shape[0]), torch.Size([data.x.shape[0], data.x.shape[0]]), ).to_dense() data.adj_mx = adj_mx
[docs] def setup_optimizer(self): lr, wd = self.optimizer_cfg["lr"], self.optimizer_cfg["weight_decay"] return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)