import torch
import torch.nn as nn
from cogdl.utils import spmm
class MeanAggregator(torch.nn.Module):
    """Project node features linearly, then aggregate them over the graph with row normalization.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        bias: Passed to ``nn.Linear``; if False, the projection has no bias term.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Not read by any method of this class; kept in case external code accesses it.
        self.cached_result = None
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    @staticmethod
    def norm(graph, x):
        """Row-normalize ``graph`` and multiply it with the feature matrix ``x``.

        NOTE(review): ``graph.row_norm()`` is called for its side effect on
        ``graph``. It presumably rescales edge weights in place so each row
        sums to 1; confirm against the cogdl Graph API. ``spmm`` is presumably
        sparse(graph) @ dense(x).
        """
        graph.row_norm()
        return spmm(graph, x)

    def forward(self, graph, x):
        """Apply ``self.linear`` to ``x``, then aggregate the result via ``self.norm``."""
        projected = self.linear(x)
        return self.norm(graph, projected)

    def __repr__(self):
        return f"{type(self).__name__}({self.in_channels}, {self.out_channels})"
class SumAggregator(torch.nn.Module):
    """Project node features linearly, then sum-aggregate them over the graph.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        bias: Passed to ``nn.Linear``; if False, the projection has no bias term.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Not read by any method of this class; kept in case external code accesses it.
        self.cached_result = None
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    @staticmethod
    def aggr(graph, x):
        """Multiply ``graph`` with ``x`` using ``spmm``; ``graph`` is not normalized here.

        NOTE(review): ``spmm`` is presumably sparse(graph) @ dense(x); confirm
        against cogdl.utils.
        """
        return spmm(graph, x)

    def forward(self, graph, x):
        """Apply ``self.linear`` to ``x``, then aggregate the result via ``self.aggr``.

        NOTE(review): the bias is added before aggregation. If ``spmm`` is a
        plain weighted sum, each node's output therefore carries the bias
        scaled by its (weighted) degree. Confirm this is intended.
        """
        projected = self.linear(x)
        return self.aggr(graph, projected)

    def __repr__(self):
        return f"{type(self).__name__}({self.in_channels}, {self.out_channels})"