import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.core.records import array
from torch.autograd import Function
from torch.nn.parameter import Parameter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.utils import add_remaining_self_loops, dense_to_sparse, softmax
from torch_scatter import scatter_add, scatter_max
from torch_sparse import coalesce, spspmm

from .. import BaseModel, register_model
from .gin import split_dataset_general

def scatter_sort(x, batch, fill_value=-1e16):
    """Sort the values of every group in descending order and accumulate them.

    The flat input is scattered into a dense ``(num_groups, max_group_size)``
    layout padded with ``fill_value``, each row is sorted in descending order
    and cumulatively summed, and the padding entries are dropped again.

    Args:
        x: 1-D tensor of values.
        batch: 1-D integer tensor of the same length assigning each value to a
            group. NOTE(review): the offset arithmetic below appears to assume
            that the members of a group are stored contiguously and that
            groups are ordered by id - confirm against callers.
        fill_value: padding value; it must be smaller than every real entry so
            that padding sorts last, and no real entry may equal it (entries
            equal to ``fill_value`` are removed from the result).

    Returns:
        A tuple ``(sorted_x, cumsum_sorted_x)`` of flat tensors: the per-group
        descending-sorted values and their per-group running sums.
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

    # Start offset of every group inside the flat input.
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    # Position of every element inside the padded dense layout.
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

    dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)

    # Padding sorts to the end of each row, so the running sums of the real
    # entries are not affected by it.
    sorted_x, _ = dense_x.sort(dim=-1, descending=True)
    cumsum_sorted_x = sorted_x.cumsum(dim=-1)
    cumsum_sorted_x = cumsum_sorted_x.view(-1)

    sorted_x = sorted_x.view(-1)
    filled_index = sorted_x != fill_value

    sorted_x = sorted_x[filled_index]
    cumsum_sorted_x = cumsum_sorted_x[filled_index]

    return sorted_x, cumsum_sorted_x

def _make_ix_like(batch):
    """Return, for every group in ``batch``, the ranks ``1..group_size``.

    Args:
        batch: 1-D integer tensor of group ids.

    Returns:
        A 1-D ``torch.long`` tensor on ``batch.device`` formed by concatenating
        ``[1, ..., n_g]`` for every group ``g`` in order of group id, where
        ``n_g`` is the number of entries of ``batch`` equal to ``g``.
    """
    num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
    idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
    idx = torch.cat(idx, dim=0)

    return idx

def _threshold_and_support(x, batch):
    """Sparsemax building block: compute the threshold.

    Args:
        x: input tensor to apply the sparsemax
        batch: group indicators

    Returns:
        A tuple ``(tau, support_size)``: the threshold value of every group and
        the number of entries in each group's support.
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    sorted_input, input_cumsum = scatter_sort(x, batch)
    input_cumsum = input_cumsum - 1.0
    rhos = _make_ix_like(batch).to(x.dtype)
    # An entry of rank k belongs to the support when k * z_(k) > cumsum_k - 1.
    support = rhos * sorted_input > input_cumsum

    # Boolean mask -> integer counts per group.
    support_size = scatter_add(support.to(batch.dtype), batch)
    # mask invalid index, for example, if batch is not start from 0 or not continuous, it may result in negative index
    idx = support_size + cum_num_nodes - 1
    mask = idx < 0
    idx[mask] = 0
    # Running sum (minus one) at the last support position of every group ...
    tau = input_cumsum.gather(0, idx)
    # ... divided by the support size gives the threshold.
    tau /= support_size.to(x.dtype)

    return tau, support_size

