import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from .. import BaseModel, register_model
from .knowledge_base import KGEModel
@register_model("distmult")
class DistMult(KGEModel):
    r"""The DistMult model from the ICLR 2015 paper
    `"EMBEDDING ENTITIES AND RELATIONS FOR LEARNING AND INFERENCE IN KNOWLEDGE BASES"
    <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ICLR2015_updated.pdf>`_,
    borrowed from `KnowledgeGraphEmbedding
    <https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding>`_.

    The class is registered under the name ``"distmult"`` and only overrides
    :meth:`score`; construction is delegated unchanged to ``KGEModel``.
    """

    def __init__(
        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
    ):
        """Forward every argument, in the same order, to ``KGEModel.__init__``.

        The meaning of each argument is defined by ``KGEModel`` (not visible
        from this class); presumably ``nentity``/``nrelation`` are vocabulary
        sizes and ``hidden_dim`` is the embedding width -- confirm against
        ``KGEModel``.
        """
        # Zero-argument super() does not depend on the module-level name
        # ``DistMult``, which the class decorator is free to rebind.
        super().__init__(
            nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
        )

    def score(self, head: Tensor, relation: Tensor, tail: Tensor, mode: str) -> Tensor:
        """Return the DistMult score: the element-wise triple product summed over dim 2.

        Args:
            head: Head-entity embeddings.
            relation: Relation embeddings.
            tail: Tail-entity embeddings.
            mode: ``"head-batch"`` selects the ``head * (relation * tail)``
                grouping; any other value selects ``(head * relation) * tail``.

        Returns:
            The product of the three inputs reduced with ``sum(dim=2)``, so the
            broadcast product must have at least three dimensions.
        """
        # Both branches compute the same triple product; only the order of the
        # pairwise multiplications differs. NOTE(review): presumably the pair
        # that needs no broadcasting is multiplied first so the intermediate
        # stays small -- confirm against the callers' tensor shapes.
        if mode == "head-batch":
            product = head * (relation * tail)
        else:
            product = (head * relation) * tail

        return product.sum(dim=2)