import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from tqdm import tqdm

from .. import BaseModel, register_model

# Borrowed from https://github.com/PetarV-/DGI
class GCN(nn.Module):
    def __init__(self, in_ft, out_ft, act, bias=True):
        super(GCN, self).__init__()
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU() if act == "prelu" else act

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_ft))
            self.bias.data.fill_(0.0)
        else:
            self.register_parameter("bias", None)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    # Shape of seq: (batch, nodes, features)
    def forward(self, seq, adj, sparse=False):
        seq_fts = self.fc(seq)
        if sparse:
            out = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq_fts, 0)), 0)
        else:
            out = torch.bmm(adj, seq_fts)
        if self.bias is not None:
            out += self.bias
        return self.act(out)

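# Example (a minimal sketch, not part of the model): a dense forward pass for
# one graph with 4 nodes and 8 input features.
#
#   gcn = GCN(8, 16, "prelu")
#   seq = torch.randn(1, 4, 8)          # (batch, nodes, features)
#   adj = torch.eye(4).unsqueeze(0)     # (batch, nodes, nodes); identity here
#   out = gcn(seq, adj, sparse=False)   # -> shape (1, 4, 16)
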
# Borrowed from https://github.com/PetarV-/DGI
class AvgReadout(nn.Module):
    def __init__(self):
        super(AvgReadout, self).__init__()

    # Mean-pool node embeddings into a single graph summary; if a mask is
    # given, average only over the masked nodes.
    def forward(self, seq, msk):
        if msk is None:
            return torch.mean(seq, 1)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 1) / torch.sum(msk)

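# Example (a sketch): reading out 4 node embeddings, with and without a mask
# that keeps only the first two nodes.
#
#   read = AvgReadout()
#   seq = torch.randn(1, 4, 16)
#   msk = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
#   read(seq, None)   # plain mean over the node dimension -> (1, 16)
#   read(seq, msk)    # mean over the first two nodes only  -> (1, 16)
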
# Borrowed from https://github.com/PetarV-/DGI
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 2)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 2)

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 1)
        return logits

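# Example (a sketch with made-up shapes): bilinearly scoring positive and
# corrupted node embeddings against one graph summary vector.
#
#   disc = Discriminator(16)
#   c = torch.randn(1, 16)        # graph summary
#   h_pl = torch.randn(1, 4, 16)  # positive node embeddings
#   h_mi = torch.randn(1, 4, 16)  # corrupted node embeddings
#   logits = disc(c, h_pl, h_mi)  # -> (1, 8): 4 positive then 4 negative scores
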
# Borrowed from https://github.com/PetarV-/DGI
class LogReg(nn.Module):
    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq):
        ret = self.fc(seq)
        return ret

class LogRegTrainer(object):
    def train(self, data, labels, opt):
        device = data.device
        idx_train = opt["idx_train"].to(device)
        idx_val = opt["idx_val"].to(device)
        idx_test = opt["idx_test"].to(device)
        nclass = opt["num_classes"]
        nhid = data.shape[-1]
        labels = labels.to(device)

        train_embs = data[idx_train]
        val_embs = data[idx_val]
        test_embs = data[idx_test]

        train_lbls = labels[idx_train]
        val_lbls = labels[idx_val]
        test_lbls = labels[idx_test]

        tot = 0
        xent = nn.CrossEntropyLoss()
        # Average test accuracy over 50 independently initialized classifiers.
        for _ in range(50):
            log = LogReg(nhid, nclass).to(device)
            optimizer = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

            for _ in range(100):
                log.train()
                optimizer.zero_grad()
                logits = log(train_embs)
                loss = xent(logits, train_lbls)
                loss.backward()
                optimizer.step()

            # Evaluate on the test split without tracking gradients.
            log.eval()
            with torch.no_grad():
                logits = log(test_embs)
            preds = torch.argmax(logits, dim=1)
            acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
            tot += acc.item()
        return tot / 50

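# Example usage (a sketch with synthetic tensors; the index tensors may be
# integer indices, as here, or boolean masks):
#
#   embs = torch.randn(100, 32)
#   labels = torch.randint(0, 3, (100,))
#   opt = {
#       "idx_train": torch.arange(0, 60),
#       "idx_val": torch.arange(60, 80),
#       "idx_test": torch.arange(80, 100),
#       "num_classes": 3,
#   }
#   mean_acc = LogRegTrainer().train(embs, labels, opt)  # mean over 50 runs
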
# Borrowed from https://github.com/PetarV-/DGI
class DGIModel(nn.Module):
    def __init__(self, n_in, n_h, activation):
        super(DGIModel, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        h_1 = self.gcn(seq1, adj, sparse)

        c = self.read(h_1, msk)
        c = self.sigm(c)

        h_2 = self.gcn(seq2, adj, sparse)

        ret = self.disc(c, h_1, h_2, samp_bias1, samp_bias2)
        return ret

    # Detach the return variables
    def embed(self, seq, adj, sparse, msk):
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)
        return h_1.detach(), c.detach()

def preprocess_features(features):
    """Row-normalize a feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features

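# Worked example: preprocess_features(np.array([[1.0, 3.0], [0.0, 0.0]]))
# yields [[0.25, 0.75], [0.0, 0.0]]. The all-zero row is left unchanged
# because its infinite inverse row sum is reset to 0.
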
def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

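# This computes D^{-1/2} A^T D^{-1/2}, which equals the standard GCN
# normalization D^{-1/2} A D^{-1/2} for a symmetric adjacency. Worked example:
# normalizing [[1, 1], [1, 1]] (one edge plus self-loops, so both degrees are
# 2) gives [[0.5, 0.5], [0.5, 0.5]].
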
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)

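# Round-trip example (a sketch): converting a 3x3 sparse identity and
# densifying it recovers the original matrix.
#
#   t = sparse_mx_to_torch_sparse_tensor(sp.eye(3, format="coo"))
#   t.to_dense()  # -> 3x3 identity, dtype float32
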
@register_model("dgi")
class DGI(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--hidden-size", type=int, default=512)
        parser.add_argument("--max-epochs", type=int, default=1000)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(args.num_features, args.hidden_size, args.num_classes, args.max_epochs)

    def __init__(self, nfeat, nhid, nclass, max_epochs):
        super(DGI, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DGIModel(nfeat, nhid, "prelu").to(self.device)
        self.nhid = nhid
        self.nclass = nclass
        self.epochs = max_epochs
        self.patience = 20

    def train(self, data):
        num_nodes = data.x.shape[0]
        features = preprocess_features(data.x.cpu().numpy())
        features = torch.FloatTensor(features).unsqueeze(0).to(self.device)

        # Build the self-loop-augmented, symmetrically normalized adjacency.
        row, col = data.edge_index.cpu().numpy()
        adj = sp.coo_matrix(
            (np.ones(data.edge_index.shape[1]), (row, col)),
            (num_nodes, num_nodes),
        )
        adj = normalize_adj(adj + sp.eye(adj.shape[0]))
        sp_adj = sparse_mx_to_torch_sparse_tensor(adj).to(self.device)

        best = 1e9
        best_t = 0
        cnt_wait = 0
        b_xent = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.0)

        epoch_iter = tqdm(range(self.epochs))
        for epoch in epoch_iter:
            self.model.train()
            optimizer.zero_grad()

            # Corrupted graph: the same adjacency with row-shuffled features.
            idx = np.random.permutation(num_nodes)
            shuf_fts = features[:, idx, :].to(self.device)

            # Positive (original) nodes are labeled 1, corrupted nodes 0.
            lbl_1 = torch.ones(1, num_nodes)
            lbl_2 = torch.zeros(1, num_nodes)
            lbl = torch.cat((lbl_1, lbl_2), 1).to(self.device)

            logits = self.model(features, shuf_fts, sp_adj, True, None, None, None)
            loss = b_xent(logits, lbl)
            epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item():.4f}")

            if loss.item() < best:
                best = loss.item()
                best_t = epoch
                cnt_wait = 0
            else:
                cnt_wait += 1
            if cnt_wait == self.patience:
                print("Early stopping!")
                break

            loss.backward()
            optimizer.step()

        embeds, _ = self.model.embed(features, sp_adj, True, None)
        opt = {
            "idx_train": data.train_mask,
            "idx_val": data.val_mask,
            "idx_test": data.test_mask,
            "num_classes": self.nclass,
        }
        result = LogRegTrainer().train(embeds[0], data.y, opt)
        return result

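# A minimal smoke test on a synthetic random graph (a sketch: the graph sizes,
# split boundaries, and the SimpleNamespace stand-in for a dataset object are
# illustrative assumptions, not part of the model). Because of the relative
# import at the top, run it as a module from inside the package
# (python -m <package>.<module>) rather than as a standalone script.
if __name__ == "__main__":
    from types import SimpleNamespace

    n, f, c = 100, 16, 3
    data = SimpleNamespace(
        x=torch.randn(n, f),
        edge_index=torch.randint(0, n, (2, 400)),
        y=torch.randint(0, c, (n,)),
        train_mask=torch.arange(n) < 60,
        val_mask=(torch.arange(n) >= 60) & (torch.arange(n) < 80),
        test_mask=torch.arange(n) >= 80,
    )
    model = DGI(nfeat=f, nhid=32, nclass=c, max_epochs=50)
    print("mean test accuracy:", model.train(data))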