class SparsemaxFunction(Function):
    """Autograd function applying sparsemax independently within each group."""

    @staticmethod
    def forward(ctx, x, batch):
        """sparsemax: normalizing sparse transform

        Parameters:
            ctx: context object
            x (Tensor): shape (N, )
            batch: group indicator

        Returns:
            output (Tensor): same shape as input
        """
        max_val, _ = scatter_max(x, batch)
        # Shift by the per-group maximum. Done out of place so the caller's
        # tensor is not silently modified.
        x = x - max_val[batch]
        tau, supp_size = _threshold_and_support(x, batch)
        output = torch.clamp(x - tau[batch], min=0)
        ctx.save_for_backward(supp_size, output, batch)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through sparsemax.

        Returns the gradient with respect to ``x`` and ``None`` for ``batch``
        (group ids are not differentiable).
        """
        supp_size, output, batch = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Entries outside the support receive no gradient.
        grad_input[output == 0] = 0

        # Mean gradient over each group's support.
        v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)

        return grad_input, None


sparsemax = SparsemaxFunction.apply

class Sparsemax(nn.Module):
    """Parameter-free module that forwards its inputs to ``sparsemax``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        """Return ``sparsemax(x, batch)``.

        Args:
            x: values to normalise.
            batch: group indicator passed through unchanged.
        """
        result = sparsemax(x, batch)
        return result

class TwoHopNeighborhood(object):
    """Transform that adds the two-hop edges of a graph to its edge set.

    Calling the instance mutates and returns ``data``: ``data.edge_index`` is
    replaced by the coalesced union of the original edges and the edges
    produced by multiplying the adjacency structure with itself. When
    ``data.edge_attr`` is present it is extended to match and coalesced
    together with the indices.
    """

    def __call__(self, data):
        edge_index, edge_attr = data.edge_index, data.edge_attr
        n = data.num_nodes

        value = edge_index.new_ones((edge_index.size(1),), dtype=torch.float)

        # Sparse-sparse product of the adjacency with itself -> two-hop pairs.
        index, value = spspmm(edge_index, value, edge_index, value, n, n, n)

        edge_index = torch.cat([edge_index, index], dim=1)
        if edge_attr is None:
            data.edge_index, _ = coalesce(edge_index, None, n, n)
        else:
            # Give the new values the same trailing shape as ``edge_attr`` so
            # both can be concatenated along the edge dimension.
            value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
            value = value.expand(-1, *list(edge_attr.size())[1:])
            edge_attr = torch.cat([edge_attr, value], dim=0)
            data.edge_index, edge_attr = coalesce(edge_index, edge_attr, n, n)
            data.edge_attr = edge_attr

        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)

class GCN(MessagePassing):
    """Graph convolution with symmetric degree normalisation and edge weights.

    Node features are multiplied by a learned weight matrix and then
    aggregated (summed) over incoming edges, each message being scaled by
    ``deg^-1/2[row] * edge_weight * deg^-1/2[col]``.

    Args:
        in_channels: size of every input node feature vector.
        out_channels: size of every output node feature vector.
        cached: when true, the normalised edges computed on the first call are
            reused on later calls (the edge count must stay the same).
        bias: when true, a learnable bias is added to the aggregated output.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs):
        super(GCN, self).__init__(aggr="add", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the parameters and drop any cached normalisation."""
        # ``torch.Tensor(...)`` allocates uninitialised memory, so the
        # parameters need explicit values before first use.
        nn.init.xavier_uniform_(self.weight.data)
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        """Return ``(edge_index, normalised_edge_weight)``.

        Missing edge weights default to ones. Degrees are the per-row sums of
        the edge weights; nodes with zero degree get a factor of 0 instead of
        ``inf``.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the convolution.

        Args:
            x: node feature matrix.
            edge_index: edge indices, either a tensor or a tuple of tensors
                (a tuple is stacked into one tensor).
            edge_weight: optional per-edge weights.

        Raises:
            RuntimeError: if caching is enabled and the number of edges
                differs from the cached one.
        """
        x = torch.matmul(x, self.weight)
        if isinstance(edge_index, tuple):
            edge_index = torch.stack(edge_index)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} number of edges, but found {}".format(self.cached_num_edges, edge_index.size(1))
                )

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # Scale every neighbour feature by its normalised edge weight.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.in_channels, self.out_channels)

