import collections
import collections.abc
import os.path as osp

import torch.utils.data

from .makedirs import makedirs
def to_list(x):
    """Wrap a scalar or a string in a one-element list.

    Args:
        x: Any object.

    Returns:
        ``[x]`` when ``x`` is not iterable or is a :class:`str`; otherwise
        ``x`` itself, unchanged (the same object, not a copy and not
        necessarily a ``list``).
    """
    # The ABC lives in ``collections.abc``; the ``collections.Iterable`` alias
    # was removed in Python 3.10 and raises AttributeError there.
    # A str is iterable but represents a single file name here, so it is
    # wrapped rather than being treated as a sequence of characters.
    if not isinstance(x, collections.abc.Iterable) or isinstance(x, str):
        x = [x]
    return x
def files_exist(files):
    """Return whether every path in ``files`` exists on disk.

    Args:
        files: Iterable of filesystem paths.

    Returns:
        ``True`` if ``osp.exists`` holds for every path (and therefore also
        for an empty iterable), ``False`` otherwise.
    """
    # A generator lets ``all`` stop at the first missing path instead of
    # checking every remaining file.
    return all(osp.exists(f) for f in files)
class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://rusty1s.github.io/pycogdl/build/html/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Subclasses provide :obj:`raw_file_names`, :obj:`processed_file_names`,
    :meth:`download`, :meth:`process`, :meth:`__len__` and :meth:`get`; the
    base implementations all raise :class:`NotImplementedError`.

    Constructing an instance runs the download step and then the processing
    step, each of which is skipped when its expected files already exist.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`cogdl.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`cogdl.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`cogdl.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self):
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def __len__(self):
        r"""The number of examples in the dataset."""
        raise NotImplementedError

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super(Dataset, self).__init__()

        # Normalise the path first, then expand a leading "~"; the raw and
        # processed folders are fixed sub-directories of the root.
        self.root = osp.expanduser(osp.normpath(root))
        self.raw_dir = osp.join(self.root, "raw")
        self.processed_dir = osp.join(self.root, "processed")

        # This class only applies ``transform`` (in ``__getitem__``);
        # ``pre_transform`` and ``pre_filter`` are stored for subclasses.
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        # Both steps run at construction time, so the subclass hooks they
        # call must be usable before ``__init__`` returns.
        self._download()
        self._process()

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        # Reads the first item via ``__getitem__``, so ``transform`` applies.
        return self[0].num_features

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Creates :obj:`self.raw_dir` and calls :meth:`download`, unless
        every path in :obj:`self.raw_paths` already exists."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Creates :obj:`self.processed_dir` and calls :meth:`process`,
        unless every path in :obj:`self.processed_paths` already exists.
        Progress messages are printed around the processing step."""
        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print("Processing...")

        makedirs(self.processed_dir)
        self.process()

        print("Done!")

    def __getitem__(self, idx):  # pragma: no cover
        r"""Gets the data object at index :obj:`idx` and transforms it (in case
        a :obj:`self.transform` is given)."""
        data = self.get(idx)
        data = data if self.transform is None else self.transform(data)
        return data

    def __repr__(self):  # pragma: no cover
        r"""Returns ``ClassName(length)`` using :meth:`__len__`."""
        return "{}({})".format(self.__class__.__name__, len(self))