import os.path as osp
import torch
from ogb.nodeproppred import NodePropPredDataset
from ogb.graphproppred import GraphPropPredDataset
from . import register_dataset
from cogdl.data import Dataset, Data, DataLoader
from cogdl.utils import cross_entropy_loss, accuracy, remove_self_loops
def coalesce(row, col, edge_attr=None):
    """Sort an edge list by ``(row, col)`` and drop duplicate edges.

    Args:
        row: Source node ids, one per edge (anything ``torch.tensor`` accepts).
        col: Target node ids, one per edge, same length as ``row``.
        edge_attr: Optional per-edge attributes whose first dimension indexes
            edges. When several copies of an edge exist, the attributes of the
            first occurrence are kept.

    Returns:
        A ``(row, col, edge_attr)`` tuple of tensors ordered by ``(row, col)``
        with every ``(row, col)`` pair appearing once. ``edge_attr`` is
        ``None`` when no attributes were given.
    """
    row = torch.tensor(row)
    col = torch.tensor(col)
    if edge_attr is not None:
        edge_attr = torch.tensor(edge_attr)

    # Nothing to sort or de-duplicate (``max`` below would fail on empty input).
    if col.shape[0] == 0:
        return row, col, edge_attr

    # One integer key per edge. The stride must exceed every column id so that
    # key order equals lexicographic (row, col) order, and the key must be an
    # exact integer: a float32 key collapses distinct edges once it passes 2**24.
    stride = int(col.max()) + 1
    key = row.to(torch.long) * stride + col.to(torch.long)

    # Only reorder when needed so already-sorted input is returned untouched.
    # The sort is stable so that "first occurrence wins" holds for duplicates.
    if not bool((key[1:] >= key[:-1]).all()):
        key, order = torch.sort(key, stable=True)
        row = row[order]
        col = col[order]
        if edge_attr is not None:
            edge_attr = edge_attr[order]

    # After sorting, duplicates are adjacent: keep an edge iff its key differs
    # from its predecessor's (the first edge is always kept).
    keep = torch.ones_like(key, dtype=torch.bool)
    keep[1:] = key[1:] != key[:-1]
    if keep.all():
        return row, col, edge_attr

    row = row[keep]
    col = col[keep]
    if edge_attr is not None:
        edge_attr = edge_attr[keep]
    return row, col, edge_attr
class OGBNDataset(Dataset):
    """Single-graph node property prediction dataset loaded through OGB.

    The first (index 0) graph of ``NodePropPredDataset(name, root)`` is
    converted into one ``Data`` object stored in ``self.data``, with every edge
    listed in both directions and boolean train/valid/test node masks attached.
    """

    def __init__(self, root, name):
        super().__init__(root)
        ogb_dataset = NodePropPredDataset(name, root)
        graph, labels = ogb_dataset[0]

        x = torch.tensor(graph["node_feat"])
        y = torch.tensor(labels.squeeze())

        # De-duplicate the raw edge list, then pass it through remove_self_loops.
        src, dst, edge_attr = coalesce(graph["edge_index"][0], graph["edge_index"][1], graph["edge_feat"])
        edge_index, edge_attr = remove_self_loops(torch.stack([src, dst], dim=0), edge_attr)

        # Append the reversed copy of every edge; attributes are duplicated so
        # they stay aligned with the doubled edge list.
        forward, backward = edge_index[0], edge_index[1]
        edge_index = torch.stack(
            [torch.cat([forward, backward]), torch.cat([backward, forward])],
            dim=0,
        )
        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, edge_attr], dim=0)

        self.data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        self.data.num_nodes = graph["num_nodes"]
        assert self.data.num_nodes == self.data.x.shape[0]

        # Turn OGB's index-based split into one boolean mask per partition.
        split_index = ogb_dataset.get_idx_split()
        for split_name, mask_attr in (
            ("train", "train_mask"),
            ("test", "test_mask"),
            ("valid", "val_mask"),
        ):
            setattr(self.data, mask_attr, torch.zeros(self.data.num_nodes, dtype=torch.bool))
            getattr(self.data, mask_attr)[split_index[split_name]] = True

        self.transform = None

    def get(self, idx):
        """Return the only graph of the dataset; ``idx`` must be 0."""
        assert idx == 0
        return self.data

    def get_loss_fn(self):
        """Return the loss function used for this dataset (cross entropy)."""
        return cross_entropy_loss

    def get_evaluator(self):
        """Return the evaluation function used for this dataset (accuracy)."""
        return accuracy

    def _download(self):
        # Intentionally empty: the OGB dataset object is built in __init__.
        pass

    def _process(self):
        # Intentionally empty: the graph is converted in __init__.
        pass
