Source code for cogdl.models.nn.rgcn

import torch
import torch.nn as nn
import torch.nn.functional as F

from cogdl.utils import row_normalization, spmm
from cogdl.layers.link_prediction_module import GNNLinkPredict, cal_mrr, sampling_edge_uniform
from .. import register_model, BaseModel


class RGCNLayer(nn.Module):
    """
    Implementation of Relational-GCN in paper `"Modeling Relational Data with Graph Convolutional Networks"`
     <https://arxiv.org/abs/1703.06103>

     Parameters
     ----------
     in_feats : int
        Size of each input embedding.
    out_feats : int
        Size of each output embedding.
    num_edge_type : int
        The number of edge type in knowledge graph.
    regularizer : str, optional
        Regularizer used to avoid overfitting, ``basis`` or ``bdd``, default : ``basis``.
    num_bases : int, optional
        The number of basis, only used when `regularizer` is `basis`, default : ``None``.
    self_loop : bool, optional
        Add self loop embedding if True, default : ``True``.
    dropout : float
    self_dropout : float, optional
        Dropout rate of self loop embedding, default : ``0.0``
    layer_norm : bool, optional
        Use layer normalization if True, default : ``True``
    bias : bool
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_edge_types,
        regularizer="basis",
        num_bases=None,
        self_loop=True,
        dropout=0.0,
        self_dropout=0.0,
        layer_norm=True,
        bias=True,
    ):
        super(RGCNLayer, self).__init__()
        self.num_bases = num_bases
        self.regularizer = regularizer
        self.num_edge_types = num_edge_types
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.self_loop = self_loop
        self.dropout = dropout
        self.self_dropout = self_dropout

        if self.num_bases is None or self.num_bases > num_edge_types or self.num_bases <= 0:
            self.num_bases = num_edge_types

        if regularizer == "basis":
            self.weight = nn.Parameter(torch.Tensor(self.num_bases, in_feats, out_feats))
            if self.num_bases < num_edge_types:
                self.alpha = nn.Parameter(torch.Tensor(num_edge_types, self.num_bases))
            else:
                self.register_buffer("alpha", None)
        elif regularizer == "bdd":
            assert (in_feats % num_bases == 0) and (out_feats % num_bases == 0)
            self.block_in_feats = in_feats // num_bases
            self.block_out_feats = out_feats // num_bases
            self.weight = nn.Parameter(
                torch.Tensor(num_edge_types, self.num_bases, self.block_in_feats * self.block_out_feats)
            )
        else:
            raise NotImplementedError

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.register_buffer("bias", None)

        if self_loop:
            self.weight_self_loop = nn.Parameter(torch.Tensor(in_feats, out_feats))
        else:
            self.register_buffer("weight_self_loop", None)

        if layer_norm:
            self.layer_norm = nn.LayerNorm(out_feats, elementwise_affine=True)
        else:
            self.register_buffer("layer_norm", None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain("relu"))
        if self.alpha is not None:
            nn.init.xavier_uniform_(self.alpha, gain=nn.init.calculate_gain("relu"))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        if self.weight_self_loop is not None:
            nn.init.xavier_uniform_(self.weight_self_loop, gain=nn.init.calculate_gain("relu"))

    def forward(self, x, edge_index, edge_type):
        if self.regularizer == "basis":
            h_list = self.basis_forward(x, edge_index, edge_type)
        else:
            h_list = self.bdd_forward(x, edge_index, edge_type)

        h_result = sum(h_list)
        h_result = F.dropout(h_result, p=self.dropout, training=self.training)
        if self.layer_norm is not None:
            h_result = self.layer_norm(h_result)
        if self.bias is not None:
            h_result = h_result + self.bias
        if self.weight_self_loop is not None:
            h_result += F.dropout(torch.matmul(x, self.weight_self_loop), p=self.self_dropout, training=self.training)
        return h_result

    def basis_forward(self, x, edge_index, edge_type):
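        # Basis decomposition (Eq. 3 of the R-GCN paper): each relation's weight
        # matrix is a learned combination of shared bases,
        #     W_r = sum_b alpha[r, b] * V_b,
        # which keeps the parameter count linear in num_bases rather than in
        # num_edge_types.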
        if self.num_bases < self.num_edge_types:
            weight = torch.matmul(self.alpha, self.weight.view(self.num_bases, -1))
            weight = weight.view(self.num_edge_types, self.in_feats, self.out_feats)
        else:
            weight = self.weight
        edge_weight = torch.ones(edge_type.shape).to(x.device)
        edge_weight = row_normalization(x.shape[0], edge_index, edge_weight)

        h = torch.matmul(x, weight)  # (N, d1) by (r, d1, d2) -> (r, N, d2)

        h_list = []
        for edge_t in range(self.num_edge_types):
            edge_mask = edge_type == edge_t
            _edge_index_t = edge_index.t()[edge_mask].t()
            temp = spmm(_edge_index_t, edge_weight[edge_mask], h[edge_t])
            h_list.append(temp)
        return h_list

    def bdd_forward(self, x, edge_index, edge_type):
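        # Block-diagonal decomposition (Eq. 4 of the R-GCN paper): each
        # relation's weight matrix is block-diagonal with num_bases blocks of
        # shape (block_in_feats, block_out_feats); the einsum below applies
        # each block to its own slice of the node features.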
        _x = x.view(-1, self.num_bases, self.block_in_feats)

        edge_weight = torch.ones(edge_type.shape).to(x.device)
        edge_weight = row_normalization(x.shape[0], edge_index, edge_weight)

        h_list = []
        for edge_t in range(self.num_edge_types):
            _weight = self.weight[edge_t].view(self.num_bases, self.block_in_feats, self.block_out_feats)
            edge_mask = edge_type == edge_t
            _edge_index_t = edge_index.t()[edge_mask].t()
            h_t = torch.einsum("abc,bcd->abd", _x, _weight).reshape(-1, self.out_feats)
            h_t = spmm(_edge_index_t, edge_weight[edge_mask], h_t)
            h_list.append(h_t)

        return h_list
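
# A minimal usage sketch for RGCNLayer (illustrative only; the toy sizes,
# relation count, and random graph below are assumptions, not values used
# anywhere in cogdl):
#
#     layer = RGCNLayer(in_feats=16, out_feats=32, num_edge_types=3)
#     x = torch.randn(10, 16)                      # 10 node embeddings
#     edge_index = torch.randint(0, 10, (2, 40))   # 40 random directed edges
#     edge_type = torch.randint(0, 3, (40,))       # one relation id per edge
#     out = layer(x, edge_index, edge_type)        # -> shape (10, 32)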


class RGCN(nn.Module):
    def __init__(
        self,
        in_feats,
        out_feats,
        num_layers,
        num_rels,
        regularizer="basis",
        num_bases=None,
        self_loop=True,
        dropout=0.0,
        self_dropout=0.0,
    ):
        super(RGCN, self).__init__()
        shapes = [in_feats] + [out_feats] * num_layers
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            RGCNLayer(shapes[i], shapes[i + 1], num_rels, regularizer, num_bases, self_loop, dropout, self_dropout)
            for i in range(num_layers)
        )

    def forward(self, x, edge_index, edge_type):
        h = x
        for i, layer in enumerate(self.layers):
            h = layer(h, edge_index, edge_type)
            if i < self.num_layers - 1:
                h = F.relu(h)
        return h
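
# Sketch of the stacked model with the same toy graph as above (illustrative
# numbers only): layer i maps shapes[i] -> shapes[i + 1], with ReLU between
# layers:
#
#     model = RGCN(in_feats=16, out_feats=32, num_layers=2, num_rels=3)
#     out = model(x, edge_index, edge_type)        # -> shape (10, 32)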


[docs]@register_model("rgcn") class LinkPredictRGCN(GNNLinkPredict, BaseModel):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument("--hidden-size", type=int, default=200)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--regularizer", type=str, default="basis")
        parser.add_argument("--self-loop", action="store_false")
        parser.add_argument("--penalty", type=float, default=0.001)
        parser.add_argument("--dropout", type=float, default=0.2)
        parser.add_argument("--self-dropout", type=float, default=0.4)
        parser.add_argument("--num-bases", type=int, default=5)
        parser.add_argument("--sampling-rate", type=float, default=0.01)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            num_entities=args.num_entities,
            num_rels=args.num_rels,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            regularizer=args.regularizer,
            num_bases=args.num_bases,
            self_loop=args.self_loop,
            sampling_rate=args.sampling_rate,
            penalty=args.penalty,
            dropout=args.dropout,
            self_dropout=args.self_dropout,
        )

    def __init__(
        self,
        num_entities,
        num_rels,
        hidden_size,
        num_layers,
        regularizer="basis",
        num_bases=None,
        self_loop=True,
        sampling_rate=0.01,
        penalty=0,
        dropout=0.0,
        self_dropout=0.0,
    ):
        BaseModel.__init__(self)
        GNNLinkPredict.__init__(self, "distmult", hidden_size)
        self.penalty = penalty
        self.num_nodes = num_entities
        self.num_rels = num_rels
        self.sampling_rate = sampling_rate
        self.edge_set = None
        self.model = RGCN(
            in_feats=hidden_size,
            out_feats=hidden_size,
            num_layers=num_layers,
            num_rels=num_rels,
            regularizer=regularizer,
            num_bases=num_bases,
            self_loop=self_loop,
            dropout=dropout,
            self_dropout=self_dropout,
        )
        # self.rel_weight = nn.Parameter(torch.Tensor(num_rels, hidden_size))
        # nn.init.xavier_normal_(self.rel_weight, gain=nn.init.calculate_gain("relu"))
        # self.emb = nn.Parameter(torch.Tensor(num_entities, hidden_size))
        # nn.init.xavier_normal_(self.emb, gain=nn.init.calculate_gain("relu"))
        self.rel_weight = nn.Embedding(num_rels, hidden_size)
        self.emb = nn.Embedding(num_entities, hidden_size)

    def forward(self, edge_index, edge_type):
        reindexed_nodes, reindexed_indices = torch.unique(edge_index, sorted=True, return_inverse=True)
        x = self.emb(reindexed_nodes)
        self.cache_index = reindexed_nodes
        output = self.model(x, reindexed_indices, edge_type)
        return output

    def loss(self, data, split="train"):
        if split == "train":
            mask = data.train_mask
        elif split == "val":
            mask = data.val_mask
        else:
            mask = data.test_mask
        edge_index, edge_types = data.edge_index[:, mask], data.edge_attr[mask]
        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            edge_index, edge_types, self.edge_set, self.sampling_rate, self.num_rels
        )
        output = self.forward(batch_edges, batch_attr)
        edge_weight = self.rel_weight(rels)
        sampled_nodes, reindexed_edges = torch.unique(samples, sorted=True, return_inverse=True)
        assert (sampled_nodes == self.cache_index).any()
        sampled_types = torch.unique(rels)
        loss_n = self._loss(
            output[reindexed_edges[0]], output[reindexed_edges[1]], edge_weight, labels
        ) + self.penalty * self._regularization([self.emb(sampled_nodes), self.rel_weight(sampled_types)])
        return loss_n

    def predict(self, edge_index, edge_type):
        indices = torch.arange(0, self.num_nodes).to(edge_index.device)
        x = self.emb(indices)
        output = self.model(x, edge_index, edge_type)
        mrr, hits = cal_mrr(
            output,
            self.rel_weight.weight,
            edge_index,
            edge_type,
            scoring=self.scoring,
            protocol="raw",
            batch_size=500,
            hits=[1, 3, 10],
        )
        return mrr, hits
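

# Hedged end-to-end sketch for the link-prediction wrapper; the entity and
# relation counts below are placeholders (in practice build_model_from_args
# reads them from the dataset via args.num_entities / args.num_rels):
#
#     model = LinkPredictRGCN(num_entities=1000, num_rels=10,
#                             hidden_size=200, num_layers=2)
#     edge_index = torch.randint(0, 1000, (2, 500))
#     edge_type = torch.randint(0, 10, (500,))
#     emb = model(edge_index, edge_type)   # rows = unique nodes in edge_index
#     mrr, hits = model.predict(edge_index, edge_type)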