import glob
import os
import os.path as osp
import shutil
import zipfile
import numpy as np
import torch
import torch.nn.functional as F
from cogdl.data.dataset import MultiGraphDataset
from cogdl.data import Graph
from cogdl.utils import download_url
def normalize_feature(data):
    """Scale each row of ``data.x`` by the reciprocal of its row sum.

    ``data.x`` is reassigned on the object that was passed in, and that same
    object is returned. Rows whose reciprocal sum is NaN or infinite (for
    instance an all-zero row) are multiplied by zero instead.
    """
    inv_row_sum = data.x.sum(dim=1).pow(-1).flatten()
    # Neutralise non-finite reciprocals so they do not poison the product.
    inv_row_sum[torch.isnan(inv_row_sum)] = 0.0
    inv_row_sum[torch.isinf(inv_row_sum)] = 0.0
    scale = inv_row_sum.unsqueeze(-1).expand_as(data.x)
    data.x = data.x * scale
    return data
def parse_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None):
    """Convert an iterable of text lines into a tensor.

    Args:
        src: Iterable of strings, one row per element.
        sep: Separator passed to ``str.split`` (``None`` splits on whitespace).
        start: Index of the first field of each line to keep.
        end: Index one past the last field to keep (``None`` keeps the rest).
        dtype: Optional torch dtype of the resulting tensor.
        device: Optional device on which the tensor is created.

    Returns:
        The parsed tensor with all size-1 dimensions squeezed out.
    """
    # Every field is parsed as a float first; ``dtype`` decides the final type.
    rows = [[float(field) for field in line.split(sep)[start:end]] for line in src]
    # ``device`` used to be accepted but silently ignored; honour it here.
    return torch.tensor(rows, dtype=dtype, device=device).squeeze()
def read_txt_array(path, sep=None, start=0, end=None, dtype=None, device=None):
    """Read a text file and parse its lines into a tensor.

    Args:
        path: Path of the text file to read.
        sep, start, end, dtype, device: Forwarded unchanged to
            ``parse_txt_array``.

    Returns:
        Whatever ``parse_txt_array`` returns for the file's lines.
    """
    with open(path, "r") as f:
        lines = f.read().split("\n")
    # Drop only trailing blank lines. The previous unconditional ``[:-1]``
    # discarded the last data row whenever the file lacked a final newline.
    while lines and not lines[-1].strip():
        lines.pop()
    return parse_txt_array(lines, sep, start, end, dtype, device)
def read_file(folder, prefix, name, dtype=None):
    """Load ``<folder>/<prefix>_<name>.txt`` as a comma-separated tensor.

    ``dtype`` is forwarded to ``read_txt_array``.
    """
    filename = "{}_{}.txt".format(prefix, name)
    return read_txt_array(osp.join(folder, filename), sep=",", dtype=dtype)
def cat(seq):
    """Concatenate the non-``None`` tensors of ``seq`` along the last dimension.

    One-dimensional tensors are first given a trailing dimension of size one.
    Returns ``None`` when ``seq`` contains no tensors at all.
    """
    columns = []
    for item in seq:
        if item is None:
            continue
        columns.append(item.unsqueeze(-1) if item.dim() == 1 else item)
    if not columns:
        return None
    return torch.cat(columns, dim=-1)
def _split(edge_index, batch, x=None, y=None, edge_attr=None):
    """Cut one concatenated graph into a list of per-graph ``Graph`` objects.

    Args:
        edge_index: Tensor whose first dimension has size two (source row,
            target row). Edges are split into consecutive chunks, so they are
            presumably already grouped by graph -- confirm against the caller.
        batch: Per-node graph id; ``np.bincount`` of it gives the node count
            of every graph.
        x: Optional node feature tensor, split by those node counts.
        y: Optional per-graph targets; element ``i`` is attached to graph ``i``.
        edge_attr: Optional per-edge tensor, split like the edges.

    Returns:
        A list with one ``Graph`` per entry of the node-count list.
    """
    nodes_per_graph = np.bincount(batch).tolist()
    src_nodes, _ = edge_index
    # An edge belongs to the graph of its source node.
    edges_per_graph = np.bincount(batch[src_nodes]).tolist()

    edge_chunks = edge_index.T.split(edges_per_graph)
    if edge_attr is not None:
        edge_attr = edge_attr.split(edges_per_graph)

    if x is None:
        # Without features, take the largest node id seen in each chunk as
        # the upper bound of that graph.
        offsets = [chunk.max().item() + 1 for chunk in edge_chunks]
    else:
        x = x.split(nodes_per_graph)
        offsets = np.cumsum([feat.shape[0] for feat in x]).tolist()
    offsets = [0] + offsets

    # Relabel only when the last chunk does not already start at node 0.
    if edge_chunks[-1].min() > 0:
        edge_chunks = [chunk - offset for chunk, offset in zip(edge_chunks, offsets)]

    graphs = []
    for i in range(len(nodes_per_graph)):
        graph = Graph(edge_index=edge_chunks[i].T)
        if x is not None:
            graph.x = x[i]
        if y is not None:
            graph.y = y[i].view(1)
        if edge_attr is not None:
            graph.edge_attr = edge_attr[i]
        graphs.append(graph)
    return graphs
