import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.parameter import Parameter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.utils import add_remaining_self_loops, dense_to_sparse, softmax
from torch_scatter import scatter_add, scatter_max
from torch_sparse import coalesce, spspmm
from cogdl.utils import split_dataset_general
from .. import BaseModel, register_model
def scatter_sort(x, batch, fill_value=-1e16):
    """Sort the entries of x within each group of the batch in descending order.

    Returns the sorted values and their per-group cumulative sums, both
    flattened back to one dimension with padding entries removed.
    """
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
sorted_x, _ = dense_x.sort(dim=-1, descending=True)
cumsum_sorted_x = sorted_x.cumsum(dim=-1)
cumsum_sorted_x = cumsum_sorted_x.view(-1)
sorted_x = sorted_x.view(-1)
filled_index = sorted_x != fill_value
sorted_x = sorted_x[filled_index]
cumsum_sorted_x = cumsum_sorted_x[filled_index]
return sorted_x, cumsum_sorted_x
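

def _demo_scatter_sort():
    # Illustrative sketch (not part of the original cogdl module): per-group
    # descending sort plus cumulative sums, with padding entries removed.
    x = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0])
    batch = torch.tensor([0, 0, 0, 1, 1])
    sorted_x, cumsum = scatter_sort(x, batch)
    # sorted_x: tensor([3., 2., 1., 5., 4.])
    # cumsum:   tensor([3., 5., 6., 5., 9.])
    return sorted_x, cumsum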
def _make_ix_like(batch):
    """Return 1-based position indices (1..n_i) within each group of the batch."""
    num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
idx = torch.cat(idx, dim=0)
return idx
def _threshold_and_support(x, batch):
    """Sparsemax building block: compute the threshold.

    Args:
        x: input tensor on which sparsemax is applied
        batch: group indicators

    Returns:
        the threshold value ``tau`` and the support size for each group
    """
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
sorted_input, input_cumsum = scatter_sort(x, batch)
input_cumsum = input_cumsum - 1.0
rhos = _make_ix_like(batch).to(x.dtype)
support = rhos * sorted_input > input_cumsum
support_size = scatter_add(support.to(batch.dtype), batch)
    # Mask invalid indices: if batch does not start at 0 or is not contiguous, idx may be negative
idx = support_size + cum_num_nodes - 1
mask = idx < 0
idx[mask] = 0
tau = input_cumsum.gather(0, idx)
tau /= support_size.to(x.dtype)
return tau, support_size
class SparsemaxFunction(Function):
@staticmethod
def forward(ctx, x, batch):
"""sparsemax: normalizing sparse transform
Parameters:
ctx: context object
x (Tensor): shape (N, )
batch: group indicator
Returns:
output (Tensor): same shape as input
"""
        # Shift by the per-group maximum for numerical stability; sparsemax is
        # invariant to constant shifts within a group. Avoid modifying the
        # input tensor in place.
        max_val, _ = scatter_max(x, batch)
        x = x - max_val[batch]
tau, supp_size = _threshold_and_support(x, batch)
output = torch.clamp(x - tau[batch], min=0)
ctx.save_for_backward(supp_size, output, batch)
return output
@staticmethod
def backward(ctx, grad_output):
supp_size, output, batch = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[output == 0] = 0
v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)
return grad_input, None
sparsemax = SparsemaxFunction.apply
class Sparsemax(nn.Module):
def __init__(self):
super(Sparsemax, self).__init__()
def forward(self, x, batch):
return sparsemax(x, batch)
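

def _demo_sparsemax():
    # Illustrative sketch (not part of the original cogdl module): sparsemax
    # maps each group of scores to a sparse probability distribution that
    # sums to 1 within the group (values below are hand-checked).
    x = torch.tensor([1.0, 2.0, 3.0, 0.5, 0.5])
    batch = torch.tensor([0, 0, 0, 1, 1])
    out = Sparsemax()(x, batch)
    # Group 0 puts all mass on its largest score; group 1 stays uniform:
    # tensor([0.0, 0.0, 1.0, 0.5, 0.5])
    return out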
class TwoHopNeighborhood(object):
    """Transform that augments a graph with its two-hop edges (the support of A + A^2)."""

    def __call__(self, data):
edge_index, edge_attr = data.edge_index, data.edge_attr
n = data.num_nodes
value = edge_index.new_ones((edge_index.size(1),), dtype=torch.float)
index, value = spspmm(edge_index, value, edge_index, value, n, n, n)
        # Zero the two-hop values so that, after coalescing (which sums
        # duplicates), pre-existing edges keep their original attributes.
        value.fill_(0)
edge_index = torch.cat([edge_index, index], dim=1)
if edge_attr is None:
data.edge_index, _ = coalesce(edge_index, None, n, n)
else:
value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
value = value.expand(-1, *list(edge_attr.size())[1:])
edge_attr = torch.cat([edge_attr, value], dim=0)
data.edge_index, edge_attr = coalesce(edge_index, edge_attr, n, n)
data.edge_attr = edge_attr
return data
def __repr__(self):
return "{}()".format(self.__class__.__name__)
class GCN(MessagePassing):
def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs):
super(GCN, self).__init__(aggr="add", **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.cached = cached
self.cached_result = None
self.cached_num_edges = None
self.weight = Parameter(torch.Tensor(in_channels, out_channels))
nn.init.xavier_uniform_(self.weight.data)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
nn.init.zeros_(self.bias.data)
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
self.cached_result = None
self.cached_num_edges = None
@staticmethod
def norm(edge_index, num_nodes, edge_weight, dtype=None):
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_weight=None):
x = torch.matmul(x, self.weight)
if isinstance(edge_index, tuple):
edge_index = torch.stack(edge_index)
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} edges, but found {}".format(self.cached_num_edges, edge_index.size(1))
                )
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
self.cached_result = edge_index, norm
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
if self.bias is not None:
aggr_out = aggr_out + self.bias
return aggr_out
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.in_channels, self.out_channels)
class NodeInformationScore(MessagePassing):
    """Propagate features with I - D^{-1/2} A D^{-1/2} (the symmetric normalized
    Laplacian), so the output magnitude measures how much a node's features
    deviate from those of its neighborhood."""

    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr="add", **kwargs)
self.improved = improved
self.cached = cached
self.cached_result = None
self.cached_num_edges = None
@staticmethod
def norm(edge_index, num_nodes, edge_weight, dtype=None):
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
        # The full set of self-loops (weight 0) is appended at the end, so the
        # last num_nodes entries below form the identity term of I - D^{-1/2} A D^{-1/2}.
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes)
        row, col = edge_index
        expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)
return (
edge_index,
expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col],
)
def forward(self, x, edge_index, edge_weight):
if isinstance(edge_index, tuple):
edge_index = torch.stack(edge_index)
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} edges, but found {}".format(self.cached_num_edges, edge_index.size(1))
                )
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
self.cached_result = edge_index, norm
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
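

def _demo_information_score():
    # Illustrative sketch (not part of the original cogdl module): a node whose
    # features match its neighborhood scores near zero; an outlier scores high.
    x = torch.tensor([[1.0], [1.0], [5.0]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    out = NodeInformationScore()(x, edge_index, None)
    return torch.abs(out).sum(dim=1)  # per-node information score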
class HGPSLPool(torch.nn.Module):
    """Hierarchical graph pooling with structure learning (HGP-SL).

    Keeps the top ``ratio`` fraction of nodes ranked by node information score,
    then optionally relearns the edge structure of the pooled subgraph.
    """

    def __init__(
self,
in_channels,
ratio=0.8,
sample=False,
sparse=False,
sl=True,
lamb=1.0,
        negative_slope=0.2,
):
super(HGPSLPool, self).__init__()
self.in_channels = in_channels
self.ratio = ratio
self.sample = sample
self.sparse = sparse
self.sl = sl
        self.negative_slope = negative_slope
self.lamb = lamb
self.att = Parameter(torch.Tensor(1, self.in_channels * 2))
nn.init.xavier_uniform_(self.att.data)
self.sparse_attention = Sparsemax()
self.neighbor_augment = TwoHopNeighborhood()
self.calc_information_score = NodeInformationScore()
def forward(self, x, edge_index, edge_attr, batch=None):
if batch is None:
batch = edge_index.new_zeros(x.size(0))
x_information_score = self.calc_information_score(x, edge_index, edge_attr)
score = torch.sum(torch.abs(x_information_score), dim=1)
# Graph Pooling
original_x = x
perm = topk(score, self.ratio, batch)
x = x[perm]
batch = batch[perm]
induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))
        # Without structure learning, return the induced subgraph directly
        if not self.sl:
            return x, induced_edge_index, induced_edge_attr, batch
# Structure Learning
if self.sample:
            # A fast mode for large graphs.
            # Learning a weight for every pair of nodes is time consuming on
            # large graphs, so we sample each node's k-hop neighbors and learn
            # edge weights only between them.
k_hop = 3
if edge_attr is None:
edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)
hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
for _ in range(k_hop - 1):
hop_data = self.neighbor_augment(hop_data)
hop_edge_index = hop_data.edge_index
hop_edge_attr = hop_data.edge_attr
new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=score.size(0))
new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
row, col = new_edge_index
weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slope) + new_edge_attr * self.lamb
adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
adj[row, col] = weights
new_edge_index, weights = dense_to_sparse(adj)
row, col = new_edge_index
if self.sparse:
new_edge_attr = self.sparse_attention(weights, row)
else:
                new_edge_attr = softmax(weights, row, num_nodes=x.size(0))
# filter out zero weight edges
adj[row, col] = new_edge_attr
new_edge_index, new_edge_attr = dense_to_sparse(adj)
# release gpu memory
del adj
torch.cuda.empty_cache()
else:
            # Learn an edge weight between every pair of nodes in the pooled
            # subgraph; more accurate but relatively slower.
if edge_attr is None:
induced_edge_attr = torch.ones(
(induced_edge_index.size(1),),
dtype=x.dtype,
device=induced_edge_index.device,
)
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
cum_num_nodes = num_nodes.cumsum(dim=0)
adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            # Construct a batched fully connected graph in block-diagonal matrix format
for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
adj[idx_i:idx_j, idx_i:idx_j] = 1.0
new_edge_index, _ = dense_to_sparse(adj)
row, col = new_edge_index
weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slope)
adj[row, col] = weights
induced_row, induced_col = induced_edge_index
adj[induced_row, induced_col] += induced_edge_attr * self.lamb
weights = adj[row, col]
if self.sparse:
new_edge_attr = self.sparse_attention(weights, row)
else:
                new_edge_attr = softmax(weights, row, num_nodes=x.size(0))
# filter out zero weight edges
adj[row, col] = new_edge_attr
new_edge_index, new_edge_attr = dense_to_sparse(adj)
# release gpu memory
del adj
torch.cuda.empty_cache()
return x, new_edge_index, new_edge_attr, batch
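

def _demo_pool():
    # Illustrative sketch (not part of the original cogdl module): pool a
    # 4-node graph down to half its nodes with structure learning enabled.
    pool = HGPSLPool(in_channels=8, ratio=0.5)
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    batch = torch.zeros(4, dtype=torch.long)
    x, edge_index, edge_attr, batch = pool(x, edge_index, None, batch)
    return x.size(0)  # 2 nodes remain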
@register_model("hgpsl")
class HGPSL(BaseModel):
    @staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument("--hidden-size", type=int, default=128)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--pooling", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--train-ratio", type=float, default=0.8)
parser.add_argument("--test-ratio", type=float, default=0.1)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--weight_decay', type=float, default=0.001)
        # With type=bool, argparse treats any non-empty string as True, so
        # parse boolean flags explicitly.
        parser.add_argument('--sample_neighbor', type=lambda s: s.lower() in ('true', '1'), default=True)
        parser.add_argument('--sparse_attention', type=lambda s: s.lower() in ('true', '1'), default=True)
        parser.add_argument('--structure_learning', type=lambda s: s.lower() in ('true', '1'), default=True)
parser.add_argument('--lamb', type=float, default=1.0)
parser.add_argument('--patience', type=int, default=100)
        parser.add_argument('--seed', type=int, nargs='+', default=[777], help='random seed')
# fmt: on
    @classmethod
def build_model_from_args(cls, args):
return cls(
args.num_features,
args.num_classes,
args.hidden_size,
args.dropout,
args.pooling,
args.sample_neighbor,
args.sparse_attention,
args.structure_learning,
args.lamb,
)
    @classmethod
def split_dataset(cls, dataset, args):
return split_dataset_general(dataset, args)
def __init__(
self,
num_features,
num_classes,
hidden_size,
dropout,
pooling,
sample_neighbor,
sparse_attention,
structure_learning,
lamb,
):
super(HGPSL, self).__init__()
self.num_features = num_features
self.hidden_size = hidden_size
self.num_classes = num_classes
self.pooling = pooling
self.dropout = dropout
self.sample = sample_neighbor
self.sparse = sparse_attention
self.sl = structure_learning
self.lamb = lamb
self.conv1 = GCNConv(self.num_features, self.hidden_size)
self.conv2 = GCN(self.hidden_size, self.hidden_size)
self.conv3 = GCN(self.hidden_size, self.hidden_size)
self.pool1 = HGPSLPool(
self.hidden_size,
self.pooling,
self.sample,
self.sparse,
self.sl,
self.lamb,
)
self.pool2 = HGPSLPool(
self.hidden_size,
self.pooling,
self.sample,
self.sparse,
self.sl,
self.lamb,
)
self.lin1 = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
self.lin2 = torch.nn.Linear(self.hidden_size, self.hidden_size // 2)
self.lin3 = torch.nn.Linear(self.hidden_size // 2, self.num_classes)
    def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
if isinstance(edge_index, tuple):
edge_index = torch.stack(edge_index)
edge_attr = None
x = F.relu(self.conv1(x, edge_index, edge_attr))
x, edge_index, edge_attr, batch = self.pool1(x, edge_index, edge_attr, batch)
x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = F.relu(self.conv2(x, edge_index, edge_attr))
x, edge_index, edge_attr, batch = self.pool2(x, edge_index, edge_attr, batch)
x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = F.relu(self.conv3(x, edge_index, edge_attr))
x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = F.relu(x1) + F.relu(x2) + F.relu(x3)
x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.relu(self.lin2(x))
x = F.dropout(x, p=self.dropout, training=self.training)
pred = self.lin3(x)
return pred
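

def _demo_hgpsl():
    # Illustrative end-to-end sketch (not part of the original cogdl module):
    # run the full model on one toy graph wrapped in a PyG Data object.
    model = HGPSL(
        num_features=8, num_classes=3, hidden_size=16, dropout=0.0,
        pooling=0.5, sample_neighbor=False, sparse_attention=False,
        structure_learning=True, lamb=1.0,
    )
    data = Data(
        x=torch.randn(6, 8),
        edge_index=torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]),
        batch=torch.zeros(6, dtype=torch.long),
    )
    return model(data)  # logits of shape (1, 3)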