Source code for cogdl.wrappers.data_wrapper.graph_classification.graph_classification_dw

from .. import DataWrapper
from cogdl.wrappers.tools.wrapper_utils import node_degree_as_feature, split_dataset
from cogdl.data import DataLoader


class GraphClassificationDataWrapper(DataWrapper):
    """Serve train/validation/test ``DataLoader`` objects for a graph dataset.

    On construction the wrapper optionally replaces the dataset's node
    features via ``node_degree_as_feature`` and asks ``split_dataset`` for
    three index sets, stored in ``self.split_idx`` as ``[train, val, test]``.
    """

    @staticmethod
    def add_args(parser):
        """Register this wrapper's command-line options on ``parser``."""
        # fmt: off
        parser.add_argument("--degree-node-features", action="store_true", help="Use one-hot degree vector as input node features")
        # NOTE: the k-fold option below is commented out and therefore not registered.
        # parser.add_argument("--kfold", action="store_true", help="Use 10-fold cross-validation")
        typed_options = (
            ("--train-ratio", float, 0.5),
            ("--test-ratio", float, 0.3),
            ("--batch-size", int, 16),
        )
        for flag, value_type, fallback in typed_options:
            parser.add_argument(flag, type=value_type, default=fallback)
        # fmt: on

    def __init__(self, dataset, degree_node_features=False, batch_size=32, train_ratio=0.5, test_ratio=0.3):
        """Store the configuration, then prepare node features and splits.

        Args:
            dataset: Dataset to wrap; it is also forwarded to the base class.
            degree_node_features: When true, node features are always rebuilt
                with ``node_degree_as_feature``.
            batch_size: Batch size handed to every ``DataLoader``.
            train_ratio: Forwarded to ``split_dataset``.
            test_ratio: Forwarded to ``split_dataset``.
        """
        super().__init__(dataset)

        self.dataset = dataset
        self.batch_size = batch_size
        self.degree_node_features = degree_node_features
        self.train_ratio = train_ratio
        self.test_ratio = test_ratio

        # Filled in by setup_node_features(), which needs the attributes above.
        self.split_idx = None
        self.setup_node_features()

    def train_wrapper(self):
        """Return a shuffling ``DataLoader`` over the training split."""
        train_subset = self.dataset[self.split_idx[0]]
        return DataLoader(train_subset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_wrapper(self):
        """Return a non-shuffling ``DataLoader`` over the validation split.

        Returns ``None`` when no validation indices were produced.
        """
        if self.split_idx[1] is None:
            return None
        val_subset = self.dataset[self.split_idx[1]]
        return DataLoader(val_subset, batch_size=self.batch_size, shuffle=False, num_workers=4)

    def test_wrapper(self):
        """Return a non-shuffling ``DataLoader`` over the test split."""
        test_subset = self.dataset[self.split_idx[2]]
        return DataLoader(test_subset, batch_size=self.batch_size, shuffle=False, num_workers=4)

    def setup_node_features(self):
        """Prepare node features if required and compute the index splits.

        ``dataset.data`` is replaced by the result of
        ``node_degree_as_feature`` when ``degree_node_features`` is set or
        when ``x`` of the first entry in ``dataset.data`` is ``None``.
        The splits are always recomputed and stored in ``self.split_idx``.
        """
        rebuild_features = self.degree_node_features or self.dataset.data[0].x is None
        if rebuild_features:
            self.dataset.data = node_degree_as_feature(self.dataset.data)

        train_part, val_part, test_part = split_dataset(
            len(self.dataset), self.train_ratio, self.test_ratio
        )
        self.split_idx = [train_part, val_part, test_part]