def segment(src, indptr):
    """Sum consecutive slices of ``src`` delimited by ``indptr``.

    For every adjacent pair ``(indptr[..., i], indptr[..., i + 1])`` the
    entries of ``src`` in that half-open index range are summed along
    dimension ``indptr.dim() - 1``; the partial sums are concatenated along
    the same dimension.
    """
    axis = indptr.dim() - 1
    pieces = []
    for i in range(indptr.size(-1) - 1):
        lo = indptr[..., i].item()
        hi = indptr[..., i + 1].item()
        picked = src.index_select(axis, torch.arange(lo, hi, dtype=torch.int64))
        pieces.append(picked.sum(dim=axis, keepdim=True))
    return torch.cat(pieces, dim=axis)
def coalesce(index, value, m, n):
    """Sort the entries of a sparse index and merge duplicates.

    Args:
        index: Indexable pair ``(row, col)`` of index tensors.
        value: Optional tensor with one entry per index pair, or ``None``.
        m: Not used by this implementation.
        n: Multiplier used to build the linear key ``n * row + col``.

    Returns:
        ``(index, value)`` where ``index`` is the stacked ``[row, col]``
        tensor ordered by linear key with repeated keys removed, and the
        values of repeated keys summed through ``segment``.
    """
    row, col = index[0], index[1]

    # Linear keys behind a leading zero, so one shifted comparison reveals
    # any out-of-order neighbour.
    keys = col.new_zeros(col.numel() + 1)
    keys[1:] = row
    keys[1:] *= n
    keys[1:] += col
    if (keys[1:] < keys[:-1]).any():
        order = keys[1:].argsort()
        row, col = row[order], col[order]
        if value is not None:
            value = value[order]

    # Rebuild the keys behind a -1 sentinel; an entry is the first of its
    # run exactly when its key is strictly larger than its predecessor's.
    keys = col.new_full((col.numel() + 1,), -1)
    keys[1:] = n * row + col
    is_first = keys[1:] > keys[:-1]

    if not is_first.all():  # Otherwise the indices are already coalesced.
        row, col = row[is_first], col[is_first]
        if value is not None:
            starts = is_first.nonzero().flatten()
            bounds = torch.cat([starts, starts.new_full((1,), value.size(0))])
            value = segment(value, bounds)
            if isinstance(value, tuple):
                value = value[0]

    return torch.stack([row, col], dim=0), value
def _read_one_hot_labels(folder, prefix, name):
    """Read the integer label file ``<prefix>_<name>.txt`` and one-hot encode it.

    Every label column is shifted so its minimum becomes zero, encoded
    separately, and the encodings are concatenated as a float tensor.
    """
    labels = read_file(folder, prefix, name, torch.long)
    if labels.dim() == 1:
        labels = labels.unsqueeze(-1)
    labels = labels - labels.min(dim=0)[0]
    encoded = [F.one_hot(column, num_classes=-1) for column in labels.unbind(dim=-1)]
    return torch.cat(encoded, dim=-1).to(torch.float)


