import os.path as osp
import torch
from cogdl.data import Graph, Dataset
from cogdl.utils import download_url
def read_triplet_data(folder):
    """Read the train/valid/test triple files of a knowledge graph.

    Expects ``train2id.txt``, ``valid2id.txt`` and ``test2id.txt`` in
    *folder*. Each file starts with one header line holding an integer
    (parsed, then discarded); every following non-blank line holds three
    whitespace-separated integer ids. Columns 0 and 1 are entity ids and
    become the endpoints of an edge; column 2 is a relation id and becomes
    the edge attribute.

    Args:
        folder: directory containing the three split files.

    Returns:
        tuple: ``(data, triples, train_start_idx, valid_start_idx,
        test_start_idx, num_entities, num_relations)`` where

        - ``data`` is a ``Graph`` with ``edge_index`` of shape ``(2, E)``,
          ``edge_attr`` of shape ``(E,)`` and boolean ``train_mask`` /
          ``val_mask`` / ``test_mask`` of length ``E`` marking each split;
        - ``triples`` lists ``(col0, col2, col1)`` tuples, i.e.
          (head, relation, tail) ordering, with the train, valid and test
          lines concatenated in that order;
        - the three ``*_start_idx`` values are the offsets in ``triples``
          where each split begins;
        - ``num_entities`` / ``num_relations`` count the *distinct* ids seen
          across all splits. NOTE: this equals ``max_id + 1`` only when the
          ids are contiguous and every id occurs in some triple.

    Raises:
        OSError: if a split file cannot be opened.
        ValueError: if a header or an id is not an integer.
        IndexError: if a non-blank line has fewer than three columns.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    edge_index = []
    edge_attr = []
    triples = []
    split_starts = []  # offset in ``triples`` where each file's lines begin
    entities = set()
    relations = set()

    for filename in filenames:
        split_starts.append(len(triples))
        with open(osp.join(folder, filename), "r", encoding="utf-8") as f:
            # Header: an integer count. Parsing it rejects a malformed file
            # early; the value itself is not used.
            int(f.readline().strip())
            for line in f:
                items = line.split()
                if not items:
                    # Skip blank lines (e.g. extra trailing newlines) instead
                    # of failing with IndexError.
                    continue
                head, tail, relation = int(items[0]), int(items[1]), int(items[2])
                edge_index.append([head, tail])
                edge_attr.append(relation)
                triples.append((head, relation, tail))
                entities.add(head)
                entities.add(tail)
                relations.add(relation)

    num_edges = len(triples)
    train_start_idx, valid_start_idx, test_start_idx = split_starts

    data = Graph()
    # reshape keeps the (2, E) layout even when no triples were read (E == 0).
    data.edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t()
    data.edge_attr = torch.tensor(edge_attr, dtype=torch.long)

    def generate_mask(start, end):
        """Boolean mask over all edges that is True exactly on ``[start, end)``."""
        mask = torch.zeros(num_edges, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(train_start_idx, valid_start_idx)
    data.val_mask = generate_mask(valid_start_idx, test_start_idx)
    data.test_mask = generate_mask(test_start_idx, num_edges)
    return (
        data,
        triples,
        train_start_idx,
        valid_start_idx,
        test_start_idx,
        len(entities),
        len(relations),
    )
class KnowledgeGraphDataset(Dataset):
    """Knowledge-graph dataset holding a single ``Graph`` plus its triples.

    The raw split files listed in ``raw_file_names`` are fetched from ``url``.
    ``process`` converts them (via ``read_triplet_data``) into two processed
    files. ``data.pt`` holds the ``Graph``. ``triple_config.pt`` holds the
    triples, the split start offsets and the entity/relation counts. Both
    files are loaded back in ``__init__``.

    Args:
        root: dataset root directory, forwarded to ``Dataset``.
        name: dataset name, substituted into ``url`` when downloading.
    """

    # Download URL template, formatted with (dataset name, raw file name).
    url = "https://cloud.tsinghua.edu.cn/d/d1c733373b014efab986/files/?p=%2F{}%2F{}&dl=1"

    def __init__(self, root, name):
        # Set before the base initializer: ``download`` reads ``self.name``,
        # and presumably ``Dataset.__init__`` is what triggers download/process
        # when files are missing (not visible here -- confirm).
        self.name = name
        super(KnowledgeGraphDataset, self).__init__(root)
        # NOTE: torch.load unpickles objects; these are files ``process`` wrote.
        self.data = torch.load(self.processed_paths[0])
        triple_config = torch.load(self.processed_paths[1])
        self.triples = triple_config["triples"]
        self._train_start_index = triple_config["train_start_index"]
        self._valid_start_index = triple_config["valid_start_index"]
        self._test_start_index = triple_config["test_start_index"]
        self._num_entities = triple_config["num_entities"]
        self._num_relations = triple_config["num_relations"]

    @property
    def raw_file_names(self):
        """Names of the raw split files, in train/valid/test order."""
        names = ["train2id.txt", "valid2id.txt", "test2id.txt"]
        return names

    @property
    def processed_file_names(self):
        """Processed files: the graph, then the triple/split metadata."""
        return ["data.pt", "triple_config.pt"]

    @property
    def train_start_idx(self):
        """Offset in ``self.triples`` where the training split begins."""
        return self._train_start_index

    @property
    def valid_start_idx(self):
        """Offset in ``self.triples`` where the validation split begins."""
        return self._valid_start_index

    @property
    def test_start_idx(self):
        """Offset in ``self.triples`` where the test split begins."""
        return self._test_start_index

    @property
    def num_entities(self):
        """Number of distinct entity ids seen across all splits."""
        return self._num_entities

    @property
    def num_relations(self):
        """Number of distinct relation ids seen across all splits."""
        return self._num_relations

    def get(self, idx):
        """Return the dataset's single graph.

        Args:
            idx: must be ``0``; the dataset holds exactly one graph.

        Raises:
            IndexError: if ``idx`` is not ``0``.
        """
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if idx != 0:
            raise IndexError(
                "{} holds a single graph; got index {}".format(type(self).__name__, idx)
            )
        return self.data

    def download(self):
        """Download every raw split file for ``self.name`` into ``self.raw_dir``."""
        for name in self.raw_file_names:
            download_url(self.url.format(self.name, name), self.raw_dir, name=name)

    def process(self):
        """Parse the raw split files and save the two processed files."""
        (
            data,
            triples,
            train_start_index,
            valid_start_index,
            test_start_index,
            num_entities,
            num_relations,
        ) = read_triplet_data(self.raw_dir)
        torch.save(data, self.processed_paths[0])
        triple_config = {
            "triples": triples,
            "train_start_index": train_start_index,
            "valid_start_index": valid_start_index,
            "test_start_index": test_start_index,
            "num_entities": num_entities,
            "num_relations": num_relations,
        }
        torch.save(triple_config, self.processed_paths[1])
class FB13Datset(KnowledgeGraphDataset):
    """FB13 knowledge graph, rooted at ``<data_path>/FB13``."""

    def __init__(self, data_path="data"):
        name = "FB13"
        super().__init__(osp.join(data_path, name), name)
class FB15kDatset(KnowledgeGraphDataset):
    """FB15K knowledge graph, rooted at ``<data_path>/FB15K``."""

    def __init__(self, data_path="data"):
        name = "FB15K"
        super().__init__(osp.join(data_path, name), name)
class FB15k237Datset(KnowledgeGraphDataset):
    """FB15K237 knowledge graph, rooted at ``<data_path>/FB15K237``."""

    def __init__(self, data_path="data"):
        name = "FB15K237"
        super().__init__(osp.join(data_path, name), name)
class WN18Datset(KnowledgeGraphDataset):
    """WN18 knowledge graph, rooted at ``<data_path>/WN18``."""

    def __init__(self, data_path="data"):
        name = "WN18"
        super().__init__(osp.join(data_path, name), name)
class WN18RRDataset(KnowledgeGraphDataset):
    """WN18RR knowledge graph, rooted at ``<data_path>/WN18RR``."""

    def __init__(self, data_path="data"):
        name = "WN18RR"
        super().__init__(osp.join(data_path, name), name)
class FB13SDatset(KnowledgeGraphDataset):
    """FB13S knowledge graph, rooted at ``<data_path>/FB13S``."""

    def __init__(self, data_path="data"):
        name = "FB13S"
        super().__init__(osp.join(data_path, name), name)