import torch
import torch.nn as nn
from .. import ModelWrapper
from cogdl.utils.link_prediction_utils import cal_mrr, DistMultLayer, ConvELayer
class GNNKGLinkPredictionModelWrapper(ModelWrapper):
    """Model wrapper for link prediction that pairs a model with a scoring layer.

    The wrapped ``model`` must provide ``loss(graph, scoring)`` and
    ``predict(graph)``. Edges are split by the boolean masks stored on the
    graph (``train_mask``, ``val_mask``, ``test_mask``); evaluation reports
    MRR and Hits@{1, 3, 10} as computed by ``cal_mrr``.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line options of this wrapper on ``parser``."""
        # fmt: off
        parser.add_argument("--score-func", type=str, default="distmult")
        # fmt: on

    def __init__(self, model, optimizer_cfg, score_func):
        """Build the wrapper and the scoring layer selected by ``score_func``.

        Args:
            model: Model exposing ``loss(graph, scoring)`` and ``predict(graph)``.
            optimizer_cfg: Mapping read for ``"lr"`` and ``"weight_decay"`` in
                ``setup_optimizer``, and for ``"hidden_size"`` when
                ``score_func`` is ``"conve"``.
            score_func: ``"distmult"`` or ``"conve"``.

        Raises:
            NotImplementedError: If ``score_func`` is not a supported name.
            KeyError: If ``score_func`` is ``"conve"`` and ``optimizer_cfg``
                has no ``"hidden_size"`` entry.
        """
        super(GNNKGLinkPredictionModelWrapper, self).__init__()
        self.model = model
        self.optimizer_cfg = optimizer_cfg
        self.score_func = score_func

        if score_func == "distmult":
            self.scoring = DistMultLayer()
        elif score_func == "conve":
            # Only ConvE needs the hidden size, so look it up here: a DistMult
            # configuration without "hidden_size" must not fail on the lookup.
            self.scoring = ConvELayer(optimizer_cfg["hidden_size"])
        else:
            raise NotImplementedError(
                f"Unsupported score function {score_func!r}; expected 'distmult' or 'conve'"
            )

    def train_step(self, subgraph):
        """Return the model loss computed on the training edges of ``subgraph``.

        The graph's edges are replaced by the subset selected by
        ``subgraph.train_mask`` only inside the ``local_graph()`` context.
        """
        train_mask = subgraph.train_mask
        # Stack the edge index components so that one column-wise mask
        # selects the same edges from every component.
        stacked_index = torch.stack(subgraph.edge_index)
        train_index = stacked_index[:, train_mask]
        train_types = subgraph.edge_attr[train_mask]

        with subgraph.local_graph():
            subgraph.edge_index = train_index
            subgraph.edge_attr = train_types
            loss = self.model.loss(subgraph, self.scoring)
        return loss

    def val_step(self, subgraph):
        """Evaluate on validation edges, running the model on training edges."""
        return self.eval_step(subgraph, subgraph.train_mask, subgraph.val_mask)

    def test_step(self, subgraph):
        """Evaluate on test edges, running the model on training and validation edges."""
        infer_mask = subgraph.train_mask | subgraph.val_mask
        return self.eval_step(subgraph, infer_mask, subgraph.test_mask)

    def eval_step(self, graph, mask1, mask2):
        """Run the model on the ``mask1`` edges and rank the ``mask2`` edges.

        Args:
            graph: Graph whose ``edge_index`` unpacks into ``(row, col)`` and
                whose ``edge_attr`` holds the edge types.
            mask1: Mask of the edges visible to ``self.model.predict``.
            mask2: Mask of the edges scored by ``cal_mrr``.

        Returns:
            dict with the keys ``mrr``, ``hits1``, ``hits3`` and ``hits10``.
        """
        row, col = graph.edge_index
        edge_types = graph.edge_attr

        # Restrict the graph to the visible edges only while predicting.
        with graph.local_graph():
            graph.edge_index = (row[mask1], col[mask1])
            graph.edge_attr = edge_types[mask1]
            output, rel_weight = self.model.predict(graph)

        mrr, hits = cal_mrr(
            output,
            rel_weight,
            (row[mask2], col[mask2]),
            edge_types[mask2],
            scoring=self.scoring,
            protocol="raw",
            batch_size=500,
            hits=[1, 3, 10],
        )
        return dict(mrr=mrr, hits1=hits[0], hits3=hits[1], hits10=hits[2])

    def setup_optimizer(self):
        """Create an AdamW optimizer over all parameters of this wrapper."""
        lr = self.optimizer_cfg["lr"]
        weight_decay = self.optimizer_cfg["weight_decay"]
        return torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

    def set_early_stopping(self):
        """Return the monitored metric name and its comparison operator.

        NOTE(review): ``">"`` presumably means that a larger MRR is better;
        confirm against the trainer that consumes this pair.
        """
        return "mrr", ">"