Source code for cogdl.models.nn.pyg_diffpool

import numpy as np
import random

from scipy.linalg import block_diag
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import add_remaining_self_loops

from .. import BaseModel, register_model
from cogdl.data import DataLoader


class EntropyLoss(nn.Module):
    # Return Scalar
    def forward(self, adj, anext, s_l):
        # entropy.mean(-1).mean(-1): 1/n over nodes and over the batch
        # entropy = (torch.distributions.Categorical(
        #     probs=s_l).entropy()).sum(-1).mean(-1)
        entropy = (torch.distributions.Categorical(probs=s_l).entropy()).mean()
        assert not torch.isnan(entropy)
        return entropy


class LinkPredLoss(nn.Module):
    def forward(self, adj, anext, s_l):
        link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
        return link_pred_loss.mean()
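
# Usage sketch (not part of the original module): a minimal, hypothetical check of
# the two auxiliary losses above. ``adj`` is a batched dense adjacency of shape
# (batch, n, n) and ``s_l`` a soft cluster assignment of shape (batch, n, k);
# all values below are random placeholders.
#
#     adj = torch.rand(2, 10, 10)
#     s_l = torch.softmax(torch.rand(2, 10, 3), dim=-1)
#     anext = s_l.transpose(-1, -2) @ adj @ s_l          # coarsened adjacency
#     EntropyLoss()(adj, anext, s_l)                     # scalar assignment entropy
#     LinkPredLoss()(adj, anext, s_l)                    # scalar link reconstruction error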


class GraphSAGE(nn.Module):
    r"""GraphSAGE from `"Inductive Representation Learning on Large Graphs"
    <https://arxiv.org/pdf/1706.02216.pdf>`__.

    .. math::
        h^{k+1}_{\mathcal{N}(v)} = AGGREGATE_{k}(h_{u}^{k})

        h^{k+1}_{v} = \sigma(\mathbf{W}^{k} \cdot CONCAT(h_{v}^{k}, h^{k+1}_{\mathcal{N}(v)}))

    Args:
        in_feats (int) : Size of each input sample.
        hidden_dim (int) : Size of hidden layer dimension.
        out_feats (int) : Size of each output sample.
        num_layers (int) : Number of GraphSAGE layers.
        dropout (float, optional) : Dropout rate, default: ``0.5``.
        normalize (bool, optional) : Normalize features after each layer if True, default: ``False``.
        use_bn (bool, optional) : Apply batch normalization after each hidden layer if True, default: ``False``.
    """

    def __init__(self, in_feats, hidden_dim, out_feats, num_layers,
                 dropout=0.5, normalize=False, concat=False, use_bn=False):
        super(GraphSAGE, self).__init__()
        self.convlist = nn.ModuleList()
        self.bn_list = nn.ModuleList()
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_bn = use_bn
        if num_layers == 1:
            self.convlist.append(SAGEConv(in_feats, out_feats, normalize, concat))
        else:
            self.convlist.append(SAGEConv(in_feats, hidden_dim, normalize, concat))
            if use_bn:
                self.bn_list.append(nn.BatchNorm1d(hidden_dim))
            for _ in range(num_layers - 2):
                self.convlist.append(SAGEConv(hidden_dim, hidden_dim, normalize, concat))
                if use_bn:
                    self.bn_list.append(nn.BatchNorm1d(hidden_dim))
            self.convlist.append(SAGEConv(hidden_dim, out_feats, normalize, concat))

    def forward(self, x, edge_index, edge_weight=None):
        h = x
        for i in range(self.num_layers - 1):
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = self.convlist[i](h, edge_index, edge_weight)
            if self.use_bn:
                h = self.bn_list[i](h)
        return self.convlist[self.num_layers - 1](h, edge_index, edge_weight)
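
# Usage sketch (not part of the original module): a hypothetical run of the sparse
# GraphSAGE encoder on a tiny graph; the feature sizes and edge list are
# illustrative only.
#
#     x = torch.rand(5, 16)                              # 5 nodes, 16 features
#     edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
#     encoder = GraphSAGE(in_feats=16, hidden_dim=32, out_feats=8, num_layers=2)
#     out = encoder(x, edge_index)                       # -> (5, 8)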


class BatchedGraphSAGE(nn.Module):
    r"""GraphSAGE with mini-batch

    Args:
        in_feats (int) : Size of each input sample.
        out_feats (int) : Size of each output sample.
        use_bn (bool) : Apply batch normalization if True, default: ``True``.
        self_loop (bool) : Add self-loops if True, default: ``True``.
    """

    def __init__(self, in_feats, out_feats, use_bn=True, self_loop=True):
        super(BatchedGraphSAGE, self).__init__()
        self.self_loop = self_loop
        self.use_bn = use_bn
        self.weight = nn.Linear(in_feats, out_feats, bias=True)
        nn.init.xavier_uniform_(self.weight.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, x, adj):
        device = x.device
        if self.self_loop:
            adj = adj + torch.eye(x.shape[1]).to(device)
        adj = adj / adj.sum(dim=1, keepdim=True)
        h = torch.matmul(adj, x)
        h = self.weight(h)
        h = F.normalize(h, dim=2, p=2)
        h = F.relu(h)
        # TODO: shape = [a, 0, b]
        # if self.use_bn and h.shape[1] > 0:
        #     self.bn = nn.BatchNorm1d(h.shape[1]).to(device)
        #     h = self.bn(h)
        return h
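
# Usage sketch (not part of the original module): the dense, batched SAGE layer
# expects ``x`` of shape (batch, nodes, features) and a dense adjacency of shape
# (batch, nodes, nodes); the sizes below are placeholders.
#
#     layer = BatchedGraphSAGE(in_feats=16, out_feats=32)
#     x = torch.rand(4, 10, 16)
#     adj = (torch.rand(4, 10, 10) > 0.5).float()
#     h = layer(x, adj)                                  # -> (4, 10, 32)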


class BatchedDiffPoolLayer(nn.Module):
    r"""DIFFPOOL from paper `"Hierarchical Graph Representation Learning
    with Differentiable Pooling" <https://arxiv.org/pdf/1806.08804.pdf>`__.

    .. math::
        X^{(l+1)} = {S^{(l)}}^{T} Z^{(l)}

        A^{(l+1)} = {S^{(l)}}^{T} A^{(l)} S^{(l)}

        Z^{(l)} = GNN_{l,embed}(A^{(l)}, X^{(l)})

        S^{(l)} = softmax(GNN_{l,pool}(A^{(l)}, X^{(l)}))

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    out_feats : int
        Size of each output sample.
    assign_dim : int
        Size of next adjacency matrix.
    batch_size : int
        Size of each mini-batch.
    dropout : float, optional
        Dropout rate, default: ``0.5``.
    link_pred_loss : bool, optional
        Use link prediction loss if True, default: ``True``.
    """

    def __init__(self, in_feats, out_feats, assign_dim, batch_size, dropout=0.5,
                 link_pred_loss=True, entropy_loss=True):
        super(BatchedDiffPoolLayer, self).__init__()
        self.assign_dim = assign_dim
        self.dropout = dropout
        self.use_link_pred = link_pred_loss
        self.batch_size = batch_size
        self.embd_gnn = SAGEConv(in_feats, out_feats, normalize=False)
        self.pool_gnn = SAGEConv(in_feats, assign_dim, normalize=False)
        self.loss_dict = dict()

    def forward(self, x, edge_index, batch, edge_weight=None):
        embed = self.embd_gnn(x, edge_index)
        pooled = F.softmax(self.pool_gnn(x, edge_index), dim=-1)
        device = x.device
        masked_tensor = []
        value_set, value_counts = torch.unique(batch, return_counts=True)
        batch_size = len(value_set)
        for i in value_counts:
            masked = torch.ones((i, int(pooled.size()[1] / batch_size)))
            masked_tensor.append(masked)
        masked = torch.FloatTensor(block_diag(*masked_tensor)).to(device)

        result = torch.nn.functional.softmax(masked * pooled, dim=-1)
        result = result * masked
        result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
        # result = masked_softmax(pooled, masked, memory_efficient=False)
        h = torch.matmul(result.t(), embed)
        if edge_weight is None:  # `if not edge_weight` is ambiguous for tensors
            edge_weight = torch.ones(edge_index.shape[1]).to(x.device)
        adj = torch.sparse_coo_tensor(edge_index, edge_weight)
        adj_new = torch.sparse.mm(adj, result)
        adj_new = torch.mm(result.t(), adj_new)

        if self.use_link_pred:
            adj_loss = torch.norm(adj.to_dense() - torch.mm(result, result.t())) / np.power(len(batch), 2)
            self.loss_dict["adj_loss"] = adj_loss
        entropy_loss = (torch.distributions.Categorical(probs=pooled).entropy()).mean()
        assert not torch.isnan(entropy_loss)
        self.loss_dict["entropy_loss"] = entropy_loss
        return adj_new, h

    def get_loss(self):
        loss_n = 0
        for _, value in self.loss_dict.items():
            loss_n += value
        return loss_n
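
# Usage sketch (not part of the original module): a hypothetical call of the first
# (sparse-input) DiffPool layer on a batch of two 4-node graphs. ``assign_dim``
# should be divisible by the number of graphs, since the assignment mask gives
# each graph ``assign_dim / batch_size`` clusters.
#
#     x = torch.rand(8, 16)
#     edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
#                                [1, 2, 3, 0, 5, 6, 7, 4]])   # a 4-cycle per graph
#     batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
#     layer = BatchedDiffPoolLayer(in_feats=16, out_feats=32, assign_dim=4, batch_size=2)
#     adj_new, h = layer(x, edge_index, batch)           # adj_new: (4, 4), h: (4, 32)
#     aux_loss = layer.get_loss()                        # link-pred + entropy terms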


class BatchedDiffPool(nn.Module):
    r"""DIFFPOOL layer with batch forward

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    next_size : int
        Size of next adjacency matrix.
    emb_size : int
        Dimension of next node feature matrix.
    use_bn : bool, optional
        Apply batch normalization if True, default: ``True``.
    self_loop : bool, optional
        Add self-loops if True, default: ``True``.
    use_link_loss : bool, optional
        Use link prediction loss if True, default: ``False``.
    use_entropy : bool, optional
        Use entropy prediction loss if True, default: ``True``.
    """

    def __init__(self, in_feats, next_size, emb_size, use_bn=True, self_loop=True,
                 use_link_loss=False, use_entropy=True):
        super(BatchedDiffPool, self).__init__()
        self.use_link_loss = use_link_loss
        self.use_bn = use_bn
        self.feat_trans = BatchedGraphSAGE(in_feats, emb_size)
        self.assign_trans = BatchedGraphSAGE(in_feats, next_size)

        self.link_loss = LinkPredLoss()
        self.entropy = EntropyLoss()
        self.loss_module = nn.ModuleList()
        if use_link_loss:
            self.loss_module.append(LinkPredLoss())
        if use_entropy:
            self.loss_module.append(EntropyLoss())
        self.loss = {}

    def forward(self, x, adj):
        h = self.feat_trans(x, adj)
        next_l = F.softmax(self.assign_trans(x, adj), dim=-1)
        h = torch.matmul(next_l.transpose(-1, -2), h)
        next = torch.matmul(next_l.transpose(-1, -2), torch.matmul(adj, next_l))

        for layer in self.loss_module:
            self.loss[str(type(layer).__name__)] = layer(adj, next, next_l)
        return h, next

    def get_loss(self):
        value = 0
        for _, v in self.loss.items():
            value += v
        return value
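
# Usage sketch (not part of the original module): the dense DiffPool block maps a
# batched graph with ``n`` nodes to one with ``next_size`` clusters; all sizes
# below are placeholders.
#
#     pool = BatchedDiffPool(in_feats=32, next_size=3, emb_size=32)
#     x = torch.rand(2, 10, 32)
#     adj = (torch.rand(2, 10, 10) > 0.5).float()
#     h, adj_next = pool(x, adj)                         # h: (2, 3, 32), adj_next: (2, 3, 3)
#     aux_loss = pool.get_loss()                         # entropy (and optional link) loss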


def toBatchedGraph(batch_adj, batch_feat, node_per_pool_graph):
    adj_list = [batch_adj[i:i + node_per_pool_graph, i:i + node_per_pool_graph]
                for i in range(0, batch_adj.size()[0], node_per_pool_graph)]
    feat_list = [batch_feat[i:i + node_per_pool_graph, :]
                 for i in range(0, batch_adj.size()[0], node_per_pool_graph)]
    adj_list = list(map(lambda x: torch.unsqueeze(x, 0), adj_list))
    feat_list = list(map(lambda x: torch.unsqueeze(x, 0), feat_list))
    adj = torch.cat(adj_list, dim=0)
    feat = torch.cat(feat_list, dim=0)
    return adj, feat
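
# Usage sketch (not part of the original module): split a block-diagonal adjacency
# of two pooled graphs with 3 clusters each into batched tensors.
#
#     batch_adj = torch.rand(6, 6)
#     batch_feat = torch.rand(6, 16)
#     adj, feat = toBatchedGraph(batch_adj, batch_feat, node_per_pool_graph=3)
#     # adj: (2, 3, 3), feat: (2, 3, 16)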


@register_model("diffpool")
class DiffPool(BaseModel):
    r"""DIFFPOOL from paper `Hierarchical Graph Representation Learning
    with Differentiable Pooling <https://arxiv.org/pdf/1806.08804.pdf>`__.

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    hidden_dim : int
        Size of hidden layer dimension of GNN.
    embed_dim : int
        Size of embedded node feature, output size of GNN.
    num_classes : int
        Number of target classes.
    num_layers : int
        Number of GNN layers.
    num_pool_layers : int
        Number of pooling layers.
    assign_dim : int
        Embedding size after the first pooling layer.
    pooling_ratio : float
        Ratio of nodes kept by each pooling layer.
    batch_size : int
        Size of each mini-batch.
    dropout : float, optional
        Dropout rate, default: ``0.5``.
    no_link_pred : bool, optional
        If True, do not use the link prediction loss, default: ``True``.
    """

    @staticmethod
    def add_args(parser):
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--num-pooling-layers", type=int, default=1)
        parser.add_argument("--no-link-pred", dest="no_link_pred", action="store_true")
        parser.add_argument("--pooling-ratio", type=float, default=0.15)
        parser.add_argument("--embedding-dim", type=int, default=64)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.1)
        parser.add_argument("--batch-size", type=int, default=20)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)
        parser.add_argument("--lr", type=float, default=0.001)

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.embedding_dim,
            args.num_classes,
            args.num_layers,
            args.num_pooling_layers,
            int(args.max_graph_size * args.pooling_ratio) * args.batch_size,
            args.pooling_ratio,
            args.batch_size,
            args.dropout,
            args.no_link_pred,
        )

    @classmethod
    def split_dataset(cls, dataset, args):
        random.shuffle(dataset)
        train_size = int(len(dataset) * args.train_ratio)
        test_size = int(len(dataset) * args.test_ratio)
        bs = args.batch_size
        train_loader = DataLoader(dataset[:train_size], batch_size=bs, drop_last=True)
        test_loader = DataLoader(dataset[-test_size:], batch_size=bs, drop_last=True)
        if args.train_ratio + args.test_ratio < 1:
            valid_loader = DataLoader(dataset[train_size:-test_size], batch_size=bs, drop_last=True)
        else:
            valid_loader = test_loader
        return train_loader, valid_loader, test_loader

    def __init__(self, in_feats, hidden_dim, embed_dim, num_classes, num_layers, num_pool_layers,
                 assign_dim, pooling_ratio, batch_size, dropout=0.5, no_link_pred=True,
                 concat=False, use_bn=False):
        super(DiffPool, self).__init__()
        self.assign_dim = assign_dim
        self.assign_dim_list = [assign_dim]
        self.use_bn = use_bn
        self.dropout = dropout
        self.use_link_loss = not no_link_pred
        # assert num_layers > 3, "layers > 3"
        self.diffpool_layers = nn.ModuleList()
        self.before_pooling = GraphSAGE(in_feats, hidden_dim, embed_dim,
                                        num_layers=num_layers, dropout=dropout, use_bn=self.use_bn)
        self.init_diffpool = BatchedDiffPoolLayer(embed_dim, hidden_dim, assign_dim,
                                                  batch_size, dropout, self.use_link_loss)

        pooled_emb_dim = embed_dim
        self.after_pool = nn.ModuleList()
        after_per_pool = nn.ModuleList()
        for _ in range(num_layers - 1):
            after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
        after_per_pool.append(BatchedGraphSAGE(hidden_dim, pooled_emb_dim))
        self.after_pool.append(after_per_pool)

        for _ in range(num_pool_layers - 1):
            self.assign_dim = int(self.assign_dim // batch_size * pooling_ratio) * batch_size
            self.diffpool_layers.append(BatchedDiffPool(
                pooled_emb_dim, self.assign_dim, hidden_dim,
                use_bn=self.use_bn, use_link_loss=self.use_link_loss
            ))

            # fresh module list for this pooling level (reusing the previous one
            # would register the same GNN stack for every level)
            after_per_pool = nn.ModuleList()
            for _ in range(num_layers - 1):
                after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
            after_per_pool.append(BatchedGraphSAGE(hidden_dim, pooled_emb_dim))
            self.after_pool.append(after_per_pool)

            self.assign_dim_list.append(self.assign_dim)

        if concat:
            out_dim = pooled_emb_dim * (num_pool_layers + 1)
        else:
            out_dim = pooled_emb_dim
        self.fc = nn.Linear(out_dim, num_classes)

    def reset_parameters(self):
        for i in self.modules():
            if isinstance(i, nn.Linear):
                nn.init.xavier_uniform_(i.weight.data, gain=nn.init.calculate_gain('relu'))
                if i.bias is not None:
                    nn.init.constant_(i.bias.data, 0.)

    def after_pooling_forward(self, gnn_layers, adj, x, concat=False):
        readouts = []
        h = x
        for layer in gnn_layers:
            h = layer(h, adj)
            readouts.append(h)
        readout = torch.cat(readouts, dim=1)
        return h

    def forward(self, batch):
        readouts_all = []

        init_emb = self.before_pooling(batch.x, batch.edge_index)
        adj, h = self.init_diffpool(init_emb, batch.edge_index, batch.batch)
        value_set, value_counts = torch.unique(batch.batch, return_counts=True)
        batch_size = len(value_set)
        adj, h = toBatchedGraph(adj, h, adj.size(0) // batch_size)
        h = self.after_pooling_forward(self.after_pool[0], adj, h)
        readout = torch.sum(h, dim=1)
        readouts_all.append(readout)

        for i, diff_layer in enumerate(self.diffpool_layers):
            h, adj = diff_layer(h, adj)
            h = self.after_pooling_forward(self.after_pool[i + 1], adj, h)
            readout = torch.sum(h, dim=1)
            readouts_all.append(readout)
        pred = self.fc(readout)

        if batch.y is not None:
            return pred, self.loss(pred, batch.y)
        return pred, None

    def loss(self, prediction, label):
        criterion = nn.CrossEntropyLoss()
        loss_n = criterion(prediction, label)
        loss_n += self.init_diffpool.get_loss()
        for layer in self.diffpool_layers:
            loss_n += layer.get_loss()
        return loss_n
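
# Usage sketch (not part of the original module): a hypothetical end-to-end run on
# a hand-built torch_geometric batch of two tiny graphs. ``assign_dim`` is chosen
# divisible by ``batch_size``, as the pooling mask above requires; every other
# size is a placeholder.
#
#     from torch_geometric.data import Data, Batch
#     graphs = [Data(x=torch.rand(4, 16),
#                    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
#                    y=torch.tensor([0]))
#               for _ in range(2)]
#     batch = Batch.from_data_list(graphs)
#     model = DiffPool(in_feats=16, hidden_dim=64, embed_dim=64, num_classes=2,
#                      num_layers=2, num_pool_layers=1, assign_dim=4,
#                      pooling_ratio=0.15, batch_size=2)
#     pred, loss = model(batch)                          # pred: (2, 2), loss: scalar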