Source code for cogdl.datasets

import importlib

from cogdl.data.dataset import Dataset
from .customized_data import NodeDataset, GraphDataset

# torch_geometric is an optional dependency: `pyg` records whether the import
# succeeded so later code can decide whether to reference the package.
pyg = True
try:
    import torch_geometric
except ImportError:
    pyg = False


# Module-level registry of datasets; starts out empty.
DATASET_REGISTRY = {}


def register_dataset(name):
    """
    New dataset types can be added to cogdl with the :func:`register_dataset`
    function decorator.

    For example::

        @register_dataset('my_dataset')
        class MyDataset(Dataset):
            (...)

    Registering a class under a name that is already present replaces the
    earlier entry in ``DATASET_REGISTRY``.

    Args:
        name (str): the name of the dataset

    Returns:
        A class decorator that stores the decorated class in
        ``DATASET_REGISTRY`` under ``name`` and returns the class unchanged.

    Raises:
        ValueError: if the decorated class extends neither
            ``cogdl.data.Dataset`` nor (only when torch_geometric could be
            imported) ``torch_geometric.data.Dataset``.
    """

    def register_dataset_cls(cls):
        # The torch_geometric base class is only acceptable when the optional
        # package was importable. Without it, the cogdl base class is the sole
        # valid parent -- previously the check was skipped entirely in that
        # case, letting arbitrary classes into the registry.
        accepted_bases = (Dataset,)
        if pyg:
            accepted_bases += (torch_geometric.data.Dataset,)

        if not issubclass(cls, accepted_bases):
            raise ValueError("Dataset ({}: {}) must extend cogdl.data.Dataset".format(name, cls.__name__))

        # Duplicate names are deliberately allowed: the newest class wins.
        DATASET_REGISTRY[name] = cls
        return cls

    return register_dataset_cls
def try_import_dataset(dataset):
    """Make sure the module providing ``dataset`` has been imported.

    Args:
        dataset: dataset name to look up.

    Returns:
        bool: ``True`` when the name is already in ``DATASET_REGISTRY`` or is
        listed in ``SUPPORTED_DATASETS`` (in which case its module is imported
        first); ``False`` when the name appears in neither mapping.
    """
    # Already registered: nothing to import.
    if dataset in DATASET_REGISTRY:
        return True

    # Unknown name: report it and let the caller decide what to do.
    if dataset not in SUPPORTED_DATASETS:
        print(f"Failed to import {dataset} dataset.")
        return False

    # Known name: import the module listed for it.
    importlib.import_module(SUPPORTED_DATASETS[dataset])
    return True
def build_dataset(args):
    """Build the dataset selected by ``args.dataset``.

    If the name can be imported via :func:`try_import_dataset`, the registered
    class is instantiated with no arguments and, when positive, its
    ``num_classes`` / ``num_features`` are copied onto ``args``. Otherwise
    ``args.dataset`` is treated as a data path and handed to
    :func:`build_dataset_from_path` together with ``args.task``; a dataset
    built that way is returned as-is, without updating ``args``.

    Args:
        args: object exposing a ``dataset`` attribute, plus a ``task``
            attribute when ``dataset`` is not a known dataset name.

    Returns:
        The constructed dataset instance.

    Raises:
        AssertionError: if the path fallback is needed but ``args`` has no
            ``task`` attribute.
        SystemExit: with status 1 when no dataset could be built.
    """
    if try_import_dataset(args.dataset):
        dataset = DATASET_REGISTRY[args.dataset]()
    else:
        assert hasattr(args, "task"), "args.task is required to build a dataset from a path"
        dataset = build_dataset_from_path(args.dataset, args.task)
        if dataset is not None:
            return dataset
        # Raise SystemExit directly instead of calling the `exit` helper that
        # the `site` module injects: that helper is meant for interactive use
        # and is not defined when the interpreter runs without `site`.
        raise SystemExit(1)

    if dataset.num_classes > 0:
        args.num_classes = dataset.num_classes
    if dataset.num_features > 0:
        args.num_features = dataset.num_features

    return dataset
def build_dataset_from_name(dataset):
    """Instantiate the dataset registered under ``dataset``.

    Args:
        dataset: dataset name to look up.

    Returns:
        An instance of the registered dataset class, constructed with no
        arguments.

    Raises:
        SystemExit: with status 1 when :func:`try_import_dataset` reports that
            the name cannot be imported.
    """
    if not try_import_dataset(dataset):
        # Raise SystemExit directly instead of calling the `exit` helper that
        # the `site` module injects: that helper is meant for interactive use
        # and is not defined when the interpreter runs without `site`.
        raise SystemExit(1)
    return DATASET_REGISTRY[dataset]()
def build_dataset_from_path(data_path, task=None, dataset=None):
    """Build a dataset from data stored at ``data_path``.

    Args:
        data_path: location passed through to the dataset constructor.
        task: optional task name; used to choose between ``NodeDataset`` and
            ``GraphDataset`` when no supported ``dataset`` name applies.
        dataset: optional dataset name; when it is listed in
            ``SUPPORTED_DATASETS`` and can be imported, the registered class
            is constructed with ``data_path=data_path``.

    Returns:
        The constructed dataset, or ``None`` when ``task`` is ``None`` or
        contains neither ``"node_classification"`` nor
        ``"graph_classification"``.
    """
    # A supported dataset name takes precedence over the task-based choice.
    is_supported = dataset is not None and dataset in SUPPORTED_DATASETS
    if is_supported and try_import_dataset(dataset):
        return DATASET_REGISTRY[dataset](data_path=data_path)

    if task is None:
        return None

    # Substring match on the task name; node classification is checked first.
    if "node_classification" in task:
        return NodeDataset(data_path)
    if "graph_classification" in task:
        return GraphDataset(data_path)
    return None
