Source code for cogdl.datasets.pyg

import os.path as osp

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Reddit, TUDataset, QM9
from torch_geometric.utils import remove_self_loops
from . import register_dataset


def normalize_feature(data):
    # Row-normalize node features: scale each row of data.x by the inverse of its sum.
    x_sum = torch.sum(data.x, dim=1)
    x_rev = x_sum.pow(-1).flatten()
    # Rows that sum to zero give inf (and NaN stays NaN); reset those scales to 0
    # so the corresponding feature rows remain all-zeros.
    x_rev[torch.isnan(x_rev)] = 0.0
    x_rev[torch.isinf(x_rev)] = 0.0
    data.x = data.x * x_rev.unsqueeze(-1).expand_as(data.x)
    return data
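
As a quick illustration (not part of the module), the sketch below feeds a toy graph through normalize_feature; it assumes only torch and torch_geometric.data.Data:

    import torch
    from torch_geometric.data import Data

    toy = Data(x=torch.tensor([[1.0, 3.0], [0.0, 0.0]]))
    toy = normalize_feature(toy)
    print(toy.x)
    # tensor([[0.2500, 0.7500],
    #         [0.0000, 0.0000]])

Each non-empty feature row is rescaled to sum to one, while rows with a zero sum stay all-zero instead of turning into NaN or inf.
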
[docs]@register_dataset("cora") class CoraDataset(Planetoid): def __init__(self): dataset = "Cora" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): Planetoid(path, dataset, transform=T.TargetIndegree()) super(CoraDataset, self).__init__(path, dataset, transform=T.TargetIndegree()) self.data = normalize_feature(self.data)
[docs]@register_dataset("citeseer") class CiteSeerDataset(Planetoid): def __init__(self): dataset = "CiteSeer" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): Planetoid(path, dataset, transform=T.TargetIndegree()) super(CiteSeerDataset, self).__init__(path, dataset, transform=T.TargetIndegree()) self.data = normalize_feature(self.data)
[docs]@register_dataset("pubmed") class PubMedDataset(Planetoid): def __init__(self): dataset = "PubMed" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): Planetoid(path, dataset, transform=T.TargetIndegree()) super(PubMedDataset, self).__init__(path, dataset, transform=T.TargetIndegree()) self.data = normalize_feature(self.data)
[docs]@register_dataset("reddit") class RedditDataset(Reddit): def __init__(self): self.url = "https://data.dgl.ai/dataset/reddit.zip" dataset = "Reddit" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): Reddit(path) super(RedditDataset, self).__init__(path, transform=T.TargetIndegree())
[docs]@register_dataset("mutag") class MUTAGDataset(TUDataset): def __init__(self): dataset = "MUTAG" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(MUTAGDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("imdb-b") class ImdbBinaryDataset(TUDataset): def __init__(self): dataset = "IMDB-BINARY" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(ImdbBinaryDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("imdb-m") class ImdbMultiDataset(TUDataset): def __init__(self): dataset = "IMDB-MULTI" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(ImdbMultiDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("collab") class CollabDataset(TUDataset): def __init__(self): dataset = "COLLAB" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(CollabDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("proteins") class ProtainsDataset(TUDataset): def __init__(self): dataset = "PROTEINS" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(ProtainsDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("reddit-b") class RedditBinary(TUDataset): def __init__(self): dataset = "REDDIT-BINARY" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(RedditBinary, self).__init__(path, name=dataset)
[docs]@register_dataset("reddit-multi-5k") class RedditMulti5K(TUDataset): def __init__(self): dataset = "REDDIT-MULTI-5K" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(RedditMulti5K, self).__init__(path, name=dataset)
[docs]@register_dataset("reddit-multi-12k") class RedditMulti12K(TUDataset): def __init__(self): dataset = "REDDIT-MULTI-12K" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(RedditMulti12K, self).__init__(path, name=dataset)
[docs]@register_dataset("ptc-mr") class PTCMRDataset(TUDataset): def __init__(self): dataset = "PTC_MR" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(PTCMRDataset, self).__init__(path, name=dataset)
[docs]@register_dataset("nci1") class NCT1Dataset(TUDataset): def __init__(self): dataset = "NCI1" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(NCT1Dataset, self).__init__(path, name=dataset)
[docs]@register_dataset("nci109") class NCT109Dataset(TUDataset): def __init__(self): dataset = "NCI109" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(NCT109Dataset, self).__init__(path, name=dataset)
[docs]@register_dataset("enzymes") class ENZYMES(TUDataset): def __init__(self): dataset = "ENZYMES" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) if not osp.exists(path): TUDataset(path, name=dataset) super(ENZYMES, self).__init__(path, name=dataset)
[docs] def __getitem__(self, idx): if isinstance(idx, int): data = self.get(self.indices()[idx]) data = data if self.transform is None else self.transform(data) edge_nodes = data.edge_index.max() + 1 if edge_nodes < data.x.size(0): data.x = data.x[:edge_nodes] return data else: return self.index_select(idx)
[docs]@register_dataset("qm9") class QM9Dataset(QM9): def __init__(self): dataset = "QM9" path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) target=0 class MyTransform(object): def __call__(self, data): # Specify target. data.y = data.y[:, target] return data class Complete(object): def __call__(self, data): device = data.edge_index.device row = torch.arange(data.num_nodes, dtype=torch.long, device=device) col = torch.arange(data.num_nodes, dtype=torch.long, device=device) row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1) col = col.repeat(data.num_nodes) edge_index = torch.stack([row, col], dim=0) edge_attr = None if data.edge_attr is not None: idx = data.edge_index[0] * data.num_nodes + data.edge_index[1] size = list(data.edge_attr.size()) size[0] = data.num_nodes * data.num_nodes edge_attr = data.edge_attr.new_zeros(size) edge_attr[idx] = data.edge_attr edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) data.edge_attr = edge_attr data.edge_index = edge_index return data transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)]) if not osp.exists(path): QM9(path) super(QM9Dataset, self).__init__(path)