Source code for cogdl.models.nn.dgl_jknet

import torch
import torch.nn.functional as F
import dgl
import dgl.function as fn
import numpy as np
from tqdm import tqdm

from cogdl.models.supervised_model import SupervisedHomogeneousNodeClassificationModel
from cogdl.trainers.supervised_model_trainer import SupervisedHomogeneousNodeClassificationTrainer
from .. import register_model


class GraphConvLayer(torch.nn.Module):
    """Graph convolution layer.

    Args:
        in_features (int): Size of each input node.
        out_features (int): Size of each output node.
        aggregation (str): 'sum', 'mean' or 'max'.
                           Specify the way to aggregate the neighbourhoods.
    """

    AGGREGATIONS = {
        "sum": torch.sum,
        "mean": torch.mean,
        "max": torch.max,
    }

    def __init__(self, in_features, out_features, aggregation="sum"):
        super(GraphConvLayer, self).__init__()

        if aggregation not in self.AGGREGATIONS:
            raise ValueError("'aggregation' argument has to be one of 'sum', 'mean' or 'max'.")
        # torch.max(..., dim=1) returns a (values, indices) tuple, so keep only the values for 'max'.
        reduce_fn = self.AGGREGATIONS[aggregation]
        self.aggregate = lambda nodes: reduce_fn(nodes, dim=1)[0] if aggregation == "max" else reduce_fn(nodes, dim=1)

        self.linear = torch.nn.Linear(in_features, out_features)
        self.self_loop_w = torch.nn.Linear(in_features, out_features)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, graph, x):
        graph.ndata["h"] = x
        graph.update_all(fn.copy_src(src="h", out="msg"), lambda nodes: {"h": self.aggregate(nodes.mailbox["msg"])})
        h = graph.ndata.pop("h")
        h = self.linear(h)
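        # Combine the aggregated neighbourhood term with a separately transformed self-loop term and a bias.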
        return h + self.self_loop_w(x) + self.bias


class JKNetConcat(torch.nn.Module):
    """An implementation of Jumping Knowledge Network (arxiv 1806.03536) which
    combine layers with concatenation.

    Args:
        in_features (int): Size of each input node.
        out_features (int): Size of each output node.
        n_layers (int): Number of the convolution layers.
        n_units (int): Size of the middle layers.
        aggregation (str): 'sum', 'mean' or 'max'.
                           Specify the way to aggregate the neighbourhoods.
    """

    def __init__(self, in_features, out_features, n_layers=6, n_units=16, aggregation="sum"):
        super(JKNetConcat, self).__init__()
        self.n_layers = n_layers

        self.gconv0 = GraphConvLayer(in_features, n_units, aggregation)
        self.dropout0 = torch.nn.Dropout(0.5)
        for i in range(1, self.n_layers):
            setattr(self, "gconv{}".format(i), GraphConvLayer(n_units, n_units, aggregation))
            setattr(self, "dropout{}".format(i), torch.nn.Dropout(0.5))
        self.last_linear = torch.nn.Linear(n_layers * n_units, out_features)

    def forward(self, graph, x):
        layer_outputs = []
        for i in range(self.n_layers):
            dropout = getattr(self, "dropout{}".format(i))
            gconv = getattr(self, "gconv{}".format(i))
            x = dropout(F.relu(gconv(graph, x)))
            layer_outputs.append(x)

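        # Jumping Knowledge with concatenation: place the per-layer representations
        # side by side before the final linear layer.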
        h = torch.cat(layer_outputs, dim=1)
        return self.last_linear(h)


class JKNetMaxpool(torch.nn.Module):
    """An implementation of Jumping Knowledge Network (arxiv 1806.03536) which
    combine layers with Maxpool.

    Args:
        in_features (int): Size of each input node.
        out_features (int): Size of each output node.
        n_layers (int): Number of the convolution layers.
        n_units (int): Size of the middle layers.
        aggregation (str): 'sum', 'mean' or 'max'.
                           Specify the way to aggregate the neighbourhoods.
    """

    def __init__(self, in_features, out_features, n_layers=6, n_units=16, aggregation="sum"):
        super(JKNetMaxpool, self).__init__()
        self.n_layers = n_layers

        self.gconv0 = GraphConvLayer(in_features, n_units, aggregation)
        self.dropout0 = torch.nn.Dropout(0.5)
        for i in range(1, self.n_layers):
            setattr(self, "gconv{}".format(i), GraphConvLayer(n_units, n_units, aggregation))
            setattr(self, "dropout{}".format(i), torch.nn.Dropout(0.5))
        self.last_linear = torch.nn.Linear(n_units, out_features)

    def forward(self, graph, x):
        layer_outputs = []
        for i in range(self.n_layers):
            dropout = getattr(self, "dropout{}".format(i))
            gconv = getattr(self, "gconv{}".format(i))
            x = dropout(F.relu(gconv(graph, x)))
            layer_outputs.append(x)

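        # Jumping Knowledge with max-pooling: take the element-wise maximum
        # of the per-layer representations.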
        h = torch.stack(layer_outputs, dim=0)
        h = torch.max(h, dim=0)[0]
        return self.last_linear(h)


class JKNetTrainer(SupervisedHomogeneousNodeClassificationTrainer):
    @classmethod
    def build_trainer_from_args(cls, args):
        return cls(args)

    def __init__(self, args):
        super(JKNetTrainer, self).__init__()
        self.graph = dgl.DGLGraph()
        self.args = args

    def _train_step(self):
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val", logits=None):
        self.model.eval()
        logits = logits if logits is not None else self.model.predict(self.data)
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask
        loss = F.nll_loss(logits[mask], self.data.y[mask]).item()

        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        return acc, loss

    def fit(self, model: SupervisedHomogeneousNodeClassificationModel, dataset):
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        device = self.args.device_id[0] if not self.args.cpu else "cpu"
        data = dataset[0]
        data.apply(lambda x: x.to(device))
        self.max_epoch = self.args.max_epoch

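        # Build a DGLGraph from the dataset's COO edge_index, adding one edge at a time.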
        row, col = data.edge_index
        row, col = row.cpu().numpy(), col.cpu().numpy()
        num_edge = row.shape[0]
        num_node = data.x.to("cpu").shape[0]
        self.graph.add_nodes(num_node)
        for i in range(num_edge):
            src, dst = row[i], col[i]
            self.graph.add_edge(src, dst)
        self.graph = self.graph.to(device)
        model.set_graph(self.graph)

        self.data = data
        self.model = model.to(device)

        epoch_iter = tqdm(range(self.max_epoch))
        best_score = 0
        best_loss = np.inf
        max_score = 0
        min_loss = np.inf
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            epoch_iter.set_description(f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}")
            if val_loss <= min_loss or val_acc >= max_score:
                if val_loss <= best_loss:  # and val_acc >= best_score:
                    best_loss = val_loss
                    best_score = val_acc
                min_loss = np.min((min_loss, val_loss))
                max_score = np.max((max_score, val_acc))

        print(f"Best accurracy = {best_score}")

        test_acc, _ = self._test_step(split="test")
        print(f"Test accuracy = {test_acc}")
        return dict(Acc=test_acc)


@register_model("jknet")
class JKNet(SupervisedHomogeneousNodeClassificationModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lr', help='Learning rate', type=float, default=0.005)
        parser.add_argument('--layer-aggregation', help='The way to aggregate outputs of layers', type=str, choices=('maxpool', 'concat'), default='maxpool')
        parser.add_argument('--weight-decay', help='Weight decay', type=float, default=0.0005)
        parser.add_argument('--node-aggregation', help='The way to aggregate neighbourhoods', type=str, choices=('sum', 'mean', 'max'), default='sum')
        parser.add_argument('--n-layers', help='Number of convolution layers', type=int, default=6)
        parser.add_argument('--n-units', help='Size of middle layers.', type=int, default=16)
        parser.add_argument('--in-features', help='Input feature dimension, 1433 for cora', type=int, default=1433)
        parser.add_argument('--out-features', help='Output feature dimension, 7 for cora', type=int, default=7)
        parser.add_argument('--max-epoch', help='Epochs to train', type=int, default=100)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.in_features,
            args.out_features,
            args.n_layers,
            args.n_units,
            args.node_aggregation,
            args.layer_aggregation,
        )

    def __init__(self, in_features, out_features, n_layers, n_units, node_aggregation, layer_aggregation):
        model_args = (in_features, out_features, n_layers, n_units, node_aggregation)
        super(JKNet, self).__init__()
        if layer_aggregation == "maxpool":
            self.model = JKNetMaxpool(*model_args)
        else:
            self.model = JKNetConcat(*model_args)

    def forward(self, graph, x):
        y = F.log_softmax(self.model(graph, x), dim=1)
        return y

    def predict(self, data):
        return self.forward(self.graph, data.x)

    def loss(self, data):
        return F.nll_loss(self.forward(self.graph, data.x)[data.train_mask], data.y[data.train_mask])

    def set_graph(self, graph):
        self.graph = graph

    @staticmethod
    def get_trainer(taskType, args):
        return JKNetTrainer
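

# Illustrative usage sketch, not part of the original cogdl module: it builds a tiny
# random graph to smoke-test the layers above. It assumes a DGL version compatible
# with the API already used in this file (e.g. ``fn.copy_src`` and the mutable
# ``dgl.DGLGraph``); the graph, feature sizes and hyperparameters are made up.
if __name__ == "__main__":
    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])  # a directed 4-cycle, so every node has in-degree 1

    x = torch.randn(4, 8)  # 4 nodes with 8-dimensional features

    layer = GraphConvLayer(in_features=8, out_features=16, aggregation="mean")
    print(layer(g, x).shape)  # expected: torch.Size([4, 16])

    net_max = JKNetMaxpool(in_features=8, out_features=3, n_layers=2, n_units=16)
    print(net_max(g, x).shape)  # expected: torch.Size([4, 3])

    net_cat = JKNetConcat(in_features=8, out_features=3, n_layers=2, n_units=16)
    print(net_cat(g, x).shape)  # expected: torch.Size([4, 3])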