def read_tu_data(folder, prefix):
    """Load a dataset from ``<prefix>_*.txt`` files and split it into graphs.

    Always reads ``<prefix>_A.txt`` (edge list) and
    ``<prefix>_graph_indicator.txt`` (graph id per node), subtracting one
    from both (presumably converting 1-based ids to 0-based). The optional
    ``node_attributes``, ``node_labels``, ``edge_attributes``,
    ``edge_labels``, ``graph_attributes`` and ``graph_labels`` files are read
    when present in ``folder``.

    Self loops are removed and duplicate edges merged. Only the one-hot label
    columns of the node and edge features are kept; the leading attribute
    columns are dropped.

    Args:
        folder: Directory containing the raw text files.
        prefix: Dataset name used as the file-name prefix.

    Returns:
        ``(graphs, y)``: the list of graphs produced by ``_split`` and the
        graph-level target tensor (``None`` if no target file exists).
    """
    paths = glob.glob(osp.join(folder, "{}_*.txt".format(prefix)))
    # Strip "<prefix>_" and ".txt" to recover the bare file kinds.
    names = [path.split(os.sep)[-1][len(prefix) + 1 : -4] for path in paths]

    edge_index = read_file(folder, prefix, "A", torch.long).t() - 1
    batch = read_file(folder, prefix, "graph_indicator", torch.long) - 1

    node_attributes = node_labels = None
    if "node_attributes" in names:
        node_attributes = read_file(folder, prefix, "node_attributes")
    if "node_labels" in names:
        node_labels = _read_one_hot_labels(folder, prefix, "node_labels")
    x = cat([node_attributes, node_labels])

    edge_attributes = edge_labels = None
    if "edge_attributes" in names:
        edge_attributes = read_file(folder, prefix, "edge_attributes")
    if "edge_labels" in names:
        edge_labels = _read_one_hot_labels(folder, prefix, "edge_labels")
    edge_attr = cat([edge_attributes, edge_labels])

    y = None
    if "graph_attributes" in names:  # Regression problem.
        y = read_file(folder, prefix, "graph_attributes")
    elif "graph_labels" in names:  # Classification problem.
        y = read_file(folder, prefix, "graph_labels", torch.long)
        # Remap arbitrary label values onto consecutive class indices.
        _, y = y.unique(sorted=True, return_inverse=True)

    num_nodes = edge_index.max().item() + 1 if x is None else x.size(0)

    # Remove self loops before merging duplicate edges.
    keep = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, keep]
    if edge_attr is not None:
        edge_attr = edge_attr[keep]
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, num_nodes)

    # Features are laid out as [attributes | one-hot labels]; keep the label
    # part for both nodes and edges.
    if x is not None:
        x = x[:, num_node_attributes(x) :]
    if edge_attr is not None:
        # The slice previously kept the attribute columns and discarded the
        # labels, which left label-only datasets with zero-width edge_attr.
        edge_attr = edge_attr[:, num_edge_attributes(edge_attr) :]

    graphs = _split(edge_index, batch=batch, x=x, y=y, edge_attr=edge_attr)
    return graphs, y
def num_node_labels(x=None):
    """Return the width of the trailing one-hot block of ``x``.

    Scans start columns from left to right and returns the number of columns
    from the first start position whose remaining columns contain only zeros
    and ones and sum to exactly one in every row. Returns 0 when ``x`` is
    ``None`` or no such suffix exists.
    """
    if x is None:
        return 0
    width = x.size(1)
    for start in range(width):
        tail = x[:, start:]
        only_binary = ((tail == 0) | (tail == 1)).all()
        if only_binary and (tail.sum(dim=1) == 1).all():
            return width - start
    return 0
def num_node_attributes(x=None):
    """Return how many columns of ``x`` are not counted by ``num_node_labels``.

    Returns 0 when ``x`` is ``None``.
    """
    return 0 if x is None else x.size(1) - num_node_labels(x)
def num_edge_labels(edge_attr=None):
    """Return the width of the trailing label block of ``edge_attr``.

    Scans start columns from left to right and returns the number of columns
    from the first start position whose remaining columns sum, over all
    entries, to the number of rows. Returns 0 when ``edge_attr`` is ``None``
    or no start column qualifies.
    """
    if edge_attr is None:
        return 0
    num_edges = edge_attr.size(0)
    width = edge_attr.size(1)
    for start in range(width):
        if edge_attr[:, start:].sum() == num_edges:
            return width - start
    return 0
def num_edge_attributes(edge_attr=None):
    """Return how many columns of ``edge_attr`` are not counted by ``num_edge_labels``.

    Returns 0 when ``edge_attr`` is ``None``.
    """
    return 0 if edge_attr is None else edge_attr.size(1) - num_edge_labels(edge_attr)
