Source code for cogdl.models.nn.sortpool

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import BaseModel, register_model
from cogdl.data import DataLoader
from .graphsage import GraphSAGELayer
from .gin import split_dataset_general


def scatter_sum(src, index, dim, dim_size):
    """Sum the entries of ``src`` that share the same ``index`` along dimension ``dim``."""
    size = list(src.size())
    if dim_size is not None:
        size[dim] = dim_size
    else:
        size[dim] = int(index.max()) + 1
    out = torch.zeros(size, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)

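# A minimal sketch of what ``scatter_sum`` computes (the tensors below are
# illustrative, not part of the original module): entries of ``src`` that share
# an index are summed into a single slot of the output.
#
#   src = torch.ones(5)                    # one entry per node
#   index = torch.tensor([0, 0, 1, 1, 1])  # graph id of each node
#   scatter_sum(src, index, dim=0, dim_size=2)
#   # -> tensor([2., 3.]), i.e. the number of nodes in each of the two graphs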

def spare2dense_batch(x, batch=None, fill_value=0):
    """Turn the node features ``x`` of a batched graph into a dense tensor of shape
    ``[batch_size, max_num_nodes, num_features]``, padding graphs with fewer than
    ``max_num_nodes`` nodes with ``fill_value``."""
    batch_size = batch[-1] + 1
    batch_num_nodes = scatter_sum(batch.new_ones(x.size(0)), batch, dim=0, dim_size=batch_size)
    max_num_nodes = batch_num_nodes.max().item()
    batch_cum_nodes = torch.cat([batch.new_zeros(1), batch_num_nodes.cumsum(dim=0)])

    # Position of each node in the dense layout: its offset inside its own graph
    # plus the block offset of that graph.
    idx = torch.arange(x.size(0), dtype=torch.long, device=x.device)
    idx = idx - batch_cum_nodes[batch] + batch * max_num_nodes

    new_size = [batch_size * max_num_nodes, x.size(1)]
    out = x.new_full(new_size, fill_value)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes, x.size(1)])
    return out

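# A minimal sketch of ``spare2dense_batch`` on toy values (illustrative only):
# two graphs with 2 and 1 nodes are packed into a dense [2, 2, 2] tensor, and the
# missing node of the second graph is padded with ``fill_value``.
#
#   x = torch.arange(6.0).view(3, 2)   # 3 nodes, 2 features each
#   batch = torch.tensor([0, 0, 1])    # nodes 0-1 belong to graph 0, node 2 to graph 1
#   spare2dense_batch(x, batch, fill_value=0)
#   # -> tensor([[[0., 1.], [2., 3.]],
#   #            [[4., 5.], [0., 0.]]])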

@register_model("sortpool")
class SortPool(BaseModel):
    r"""Implementation of SortPooling from the paper `"An End-to-End Deep Learning
    Architecture for Graph Classification"
    <https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf>`__.

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    hidden_dim : int
        Dimension of hidden layer embedding.
    num_classes : int
        Number of target classes.
    num_layers : int
        Number of graph neural network layers before pooling.
    k : int, optional
        Number of selected features to sort, default: ``30``.
    out_channel : int
        Number of the first convolution's output channels.
    kernel_size : int
        Size of the first convolution's kernel.
    dropout : float, optional
        Dropout rate, default: ``0.5``.
    """

    @staticmethod
    def add_args(parser):
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.1)
        parser.add_argument("--batch-size", type=int, default=20)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--out-channels", type=int, default=32)
        parser.add_argument("--k", type=int, default=30)
        parser.add_argument("--kernel-size", type=int, default=5)

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.num_layers,
            args.out_channels,
            args.kernel_size,
            args.k,
            args.dropout,
        )

    @classmethod
    def split_dataset(cls, dataset, args):
        return split_dataset_general(dataset, args)

    def __init__(self, in_feats, hidden_dim, num_classes, num_layers, out_channel, kernel_size, k=30, dropout=0.5):
        super(SortPool, self).__init__()
        self.k = k
        self.dropout = dropout
        self.num_layers = num_layers
        self.gnn_convs = nn.ModuleList()
        self.gnn_convs.append(GraphSAGELayer(in_feats, hidden_dim))
        for _ in range(self.num_layers - 1):
            self.gnn_convs.append(GraphSAGELayer(hidden_dim, hidden_dim))
        # 1-D convolution over the k sorted nodes; its flattened output of size
        # out_channel * (k - kernel_size + 1) feeds the classifier MLP.
        self.conv1d = nn.Conv1d(hidden_dim, out_channel, kernel_size)
        self.fc1 = nn.Linear(out_channel * (self.k - kernel_size + 1), hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        h = batch.x
        for i in range(self.num_layers):
            h = self.gnn_convs[i](batch, h)
            h = F.relu(h)

        # SortPooling: sort each node's feature channels, then convert the batched
        # node features into a dense [batch_size, num_nodes, dim] tensor, padding
        # with a value smaller than any real feature.
        h, _ = h.sort(dim=-1)
        fill_value = h.min().item() - 1
        batch_h = spare2dense_batch(h, batch.batch, fill_value)
        batch_size, num_nodes, xdim = batch_h.size()

        # Order the nodes of each graph by their last feature channel, descending.
        _, order = batch_h[:, :, -1].sort(dim=-1, descending=True)
        order = order + torch.arange(batch_size, dtype=torch.long, device=order.device).view(-1, 1) * num_nodes
        batch_h = batch_h.view(batch_size * num_nodes, xdim)
        batch_h = batch_h[order].view(batch_size, num_nodes, xdim)

        # Keep the top-k nodes per graph, padding smaller graphs up to k.
        if num_nodes >= self.k:
            batch_h = batch_h[:, : self.k].contiguous()
        else:
            fill_batch = batch_h.new_full((batch_size, self.k - num_nodes, xdim), fill_value)
            batch_h = torch.cat([batch_h, fill_batch], dim=1)
        batch_h[batch_h == fill_value] = 0
        h = batch_h

        # 1-D convolution over the k sorted nodes, followed by the MLP classifier.
        h = h.permute(0, 2, 1)  # [batch_size, hidden_dim, k]
        h = F.relu(self.conv1d(h)).view(batch_size, -1)
        h = F.relu(self.fc1(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.fc2(h)
        return h
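

# A usage sketch, not part of the original module: it assumes a cogdl
# graph-classification dataset whose mini-batches (produced by the DataLoader
# imported above) expose the attributes ``forward`` relies on (``x``, ``batch``,
# and the graph structure consumed by GraphSAGELayer). ``dataset`` and its
# ``num_features`` / ``num_classes`` attributes are placeholders for whatever
# dataset object is in use.
#
#   model = SortPool(
#       in_feats=dataset.num_features,
#       hidden_dim=64,
#       num_classes=dataset.num_classes,
#       num_layers=2,
#       out_channel=32,
#       kernel_size=5,
#       k=30,
#       dropout=0.5,
#   )
#   loader = DataLoader(dataset, batch_size=20, shuffle=True)
#   for batch in loader:
#       logits = model(batch)  # shape: [num_graphs_in_batch, num_classes]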