Source code for cogdl.models.nn.asgcn

import math
import random
import collections
import numpy as np
from scipy import sparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .. import BaseModel, register_model

class GraphConvolution(nn.Module):
    """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform initialization in [-stdv, stdv], as in the reference GCN code.
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)  # XW
        output = torch.spmm(adj, support)  # A(XW), sparse-dense product
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + str(self.in_features)
            + " -> "
            + str(self.out_features)
            + ")"
        )
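
# A minimal usage sketch for GraphConvolution (illustrative only, not part of
# the original module). The toy graph and the unnormalized sparse adjacency
# below are assumptions; in practice the caller supplies the support matrices.
#
#     layer = GraphConvolution(in_features=3, out_features=8)
#     x = torch.randn(4, 3)
#     edges = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
#     adj = torch.sparse_coo_tensor(edges, torch.ones(4), (4, 4))
#     out = layer(x, adj)  # -> shape (4, 8)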


@register_model("asgcn")
class ASGCN(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--num-layers", type=int, default=3)
        parser.add_argument("--sample-size", type=int, nargs="+", default=[64, 64, 32])
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.num_classes,
            args.hidden_size,
            args.num_layers,
            args.dropout,
            args.sample_size,
        )

    def __init__(self, num_features, num_classes, hidden_size, num_layers, dropout, sample_size):
        super(ASGCN, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.sample_size = sample_size
        # Weights of the self-dependent sampler: one score vector for the
        # parent nodes (w_s1) and one for the candidate neighbors (w_s0).
        self.w_s0 = Parameter(torch.FloatTensor(num_features))
        self.w_s1 = Parameter(torch.FloatTensor(num_features))
        shapes = [num_features] + [hidden_size] * (num_layers - 1) + [num_classes]
        self.convs = nn.ModuleList(
            [GraphConvolution(shapes[layer], shapes[layer + 1]) for layer in range(num_layers)]
        )
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        self.w_s0.data.uniform_(-stdv, stdv)
        self.w_s1.data.uniform_(-stdv, stdv)

    def set_adj(self, edge_index, num_nodes):
        """Build and cache the adjacency-list representation of the graph."""
        self.sparse_adj = sparse.coo_matrix(
            (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),
            shape=(num_nodes, num_nodes),
        ).tocsr()
        self.num_nodes = num_nodes
        self.adj = self.compute_adjlist(self.sparse_adj)
        self.adj = torch.tensor(self.adj)

    def compute_adjlist(self, sp_adj, max_degree=32):
        """Transfer a sparse adjacency matrix to adjacency-list format."""
        num_data = sp_adj.shape[0]
        # Row `num_data` serves as a padding node; unused slots point to it.
        adj = num_data + np.zeros((num_data + 1, max_degree), dtype=np.int32)
        for v in range(num_data):
            neighbors = np.nonzero(sp_adj[v, :])[1]
            len_neighbors = len(neighbors)
            if len_neighbors > max_degree:
                # Subsample high-degree nodes down to max_degree neighbors.
                neighbors = np.random.choice(neighbors, max_degree, replace=False)
                adj[v] = neighbors
            else:
                adj[v, :len_neighbors] = neighbors
        return adj
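
    # Illustration of the padding behavior (toy numbers, assumed for this
    # sketch; `model` stands for any ASGCN instance). On a 3-node path graph
    # with max_degree=2, index 3 (= num_data) acts as the padding node:
    #
    #     sp = sparse.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]))
    #     model.compute_adjlist(sp, max_degree=2)
    #     # -> [[1, 3],   node 0: neighbor 1, then padding
    #     #     [0, 2],   node 1: neighbors 0 and 2
    #     #     [1, 3],   node 2: neighbor 1, then padding
    #     #     [3, 3]]   the padding row itself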

    def from_adjlist(self, adj):
        """Transfer adjacency-list format to a sparse support tensor."""
        u_sampled, index = torch.unique(torch.flatten(adj), return_inverse=True)
        # Flattened entry i originates from row i // max_degree of `adj`.
        row = torch.arange(index.shape[0], device=adj.device) // adj.shape[1]
        col = index
        values = torch.ones(index.shape[0]).float().to(adj.device)
        indices = torch.cat([row.unsqueeze(1), col.unsqueeze(1)], dim=1).t()
        dense_shape = (adj.shape[0], u_sampled.shape[0])
        support = torch.sparse_coo_tensor(indices, values, dense_shape)
        return support, u_sampled.long()
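
    # Continuing the sketch above (assumed inputs): from_adjlist turns those
    # rows back into a sparse support of shape (num_rows, num_unique_nodes).
    #
    #     adjlist = torch.tensor([[1, 3], [0, 2], [1, 3], [3, 3]])
    #     support, u = model.from_adjlist(adjlist)
    #     # u == tensor([0, 1, 2, 3]); support is a 4 x 4 sparse tensor with
    #     # one entry per adjacency slot (duplicates stay uncoalesced)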

    def _sample_one_layer(self, x, adj, v, sample_size):
        # Adaptive layer-wise sampling: score every candidate neighbor u of the
        # current nodes v, then draw `sample_size` of them without replacement.
        support, u = self.from_adjlist(adj)
        h_v = torch.sum(torch.matmul(x[v], self.w_s1))
        h_u = torch.matmul(x[u], self.w_s0)
        attention = (F.relu(h_v + h_u) + 1) * (1.0 / sample_size)
        g_u = F.relu(h_u) + 1
        p1 = attention * g_u
        p1 = p1.cpu()
        # The padding node (index self.num_nodes) must never be sampled.
        if self.num_nodes in u:
            p1[u == self.num_nodes] = 0
        p1 = p1 / torch.sum(p1)
        samples = torch.multinomial(p1, sample_size, replacement=False)
        u_sampled = u[samples]
        support_sampled = torch.index_select(support, 1, samples)
        return u_sampled, support_sampled

    def sampling(self, x, v):
        all_support = [[] for _ in range(self.num_layers)]
        sampled = v
        # Append a zero feature row for the padding node.
        x = torch.cat((x, torch.zeros(1, x.shape[1]).to(x.device)), dim=0)
        # Sample top-down: the output nodes choose their layer-(L-1) inputs,
        # which in turn choose theirs, down to the input layer.
        for i in range(self.num_layers - 1, -1, -1):
            cur_sampled, cur_support = self._sample_one_layer(x, self.adj[sampled], sampled, self.sample_size[i])
            all_support[i] = cur_support.to(x.device)
            sampled = cur_sampled
        return x[sampled.to(x.device)], all_support, 0

    def forward(self, x, adj):
        for index, conv in enumerate(self.convs[:-1]):
            x = F.relu(conv(x, adj[index]))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj[-1])
        return F.log_softmax(x, dim=1)
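
# End-to-end sketch (assumed driver code; cogdl's task runners wire this up
# for real). The sizes, the random features/edges, and the use of argparse
# below are all made up for illustration; assumes `import argparse`.
#
#     parser = argparse.ArgumentParser()
#     ASGCN.add_args(parser)
#     args = parser.parse_args(["--num-features", "16", "--num-classes", "7"])
#     model = ASGCN.build_model_from_args(args)
#
#     x = torch.randn(100, 16)
#     edge_index = np.random.randint(0, 100, (2, 2000))
#     model.set_adj(edge_index, num_nodes=100)
#
#     batch = torch.arange(10)                      # nodes to classify
#     x_in, supports, _ = model.sampling(x, batch)  # layer-wise sampled inputs
#     log_probs = model(x_in, supports)             # -> shape (10, 7)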