class TUDataset(MultiGraphDataset):
    """Multi-graph dataset built from ``<name>_*.txt`` files fetched from ``url``.

    After the base-class initialisation, ``self.data`` and ``self.y`` are
    loaded from the first processed path, which ``process`` fills with the
    ``(graphs, y)`` pair returned by ``read_tu_data``.
    """

    url = "https://www.chrsmrrs.com/graphkerneldatasets"

    def __init__(self, root, name):
        # ``name`` is set before the base initialiser runs because the
        # file-name properties below read it.
        self.name = name
        super().__init__(root)
        self.data, self.y = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the two raw files listed for this dataset."""
        return ["{}_{}.txt".format(self.name, suffix) for suffix in ("A", "graph_indicator")]

    @property
    def processed_file_names(self):
        """Name of the processed file."""
        return "data.pt"

    def download(self):
        """Download ``<url>/<name>.zip``, unpack it and move it to ``raw_dir``."""
        folder = osp.join(self.root)
        archive = download_url("{}/{}.zip".format(self.url, self.name), folder)
        with zipfile.ZipFile(archive, "r") as zf:
            zf.extractall(folder)
        os.unlink(archive)
        # Replace the existing raw directory with the extracted folder.
        shutil.rmtree(self.raw_dir)
        os.rename(osp.join(folder, self.name), self.raw_dir)

    def process(self):
        """Parse the raw text files and save the result to the processed path."""
        torch.save(read_tu_data(self.raw_dir, self.name), self.processed_paths[0])

    @property
    def num_classes(self):
        r"""The number of classes in the dataset."""
        if self.y.dim() == 1:
            return self.y.max().item() + 1
        return self.y.size(1)

    def __len__(self):
        return len(self.data)
class MUTAGDataset(TUDataset):
    """``TUDataset`` named ``MUTAG``, rooted at ``<data_path>/MUTAG``."""

    def __init__(self, data_path="data"):
        name = "MUTAG"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class ImdbBinaryDataset(TUDataset):
    """``TUDataset`` named ``IMDB-BINARY``, rooted at ``<data_path>/IMDB-BINARY``."""

    def __init__(self, data_path="data"):
        name = "IMDB-BINARY"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class ImdbMultiDataset(TUDataset):
    """``TUDataset`` named ``IMDB-MULTI``, rooted at ``<data_path>/IMDB-MULTI``."""

    def __init__(self, data_path="data"):
        name = "IMDB-MULTI"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class CollabDataset(TUDataset):
    """``TUDataset`` named ``COLLAB``, rooted at ``<data_path>/COLLAB``."""

    def __init__(self, data_path="data"):
        name = "COLLAB"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class ProteinsDataset(TUDataset):
    """``TUDataset`` named ``PROTEINS``, rooted at ``<data_path>/PROTEINS``."""

    def __init__(self, data_path="data"):
        name = "PROTEINS"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class RedditBinary(TUDataset):
    """``TUDataset`` named ``REDDIT-BINARY``, rooted at ``<data_path>/REDDIT-BINARY``."""

    def __init__(self, data_path="data"):
        name = "REDDIT-BINARY"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class RedditMulti5K(TUDataset):
    """``TUDataset`` named ``REDDIT-MULTI-5K``, rooted at ``<data_path>/REDDIT-MULTI-5K``."""

    def __init__(self, data_path="data"):
        name = "REDDIT-MULTI-5K"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class RedditMulti12K(TUDataset):
    """``TUDataset`` named ``REDDIT-MULTI-12K``, rooted at ``<data_path>/REDDIT-MULTI-12K``."""

    def __init__(self, data_path="data"):
        name = "REDDIT-MULTI-12K"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class PTCMRDataset(TUDataset):
    """``TUDataset`` named ``PTC_MR``, rooted at ``<data_path>/PTC_MR``."""

    def __init__(self, data_path="data"):
        name = "PTC_MR"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class NCI1Dataset(TUDataset):
    """``TUDataset`` named ``NCI1``, rooted at ``<data_path>/NCI1``."""

    def __init__(self, data_path="data"):
        name = "NCI1"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class NCI109Dataset(TUDataset):
    """``TUDataset`` named ``NCI109``, rooted at ``<data_path>/NCI109``."""

    def __init__(self, data_path="data"):
        name = "NCI109"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)
class ENZYMES(TUDataset):
    """``TUDataset`` named ``ENZYMES``, rooted at ``<data_path>/ENZYMES``.

    Integer indexing additionally trims the node features of the returned
    graph; see ``__getitem__``.
    """

    def __init__(self, data_path="data"):
        name = "ENZYMES"
        root = osp.join(data_path, name)
        # If the directory is missing, instantiate a plain TUDataset first
        # (presumably to download and process the data -- confirm in the base class).
        if not osp.exists(root):
            TUDataset(root, name=name)
        super().__init__(root, name=name)

    def __getitem__(self, idx):
        """Return one graph for an ``int`` index, else ``self.index_select(idx)``.

        For an ``int`` index the graph is fetched, passed through
        ``self.transform`` when one is set, and its ``x`` is cut down to the
        first ``edge_index.max() + 1`` rows if it has more rows than that.
        """
        if not isinstance(idx, int):
            return self.index_select(idx)

        graph = self.get(self.indices()[idx])
        if self.transform is not None:
            graph = self.transform(graph)

        # Drop feature rows beyond the largest node id that occurs in an edge.
        used_nodes = graph.edge_index.max() + 1
        if used_nodes < graph.x.size(0):
            graph.x = graph.x[:used_nodes]
        return graph