import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .. import BaseModel, register_model
from cogdl.utils import add_remaining_self_loops, mul_edge_softmax, spmm
class GATLayer(nn.Module):
    """
    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903

    Args:
        in_features (int): Size of each input node feature vector.
        out_features (int): Size of each per-head output feature vector.
        nhead (int): Number of attention heads.
        alpha (float): Negative slope of the LeakyReLU applied to the raw attention scores.
        dropout (float): Dropout probability applied to the normalised attention weights.
        concat (bool): If True the head outputs are concatenated (N x nhead*out_features),
            otherwise they are averaged (N x out_features).
        residual (bool): If True a linear projection of the input is added to the output.
        fast_mode (bool): If True all heads are aggregated with a single ``spmm`` call over a
            block-diagonal edge index instead of one ``spmm`` call per head.
    """

    def __init__(
        self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False, fast_mode=False
    ):
        super(GATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.nhead = nhead
        self.fast_mode = fast_mode

        # W projects the input to all heads at once; a_l / a_r score the source / target side.
        # The values are filled in by reset_parameters() below.
        self.W = nn.Parameter(torch.FloatTensor(in_features, out_features * nhead))
        self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))
        self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        if residual:
            # The residual projection must match the layer output width.
            res_features = out_features * nhead if concat else out_features
            self.residual = nn.Linear(in_features, res_features)
        else:
            # Registering None keeps `self.residual` defined without adding parameters.
            self.register_buffer("residual", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise W, a_l and a_r uniformly in [-s, s], s = sqrt(6 / (size(-2) + size(-1)))."""

        def reset(tensor):
            stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
            tensor.data.uniform_(-stdv, stdv)

        reset(self.a_l)
        reset(self.a_r)
        reset(self.W)

    @staticmethod
    def _merge_heads(edge, edge_attention, num_nodes):
        """Build the block-diagonal edge index and matching weights covering every head.

        Head ``k`` gets its node ids shifted by ``k * num_nodes``. Both the index and the
        weights are laid out head-major (h0e0, h0e1, ..., h1e0, ...), so position ``i`` of
        the returned weights belongs to column ``i`` of the returned index.

        Args:
            edge: Edge index, indexed as ``edge[0]`` / ``edge[1]`` with E entries each.
            edge_attention: Attention weights of shape (E, H).
            num_nodes (int): Number of nodes N.

        Returns:
            Tuple ``(edge_index, weights)`` of shapes (2, H*E) and (H*E,).
        """
        nhead = edge_attention.shape[1]
        # arange(nhead) * num_nodes also works for an empty graph, where a step of 0 would raise.
        offsets = (torch.arange(nhead, device=edge.device) * num_nodes).view(-1, 1)
        row = (edge[0].unsqueeze(0) + offsets).reshape(-1)
        col = (edge[1].unsqueeze(0) + offsets).reshape(-1)
        # Transpose before flattening: a plain flatten of (E, H) would be edge-major and
        # would pair weights with the wrong edges whenever H > 1.
        weights = edge_attention.t().reshape(-1)
        return torch.stack([row, col]), weights

    def forward(self, x, edge):
        """Compute the attention-weighted aggregation of projected node features.

        Args:
            x: Node features; the first dimension is the number of nodes N.
            edge: Edge index, indexed as ``edge[0, :]`` and ``edge[1, :]``.

        Returns:
            Concatenation of the per-head results along dim 1 if ``concat`` is set,
            otherwise their mean over heads; the residual projection is added if enabled.
        """
        num_nodes = x.shape[0]
        h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
        # h: N * H * d
        if torch.isnan(self.W.data).any():
            # Zero out NaNs produced by a corrupted weight matrix instead of propagating them.
            h[torch.isnan(h)] = 0

        # Self-attention on the nodes - Shared attention mechanism
        h_l = (self.a_l * h).sum(dim=-1)[edge[0, :]]
        h_r = (self.a_r * h).sum(dim=-1)[edge[1, :]]
        edge_attention = self.leakyrelu(h_l + h_r)
        # edge_attention: E * H
        edge_attention = mul_edge_softmax(edge, edge_attention, shape=(num_nodes, num_nodes))
        # Attention dropout is applied before branching so that both aggregation modes
        # regularise identically.
        edge_attention = self.dropout(edge_attention)

        # Head-major layout of the projected features: H * N * d
        h = h.permute(1, 0, 2)
        if self.fast_mode:
            edge_index, edge_weight = self._merge_heads(edge, edge_attention, num_nodes)
            h_prime = spmm(edge_index, edge_weight, h.reshape(num_nodes * self.nhead, -1))
            assert not torch.isnan(h_prime).any()
            h_prime = h_prime.split([num_nodes] * self.nhead)
        else:
            h = h.contiguous()
            h_prime = []
            for i in range(self.nhead):
                hidden = h[i]
                assert not torch.isnan(hidden).any()
                h_prime.append(spmm(edge, edge_attention[:, i], hidden))

        res = self.residual(x) if self.residual is not None else 0

        if self.concat:
            # if this layer is not last layer,
            out = torch.cat(h_prime, dim=1) + res
        else:
            # if this layer is last layer,
            out = sum(h_prime) / self.nhead + res
        return out

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
@register_model("gat")
class GAT(BaseModel):
    r"""The GAT model from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    Args:
        in_feats (int) : Number of input features.
        hidden_size (int) : Per-head output size of every non-final layer.
        out_features (int) : Output size of the final layer (``--num-classes`` on the command line).
        num_layers (int) : Requested number of layers; a first and a last layer are always built.
        dropout (float) : Dropout rate, applied to each layer's input and passed to each layer.
        alpha (float) : Coefficient of leaky_relu.
        nhead (int) : Number of attention heads in the non-final layers.
        residual (bool) : Whether the non-final layers use a residual connection.
        last_nhead (int) : Number of attention heads in the final layer.
        fast_mode (bool) : Passed through to every ``GATLayer``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--residual", action="store_true")
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=8)
        parser.add_argument("--dropout", type=float, default=0.6)
        parser.add_argument("--alpha", type=float, default=0.2)
        parser.add_argument("--nhead", type=int, default=8)
        parser.add_argument("--last-nhead", type=int, default=1)
        parser.add_argument("--fast-mode", action="store_true", default=False)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from the attributes registered by ``add_args``."""
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.num_layers,
            args.dropout,
            args.alpha,
            args.nhead,
            args.residual,
            args.last_nhead,
            args.fast_mode,
        )

    def __init__(
        self,
        in_feats,
        hidden_size,
        out_features,
        num_layers,
        dropout,
        alpha,
        nhead,
        residual,
        last_nhead,
        fast_mode=False,
    ):
        """Sparse version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.attentions = nn.ModuleList()

        # Input layer: in_feats -> hidden_size per head, heads concatenated.
        self.attentions.append(
            GATLayer(
                in_feats,
                hidden_size,
                nhead=nhead,
                dropout=dropout,
                alpha=alpha,
                concat=True,
                residual=residual,
                fast_mode=fast_mode,
            )
        )
        # Intermediate layers consume the concatenated heads of the previous layer.
        for _ in range(num_layers - 2):
            self.attentions.append(
                GATLayer(
                    hidden_size * nhead,
                    hidden_size,
                    nhead=nhead,
                    dropout=dropout,
                    alpha=alpha,
                    concat=True,
                    residual=residual,
                    fast_mode=fast_mode,
                )
            )
        # Output layer: heads are averaged and no residual connection is used.
        self.attentions.append(
            GATLayer(
                hidden_size * nhead,
                out_features,
                dropout=dropout,
                alpha=alpha,
                concat=False,
                nhead=last_nhead,
                residual=False,
                fast_mode=fast_mode,
            )
        )
        self.num_layers = num_layers
        self.last_nhead = last_nhead
        self.residual = residual

    def forward(self, x, edge_index):
        """Run every attention layer on ``x`` and return the output of the last one.

        ``edge_index`` is first passed through ``add_remaining_self_loops``. Dropout is
        applied to the input of each layer and ELU to the output of every layer but the last.
        """
        edge_index, _ = add_remaining_self_loops(edge_index)
        # Use the number of layers actually built: the constructor always creates a first
        # and a last layer, so this can differ from `num_layers` when it is below 2.
        last = len(self.attentions) - 1
        for i, layer in enumerate(self.attentions):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = layer(x, edge_index)
            if i != last:
                x = F.elu(x)
        return x

    def predict(self, data):
        """Return ``forward(data.x, data.edge_index)``."""
        return self.forward(data.x, data.edge_index)