Source code for cogdl.models.nn.pyg_sagpool

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.nn.pool.topk_pool import filter_adj, topk

from cogdl.layers import GCNLayer
from cogdl.utils import split_dataset_general

from .. import BaseModel, register_model


class SAGPoolLayers(nn.Module):
    """Self-attention graph pooling layer.

    A convolution (``Conv(nhid, 1)``) produces one score per node; the nodes
    selected by ``topk`` are kept, their features are gated by
    ``non_linearity(score)``, and the adjacency is filtered to the kept nodes.

    Args:
        nhid: Input feature size handed to the scoring convolution.
        ratio: Pooling ratio forwarded to ``topk``.
        Conv: Layer class used to build the scoring layer; called as
            ``Conv(nhid, 1)`` and invoked as ``layer(graph, x)``.
        non_linearity: Callable applied to the selected scores before they
            are multiplied into the node features.
    """

    def __init__(self, nhid, ratio=0.8, Conv=GCNLayer, non_linearity=torch.tanh):
        super().__init__()
        self.nhid = nhid
        self.ratio = ratio
        self.score_layer = Conv(nhid, 1)
        self.non_linearity = non_linearity

    def forward(self, graph, x, batch=None):
        """Score the nodes, keep the top-scoring ones and filter the edges.

        Args:
            graph: Graph object exposing ``edge_index`` and ``edge_weight``.
            x: Node feature tensor; ``x.size(0)`` is used as the node count.
            batch: Per-node graph assignment vector. When ``None`` every node
                is assigned to graph ``0``.

        Returns:
            Tuple ``(x, edge_index, edge_attr, batch, perm)``: the gated
            features of the kept nodes, the filtered edges and edge attributes
            returned by ``filter_adj``, the assignment vector restricted to
            the kept nodes, and the index tensor ``perm`` returned by ``topk``.
        """
        if batch is None:
            batch = graph.edge_index.new_zeros(x.size(0))

        # Flatten explicitly instead of a bare ``squeeze()``: with a single
        # node a bare squeeze turns the score into a 0-d tensor, which breaks
        # ``score.size(0)`` and the top-k selection below.
        # NOTE(review): assumes the scoring layer yields one value per node
        # (it is built with an output size of 1) -- confirm for custom Conv.
        score = self.score_layer(graph, x).reshape(-1)

        perm = topk(score, self.ratio, batch)
        gate = self.non_linearity(score[perm]).view(-1, 1)
        x = x[perm] * gate
        batch = batch[perm]

        edge_index, edge_attr = filter_adj(
            graph.edge_index, graph.edge_weight, perm, num_nodes=score.size(0)
        )
        return x, edge_index, edge_attr, batch, perm


@register_model("sagpool")
class SAGPoolNetwork(BaseModel):
    """Graph classification network built from three GCN + SAGPool stages.

    Each stage applies a ``GCNLayer`` with ReLU, pools the nodes with
    ``SAGPoolLayers`` and reads the graph out as the concatenation of global
    max and global mean pooling. The three readouts are summed and passed
    through a three-layer MLP whose last layer has ``nclass`` outputs.

    Args:
        nfeat: Input node feature size.
        nhid: Hidden size of the convolution layers.
        nclass: Number of output classes.
        dropout: Dropout probability applied inside the MLP head.
        pooling_ratio: Ratio passed to every ``SAGPoolLayers`` instance.
        pooling_layer_type: Name of the scoring convolution. Only
            ``"gcnconv"`` is known; any other value also falls back to
            ``GCNLayer``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--pooling-ratio", type=float, default=0.5)
        parser.add_argument("--pooling-layer-type", type=str, default="gcnconv")
        parser.add_argument("--batch-size", type=int, default=20)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from parsed command-line arguments."""
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.dropout,
            args.pooling_ratio,
            args.pooling_layer_type,
        )

    @classmethod
    def split_dataset(cls, dataset, args):
        """Split ``dataset`` by delegating to ``split_dataset_general``."""
        return split_dataset_general(dataset, args)

    def __init__(self, nfeat, nhid, nclass, dropout, pooling_ratio, pooling_layer_type):
        def _get_layer_from_str(layer_type):
            # Only "gcnconv" is recognised; every other name deliberately
            # falls back to the same layer, as before.
            if layer_type == "gcnconv":
                return GCNLayer
            return GCNLayer

        super().__init__()
        self.nfeat = nfeat
        self.nhid = nhid
        self.nclass = nclass
        self.dropout = dropout
        self.pooling_ratio = pooling_ratio

        self.conv_layer_1 = GCNLayer(self.nfeat, self.nhid)
        self.conv_layer_2 = GCNLayer(self.nhid, self.nhid)
        self.conv_layer_3 = GCNLayer(self.nhid, self.nhid)

        pool_conv = _get_layer_from_str(pooling_layer_type)
        self.pool_layer_1 = SAGPoolLayers(self.nhid, Conv=pool_conv, ratio=self.pooling_ratio)
        self.pool_layer_2 = SAGPoolLayers(self.nhid, Conv=pool_conv, ratio=self.pooling_ratio)
        self.pool_layer_3 = SAGPoolLayers(self.nhid, Conv=pool_conv, ratio=self.pooling_ratio)

        # The readout concatenates max and mean pooling, hence ``nhid * 2``.
        self.lin_layer_1 = torch.nn.Linear(self.nhid * 2, self.nhid)
        self.lin_layer_2 = torch.nn.Linear(self.nhid, self.nhid // 2)
        self.lin_layer_3 = torch.nn.Linear(self.nhid // 2, self.nclass)

    def forward(self, batch):
        """Compute per-graph class scores for a batch of graphs.

        Args:
            batch: Batched graph exposing ``x``, ``edge_index``, ``batch`` and
                a ``local_graph()`` context manager.

        Returns:
            Tensor of unnormalised class scores (logits) with ``nclass``
            columns, one row per graph in the batch.
        """
        x = batch.x
        batch_h = batch.batch
        stages = (
            (self.conv_layer_1, self.pool_layer_1),
            (self.conv_layer_2, self.pool_layer_2),
            (self.conv_layer_3, self.pool_layer_3),
        )

        out = None
        edge_index = None
        # NOTE(review): ``local_graph()`` presumably scopes the edge_index
        # reassignment below to this block -- confirm against cogdl's Graph.
        with batch.local_graph():
            for stage, (conv, pool) in enumerate(stages):
                if stage > 0:
                    # Later stages convolve over the edges kept by the
                    # previous pooling step.
                    batch.edge_index = edge_index
                x = F.relu(conv(batch, x))
                x, edge_index, _, batch_h, _ = pool(batch, x, batch_h)
                readout = torch.cat([gmp(x, batch_h), gap(x, batch_h)], dim=1)
                out = readout if out is None else out + readout

        out = F.relu(self.lin_layer_1(out))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = F.relu(self.lin_layer_2(out))
        # Project onto the classes. Previously the method returned the
        # ``nhid // 2`` hidden features and ``lin_layer_3`` was never used.
        out = self.lin_layer_3(out)
        return out