cogdl.models.nn.rgcn

Module Contents

Classes

RGCNLayer

RGCN

LinkPredictRGCN

class cogdl.models.nn.rgcn.RGCNLayer(in_feats, out_feats, num_edge_types, regularizer='basis', num_bases=None, self_loop=True, dropout=0.0, self_dropout=0.0, layer_norm=True, bias=True)[source]

Bases: torch.nn.Module

reset_parameters(self)[source]
forward(self, x, edge_index, edge_type)[source]
basis_forward(self, x, edge_index, edge_type)[source]
bdd_forward(self, x, edge_index, edge_type)[source]
class cogdl.models.nn.rgcn.RGCN(in_feats, out_feats, num_layers, num_rels, regularizer='basis', num_bases=None, self_loop=True, dropout=0.0, self_dropout=0.0)[source]

Bases: torch.nn.Module

forward(self, x, edge_index, edge_type)[source]
class cogdl.models.nn.rgcn.LinkPredictRGCN(num_entities, num_rels, hidden_size, num_layers, regularizer='basis', num_bases=None, self_loop=True, sampling_rate=0.01, penalty=0, dropout=0.0, self_dropout=0.0)[source]

Bases: cogdl.layers.link_prediction_module.GNNLinkPredict, cogdl.models.BaseModel

static add_args(parser)[source]

Add model-specific arguments to the parser.

classmethod build_model_from_args(cls, args)[source]

Build a new model instance from the parsed arguments `args`.

forward(self, edge_index, edge_type)[source]
loss(self, data, split='train')[source]
predict(self, edge_index, edge_type)[source]