import os.path as osp
import numpy as np
import scipy.io as sio
import torch
from cogdl.data import Graph, Dataset
from cogdl.utils import download_url, untar
def sample_mask(idx, length):
    """Create a boolean mask that is ``True`` at the given indices.

    Args:
        idx: Integer index, or array-like of integer indices, of the
            entries to mark.
        length: Shape of the mask to allocate (callers in this file pass
            an ``int``, giving a 1-D mask).

    Returns:
        numpy.ndarray: Array of dtype ``bool`` that is ``True`` at the
        positions selected by ``idx`` and ``False`` everywhere else.
    """
    # ``np.bool`` was only an alias of the builtin ``bool``; it was
    # deprecated in NumPy 1.20 and removed in 1.24, so the builtin is used
    # and the mask is allocated with the final dtype straight away.
    mask = np.zeros(length, dtype=bool)
    mask[idx] = True
    return mask
class HANDataset(Dataset):
    r"""The network datasets "ACM", "DBLP" and "IMDB" from the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/abs/1903.07293>`_ paper.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"han-acm"`,
            :obj:`"han-dblp"`, :obj:`"han-imdb"`).
    """

    def __init__(self, root, name):
        # ``name`` and ``url`` are set before the base constructor runs
        # because ``download()`` / ``read_gtn_data()`` read them.
        # NOTE(review): presumably ``Dataset.__init__`` is what triggers
        # ``download()`` and ``process()`` -- confirm against cogdl.
        self.name = name
        self.url = f"https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true"
        super().__init__(root)
        # NOTE(review): torch >= 2.6 defaults ``weights_only=True``, which
        # may refuse to unpickle a ``Graph`` -- confirm against the torch
        # versions this project supports.
        self.data = torch.load(self.processed_paths[0])
        # ``adj`` holds one (indices, values) pair per relation plus the
        # identity relation appended by ``read_gtn_data``.
        self.num_edge = len(self.data.adj)
        self.num_nodes = self.data.x.shape[0]

    @property
    def raw_file_names(self):
        """List of raw file names expected for this dataset."""
        return ["data.mat"]

    @property
    def processed_file_names(self):
        """List of processed file names produced by :meth:`process`."""
        return ["data.pt"]

    @property
    def num_classes(self):
        """Number of classes, taken as the largest training label plus one."""
        return torch.max(self.data.train_target).item() + 1

    def read_gtn_data(self, folder):
        """Parse ``data.mat`` in ``folder`` and store the graph in ``self.data``.

        The resulting :class:`Graph` carries:

        * ``adj``: a list of ``(edge_index, edge_value)`` tuples, one per
          relation matrix of the dataset, followed by an identity relation
          (``i -> i`` for every node). All edge values are ``1.0``.
        * ``x``: node features as a float tensor.
        * ``train_node`` / ``valid_node`` / ``test_node``: node ids of each
          split, and ``*_target`` with the matching class ids.
        * ``y``: per-node class id, ``0`` for nodes outside every split.

        Args:
            folder (str): Directory that contains ``data.mat``.

        Raises:
            ValueError: If ``self.name`` is not a supported dataset name.
        """
        # Per-dataset .mat layout: feature key and relation-matrix keys.
        layouts = {
            "han-acm": ("feature", ("PAP", "PLP")),
            "han-dblp": ("features", ("net_APA", "net_APCPA", "net_APTPA")),
            "han-imdb": ("feature", ("MAM", "MDM", "MYM")),
        }
        if self.name not in layouts:
            # Previously an unknown name surfaced as an UnboundLocalError
            # further down; fail early with an actionable message instead.
            raise ValueError(
                f"Unsupported dataset name {self.name!r}; expected one of {sorted(layouts)}"
            )
        feature_key, relation_keys = layouts[self.name]

        mat = sio.loadmat(osp.join(folder, "data.mat"))
        truelabels = mat["label"]
        truefeatures = mat[feature_key].astype(float)
        num_nodes = truefeatures.shape[0]

        # The identity is subtracted from every relation matrix (presumably
        # to drop self-loops stored in the file -- confirm); build it once.
        identity = np.eye(num_nodes)
        rownetworks = [mat[key] - identity for key in relation_keys]

        # ``loadmat`` yields 2-D arrays; row 0 holds the node ids.
        train_idx = mat["train_idx"]
        val_idx = mat["val_idx"]
        test_idx = mat["test_idx"]

        # Gather labels with the same index rows that become
        # ``*_node`` below, so targets stay aligned with their nodes even if
        # the stored ids are unsorted (a boolean mask would return them in
        # ascending node order). Labels are presumably one-hot rows; argmax
        # turns each row into a class id.
        y_train = np.argmax(truelabels[train_idx[0]], axis=1)
        y_val = np.argmax(truelabels[val_idx[0]], axis=1)
        y_test = np.argmax(truelabels[test_idx[0]], axis=1)

        graph = Graph()

        adjacency = []
        for relation in rownetworks:
            nonzero = relation.nonzero()
            edge_index = torch.from_numpy(np.vstack((nonzero[0], nonzero[1]))).type(torch.LongTensor)
            edge_value = torch.ones(edge_index.shape[1]).type(torch.FloatTensor)
            adjacency.append((edge_index, edge_value))

        # Final relation: the identity (every node connected to itself).
        node_ids = torch.arange(0, num_nodes)
        edge_index = torch.stack((node_ids, node_ids)).type(torch.LongTensor)
        edge_value = torch.ones(num_nodes).type(torch.FloatTensor)
        adjacency.append((edge_index, edge_value))
        graph.adj = adjacency

        graph.x = torch.from_numpy(truefeatures).type(torch.FloatTensor)

        graph.train_node = torch.from_numpy(train_idx[0]).type(torch.LongTensor)
        graph.train_target = torch.from_numpy(y_train).type(torch.LongTensor)
        graph.valid_node = torch.from_numpy(val_idx[0]).type(torch.LongTensor)
        graph.valid_target = torch.from_numpy(y_val).type(torch.LongTensor)
        graph.test_node = torch.from_numpy(test_idx[0]).type(torch.LongTensor)
        graph.test_target = torch.from_numpy(y_test).type(torch.LongTensor)

        # Dense per-node label vector; nodes outside all splits keep 0.
        y = np.zeros((num_nodes), dtype=int)
        x_index = torch.cat((graph.train_node, graph.valid_node, graph.test_node))
        y_index = torch.cat((graph.train_target, graph.valid_target, graph.test_target))
        y[x_index.numpy()] = y_index.numpy()
        graph.y = torch.from_numpy(y)

        self.data = graph

    def get(self, idx):
        """Return the single graph of this dataset; ``idx`` must be ``0``."""
        assert idx == 0
        return self.data

    def apply_to_device(self, device):
        """Move the node tensors and every adjacency pair to ``device``.

        Args:
            device: Target device, forwarded to ``Tensor.to``.
        """
        tensor_fields = (
            "x",
            "y",
            "train_node",
            "valid_node",
            "test_node",
            "train_target",
            "valid_target",
            "test_target",
        )
        for field in tensor_fields:
            setattr(self.data, field, getattr(self.data, field).to(device))
        self.data.adj = [(t1.to(device), t2.to(device)) for (t1, t2) in self.data.adj]

    def download(self):
        """Download ``<name>.zip`` into ``raw_dir`` and unpack it there."""
        download_url(self.url, self.raw_dir, name=self.name + ".zip")
        untar(self.raw_dir, self.name + ".zip")

    def process(self):
        """Build the graph from the raw files and save it to the processed path."""
        self.read_gtn_data(self.raw_dir)
        torch.save(self.data, self.processed_paths[0])

    def __repr__(self):
        return "{}()".format(self.name)
class ACM_HANDataset(HANDataset):
    """:class:`HANDataset` preset for the ``"han-acm"`` network.

    Args:
        data_path (str): Parent directory; ``han-acm`` is appended to it to
            form the dataset root. Defaults to ``"data"``.
    """

    def __init__(self, data_path="data"):
        name = "han-acm"
        super().__init__(osp.join(data_path, name), name)
class DBLP_HANDataset(HANDataset):
    """:class:`HANDataset` preset for the ``"han-dblp"`` network.

    Args:
        data_path (str): Parent directory; ``han-dblp`` is appended to it to
            form the dataset root. Defaults to ``"data"``.
    """

    def __init__(self, data_path="data"):
        name = "han-dblp"
        super().__init__(osp.join(data_path, name), name)
class IMDB_HANDataset(HANDataset):
    """:class:`HANDataset` preset for the ``"han-imdb"`` network.

    Args:
        data_path (str): Parent directory; ``han-imdb`` is appended to it to
            form the dataset root. Defaults to ``"data"``.
    """

    def __init__(self, data_path="data"):
        name = "han-imdb"
        super().__init__(osp.join(data_path, name), name)