Source code for cogdl.models.nn.dgi

import math

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.nn.parameter import Parameter

from .. import BaseModel, register_model


# Borrowed from https://github.com/PetarV-/DGI
class GCN(nn.Module):
    def __init__(self, in_ft, out_ft, act, bias=True):
        super(GCN, self).__init__()
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU() if act == 'prelu' else act

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_ft))
            self.bias.data.fill_(0.0)
        else:
            self.register_parameter('bias', None)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    # Shape of seq: (batch, nodes, features)
    def forward(self, seq, adj, sparse=False):
        seq_fts = self.fc(seq)
        if sparse:
            out = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq_fts, 0)), 0)
        else:
            out = torch.bmm(adj, seq_fts)
        if self.bias is not None:
            out += self.bias
        return self.act(out)
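
# Illustrative sketch (not part of the original module): how GCN.forward is
# called elsewhere in this file. The dense branch expects a batched adjacency
# for torch.bmm, while the sparse branch expects an unbatched torch sparse
# matrix. All sizes below are assumptions chosen for demonstration only.
def _example_gcn_shapes():
    num_nodes, in_ft, out_ft = 4, 8, 16
    gcn = GCN(in_ft, out_ft, 'prelu')
    seq = torch.randn(1, num_nodes, in_ft)          # (batch, nodes, features)
    dense_adj = torch.eye(num_nodes).unsqueeze(0)   # (batch, nodes, nodes)
    sparse_adj = torch.eye(num_nodes).to_sparse()   # (nodes, nodes), sparse
    out_dense = gcn(seq, dense_adj, sparse=False)   # -> (1, num_nodes, out_ft)
    out_sparse = gcn(seq, sparse_adj, sparse=True)  # -> (1, num_nodes, out_ft)
    return out_dense.shape, out_sparse.shape
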
# Borrowed from https://github.com/PetarV-/DGI
class AvgReadout(nn.Module):
    def __init__(self):
        super(AvgReadout, self).__init__()

    def forward(self, seq, msk):
        if msk is None:
            return torch.mean(seq, 1)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 1) / torch.sum(msk)

# Borrowed from https://github.com/PetarV-/DGI
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 2)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 2)

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 1)
        return logits
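
# Illustrative sketch (not part of the original module): the discriminator
# scores each node embedding against the graph summary c. The real embeddings
# h_pl and the corrupted embeddings h_mi each yield one logit per node, and the
# two sets are concatenated into a (batch, 2 * nodes) tensor. Sizes below are
# assumptions for demonstration only.
def _example_discriminator_logits():
    num_nodes, n_h = 4, 16
    disc = Discriminator(n_h)
    c = torch.randn(1, n_h)                # graph-level summary vector
    h_pl = torch.randn(1, num_nodes, n_h)  # embeddings of the real graph
    h_mi = torch.randn(1, num_nodes, n_h)  # embeddings of the corrupted graph
    logits = disc(c, h_pl, h_mi)           # -> (1, 2 * num_nodes)
    return logits.shape
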
# Borrowed from https://github.com/PetarV-/DGI
class LogReg(nn.Module):
    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq):
        ret = self.fc(seq)
        return ret

class LogRegTrainer(object):
    def train(self, data, labels, opt):
        device = data.device
        idx_train = opt["idx_train"].to(device)
        idx_val = opt["idx_val"].to(device)
        idx_test = opt["idx_test"].to(device)
        nclass = opt["num_classes"]
        nhid = data.shape[-1]
        labels = labels.to(device)

        train_embs = data[idx_train]
        val_embs = data[idx_val]
        test_embs = data[idx_test]

        train_lbls = labels[idx_train]
        val_lbls = labels[idx_val]
        test_lbls = labels[idx_test]

        tot = 0
        xent = nn.CrossEntropyLoss()

        # Train 50 independent logistic-regression classifiers on the frozen
        # embeddings and report the average test accuracy.
        for _ in range(50):
            log = LogReg(nhid, nclass).to(device)
            optimizer = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
            log.to(device)

            for _ in range(100):
                log.train()
                optimizer.zero_grad()
                logits = log(train_embs)
                loss = xent(logits, train_lbls)
                loss.backward()
                optimizer.step()

            logits = log(test_embs)
            preds = torch.argmax(logits, dim=1)
            acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
            tot += acc.item()

        return tot / 50

# Borrowed from https://github.com/PetarV-/DGI
class DGIModel(nn.Module):
    def __init__(self, n_in, n_h, activation):
        super(DGIModel, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()

        self.sigm = nn.Sigmoid()

        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        h_1 = self.gcn(seq1, adj, sparse)

        c = self.read(h_1, msk)
        c = self.sigm(c)

        h_2 = self.gcn(seq2, adj, sparse)

        ret = self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

        return ret

    # Detach the return variables
    def embed(self, seq, adj, sparse, msk):
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)

        return h_1.detach(), c.detach()
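
# Illustrative sketch (not part of the original module): one DGI step on toy
# data. The corrupted view is obtained by shuffling node features, and the
# discriminator logits are trained with BCE against labels that are 1 for real
# nodes and 0 for corrupted ones, mirroring DGI.train below. Sizes and the
# identity adjacency are assumptions for demonstration only.
def _example_dgi_step():
    num_nodes, n_in, n_h = 4, 8, 16
    model = DGIModel(n_in, n_h, 'prelu')
    features = torch.randn(1, num_nodes, n_in)
    adj = torch.eye(num_nodes).unsqueeze(0)            # dense adjacency
    shuf = features[:, torch.randperm(num_nodes), :]   # corrupted view
    lbl = torch.cat((torch.ones(1, num_nodes), torch.zeros(1, num_nodes)), 1)
    logits = model(features, shuf, adj, False, None, None, None)
    loss = nn.BCEWithLogitsLoss()(logits, lbl)
    return loss
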
def preprocess_features(features):
    """Row-normalize the feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features

def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)
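
# Illustrative sketch (not part of the original module): the preprocessing
# pipeline used by DGI.train below, applied to a toy graph. The edge_index
# array is a hypothetical 2 x num_edges COO edge list, and the random features
# are assumptions for demonstration only.
def _example_preprocessing():
    num_nodes = 3
    edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])  # assumed toy edge list
    features = np.random.rand(num_nodes, 4).astype(np.float32)
    features = preprocess_features(features)              # row-normalize
    adj = sp.coo_matrix(
        (np.ones(edge_index.shape[1]), edge_index), (num_nodes, num_nodes)
    )
    adj = normalize_adj(adj + sp.eye(adj.shape[0]))        # self-loops + D^-1/2 A D^-1/2
    sp_adj = sparse_mx_to_torch_sparse_tensor(adj)         # torch sparse tensor
    return torch.FloatTensor(features), sp_adj
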
[docs]@register_model("dgi") class DGI(BaseModel): @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--hidden-size", type=int, default=512)
        parser.add_argument("--max-epochs", type=int, default=1000)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(args.num_features, args.hidden_size, args.num_classes, args.max_epochs)

    def __init__(self, nfeat, nhid, nclass, max_epochs):
        super(DGI, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = DGIModel(nfeat, nhid, 'prelu').to(self.device)
        self.nhid = nhid
        self.nclass = nclass
        self.epochs = max_epochs
        self.patience = 20

    def train(self, data):
        num_nodes = data.x.shape[0]
        # Row-normalize features and add a batch dimension: (1, nodes, features).
        features = preprocess_features(data.x.cpu().numpy())
        features = torch.FloatTensor(features).unsqueeze(0).to(self.device)

        # Build a symmetrically normalized adjacency matrix with self-loops.
        adj = sp.coo_matrix(
            (np.ones(data.edge_index.shape[1]), data.edge_index.cpu()),
            (num_nodes, num_nodes),
        )
        adj = normalize_adj(adj + sp.eye(adj.shape[0]))
        sp_adj = sparse_mx_to_torch_sparse_tensor(adj)
        sp_adj = sp_adj.to(self.device)

        best = 1e9
        best_t = 0
        cnt_wait = 0
        b_xent = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.0)
        epoch_iter = tqdm(range(self.epochs))
        for epoch in epoch_iter:
            self.model.train()
            optimizer.zero_grad()

            # Corrupted view: permute node features across nodes.
            idx = np.random.permutation(num_nodes)
            shuf_fts = features[:, idx, :]

            # Labels: 1 for real nodes, 0 for corrupted nodes.
            lbl_1 = torch.ones(1, num_nodes)
            lbl_2 = torch.zeros(1, num_nodes)
            lbl = torch.cat((lbl_1, lbl_2), 1)

            shuf_fts = shuf_fts.to(self.device)
            lbl = lbl.to(self.device)

            logits = self.model(features, shuf_fts, sp_adj, True, None, None, None)
            loss = b_xent(logits, lbl)
            epoch_iter.set_description(f'Epoch: {epoch:03d}, Loss: {loss.item()}')

            # Early stopping on the training loss.
            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                print('Early stopping!')
                break

            loss.backward()
            optimizer.step()

        # Evaluate the frozen embeddings with a logistic-regression classifier.
        embeds, _ = self.model.embed(features, sp_adj, True, None)
        opt = {
            "idx_train": data.train_mask,
            "idx_val": data.val_mask,
            "idx_test": data.test_mask,
            "num_classes": self.nclass,
        }
        result = LogRegTrainer().train(embeds[0], data.y, opt)
        return result
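
# Illustrative sketch (not part of the original module): building the model
# from command-line style arguments. In cogdl, num_features and num_classes are
# normally filled in from the dataset; here they are set by hand as an
# assumption for demonstration only.
def _example_build_from_args():
    import argparse

    parser = argparse.ArgumentParser()
    DGI.add_args(parser)
    args = parser.parse_args(["--num-features", "1433", "--hidden-size", "512"])
    args.num_classes = 7  # assumed; normally provided by the dataset in cogdl
    model = DGI.build_model_from_args(args)
    return model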