# Source code for cogdl.wrappers.data_wrapper.node_classification.node_classification_dw

from .. import DataWrapper
from cogdl.data import Graph


class FullBatchNodeClfDataWrapper(DataWrapper):
    """Data wrapper for full-batch node classification.

    Every stage (train, validation, test) is handed the very same graph
    object, ``self.dataset.data``; this wrapper performs no batching or
    subgraph sampling. Presumably the split for each stage is selected
    downstream (e.g. via node masks stored on the graph) -- TODO confirm
    against the model wrapper that consumes this.

    Args:
        dataset: dataset object whose ``data`` attribute holds the graph.
    """

    def __init__(self, dataset):
        super().__init__(dataset)
        # Keep a direct reference; every wrapper method below reads it.
        self.dataset = dataset

    def train_wrapper(self) -> Graph:
        """Return the full graph for the training stage."""
        return self.dataset.data

    def val_wrapper(self) -> Graph:
        """Return the full graph for the validation stage (same object as training)."""
        return self.dataset.data

    def test_wrapper(self) -> Graph:
        """Return the full graph for the test stage (same object as training)."""
        return self.dataset.data

    def pre_transform(self):
        """Preprocess the dataset's graph by calling ``add_remaining_self_loops`` on it.

        The call operates on ``self.dataset.data`` directly, so all stages
        see the result because they share that object. Per the method name,
        this presumably adds self-loops only to nodes that lack one --
        confirm in ``cogdl.data.Graph``.
        """
        self.dataset.data.add_remaining_self_loops()