# Source code for cogdl.models.nn.pprgo

import math

from typing import Any
import torch
import torch.nn as nn

from .. import BaseModel, register_model
from cogdl.utils import get_activation, spmm
from cogdl.trainers.ppr_trainer import PPRGoTrainer

class LinearLayer(nn.Module):
    """Affine layer ``y = input @ weight.T + bias`` with fan-out Kaiming-uniform init.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: if False, no additive bias is learned and ``self.bias`` is None.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(LinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.Tensor(...) allocates uninitialized memory; reset_parameters() fills it below.
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            # Register as None so `self.bias` always exists and F.linear skips the bias term.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize parameters in place.

        The weight uses Kaiming-uniform init in ``fan_out`` mode with ``a=sqrt(5)``.
        The bias, if present, is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        nn.init.kaiming_uniform_(self.weight, mode="fan_out", a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the affine map to the last dimension of ``input``."""
        return torch.nn.functional.linear(input, self.weight, self.bias)

class PPRGoMLP(nn.Module):
    """Bias-free feed-forward network mapping node features to output logits.

    Layer sizes are ``in_feats -> hidden_size -> ... -> hidden_size -> out_feats``,
    with ``num_layers`` affine layers in total. Dropout is applied before every
    layer, and ``activation`` after every layer except the last.

    Args:
        in_feats: input feature dimension.
        hidden_size: width of the hidden layers.
        out_feats: output dimension.
        num_layers: number of affine layers; must be >= 1.
        dropout: dropout probability applied to each layer's input while training.
        activation: name passed to ``get_activation``.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, in_feats, hidden_size, out_feats, num_layers, dropout, activation="relu"):
        super(PPRGoMLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout
        self.nlayers = num_layers
        # Start the size list at in_feats so a single-layer MLP maps directly to out_feats.
        shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats]
        self.layers = nn.ModuleList()
        self.layers.append(LinearLayer(shapes[0], shapes[1], bias=False))
        for i in range(1, num_layers):
            self.layers.append(nn.Linear(shapes[i], shapes[i + 1], bias=False))
        self.activation = get_activation(activation)

    def forward(self, x):
        """Run ``x`` through all layers and return the final-layer output."""
        h = x
        for i, layer in enumerate(self.layers):
            # training=self.training disables dropout once the module is in eval mode.
            h = nn.functional.dropout(h, p=self.dropout, training=self.training)
            h = layer(h)
            if i != self.nlayers - 1:
                h = self.activation(h)
        return h

@register_model("pprgo")
class PPRGo(BaseModel):
    """PPRGo model: an MLP over node features whose logits are mixed by PPR-style weights.

    Training (``forward``) weights each MLP output row by a per-row score and sums
    the rows onto their target node. Inference (``predict``) instead runs
    ``nprop`` propagation steps over the normalized graph:
    ``P <- (1 - alpha) * A_norm @ P + alpha * logits``.
    """

    @staticmethod
    def add_args(parser):
        """Register PPRGo's command-line hyper-parameters on ``parser``."""
        parser.add_argument("--hidden-size", type=int, default=32)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--dropout", type=float, default=0.1)
        parser.add_argument("--activation", type=str, default="relu")
        parser.add_argument("--nprop-inference", type=int, default=2)
        parser.add_argument("--alpha", type=float, default=0.5)
        parser.add_argument("--k", type=int, default=32)
        parser.add_argument("--norm", type=str, default="sym")
        parser.add_argument("--eps", type=float, default=1e-4)
        parser.add_argument("--eval-step", type=int, default=4)
        parser.add_argument("--batch-size", type=int, default=512)
        parser.add_argument("--test-batch-size", type=int, default=10000)

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from a parsed argument namespace."""
        return cls(
            in_feats=args.num_features,
            hidden_size=args.hidden_size,
            out_feats=args.num_classes,
            num_layers=args.num_layers,
            alpha=args.alpha,
            dropout=args.dropout,
            activation=args.activation,
            nprop=args.nprop_inference,
        )

    def __init__(self, in_feats, hidden_size, out_feats, num_layers, alpha, dropout, activation="relu", nprop=2):
        super(PPRGo, self).__init__()
        self.alpha = alpha  # teleport/restart weight used in predict()
        self.nprop = nprop  # number of propagation steps used in predict()
        self.fc = PPRGoMLP(in_feats, hidden_size, out_feats, num_layers, dropout, activation)

    def forward(self, x, targets, ppr_scores):
        """Sum score-weighted MLP logits onto their target rows.

        Args:
            x: feature rows; row ``j`` contributes to output row ``targets[j]``.
            targets: 1-D integer tensor of output-row indices, one per row of ``x``
                (assumed non-empty).
            ppr_scores: 1-D tensor of weights, one per row of ``x``.

        Returns:
            Tensor of shape ``(int(targets.max()) + 1, out_feats)``.
        """
        h = self.fc(x)
        h = ppr_scores.unsqueeze(1) * h
        # Use max() rather than targets[-1] so unsorted targets are handled correctly.
        num_targets = int(targets.max()) + 1
        # Allocate with h's dtype/device so index_add_ never sees a dtype mismatch
        # (e.g. when fc runs under autocast and h is lower precision than x).
        out = torch.zeros(num_targets, h.shape[1], device=h.device, dtype=h.dtype)
        return out.index_add_(0, targets, h)

    def node_classification_loss(self, x, targets, ppr_scores, y):
        """Return ``self.loss_fn`` applied to ``forward(x, targets, ppr_scores)`` and ``y``."""
        # NOTE(review): self.loss_fn is not defined here; presumably provided by BaseModel.
        pred = self.forward(x, targets, ppr_scores)
        loss = self.loss_fn(pred, y)
        return loss

    def predict(self, graph, batch_size, norm):
        """Full-graph inference: batched MLP logits followed by ``nprop`` propagation steps.

        Args:
            graph: graph object exposing ``x``, ``edge_weight``, ``local_graph()``,
                ``sym_norm()`` and ``row_norm()``.
            batch_size: number of nodes per MLP forward pass.
            norm: ``"sym"`` or ``"row"``, the adjacency normalization to apply.

        Returns:
            CPU tensor of propagated logits, one row per node.

        Raises:
            NotImplementedError: if ``norm`` is not ``"sym"`` or ``"row"``.
        """
        # Validate before the (potentially expensive) MLP pass.
        if norm not in ("sym", "row"):
            raise NotImplementedError(f"Unsupported norm: {norm!r}")

        device = next(self.fc.parameters()).device
        x = graph.x
        num_nodes = x.shape[0]
        pred_logits = []
        # Dropout inside fc still follows self.training; callers are expected to call eval().
        with torch.no_grad():
            for start in range(0, num_nodes, batch_size):
                batch_x = x[start : start + batch_size].to(device)
                pred_logits.append(self.fc(batch_x).cpu())
        pred_logits = torch.cat(pred_logits, dim=0)

        # local_graph() scopes the normalization/edge-weight changes to this block.
        with graph.local_graph():
            if norm == "sym":
                graph.sym_norm()
            else:
                graph.row_norm()
            graph.edge_weight = graph.edge_weight * (1 - self.alpha)
            predictions = pred_logits
            for _ in range(self.nprop):
                predictions = spmm(graph, predictions) + self.alpha * pred_logits
        return predictions

    @staticmethod
    def get_trainer(taskType: Any, args: Any):
        """Return the trainer class used for this model (independent of the arguments)."""
        return PPRGoTrainer