Source code for cogdl.models.nn.pairnorm

import torch as torch
import torch.nn as nn

from .. import BaseModel, register_model
from .gcn import GraphConvolution
from cogdl.utils import row_normalization, spmm, edge_softmax


class GraphAttConv(nn.Module):
    def __init__(self, in_features, out_features, heads, dropout):
        super(GraphAttConv, self).__init__()
        assert out_features % heads == 0
        out_perhead = out_features // heads

        self.graph_atts = nn.ModuleList(
            [GraphAttConvOneHead(in_features, out_perhead, dropout=dropout) for _ in range(heads)]
        )

        self.in_features = in_features
        self.out_perhead = out_perhead
        self.heads = heads

    def forward(self, input, adj):
        output = torch.cat([att(input, adj) for att in self.graph_atts], dim=1)
        # notice that original GAT use elu as activation func.
        return output

    def __repr__(self):
        return self.__class__.__name__ + "({}->[{}x{}])".format(self.in_features, self.heads, self.out_perhead)


class GraphAttConvOneHead(nn.Module):
    """
    Sparse version GAT layer, single head
    """

    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2):
        super(GraphAttConvOneHead, self).__init__()
        self.weight = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        # init
        nn.init.xavier_normal_(self.weight.data, gain=nn.init.calculate_gain("relu"))  # look at here
        nn.init.xavier_normal_(self.a.data, gain=nn.init.calculate_gain("relu"))

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, input, edge_index):
        h = torch.mm(input, self.weight)
        # Self-attention on the nodes - Shared attention mechanism
        # edge_h: 2*D x E
        edge_h = torch.cat((h[edge_index[0, :], :], h[edge_index[1, :], :]), dim=1).t()
        # do softmax for each row, this need index of each row, and for each row do softmax over it
        alpha = self.leakyrelu(self.a.mm(edge_h).squeeze())  # E
        n = len(input)
        alpha = edge_softmax(edge_index, alpha, shape=(n, n))
        output = spmm(edge_index, self.dropout(alpha), h)  # h_prime: N x out
        # output = spmm(edge, self.dropout(alpha), n, n, self.dropout(h)) # h_prime: N x out
        return output


class PairNormNorm(nn.Module):
    def __init__(self, mode="PN", scale=1):
        """
        mode:
          "None" : No normalization
          "PN"   : Original version
          "PN-SI"  : Scale-Individually version
          "PN-SCS" : Scale-and-Center-Simultaneously version

        ("SCS"-mode is not in the paper but we found it works well in practice,
          especially for GCN and GAT.)
        PairNormNorm is typically used after each graph convolution operation.
        """
        assert mode in ["None", "PN", "PN-SI", "PN-SCS"]
        super(PairNormNorm, self).__init__()
        self.mode = mode
        self.scale = scale

        # Scale can be set based on origina data, and also the current feature lengths.
        # We leave the experiments to future. A good pool we used for choosing scale:
        # [0.1, 1, 10, 50, 100]

    def forward(self, x):
        if self.mode == "None":
            return x

        col_mean = x.mean(dim=0)
        if self.mode == "PN":
            x = x - col_mean
            rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
            x = self.scale * x / rownorm_mean

        if self.mode == "PN-SI":
            x = x - col_mean
            rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
            x = self.scale * x / rownorm_individual

        if self.mode == "PN-SCS":
            rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
            x = self.scale * x / rownorm_individual - col_mean

        return x


class SGC(nn.Module):
    # for SGC we use data without normalization
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, norm_mode="None", norm_scale=10, **kwargs):
        super(SGC, self).__init__()
        self.linear = torch.nn.Linear(nfeat, nclass)
        self.norm = PairNormNorm(norm_mode, norm_scale)
        self.dropout = nn.Dropout(p=dropout)
        self.nlayer = nlayer

    def forward(self, x, edge_index, edge_attr):
        x = self.norm(x)
        for _ in range(self.nlayer):
            x = spmm(edge_index, edge_attr, x)
            x = self.norm(x)
        x = self.dropout(x)
        x = self.linear(x)
        return x


class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, norm_mode="None", norm_scale=1, **kwargs):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU(True)
        self.norm = PairNormNorm(norm_mode, norm_scale)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.dropout(x)
        x = self.gc1(x, edge_index, edge_attr)
        x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.gc2(x, edge_index, edge_attr)
        return x


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nhead, norm_mode="None", norm_scale=1, **kwargs):
        super(GAT, self).__init__()
        alpha_droprate = dropout
        self.gac1 = GraphAttConv(nfeat, nhid, nhead, alpha_droprate)
        self.gac2 = GraphAttConv(nhid, nclass, 1, alpha_droprate)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ELU(True)
        self.norm = PairNormNorm(norm_mode, norm_scale)

    def forward(self, x, adj):
        x = self.dropout(x)  # ?
        x = self.gac1(x, adj)
        x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.gac2(x, adj)
        return x


class DeepGCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0, norm_mode="None", norm_scale=1, **kwargs):
        super(DeepGCN, self).__init__()
        assert nlayer >= 1
        self.hidden_layers = nn.ModuleList(
            [GraphConvolution(nfeat if i == 0 else nhid, nhid) for i in range(nlayer - 1)]
        )
        self.out_layer = GraphConvolution(nfeat if nlayer == 1 else nhid, nclass)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU(True)
        self.norm = PairNormNorm(norm_mode, norm_scale)
        self.skip = residual

    def forward(self, x, edge_index, edge_attr):
        x_old = 0
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, edge_index, edge_attr)
            x = self.norm(x)
            x = self.relu(x)
            if self.skip > 0 and i % self.skip == 0:
                x = x + x_old
                x_old = x

        x = self.dropout(x)
        x = self.out_layer(x, edge_index, edge_attr)
        return x


class DeepGAT(nn.Module):
    def __init__(
        self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0, nhead=1, norm_mode="None", norm_scale=1, **kwargs
    ):
        super(DeepGAT, self).__init__()
        assert nlayer >= 1
        alpha_droprate = dropout
        self.hidden_layers = nn.ModuleList(
            [GraphAttConv(nfeat if i == 0 else nhid, nhid, nhead, alpha_droprate) for i in range(nlayer - 1)]
        )
        self.out_layer = GraphAttConv(nfeat if nlayer == 1 else nhid, nclass, 1, alpha_droprate)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ELU(True)
        self.norm = PairNormNorm(norm_mode, norm_scale)
        self.skip = residual

    def forward(self, x, edge_index, edge_attr=None):
        x_old = 0
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, edge_index)
            x = self.norm(x)
            x = self.relu(x)
            if self.skip > 0 and i % self.skip == 0:
                x = x + x_old
                x_old = x

        x = self.dropout(x)
        x = self.out_layer(x, edge_index)
        return x


[docs]@register_model("pairnorm") class PairNorm(BaseModel):
[docs] @staticmethod def add_args(parser): parser.add_argument("--pn_model", type=str, default="GCN", help="{SGC, DeepGCN, DeepGAT}") parser.add_argument("--hidden_layers", type=int, default=64, help="Number of hidden units.") parser.add_argument("--nhead", type=int, default=1, help="Number of head attentions.") parser.add_argument("--dropout", type=float, default=0.6, help="Dropout rate.") parser.add_argument("--nlayer", type=int, default=2, help="Number of layers, works for Deep model.") parser.add_argument("--residual", type=int, default=0, help="Residual connection") parser.add_argument( "--norm_mode", type=str, default="None", help="Mode for PairNorm, {None, PN, PN-SI, PN-SCS}" ) parser.add_argument("--norm_scale", type=float, default=1.0, help="Row-normalization scale")
[docs] @classmethod def build_model_from_args(cls, args): return cls( args.pn_model, args.hidden_layers, args.nhead, args.dropout, args.nlayer, args.residual, args.norm_mode, args.norm_scale, args.num_features, args.num_classes, )
def __init__( self, pn_model, hidden_layers, nhead, dropout, nlayer, residual, norm_mode, norm_scale, num_features, num_classes, ): super(PairNorm, self).__init__() self.edge_attr = None if pn_model == "GCN": self.pn_model = GCN(num_features, hidden_layers, num_classes, dropout, norm_mode, norm_scale) elif pn_model == "SGC": self.pn_model = SGC(num_features, hidden_layers, num_classes, dropout, nlayer, norm_mode, norm_scale) elif pn_model == "DeepGCN": self.pn_model = DeepGCN( num_features, hidden_layers, num_classes, dropout, nlayer, residual, norm_mode, norm_scale ) else: self.pn_model = DeepGAT( num_features, hidden_layers, num_classes, dropout, nlayer, residual, nhead, norm_mode, norm_scale )
[docs] def forward(self, x, edge_index): if self.edge_attr is None: self.edge_attr = row_normalization(x.shape[0], edge_index) edge_attr = self.edge_attr return self.pn_model(x, edge_index, edge_attr)
[docs] def predict(self, data): return self.forward(data.x, data.edge_index)