# Source code for cogdl.wrappers.model_wrapper.node_classification.dgi_mw

import numpy as np

import torch
import torch.nn as nn

from .. import UnsupervisedModelWrapper
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg


class DGIModelWrapper(UnsupervisedModelWrapper):
    """Unsupervised wrapper that trains a node encoder contrastively.

    Each training step encodes the graph twice -- once with its real node
    features and once with row-shuffled features -- and trains a bilinear
    discriminator to tell the two sets of node embeddings apart, given a
    sigmoid-squashed mean of the real embeddings as the graph summary.
    """

    @staticmethod
    def add_args(parser):
        """Register this wrapper's command-line options on ``parser``."""
        # fmt: off
        parser.add_argument("--hidden-size", type=int, default=512)
        # fmt: on

    def __init__(self, model, optimizer_cfg):
        """Build the readout, discriminator, loss and activation.

        Args:
            model: encoder module; it is called as ``model(graph)``.
            optimizer_cfg: mapping that must contain ``"hidden_size"``
                (read here) and ``"lr"`` / ``"weight_decay"`` (read by
                :meth:`setup_optimizer`).

        Raises:
            AssertionError: if ``optimizer_cfg["hidden_size"]`` is not
                positive.
        """
        super().__init__()
        self.model = model
        self.optimizer_cfg = optimizer_cfg
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()

        hidden_size = optimizer_cfg["hidden_size"]
        assert hidden_size > 0, "hidden_size must be a positive integer"
        self.disc = Discriminator(hidden_size)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.act = nn.PReLU()

    def train_step(self, subgraph):
        """Compute the contrastive loss for one (sub)graph.

        Args:
            subgraph: graph object exposing ``sym_norm()``, ``x`` and
                ``num_nodes``. ``sym_norm()`` is invoked on it; ``x`` is
                replaced temporarily and restored before returning.

        Returns:
            The scalar ``BCEWithLogitsLoss`` over the discriminator logits,
            with target 1 for real-feature nodes and 0 for shuffled ones.
        """
        graph = subgraph
        graph.sym_norm()
        x = graph.x
        shuffle_x = self.augment(graph)

        # Positive pass: embeddings and summary from the real features.
        h_pos = self.act(self.model(graph))
        c = self.sigm(self.read(h_pos))

        # Negative pass: the encoder reads features off the graph, so the
        # shuffled features are swapped in temporarily. The finally clause
        # guarantees the caller's graph gets its real features back even
        # if the encoder raises.
        try:
            graph.x = shuffle_x
            h_neg = self.act(self.model(graph))
        finally:
            graph.x = x

        logits = self.disc(c, h_pos, h_neg)

        # Logits are ordered [positive scores, negative scores].
        num_nodes = x.shape[0]
        labels = torch.zeros((num_nodes * 2,), device=x.device)
        labels[:num_nodes] = 1
        return self.loss_fn(logits, labels)

    def test_step(self, graph):
        """Evaluate the embeddings and record the result as ``test_acc``.

        The embeddings are computed without gradient tracking and passed,
        together with ``graph.y``, ``graph.train_mask`` and
        ``graph.test_mask``, to ``evaluate_node_embeddings_using_logreg``.
        """
        with torch.no_grad():
            pred = self.act(self.model(graph))
        y = graph.y
        result = evaluate_node_embeddings_using_logreg(pred, y, graph.train_mask, graph.test_mask)
        self.note("test_acc", result)

    @staticmethod
    def augment(graph):
        """Return ``graph.x`` with its rows in a random order.

        The permutation is drawn from NumPy's global random state; the
        graph itself is not modified.
        """
        idx = np.random.permutation(graph.num_nodes)
        shuffle_x = graph.x[idx, :]
        return shuffle_x

    def setup_optimizer(self):
        """Create an Adam optimizer over all parameters of this wrapper.

        Learning rate and weight decay are read from ``optimizer_cfg``.
        """
        cfg = self.optimizer_cfg
        return torch.optim.Adam(self.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
# Borrowed from https://github.com/PetarV-/DGI
class AvgReadout(nn.Module):
    """Average a sequence of embeddings over its second-to-last axis."""

    def __init__(self):
        super().__init__()

    def forward(self, seq, msk=None):
        """Return the (optionally masked) average of ``seq``.

        Without a mask this is a plain mean over the second-to-last axis.
        With a mask, ``seq * msk`` is summed over that axis and divided by
        the sum of all mask entries.
        """
        axis = seq.dim() - 2
        if msk is not None:
            weighted_total = (seq * msk).sum(axis)
            return weighted_total / msk.sum()
        return seq.mean(axis)


# Borrowed from https://github.com/PetarV-/DGI
class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings against a summary vector."""

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-initialise bilinear weights and zero the bias, if any.

        Modules that are not ``nn.Bilinear`` are left untouched.
        """
        if not isinstance(m, nn.Bilinear):
            return
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """Score positive and negative embeddings against the summary ``c``.

        ``c`` is broadcast to the shape of ``h_pl`` and used as the second
        bilinear input for both ``h_pl`` and ``h_mi``. The optional biases
        are added in place to the respective scores. The result is the
        positive scores followed by the negative scores.
        """
        summary = c.unsqueeze(0).expand_as(h_pl)

        pos_scores = self.f_k(h_pl, summary).squeeze(1)
        neg_scores = self.f_k(h_mi, summary).squeeze(1)

        if s_bias1 is not None:
            pos_scores.add_(s_bias1)
        if s_bias2 is not None:
            neg_scores.add_(s_bias2)

        return torch.cat((pos_scores, neg_scores))