# Source code for cogdl.datasets.han_data

import sys
import time
import os
import os.path as osp
import requests
import shutil
import tqdm
import pickle
import numpy as np
import scipy.io as sio
import scipy.sparse as sp

import torch

from cogdl.data import Data, Dataset, download_url

from . import register_dataset


def untar(path, fname, deleteTar=True):
    """Extract the archive ``fname`` found in directory ``path`` into ``path``.

    The archive format is inferred by :func:`shutil.unpack_archive` from the
    file name. If ``deleteTar`` is truthy (the default), the archive file is
    removed once extraction has finished.
    """
    print('unpacking ' + fname)
    archive_path = os.path.join(path, fname)
    shutil.unpack_archive(archive_path, path)
    if not deleteTar:
        return
    os.remove(archive_path)
def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: an integer index or array-like of integer indices (any shape is
            accepted by NumPy fancy indexing, e.g. the ``(1, n)`` arrays that
            ``scipy.io.loadmat`` returns).
        l: length of the mask.

    Returns:
        np.ndarray: 1-D array of dtype ``bool`` and shape ``(l,)``.
    """
    # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
    # removed in NumPy 1.24; build the mask as bool directly instead of
    # allocating floats and converting afterwards.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask
class HANDataset(Dataset):
    r"""The network datasets "ACM", "DBLP" and "IMDB" from the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/abs/1903.07293>`_ paper.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"han-acm"`,
            :obj:`"han-dblp"`, :obj:`"han-imdb"`).
    """

    # Per-dataset keys inside ``data.mat``: (feature matrix key, meta-path
    # adjacency matrix keys). The label and split keys are shared by all three.
    _MAT_KEYS = {
        "han-acm": ("feature", ("PAP", "PLP")),
        "han-dblp": ("features", ("net_APA", "net_APCPA", "net_APTPA")),
        "han-imdb": ("feature", ("MAM", "MDM", "MYM")),
    }

    def __init__(self, root, name):
        self.name = name
        # Set before the base-class constructor because ``download`` reads it;
        # presumably the base class triggers download()/process() when the
        # files are missing (verify in ``Dataset``).
        self.url = f'https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true'
        super(HANDataset, self).__init__(root)
        # NOTE(review): torch >= 2.6 defaults ``torch.load`` to
        # ``weights_only=True``, which refuses to unpickle a ``Data`` object;
        # pass ``weights_only=False`` here if loading fails on newer torch.
        self.data = torch.load(self.processed_paths[0])
        # Derived from training targets only -- assumes every class occurs in
        # the training split.
        self.num_classes = torch.max(self.data.train_target).item() + 1
        # Number of relations (entries of ``adj``).
        self.num_edge = len(self.data.adj)
        self.num_nodes = self.data.x.shape[0]

    @property
    def raw_file_names(self):
        return ["data.mat"]

    @property
    def processed_file_names(self):
        return ["data.pt"]

    def read_gtn_data(self, folder):
        """Load ``folder/data.mat`` and store the converted graph in ``self.data``.

        The resulting :class:`Data` has:

        * ``adj`` -- list of ``(edge_index, edge_weight)`` tuples: one per
          meta-path matrix, followed by one relation holding a self-loop for
          every node. ``edge_index`` is a ``(2, E)`` LongTensor and every
          weight is ``1.0`` (only the sparsity pattern is kept).
        * ``x`` -- node features as a FloatTensor.
        * ``train/valid/test_node`` and ``train/valid/test_target`` -- split
          node indices and, in the same order, their class ids (row-wise argmax
          of the label matrix).

        Raises:
            ValueError: if ``self.name`` is not a supported HAN dataset.
        """
        try:
            feature_key, metapath_keys = self._MAT_KEYS[self.name]
        except KeyError:
            raise ValueError(f"Unknown HAN dataset: {self.name!r}") from None

        mat = sio.loadmat(osp.join(folder, 'data.mat'))
        labels = mat['label']
        features = mat[feature_key].astype(float)
        num_nodes = features.shape[0]

        adj = []
        identity = np.eye(num_nodes)
        for key in metapath_keys:
            # Subtract the identity; presumably the stored meta-path matrices
            # carry self-loops (verify). Self-loops are re-added below as a
            # separate relation.
            rows, cols = (mat[key] - identity).nonzero()
            edge_index = torch.from_numpy(np.vstack((rows, cols))).long()
            edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float)
            adj.append((edge_index, edge_weight))
        loops = torch.arange(num_nodes, dtype=torch.long)
        adj.append((torch.stack((loops, loops)), torch.ones(num_nodes, dtype=torch.float)))

        data = Data()
        data.adj = adj
        data.x = torch.from_numpy(features).float()
        data.train_node, data.train_target = self._load_split(mat, labels, 'train_idx')
        data.valid_node, data.valid_target = self._load_split(mat, labels, 'val_idx')
        data.test_node, data.test_target = self._load_split(mat, labels, 'test_idx')
        self.data = data

    @staticmethod
    def _load_split(mat, labels, key):
        """Return ``(node_index, target)`` LongTensors for split ``key`` of ``mat``.

        Targets are gathered with the index array itself (not a boolean mask),
        so the i-th target always belongs to the i-th node even when the stored
        indices are unsorted or repeated.
        """
        # scipy.io.loadmat returns vectors as 2-D arrays (e.g. shape (1, n)).
        idx = np.asarray(mat[key]).reshape(-1).astype(np.int64)
        target = np.argmax(labels[idx], axis=1)
        return torch.from_numpy(idx), torch.from_numpy(target).long()

    def get(self, idx):
        # The dataset holds a single graph.
        assert idx == 0
        return self.data

    def apply_to_device(self, device):
        """Move features, split indices/targets and adjacency tensors to ``device``."""
        for attr in ("x", "train_node", "valid_node", "test_node",
                     "train_target", "valid_target", "test_target"):
            setattr(self.data, attr, getattr(self.data, attr).to(device))
        self.data.adj = [(edge_index.to(device), edge_weight.to(device))
                         for edge_index, edge_weight in self.data.adj]

    def download(self):
        archive = self.name + '.zip'
        download_url(self.url, self.raw_dir, name=archive)
        untar(self.raw_dir, archive)

    def process(self):
        self.read_gtn_data(self.raw_dir)
        torch.save(self.data, self.processed_paths[0])

    def __repr__(self):
        return "{}()".format(self.name)
@register_dataset("han-acm")
class ACM_HANDataset(HANDataset):
    """HAN "ACM" dataset, rooted at ``../../data/han-acm`` relative to this
    module's own directory."""

    def __init__(self):
        dataset = "han-acm"
        module_dir = osp.dirname(osp.realpath(__file__))
        root = osp.join(module_dir, "../..", "data", dataset)
        super().__init__(root, dataset)
@register_dataset("han-dblp")
class DBLP_HANDataset(HANDataset):
    """HAN "DBLP" dataset, rooted at ``../../data/han-dblp`` relative to this
    module's own directory."""

    def __init__(self):
        dataset = "han-dblp"
        module_dir = osp.dirname(osp.realpath(__file__))
        root = osp.join(module_dir, "../..", "data", dataset)
        super().__init__(root, dataset)
@register_dataset("han-imdb")
class IMDB_HANDataset(HANDataset):
    """HAN "IMDB" dataset, rooted at ``../../data/han-imdb`` relative to this
    module's own directory."""

    def __init__(self):
        dataset = "han-imdb"
        module_dir = osp.dirname(osp.realpath(__file__))
        root = osp.join(module_dir, "../..", "data", dataset)
        super().__init__(root, dataset)