Source code for layers.maggregator

import torch
import torch.nn as nn


[docs]class MeanAggregator(torch.nn.Module): def __init__( self, in_channels, out_channels, improved=False, cached=False, bias=True ): super(MeanAggregator, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.improved = improved self.cached = cached self.cached_result = None self.linear = nn.Linear(in_channels, out_channels, bias) @staticmethod
[docs] def norm(x, edge_index): # here edge_index is already a sparse tensor deg = torch.sparse.sum(edge_index, 1) deg_inv = deg.pow(-1).to_dense() x = torch.matmul(edge_index, x) # print(x,deg_inv) x = x.t() * deg_inv # x:512*dim, edge_weight:256*512 return x.t()
[docs] def forward(self, x, edge_index, edge_weight=None, bias=True): """""" x = self.linear(x) x = self.norm(x, edge_index) return x
[docs] def update(self, aggr_out): if self.bias is not None: aggr_out = aggr_out + self.bias return aggr_out
[docs] def __repr__(self): return "{}({}, {})".format( self.__class__.__name__, self.in_channels, self.out_channels
)