@register_dataset("ogbn-arxiv")
class OGBArxivDataset(OGBNDataset):
    """``ogbn-arxiv`` node dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbn-arxiv")
@register_dataset("ogbn-products")
class OGBProductsDataset(OGBNDataset):
    """``ogbn-products`` node dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbn-products")
@register_dataset("ogbn-proteins")
class OGBProteinsDataset(OGBNDataset):
    """``ogbn-proteins`` node dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbn-proteins")
@register_dataset("ogbn-mag")
class OGBMAGDataset(OGBNDataset):
    """``ogbn-mag`` node dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbn-mag")
@register_dataset("ogbn-papers100M")
class OGBPapers100MDataset(OGBNDataset):
    """``ogbn-papers100M`` node dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbn-papers100M")
class OGBGDataset(Dataset):
    """Multi-graph property prediction dataset loaded through OGB.

    Every graph of ``GraphPropPredDataset(name, root)`` is converted into a
    ``Data`` object and kept in ``self.graphs``; ``self.all_nodes`` and
    ``self.all_edges`` accumulate the node and edge counts over all graphs.
    """

    def __init__(self, root, name):
        super(OGBGDataset, self).__init__(root)
        self.name = name
        self.dataset = GraphPropPredDataset(self.name, root)

        self.graphs = []
        self.all_nodes = 0
        self.all_edges = 0
        for i in range(len(self.dataset.graphs)):
            graph, label = self.dataset[i]
            # The "edge_feat" entry may be missing or present with value None
            # (graph without edge features); both must map to edge_attr=None,
            # since torch.tensor(None) raises.
            edge_feat = graph.get("edge_feat")
            data = Data(
                x=torch.tensor(graph["node_feat"], dtype=torch.float),
                edge_index=torch.tensor(graph["edge_index"]),
                edge_attr=None if edge_feat is None else torch.tensor(edge_feat, dtype=torch.float),
                y=torch.tensor(label),
            )
            data.num_nodes = graph["num_nodes"]
            self.graphs.append(data)

            self.all_nodes += graph["num_nodes"]
            self.all_edges += graph["edge_index"].shape[1]

        self.transform = None

    def get_loader(self, args):
        """Build the train/valid/test loaders from OGB's index split.

        Args:
            args: Object providing ``batch_size``.

        Returns:
            ``(train_loader, valid_loader, test_loader)``; only the training
            loader shuffles.
        """
        split_index = self.dataset.get_idx_split()
        train_loader = DataLoader(self.get_subset(split_index["train"]), batch_size=args.batch_size, shuffle=True)
        valid_loader = DataLoader(self.get_subset(split_index["valid"]), batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(self.get_subset(split_index["test"]), batch_size=args.batch_size, shuffle=False)
        return train_loader, valid_loader, test_loader

    def get_subset(self, subset):
        """Return the list of graphs selected by the indices in ``subset``."""
        return [self.graphs[idx] for idx in subset]

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.graphs[idx]

    def _download(self):
        # Intentionally empty: the OGB dataset object is built in __init__.
        pass

    def _process(self):
        # Intentionally empty: graphs are converted in __init__.
        pass

    @property
    def num_classes(self):
        """Number of classes reported by the underlying OGB dataset, as an int."""
        return int(self.dataset.num_classes)
@register_dataset("ogbg-molbace")
class OGBMolbaceDataset(OGBGDataset):
    """``ogbg-molbace`` graph dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbg-molbace")
@register_dataset("ogbg-molhiv")
class OGBMolhivDataset(OGBGDataset):
    """``ogbg-molhiv`` graph dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbg-molhiv")
@register_dataset("ogbg-molpcba")
class OGBMolpcbaDataset(OGBGDataset):
    """``ogbg-molpcba`` graph dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbg-molpcba")
@register_dataset("ogbg-ppa")
class OGBPpaDataset(OGBGDataset):
    """``ogbg-ppa`` graph dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbg-ppa")
@register_dataset("ogbg-code")
class OGBCodeDataset(OGBGDataset):
    """``ogbg-code`` graph dataset, rooted at the ``data`` directory."""

    def __init__(self):
        super().__init__("data", "ogbg-code")