Source code for cogdl.models.nn.sgc

import torch.nn as nn

from .. import BaseModel, register_model
from cogdl.utils import spmm


class SimpleGraphConvolution(nn.Module):
    def __init__(self, in_features, out_features, order=3):
        super(SimpleGraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.order = order
        self.W = nn.Linear(in_features, out_features)

    def forward(self, graph, x):
        output = self.W(x)
        for _ in range(self.order):
            output = spmm(graph, output)
        return output


[docs]@register_model("sgc") class sgc(BaseModel):
[docs] @staticmethod def add_args(parser): parser.add_argument("--num-features", type=int) parser.add_argument("--num-classes", type=int)
[docs] @classmethod def build_model_from_args(cls, args): return cls(in_feats=args.num_features, out_feats=args.num_classes)
def __init__(self, in_feats, out_feats): super(sgc, self).__init__() self.nn = SimpleGraphConvolution(in_feats, out_feats) self.cache = dict()
[docs] def forward(self, graph): graph.sym_norm() x = self.nn(graph, graph.x) return x
[docs] def predict(self, data): return self.forward(data)