Source code for cogdl.models.nn.gat

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import BaseModel, register_model
from cogdl.utils import add_remaining_self_loops


class SpecialSpmmFunction(torch.autograd.Function):
    """Special function that backpropagates only through the sparse region."""

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the sparse values: gather the dense gradient
            # entries at the stored (row, col) positions only.
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)
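

# Illustrative sketch (not part of the original source): SpecialSpmm multiplies a
# sparse matrix, given as COO ``indices``/``values`` with dense shape ``shape``, by a
# dense matrix ``b``, while gradients flow only through the stored entries, e.g.:
#
#     indices = torch.tensor([[0, 1], [1, 0]])        # hypothetical 2x2 adjacency
#     values = torch.ones(2, requires_grad=True)
#     out = SpecialSpmm()(indices, values, torch.Size([2, 2]), torch.randn(2, 3))
#     # out has shape [2, 3]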


class SpGraphAttentionLayer(nn.Module):
    """
    Sparse version of the GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, edge):
        N = input.size()[0]

        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()

        # Self-attention on the nodes - shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_h: 2*out x E

        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
        assert not torch.isnan(edge_e).any()
        # edge_e: E (unnormalized attention coefficients)

        e_rowsum = self.special_spmm(
            edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1)).to(input.device)
        )
        # e_rowsum: N x 1 (softmax normalization constants)

        edge_e = self.dropout(edge_e)
        # edge_e: E

        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x out

        h_prime = h_prime.div(e_rowsum + 1e-8)
        # h_prime: N x out
        assert not torch.isnan(h_prime).any()

        if self.concat:
            # if this layer is not the last layer,
            return F.elu(h_prime)
        else:
            # if this layer is the last layer,
            return h_prime

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + str(self.in_features)
            + " -> "
            + str(self.out_features)
            + ")"
        )
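

# Illustrative sketch (not part of the original source): a single sparse attention
# layer maps [N, in_features] node features to [N, out_features], given a COO edge
# index of shape [2, E]; the sizes below are made-up toy values, e.g.:
#
#     layer = SpGraphAttentionLayer(16, 8, dropout=0.6, alpha=0.2, concat=True)
#     out = layer(torch.randn(4, 16), torch.tensor([[0, 1], [1, 0]]))
#     # out has shape [4, 8]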


@register_model("gat")
class PetarVSpGAT(BaseModel):
    r"""The GAT model from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper.

    Args:
        num_features (int) : Number of input features.
        num_classes (int) : Number of classes.
        hidden_size (int) : The dimension of node representation.
        dropout (float) : Dropout rate for model training.
        alpha (float) : Coefficient of leaky_relu.
        nheads (int) : Number of attention heads.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=8)
        parser.add_argument("--dropout", type=float, default=0.6)
        parser.add_argument("--alpha", type=float, default=0.2)
        parser.add_argument("--nheads", type=int, default=8)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.dropout,
            args.alpha,
            args.nheads,
        )

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Sparse version of GAT."""
        super(PetarVSpGAT, self).__init__()
        self.dropout = dropout

        self.attentions = [
            SpGraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        for i, attention in enumerate(self.attentions):
            self.add_module("attention_{}".format(i), attention)

        self.out_att = SpGraphAttentionLayer(
            nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, edge_index):
        # Ensure every node has a self-loop so it attends to itself.
        edge_index, _ = add_remaining_self_loops(edge_index)
        x = F.dropout(x, self.dropout, training=self.training)
        # Concatenate the outputs of all attention heads.
        x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, edge_index))
        return F.log_softmax(x, dim=1)

    def loss(self, data):
        return F.nll_loss(
            self.forward(data.x, data.edge_index)[data.train_mask],
            data.y[data.train_mask],
        )

    def predict(self, data):
        return self.forward(data.x, data.edge_index)
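

# A minimal usage sketch (illustrative assumption, not part of the original source):
# build the model directly with made-up toy sizes and run one forward pass.
if __name__ == "__main__":
    x = torch.randn(4, 16)                                   # 4 nodes, 16 input features
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # toy 4-edge ring graph
    model = PetarVSpGAT(nfeat=16, nhid=8, nclass=3, dropout=0.6, alpha=0.2, nheads=2)
    out = model(x, edge_index)
    print(out.shape)  # torch.Size([4, 3]); rows are log-probabilities over 3 classes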