Source code for cogdl.data.dataloader

from abc import ABCMeta
import torch.utils.data
from torch.utils.data.dataloader import default_collate

from cogdl.data import Batch, Graph

try:
    from typing import GenericMeta  # python 3.6
except ImportError:
    # in 3.7, genericmeta doesn't exist but we don't need it
    class GenericMeta(type):
        pass


class RecordParameters(ABCMeta):
    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        obj.record_parameters([args, kwargs])
        return obj


class GenericRecordParameters(GenericMeta, RecordParameters):
    pass


[docs]class DataLoader(torch.utils.data.DataLoader, metaclass=GenericRecordParameters): r"""Data loader which merges data objects from a :class:`cogdl.data.dataset` to a mini-batch. Args: dataset (Dataset): The dataset from which to load the data. batch_size (int, optional): How may samples per batch to load. (default: :obj:`1`) shuffle (bool, optional): If set to :obj:`True`, the data will be reshuffled at every epoch (default: :obj:`True`) """ def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): if "collate_fn" not in kwargs or kwargs["collate_fn"] is None: kwargs["collate_fn"] = self.collate_fn super(DataLoader, self).__init__( dataset, batch_size, shuffle, **kwargs, )
[docs] @staticmethod def collate_fn(batch): item = batch[0] if isinstance(item, Graph): return Batch.from_data_list(batch) elif isinstance(item, torch.Tensor): return default_collate(batch) elif isinstance(item, float): return torch.tensor(batch, dtype=torch.float) raise TypeError("DataLoader found invalid type: {}".format(type(item)))
[docs] def get_parameters(self): return self.default_kwargs
[docs] def record_parameters(self, params): self.default_kwargs = params