class NodeInformationScore(MessagePassing):
    """Parameter-free propagation used to score nodes for pooling.

    Messages are weighted by ``expand_deg - deg^-1/2[row] * w * deg^-1/2[col]``,
    where ``expand_deg`` is 1 on the last ``num_nodes`` edges returned by
    ``add_remaining_self_loops`` and 0 elsewhere.
    NOTE(review): this presumes the self loops occupy exactly the last
    ``num_nodes`` positions of the returned edge list - confirm against the
    installed torch_geometric version.

    Args:
        improved: stored on the instance; not read anywhere in this class.
        cached: when true, the weights computed on the first call are reused
            on later calls (the edge count must stay the same).
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr="add", **kwargs)

        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        """Return ``(edge_index_with_self_loops, per_edge_weight)``."""
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)

        # Degrees are computed before the self loops are added.
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

        # Added self loops get weight 0.
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes)

        row, col = edge_index
        expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)

        return (
            edge_index,
            expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col],
        )

    def forward(self, x, edge_index, edge_weight):
        """Propagate ``x`` with the weights produced by :meth:`norm`.

        Raises:
            RuntimeError: if caching is enabled and the number of edges
                differs from the cached one.
        """
        if isinstance(edge_index, tuple):
            edge_index = torch.stack(edge_index)
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} number of edges, but found {}".format(self.cached_num_edges, edge_index.size(1))
                )

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out

class HGPSLPool(torch.nn.Module):
    """Hierarchical graph pooling with optional structure learning.

    Nodes are ranked by an information score, the top ``ratio`` fraction per
    graph is kept, and (unless ``sl`` is ``False``) new edge weights between
    the kept nodes are learned with an attention vector.

    NOTE(review): the constructor signature was missing from the reviewed
    text. Parameter names come from the attribute assignments below and the
    positional order from the call sites in ``HGPSL`` (in_channels, ratio,
    sample, sparse, sl, lamb); the default values are a reconstruction and
    should be confirmed against the upstream implementation.

    Args:
        in_channels: node feature size.
        ratio: pooling ratio handed to ``topk``.
        sample: use the neighbourhood-sampling fast mode for structure
            learning instead of the dense all-pairs mode.
        sparse: normalise learned edge weights with sparsemax instead of
            softmax.
        sl: enable the structure learning step.
        lamb: weight of the existing edge attributes when combined with the
            learned attention scores.
        negative_slop: negative slope of the leaky ReLU applied to the
            attention scores.
    """

    def __init__(self, in_channels, ratio=0.8, sample=False, sparse=False, sl=True, lamb=1.0, negative_slop=0.2):
        super(HGPSLPool, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.negative_slop = negative_slop
        self.lamb = lamb

        self.att = Parameter(torch.Tensor(1, self.in_channels * 2))
        # ``torch.Tensor(...)`` is uninitialised memory; give it defined values.
        nn.init.xavier_uniform_(self.att.data)
        self.sparse_attention = Sparsemax()
        self.neighbor_augment = TwoHopNeighborhood()
        self.calc_information_score = NodeInformationScore()

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Pool the graph and optionally learn a new edge structure.

        Args:
            x: node feature matrix.
            edge_index: edge indices.
            edge_attr: edge weights or ``None``.
            batch: graph id of every node; defaults to all zeros.

        Returns:
            ``(x, edge_index, edge_attr, batch)`` of the pooled graph.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x_information_score = self.calc_information_score(x, edge_index, edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph Pooling
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))

        # Discard structure learning layer, directly return
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between each pair of nodes is time consuming.
            # To accelerate this process, we sample its K-Hop neighbors for each node and then learn the
            # edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)

            hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=score.size(0))

            new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
            row, col = new_edge_index
            # Attention score of every candidate edge from its endpoint features.
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
        else:
            # Learning the possible edge weights between each pair of nodes in the pooled subgraph, relative slower.
            if edge_attr is None:
                induced_edge_attr = torch.ones(
                    (induced_edge_index.size(1),), dtype=x.dtype, device=induced_edge_index.device
                )
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            # Construct batch fully connected graph in block diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index

            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            # Bias the learned scores towards edges that already exist.
            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj

        return x, new_edge_index, new_edge_attr, batch

@register_model("hgpsl")
class HGPSL(BaseModel):
    """Graph classifier built from three graph convolutions and two HGP-SL pools.

    After each of the three convolution stages the node features are read out
    with global max and mean pooling; the three readouts are summed and fed
    through a three-layer MLP that produces the class scores.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # NOTE(review): ``type=bool`` converts any non-empty string (including
        # "False") to True, so these flags cannot be switched off from the
        # command line as written. Left unchanged to keep the CLI as is.
        # fmt: off
        parser.add_argument("--hidden-size", type=int, default=128)
        parser.add_argument("--dropout", type=float, default=0.0)
        parser.add_argument("--pooling", type=float, default=0.5)
        parser.add_argument("--batch-size", type=int, default=64)
        parser.add_argument("--train-ratio", type=float, default=0.8)
        parser.add_argument("--test-ratio", type=float, default=0.1)
        parser.add_argument('--lr', type=float, default=0.001)
        parser.add_argument('--weight_decay', type=float, default=0.001)
        parser.add_argument('--sample_neighbor', type=bool, default=True)
        parser.add_argument('--sparse_attention', type=bool, default=True)
        parser.add_argument('--structure_learning', type=bool, default=True)
        parser.add_argument('--lamb', type=float, default=1.0)
        parser.add_argument('--patience', type=int, default=100)
        parser.add_argument('--seed', type=array, default=[777], help='random seed')
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from parsed command-line arguments."""
        return cls(
            args.num_features,
            args.num_classes,
            args.hidden_size,
            args.dropout,
            args.pooling,
            args.sample_neighbor,
            args.sparse_attention,
            args.structure_learning,
            args.lamb,
        )

    @classmethod
    def split_dataset(cls, dataset, args):
        """Delegate dataset splitting to ``split_dataset_general``."""
        return split_dataset_general(dataset, args)

    def __init__(
        self,
        num_features,
        num_classes,
        hidden_size,
        dropout,
        pooling,
        sample_neighbor,
        sparse_attention,
        structure_learning,
        lamb,
    ):
        """Build the layers.

        Args:
            num_features: input node feature size.
            num_classes: number of output classes.
            hidden_size: hidden feature size.
            dropout: dropout probability used in the MLP head.
            pooling: pooling ratio passed to both ``HGPSLPool`` layers.
            sample_neighbor: passed to ``HGPSLPool`` as ``sample``.
            sparse_attention: passed to ``HGPSLPool`` as ``sparse``.
            structure_learning: passed to ``HGPSLPool`` as the fifth
                positional argument (structure learning switch).
            lamb: passed to ``HGPSLPool`` as ``lamb``.
        """
        super(HGPSL, self).__init__()

        self.num_features = num_features
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.pooling = pooling
        self.dropout = dropout
        self.sample = sample_neighbor
        self.sparse = sparse_attention
        self.sl = structure_learning
        self.lamb = lamb

        self.conv1 = GCNConv(self.num_features, self.hidden_size)
        self.conv2 = GCN(self.hidden_size, self.hidden_size)
        self.conv3 = GCN(self.hidden_size, self.hidden_size)

        self.pool1 = HGPSLPool(
            self.hidden_size,
            self.pooling,
            self.sample,
            self.sparse,
            self.sl,
            self.lamb,
        )
        self.pool2 = HGPSLPool(
            self.hidden_size,
            self.pooling,
            self.sample,
            self.sparse,
            self.sl,
            self.lamb,
        )

        # The readout concatenates max and mean pooling, hence hidden_size * 2.
        self.lin1 = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.lin2 = torch.nn.Linear(self.hidden_size, self.hidden_size // 2)
        self.lin3 = torch.nn.Linear(self.hidden_size // 2, self.num_classes)

    def forward(self, data):
        """Return the class scores for the graphs in ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``;
        ``edge_index`` may be a tuple of tensors, which is stacked.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if isinstance(edge_index, tuple):
            edge_index = torch.stack(edge_index)
        edge_attr = None

        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch = self.pool1(x, edge_index, edge_attr, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch = self.pool2(x, edge_index, edge_attr, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index, edge_attr))
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(x1) + F.relu(x2) + F.relu(x3)

        x = F.relu(self.lin1(x))
        # Tie dropout to the module's train/eval mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        pred = self.lin3(x)

        return pred