Source code for cogdl.models.nn.gin

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import BaseModel, register_model
from .mlp import MLP
from cogdl.data import DataLoader
from cogdl.utils import spmm


def split_dataset_general(dataset, args):
    droplast = args.model == "diffpool"

    train_size = int(len(dataset) * args.train_ratio)
    test_size = int(len(dataset) * args.test_ratio)
    index = list(range(len(dataset)))
    random.shuffle(index)

    train_index = index[:train_size]
    test_index = index[-test_size:]

    bs = args.batch_size
    train_loader = DataLoader([dataset[i] for i in train_index], batch_size=bs, drop_last=droplast)
    test_loader = DataLoader([dataset[i] for i in test_index], batch_size=bs, drop_last=droplast)
    if args.train_ratio + args.test_ratio < 1:
        val_index = index[train_size:-test_size]
        valid_loader = DataLoader([dataset[i] for i in val_index], batch_size=bs, drop_last=droplast)
    else:
        valid_loader = test_loader
    return train_loader, valid_loader, test_loader
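# Usage sketch (illustrative only): the argparse namespace fields below mirror the
# defaults registered in GIN.add_args; `dataset` stands for any indexable graph
# classification dataset and is an assumption of this example, not defined here.
#
#   from argparse import Namespace
#   args = Namespace(model="gin", train_ratio=0.7, test_ratio=0.1, batch_size=128)
#   train_loader, valid_loader, test_loader = split_dataset_general(dataset, args)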


class GINLayer(nn.Module):
    r"""Graph Isomorphism Network layer from paper `"How Powerful are Graph
    Neural Networks?" <https://arxiv.org/pdf/1810.00826.pdf>`__.

    .. math::
        h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
        \mathrm{sum}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)

    Parameters
    ----------
    apply_func : callable, optional
        Layer or function applied to update node features.
    eps : float32, optional
        Initial :math:`\epsilon` value.
    train_eps : bool, optional
        If True, :math:`\epsilon` will be a learnable parameter.
    """

    def __init__(self, apply_func=None, eps=0, train_eps=True):
        super(GINLayer, self).__init__()
        if train_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([eps]))
        else:
            self.register_buffer("eps", torch.FloatTensor([eps]))
        self.apply_func = apply_func

    def forward(self, graph, x):
        out = (1 + self.eps) * x + spmm(graph, x)
        if self.apply_func is not None:
            out = self.apply_func(out)
        return out
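# Usage sketch for GINLayer (illustrative; `graph` is a cogdl graph whose adjacency
# `spmm` uses to sum-aggregate neighbor features, and `x` is the node feature matrix;
# the MLP sizes mirror the in-module call but are otherwise arbitrary assumptions):
#
#   mlp = MLP(16, 32, 32, 2, norm="batchnorm")
#   layer = GINLayer(apply_func=mlp, eps=0.0, train_eps=True)
#   h = layer(graph, x)   # f_theta((1 + eps) * x + sum of neighbor features)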


[docs]@register_model("gin") class GIN(BaseModel): r"""Graph Isomorphism Network from paper `"How Powerful are Graph Neural Networks?" <https://arxiv.org/pdf/1810.00826.pdf>`__. Args: num_layers : int Number of GIN layers in_feats : int Size of each input sample out_feats : int Size of each output sample hidden_dim : int Size of each hidden layer dimension num_mlp_layers : int Number of MLP layers eps : float32, optional Initial `\epsilon` value, default: ``0`` pooling : str, optional Aggregator type to use, default: ``sum`` train_eps : bool, optional If True, `\epsilon` will be a learnable parameter, default: ``True`` """
    @staticmethod
    def add_args(parser):
        """Add model parameters to the argument parser."""
        parser.add_argument("--epsilon", type=float, default=0.0)
        parser.add_argument("--hidden-size", type=int, default=32)
        parser.add_argument("--num-layers", type=int, default=3)
        parser.add_argument("--num-mlp-layers", type=int, default=2)
        parser.add_argument("--dropout", type=float, default=0.5)
        # store_false: train_epsilon defaults to True; passing --train-epsilon disables learning epsilon
        parser.add_argument("--train-epsilon", dest="train_epsilon", action="store_false")
        parser.add_argument("--pooling", type=str, default="sum")
        parser.add_argument("--batch-size", type=int, default=128)
        parser.add_argument("--lr", type=float, default=0.001)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)
    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_layers,
            args.num_features,
            args.num_classes,
            args.hidden_size,
            args.num_mlp_layers,
            args.epsilon,
            args.pooling,
            args.train_epsilon,
            args.dropout,
        )
    @classmethod
    def split_dataset(cls, dataset, args):
        return split_dataset_general(dataset, args)
    def __init__(
        self,
        num_layers,
        in_feats,
        out_feats,
        hidden_dim,
        num_mlp_layers,
        eps=0,
        pooling="sum",
        train_eps=False,
        dropout=0.5,
    ):
        super(GIN, self).__init__()
        self.gin_layers = nn.ModuleList()
        self.batch_norm = nn.ModuleList()
        self.num_layers = num_layers
        # GIN layers: the first maps in_feats -> hidden_dim, the rest stay in hidden_dim
        for i in range(num_layers - 1):
            if i == 0:
                mlp = MLP(in_feats, hidden_dim, hidden_dim, num_mlp_layers, norm="batchnorm")
            else:
                mlp = MLP(hidden_dim, hidden_dim, hidden_dim, num_mlp_layers, norm="batchnorm")
            self.gin_layers.append(GINLayer(mlp, eps, train_eps))
            self.batch_norm.append(nn.BatchNorm1d(hidden_dim))

        # one linear readout per layer representation (including the raw input features)
        self.linear_prediction = nn.ModuleList()
        for i in range(self.num_layers):
            if i == 0:
                self.linear_prediction.append(nn.Linear(in_feats, out_feats))
            else:
                self.linear_prediction.append(nn.Linear(hidden_dim, out_feats))
        self.dropout = nn.Dropout(dropout)
        self.criterion = torch.nn.CrossEntropyLoss()
    def forward(self, batch):
        h = batch.x
        device = h.device
        batchsize = int(torch.max(batch.batch)) + 1

        # node representations after each GIN layer (layer_rep[0] is the input features)
        layer_rep = [h]
        for i in range(self.num_layers - 1):
            h = self.gin_layers[i](batch, h)
            h = self.batch_norm[i](h)
            h = F.relu(h)
            layer_rep.append(h)

        # sum-pool node representations per graph with scatter_add_, then sum the
        # per-layer readouts to obtain the graph-level scores
        final_score = 0
        for i in range(self.num_layers):
            hsize = layer_rep[i].shape[1]
            output = torch.zeros(batchsize, hsize).to(device)
            pooled = output.scatter_add_(dim=0, index=batch.batch.view(-1, 1).repeat(1, hsize), src=layer_rep[i])
            final_score += self.dropout(self.linear_prediction[i](pooled))
        return final_score
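# Pooling sketch: `forward` sum-pools node representations per graph with
# `scatter_add_`. A minimal standalone illustration with toy tensors (not part
# of the model; values chosen only to show the indexing):
#
#   x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])   # 3 nodes, 2 features
#   b = torch.tensor([0, 0, 1])                              # first two nodes belong to graph 0
#   pooled = torch.zeros(2, 2).scatter_add_(0, b.view(-1, 1).repeat(1, 2), x)
#   # pooled == [[4., 6.], [5., 6.]]
#
# Usage sketch for the full model (illustrative; sizes are assumptions, and
# `batch` is a batched cogdl graph with `x` and `batch` attributes):
#
#   model = GIN(num_layers=3, in_feats=16, out_feats=2, hidden_dim=32,
#               num_mlp_layers=2, eps=0, pooling="sum", train_eps=True, dropout=0.5)
#   # logits = model(batch)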