# Maps each supported dataset name to the dotted path of the module listed for it.
SUPPORTED_DATASETS = {
    # cogdl.datasets.gcc_data
    "kdd_icdm": "cogdl.datasets.gcc_data",
    "sigir_cikm": "cogdl.datasets.gcc_data",
    "sigmod_icde": "cogdl.datasets.gcc_data",
    "usa-airport": "cogdl.datasets.gcc_data",
    # cogdl.datasets.test_data
    "test_small": "cogdl.datasets.test_data",
    # cogdl.datasets.ogb
    "ogbn-arxiv": "cogdl.datasets.ogb",
    "ogbn-products": "cogdl.datasets.ogb",
    "ogbn-proteins": "cogdl.datasets.ogb",
    "ogbn-mag": "cogdl.datasets.ogb",
    "ogbn-papers100M": "cogdl.datasets.ogb",
    "ogbg-molbace": "cogdl.datasets.ogb",
    "ogbg-molhiv": "cogdl.datasets.ogb",
    "ogbg-molpcba": "cogdl.datasets.ogb",
    "ogbg-ppa": "cogdl.datasets.ogb",
    "ogbg-code": "cogdl.datasets.ogb",
    # cogdl.datasets.gatne
    "amazon": "cogdl.datasets.gatne",
    "twitter": "cogdl.datasets.gatne",
    "youtube": "cogdl.datasets.gatne",
    # cogdl.datasets.gtn_data
    "gtn-acm": "cogdl.datasets.gtn_data",
    "gtn-dblp": "cogdl.datasets.gtn_data",
    "gtn-imdb": "cogdl.datasets.gtn_data",
    # cogdl.datasets.kg_data
    "fb13": "cogdl.datasets.kg_data",
    "fb15k": "cogdl.datasets.kg_data",
    "fb15k237": "cogdl.datasets.kg_data",
    "wn18": "cogdl.datasets.kg_data",
    "wn18rr": "cogdl.datasets.kg_data",
    "fb13s": "cogdl.datasets.kg_data",
    # cogdl.datasets.planetoid_data
    "cora": "cogdl.datasets.planetoid_data",
    "citeseer": "cogdl.datasets.planetoid_data",
    "pubmed": "cogdl.datasets.planetoid_data",
    # cogdl.datasets.matlab_matrix
    "blogcatalog": "cogdl.datasets.matlab_matrix",
    "flickr-ne": "cogdl.datasets.matlab_matrix",
    "dblp-ne": "cogdl.datasets.matlab_matrix",
    "youtube-ne": "cogdl.datasets.matlab_matrix",
    "wikipedia": "cogdl.datasets.matlab_matrix",
    "ppi-ne": "cogdl.datasets.matlab_matrix",
    # cogdl.datasets.han_data
    "han-acm": "cogdl.datasets.han_data",
    "han-dblp": "cogdl.datasets.han_data",
    "han-imdb": "cogdl.datasets.han_data",
    # cogdl.datasets.tu_data
    "mutag": "cogdl.datasets.tu_data",
    "imdb-b": "cogdl.datasets.tu_data",
    "imdb-m": "cogdl.datasets.tu_data",
    "collab": "cogdl.datasets.tu_data",
    "proteins": "cogdl.datasets.tu_data",
    "reddit-b": "cogdl.datasets.tu_data",
    "reddit-multi-5k": "cogdl.datasets.tu_data",
    "reddit-multi-12k": "cogdl.datasets.tu_data",
    "ptc-mr": "cogdl.datasets.tu_data",
    "nci1": "cogdl.datasets.tu_data",
    "nci109": "cogdl.datasets.tu_data",
    "enzymes": "cogdl.datasets.tu_data",
    # cogdl.datasets.saint_data
    "yelp": "cogdl.datasets.saint_data",
    "amazon-s": "cogdl.datasets.saint_data",
    "flickr": "cogdl.datasets.saint_data",
    "reddit": "cogdl.datasets.saint_data",
    "ppi": "cogdl.datasets.saint_data",
    "ppi-large": "cogdl.datasets.saint_data",
    # cogdl.datasets.strategies_data
    "test_bio": "cogdl.datasets.strategies_data",
    "test_chem": "cogdl.datasets.strategies_data",
    "bio": "cogdl.datasets.strategies_data",
    "chem": "cogdl.datasets.strategies_data",
    "bace": "cogdl.datasets.strategies_data",
    "bbbp": "cogdl.datasets.strategies_data",
    # cogdl.datasets.oagbert_data
    "l0fos": "cogdl.datasets.oagbert_data",
    "aff30": "cogdl.datasets.oagbert_data",
    "arxivvenue": "cogdl.datasets.oagbert_data",
    # cogdl.datasets.rec_data
    "yelp2018": "cogdl.datasets.rec_data",
    "ali": "cogdl.datasets.rec_data",
    "amazon-rec": "cogdl.datasets.rec_data",
}