import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .. import BaseModel, register_model
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    The layer projects node features with a learned weight matrix and then
    sums the projected features of each node's neighbours through a sparse
    adjacency matrix built from ``edge_index``. Every listed edge gets
    weight 1; no normalisation or self-loops are added here.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        bias: If ``True``, a learnable bias of size ``out_features`` is added
            to the output.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty allocates uninitialised storage; reset_parameters()
        # below fills it.
        self.weight = Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            # Registering None keeps `self.bias` defined so forward() can
            # test it uniformly.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight and bias uniformly in ``[-stdv, stdv]``.

        ``stdv`` is ``1 / sqrt(out_features)``.
        """
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        # The bounds are (low, high) of a uniform distribution. Passing them
        # to normal_() would be read as (mean, std) and shift every
        # parameter towards a negative mean.
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, input, edge_index):
        """Apply the graph convolution.

        Args:
            input: Dense node feature matrix with one row per node and
                ``in_features`` columns.
            edge_index: Integer tensor with two rows and one column per edge,
                used as the COO indices of the adjacency matrix. Row 0 is the
                node receiving the aggregated features, row 1 the node
                contributing them.

        Returns:
            Tensor with one row per node and ``out_features`` columns.
        """
        num_nodes = input.shape[0]
        support = torch.mm(input, self.weight)

        # Build the adjacency directly with the dtype/device of the projected
        # features so the sparse product works for non-float32 modules and no
        # extra host<->device round trip is needed.
        indices = edge_index.to(support.device)
        values = torch.ones(
            indices.shape[1], dtype=support.dtype, device=support.device
        )
        adj = torch.sparse_coo_tensor(indices, values, (num_nodes, num_nodes))

        output = torch.sparse.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        return output

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
@register_model("gcn")
class TKipfGCN(BaseModel):
    """Two-layer graph convolutional network registered under the name ``gcn``.

    The first ``GraphConvolution`` maps input features to a hidden
    representation (ReLU + dropout); the second maps the hidden
    representation to per-class scores, returned as log-probabilities.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.5)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from the parsed command-line ``args``."""
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.dropout,
        )

    def __init__(self, nfeat, nhid, nclass, dropout):
        """Create the two convolution layers.

        Args:
            nfeat: Number of input features per node.
            nhid: Width of the hidden layer.
            nclass: Number of output classes.
            dropout: Dropout probability applied after the first layer.
        """
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        """Return per-node log-probabilities over the classes.

        ``adj`` is forwarded unchanged to both ``GraphConvolution`` layers as
        their ``edge_index`` argument.
        """
        hidden = F.relu(self.gc1(x, adj))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.gc2(hidden, adj)
        return F.log_softmax(logits, dim=-1)

    def loss(self, data):
        """Negative log-likelihood on the nodes selected by ``data.train_mask``."""
        log_probs = self.forward(data.x, data.edge_index)
        mask = data.train_mask
        return F.nll_loss(log_probs[mask], data.y[mask])

    def predict(self, data):
        """Return the log-probabilities for every node in ``data``."""
        return self.forward(data.x, data.edge_index)