import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import GCNConv
from .. import BaseModel, register_model
from cogdl.trainers.sampled_trainer import SAINTTrainer
@register_model("pyg_gcn")
class GCN(BaseModel):
    """Multi-layer graph convolutional network built from ``GCNConv`` layers.

    Every layer except the last is followed by ReLU and dropout; the last
    layer's output is passed through ``log_softmax`` over dimension 1.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument('--sampler', default='none', type=str, help='graph samplers')
        parser.add_argument('--sample-coverage', default=20, type=float, help='sample coverage ratio')
        parser.add_argument('--size-subgraph', default=1200, type=int, help='subgraph size')
        parser.add_argument('--num-walks', default=50, type=int, help='number of random walks')
        parser.add_argument('--walk-length', default=20, type=int, help='random walk length')
        parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks')
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from parsed command-line arguments.

        Only the architecture arguments are consumed here; the sampler
        options registered in ``add_args`` are not passed to the constructor.
        """
        return cls(
            args.num_features,
            args.num_classes,
            args.hidden_size,
            args.num_layers,
            args.dropout,
        )

    def get_trainer(self, task, args):
        """Return the trainer class to use for this model, if any.

        Returns:
            ``SAINTTrainer`` (the class itself, not an instance) when
            ``args.sampler`` is anything other than ``'none'``; otherwise
            ``None``. ``task`` is accepted for interface compatibility and is
            not used.
        """
        return SAINTTrainer if args.sampler != "none" else None

    def __init__(self, num_features, num_classes, hidden_size, num_layers, dropout):
        """Build the stack of graph convolutions.

        Args:
            num_features: Input feature dimension of the first layer.
            num_classes: Output dimension of the last layer.
            hidden_size: Output dimension of every intermediate layer.
            num_layers: Total number of ``GCNConv`` layers; must be >= 1.
            dropout: Dropout probability applied after each hidden layer.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        # Fail fast: with no layers ``forward`` would only break later with an
        # opaque IndexError when it indexes ``self.convs[-1]``.
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.num_features = num_features
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        # Layer widths: input -> (num_layers - 1) hidden widths -> output.
        shapes = [num_features] + [hidden_size] * (num_layers - 1) + [num_classes]
        self.convs = nn.ModuleList(
            [
                GCNConv(in_dim, out_dim, cached=False)
                for in_dim, out_dim in zip(shapes[:-1], shapes[1:])
            ]
        )

    def forward(self, x, edge_index, weight=None):
        """Run all convolution layers and return log-probabilities.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed unchanged to every layer.
            weight: Optional value forwarded as the third positional argument
                of every convolution (presumably per-edge weights; verify
                against the ``GCNConv`` version in use).

        Returns:
            ``log_softmax`` of the final layer's output taken over dim 1.
        """
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index, weight))
            x = F.dropout(x, p=self.dropout, training=self.training)
        # The last layer gets neither ReLU nor dropout.
        x = self.convs[-1](x, edge_index, weight)
        return F.log_softmax(x, dim=1)

    def loss(self, data):
        """Return the negative log-likelihood loss on the training nodes.

        The rows of the forward output and of ``data.y`` are selected with
        ``data.train_mask`` before ``nll_loss`` is computed.
        """
        # NOTE(review): unlike ``predict``, no ``norm_aggr`` weights are
        # passed to ``forward`` here -- confirm that this is intentional.
        return F.nll_loss(
            self.forward(data.x, data.edge_index)[data.train_mask],
            data.y[data.train_mask],
        )

    def predict(self, data):
        """Return log-probabilities for all rows of ``data.x``.

        When ``data`` contains a ``norm_aggr`` entry it is forwarded as the
        ``weight`` argument of ``forward``; otherwise ``None`` is used.
        """
        weight = data.norm_aggr if "norm_aggr" in data else None
        return self.forward(data.x, data.edge_index, weight)