Source code for cogdl.datasets.kg_data

import json
import os
import os.path as osp
import subprocess
import sys
from itertools import product

import numpy as np
import scipy.io
import torch

from cogdl.data import Data, Dataset, download_url
from cogdl.datasets import register_dataset


class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        self.step += 1
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch DataLoader into an endless Python iterator.
        '''
        while True:
            for data in dataloader:
                yield data
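
# Illustrative usage sketch (not part of the original source): the iterator
# alternates between the two loaders, so with `step` starting at 0 the first
# call draws from the tail-batch loader and the second from the head-batch
# loader. Assuming `loader_head` and `loader_tail` are ordinary DataLoaders:
#
#   train_iterator = BidirectionalOneShotIterator(loader_head, loader_tail)
#   batch = next(train_iterator)   # step 1 (odd)  -> tail-batch loader
#   batch = next(train_iterator)   # step 2 (even) -> head-batch loader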

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, triples, all_true_triples, nentity, nrelation, mode):
        self.len = len(triples)
        self.triple_set = set(all_true_triples)
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.mode = mode

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        head, relation, tail = self.triples[idx]

        if self.mode == 'head-batch':
            tmp = [
                (0, rand_head) if (rand_head, relation, tail) not in self.triple_set else (-1, head)
                for rand_head in range(self.nentity)
            ]
            tmp[head] = (0, head)
        elif self.mode == 'tail-batch':
            tmp = [
                (0, rand_tail) if (head, relation, rand_tail) not in self.triple_set else (-1, tail)
                for rand_tail in range(self.nentity)
            ]
            tmp[tail] = (0, tail)
        else:
            raise ValueError('negative batch mode %s not supported' % self.mode)

        tmp = torch.LongTensor(tmp)
        filter_bias = tmp[:, 0].float()
        negative_sample = tmp[:, 1]
        positive_sample = torch.LongTensor((head, relation, tail))

        return positive_sample, negative_sample, filter_bias, self.mode

    @staticmethod
    def collate_fn(data):
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        filter_bias = torch.stack([_[2] for _ in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, filter_bias, mode
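
# Illustrative sketch (not part of the original source): building a filtered
# evaluation loader. `test_triples`, `all_true_triples`, `nentity` and
# `nrelation` are assumed to come from a processed KnowledgeGraphDataset.
#
#   test_loader_tail = torch.utils.data.DataLoader(
#       TestDataset(test_triples, all_true_triples, nentity, nrelation, 'tail-batch'),
#       batch_size=16,
#       collate_fn=TestDataset.collate_fn,
#   )
#   # Each item scores every entity as a candidate tail; filter_bias is 0 for
#   # candidates to be ranked and -1 for other known true triples, so known
#   # triples can be masked out of the ranking.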

class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, triples, nentity, nrelation, negative_sample_size, mode):
        self.len = len(triples)
        self.triples = triples
        self.triple_set = set(triples)
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = self.count_frequency(triples)
        self.true_head, self.true_tail = self.get_true_head_and_tail(self.triples)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        positive_sample = self.triples[idx]
        head, relation, tail = positive_sample

        subsampling_weight = self.count[(head, relation)] + self.count[(tail, -relation - 1)]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        negative_sample_list = []
        negative_sample_size = 0
        while negative_sample_size < self.negative_sample_size:
            negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size * 2)
            if self.mode == 'head-batch':
                mask = np.in1d(
                    negative_sample,
                    self.true_head[(relation, tail)],
                    assume_unique=True,
                    invert=True,
                )
            elif self.mode == 'tail-batch':
                mask = np.in1d(
                    negative_sample,
                    self.true_tail[(head, relation)],
                    assume_unique=True,
                    invert=True,
                )
            else:
                raise ValueError('Training batch mode %s not supported' % self.mode)
            negative_sample = negative_sample[mask]
            negative_sample_list.append(negative_sample)
            negative_sample_size += negative_sample.size

        negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size]
        negative_sample = torch.LongTensor(negative_sample)
        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode
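
    # Worked example (illustrative, not in the original source): count_frequency
    # initializes every (head, relation) / (tail, -relation-1) pattern at start=4 and
    # adds 1 per extra occurrence. A (head, relation) pattern seen 4 times therefore
    # has count 7; if the matching (tail, -relation-1) pattern is seen twice, its
    # count is 5, and the triple's subsampling weight is sqrt(1 / (7 + 5)) ≈ 0.289,
    # so frequent patterns are down-weighted, as in word2vec subsampling.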

    @staticmethod
    def collate_fn(data):
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode

    @staticmethod
    def count_frequency(triples, start=4):
        '''
        Get the frequency of a partial triple like (head, relation) or (relation, tail).
        The frequency will be used for subsampling, as in word2vec.
        '''
        count = {}
        for head, relation, tail in triples:
            if (head, relation) not in count:
                count[(head, relation)] = start
            else:
                count[(head, relation)] += 1

            if (tail, -relation - 1) not in count:
                count[(tail, -relation - 1)] = start
            else:
                count[(tail, -relation - 1)] += 1
        return count

    @staticmethod
    def get_true_head_and_tail(triples):
        '''
        Build dictionaries of true heads and tails, used to filter out known true
        triples during negative sampling.
        '''
        true_head = {}
        true_tail = {}

        for head, relation, tail in triples:
            if (head, relation) not in true_tail:
                true_tail[(head, relation)] = []
            true_tail[(head, relation)].append(tail)
            if (relation, tail) not in true_head:
                true_head[(relation, tail)] = []
            true_head[(relation, tail)].append(head)

        for relation, tail in true_head:
            true_head[(relation, tail)] = np.array(list(set(true_head[(relation, tail)])))
        for head, relation in true_tail:
            true_tail[(head, relation)] = np.array(list(set(true_tail[(head, relation)])))

        return true_head, true_tail
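
# Illustrative sketch (not part of the original source): negative-sampling
# training loaders wired into BidirectionalOneShotIterator above.
# `train_triples`, `nentity` and `nrelation` are assumed inputs.
#
#   loader_head = torch.utils.data.DataLoader(
#       TrainDataset(train_triples, nentity, nrelation, 128, 'head-batch'),
#       batch_size=1024, shuffle=True, collate_fn=TrainDataset.collate_fn,
#   )
#   loader_tail = torch.utils.data.DataLoader(
#       TrainDataset(train_triples, nentity, nrelation, 128, 'tail-batch'),
#       batch_size=1024, shuffle=True, collate_fn=TrainDataset.collate_fn,
#   )
#   train_iterator = BidirectionalOneShotIterator(loader_head, loader_tail)
#   positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)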

def read_triplet_data(folder):
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    count = 0
    edge_index = []
    edge_attr = []
    count_list = []
    triples = []
    num_entities = 0
    num_relations = 0
    entity_dic = {}
    relation_dic = {}
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            num = int(f.readline().strip())
            if "train" in filename:
                train_start_idx = len(triples)
            elif "valid" in filename:
                valid_start_idx = len(triples)
            elif "test" in filename:
                test_start_idx = len(triples)
            for line in f:
                items = line.strip().split()
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
                triples.append((int(items[0]), int(items[2]), int(items[1])))
                if items[0] not in entity_dic:
                    entity_dic[items[0]] = num_entities
                    num_entities += 1
                if items[1] not in entity_dic:
                    entity_dic[items[1]] = num_entities
                    num_entities += 1
                if items[2] not in relation_dic:
                    relation_dic[items[2]] = num_relations
                    num_relations += 1
                count += 1
            count_list.append(count)

    edge_index = torch.LongTensor(edge_index).t()
    edge_attr = torch.LongTensor(edge_attr)
    data = Data()
    data.edge_index = edge_index
    data.edge_attr = edge_attr

    def generate_mask(start, end):
        mask = torch.BoolTensor(count)
        mask[:] = False
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, count_list[0])
    data.val_mask = generate_mask(count_list[0], count_list[1])
    data.test_mask = generate_mask(count_list[1], count_list[2])
    return data, triples, train_start_idx, valid_start_idx, test_start_idx, num_entities, num_relations
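
# The *2id.txt files follow the OpenKE benchmark layout read above: the first
# line gives the number of triples, and every following line is
# "head_id tail_id relation_id". A minimal sketch of a valid train2id.txt:
#
#   2
#   0 1 0
#   1 2 1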

class KnowledgeGraphDataset(Dataset):
    url = "https://raw.githubusercontent.com/thunlp/OpenKE/OpenKE-PyTorch/benchmarks"

    def __init__(self, root, name):
        self.name = name
        super(KnowledgeGraphDataset, self).__init__(root)
        self.data = torch.load(self.processed_paths[0])
        triple_config = torch.load(self.processed_paths[1])
        self.triples = triple_config["triples"]
        self._train_start_index = triple_config["train_start_index"]
        self._valid_start_index = triple_config["valid_start_index"]
        self._test_start_index = triple_config["test_start_index"]
        self._num_entities = triple_config["num_entities"]
        self._num_relations = triple_config["num_relations"]

    @property
    def raw_file_names(self):
        names = ["train2id.txt", "valid2id.txt", "test2id.txt"]
        return names

    @property
    def processed_file_names(self):
        return ["data.pt", "triple_config.pt"]

    @property
    def train_start_idx(self):
        return self._train_start_index

    @property
    def valid_start_idx(self):
        return self._valid_start_index

    @property
    def test_start_idx(self):
        return self._test_start_index

    @property
    def num_entities(self):
        return self._num_entities

    @property
    def num_relations(self):
        return self._num_relations

    def get(self, idx):
        assert idx == 0
        return self.data

    def download(self):
        for name in self.raw_file_names:
            download_url("{}/{}/{}".format(self.url, self.name, name), self.raw_dir)

    def process(self):
        (
            data,
            triples,
            train_start_index,
            valid_start_index,
            test_start_index,
            num_entities,
            num_relations,
        ) = read_triplet_data(self.raw_dir)
        torch.save(data, self.processed_paths[0])
        triple_config = {
            "triples": triples,
            "train_start_index": train_start_index,
            "valid_start_index": valid_start_index,
            "test_start_index": test_start_index,
            "num_entities": num_entities,
            "num_relations": num_relations,
        }
        torch.save(triple_config, self.processed_paths[1])
[docs]@register_dataset("fb13") class FB13Datset(KnowledgeGraphDataset): def __init__(self): dataset = "FB13" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(FB13Datset, self).__init__(path, dataset)
[docs]@register_dataset("fb15k") class FB15kDatset(KnowledgeGraphDataset): def __init__(self): dataset = "FB15K" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(FB15kDatset, self).__init__(path, dataset)
[docs]@register_dataset("fb15k237") class FB15k237Datset(KnowledgeGraphDataset): def __init__(self): dataset = "FB15K237" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(FB15k237Datset, self).__init__(path, dataset)
[docs]@register_dataset("wn18") class WN18Datset(KnowledgeGraphDataset): def __init__(self): dataset = "WN18" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(WN18Datset, self).__init__(path, dataset)
[docs]@register_dataset("wn18rr") class WN18RRDataset(KnowledgeGraphDataset): def __init__(self): dataset = "WN18RR" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(WN18RRDataset, self).__init__(path, dataset)
[docs]@register_dataset("fb13s") class FB13SDatset(KnowledgeGraphDataset):
[docs] url = "https://raw.githubusercontent.com/cenyk1230/test-data/main"
def __init__(self): dataset = "FB13-S" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) super(FB13SDatset, self).__init__(path, dataset)