Source code for cogdl.models.nn.mvgrl

import math

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import fractional_matrix_power, inv
from sklearn.preprocessing import MinMaxScaler
from torch.nn.parameter import Parameter
from tqdm import tqdm

from .. import BaseModel, register_model
from .dgi import GCN, AvgReadout, LogReg


# Borrowed from https://github.com/kavehhassani/mvgrl
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)
    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
    def forward(self, c1, c2, h1, h2, h3, h4, s_bias1=None, s_bias2=None):
        c_x1 = torch.unsqueeze(c1, 1)
        c_x1 = c_x1.expand_as(h1).contiguous()
        c_x2 = torch.unsqueeze(c2, 1)
        c_x2 = c_x2.expand_as(h2).contiguous()

        # positive
        sc_1 = torch.squeeze(self.f_k(h2, c_x1), 2)
        sc_2 = torch.squeeze(self.f_k(h1, c_x2), 2)

        # negative
        sc_3 = torch.squeeze(self.f_k(h4, c_x1), 2)
        sc_4 = torch.squeeze(self.f_k(h3, c_x2), 2)

        logits = torch.cat((sc_1, sc_2, sc_3, sc_4), 1)

        return logits
# Borrowed from https://github.com/kavehhassani/mvgrl
class Model(nn.Module):
    def __init__(self, n_in, n_h):
        super(Model, self).__init__()
        self.gcn1 = GCN(n_in, n_h, 'prelu')
        self.gcn2 = GCN(n_in, n_h, 'prelu')
        self.read = AvgReadout()

        self.sigm = nn.Sigmoid()

        self.disc = Discriminator(n_h)
    def forward(self, seq1, seq2, adj, diff, sparse, msk, samp_bias1, samp_bias2):
        h_1 = self.gcn1(seq1, adj, sparse)
        c_1 = self.read(h_1, msk)
        c_1 = self.sigm(c_1)

        h_2 = self.gcn2(seq1, diff, sparse)
        c_2 = self.read(h_2, msk)
        c_2 = self.sigm(c_2)

        h_3 = self.gcn1(seq2, adj, sparse)
        h_4 = self.gcn2(seq2, diff, sparse)

        ret = self.disc(c_1, c_2, h_1, h_2, h_3, h_4, samp_bias1, samp_bias2)

        return ret, h_1, h_2
    def embed(self, seq, adj, diff, sparse, msk):
        h_1 = self.gcn1(seq, adj, sparse)
        c = self.read(h_1, msk)

        h_2 = self.gcn2(seq, diff, sparse)
        return (h_1 + h_2).detach(), c.detach()
def preprocess_features(features):
    """Row-normalize feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features
def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
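# Worked example (illustrative, not part of the original source): symmetric
# normalization computes D^(-1/2) A D^(-1/2). For a two-node graph with a single
# edge and self-loops added, every node has degree 2, so every nonzero entry of
# the normalized matrix is 0.5:
#
#   normalize_adj(sp.coo_matrix(np.array([[0., 1.], [1., 0.]])) + sp.eye(2)).toarray()
#   # -> [[0.5, 0.5],
#   #     [0.5, 0.5]]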
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)
def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
    a = nx.convert_matrix.to_numpy_array(graph)
    if self_loop:
        a = a + np.eye(a.shape[0])                                # A^ = A + I_n
    d = np.diag(np.sum(a, 1))                                     # D^ = Sigma A^_ii
    dinv = fractional_matrix_power(d, -0.5)                       # D^(-1/2)
    at = np.matmul(np.matmul(dinv, a), dinv)                      # A~ = D^(-1/2) A^ D^(-1/2)
    return alpha * inv(np.eye(a.shape[0]) - (1 - alpha) * at)     # alpha (I_n - (1-alpha) A~)^-1
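# Sanity check (illustrative, not part of the original source): with alpha = 0.2
# on a two-node path graph, A~ = [[0.5, 0.5], [0.5, 0.5]], so the diffusion matrix
# is 0.2 * inv(I - 0.8 * A~) = [[0.6, 0.4], [0.4, 0.6]]. Each row sums to 1, as
# expected for a personalized PageRank kernel:
#
#   compute_ppr(nx.path_graph(2), alpha=0.2)
#   # -> [[0.6, 0.4],
#   #     [0.4, 0.6]]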
@register_model("mvgrl")
class MVGRL(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--hidden-size", type=int, default=512)
        parser.add_argument("--max-epochs", type=int, default=1000)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(args.num_features, args.hidden_size, args.num_classes, args.max_epochs)

    def __init__(self, nfeat, nhid, nclass, max_epochs):
        super(MVGRL, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = Model(nfeat, nhid).to(self.device)
        self.nhid = nhid
        self.nclass = nclass
        self.epochs = max_epochs
        self.patience = 50
    def train(self, data, dataset_name):
        num_nodes = data.x.shape[0]

        # Build the two graph views: the symmetrically normalized adjacency
        # matrix and the PPR diffusion matrix.
        features = preprocess_features(data.x.numpy())
        adj = sp.coo_matrix(
            (np.ones(data.edge_index.shape[1]), data.edge_index),
            (num_nodes, num_nodes),
        )
        adj = normalize_adj(adj + sp.eye(adj.shape[0])).todense()

        g = nx.Graph()
        g.add_nodes_from(list(range(num_nodes)))
        g.add_edges_from(data.edge_index.numpy().transpose())
        diff = compute_ppr(g, 0.2)

        if dataset_name == 'citeseer':
            # Sparsify the diffusion matrix so its density roughly matches the
            # average degree of the adjacency matrix, then rescale it to [0, 1].
            epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
            avg_degree = np.sum(adj) / adj.shape[0]
            epsilon = epsilons[
                np.argmin([abs(avg_degree - np.argwhere(diff >= e).shape[0] / diff.shape[0]) for e in epsilons])
            ]

            diff[diff < epsilon] = 0.0
            scaler = MinMaxScaler()
            scaler.fit(diff)
            diff = scaler.transform(diff)

        best = 1e9
        best_t = 0
        cnt_wait = 0
        sparse = False

        b_xent = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.0)

        ft_size = features.shape[1]
        sample_size = 2000
        batch_size = 4

        # Labels for the discriminator: 1 for the real (positive) pairs,
        # 0 for the shuffled (negative) pairs.
        lbl_1 = torch.ones(batch_size, sample_size * 2)
        lbl_2 = torch.zeros(batch_size, sample_size * 2)
        lbl = torch.cat((lbl_1, lbl_2), 1)
        lbl = lbl.to(self.device)

        epoch_iter = tqdm(range(self.epochs))
        for epoch in epoch_iter:
            self.model.train()
            optimizer.zero_grad()

            # Sample a batch of sub-graphs of `sample_size` consecutive nodes.
            idx = np.random.randint(0, adj.shape[-1] - sample_size + 1, batch_size)
            ba, bd, bf = [], [], []
            for i in idx:
                ba.append(adj[i: i + sample_size, i: i + sample_size])
                bd.append(diff[i: i + sample_size, i: i + sample_size])
                bf.append(features[i: i + sample_size])

            ba = np.array(ba).reshape(batch_size, sample_size, sample_size)
            bd = np.array(bd)
            bd = bd.reshape(batch_size, sample_size, sample_size)
            bf = np.array(bf).reshape(batch_size, sample_size, ft_size)

            if sparse:
                ba = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(ba))
                bd = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(bd))
            else:
                ba = torch.FloatTensor(ba)
                bd = torch.FloatTensor(bd)

            bf = torch.FloatTensor(bf)

            # Row-shuffled features act as the corrupted (negative) input.
            idx = np.random.permutation(sample_size)
            shuf_fts = bf[:, idx, :]

            bf = bf.to(self.device)
            ba = ba.to(self.device)
            bd = bd.to(self.device)
            shuf_fts = shuf_fts.to(self.device)

            logits, _, _ = self.model(bf, shuf_fts, ba, bd, sparse, None, None, None)

            loss = b_xent(logits, lbl)
            epoch_iter.set_description(f'Epoch: {epoch:03d}, Loss: {loss.item()}')

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                print('Early stopping!')
                break

            loss.backward()
            optimizer.step()

        if sparse:
            adj = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(adj))
            diff = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(diff))

        features = torch.FloatTensor(features[np.newaxis]).to(self.device)
        adj = torch.FloatTensor(adj[np.newaxis]).to(self.device)
        diff = torch.FloatTensor(diff[np.newaxis]).to(self.device)

        embeds, _ = self.model.embed(features, adj, diff, sparse, None)

        idx_train = data.train_mask.to(self.device)
        idx_val = data.val_mask.to(self.device)
        idx_test = data.test_mask.to(self.device)
        labels = data.y.to(self.device)

        train_embs = embeds[0, idx_train]
        val_embs = embeds[0, idx_val]
        test_embs = embeds[0, idx_test]

        train_lbls = labels[idx_train]
        val_lbls = labels[idx_val]
        test_lbls = labels[idx_test]

        # Evaluate the frozen embeddings with a logistic-regression classifier,
        # averaging test accuracy over 50 runs.
        tot = 0
        xent = nn.CrossEntropyLoss()
        wd = 0.01 if dataset_name == 'citeseer' else 0.0

        for _ in range(50):
            log = LogReg(self.nhid, self.nclass)
            opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=wd)
            log.to(self.device)

            for _ in range(300):
                log.train()
                opt.zero_grad()

                logits = log(train_embs)
                loss = xent(logits, train_lbls)
                loss.backward()
                opt.step()

            logits = log(test_embs)
            preds = torch.argmax(logits, dim=1)
            acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
            tot += acc.item()

        return tot / 50
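
# Usage sketch (illustrative, not part of the original source). It assumes `args`
# exposes num_features, hidden_size, num_classes, and max_epochs, and that `data`
# is a cogdl graph object with the attributes accessed in MVGRL.train
# (x, edge_index, train_mask, val_mask, test_mask, y):
#
#   model = MVGRL.build_model_from_args(args)
#   mean_test_acc = model.train(data, dataset_name="cora")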