tasks.node_classification_sampling

Module Contents

Classes

NodeClassificationSampling

Node classification task with sampling.

Functions

get_batches(train_nodes, train_labels, batch_size=64, shuffle=True)

tasks.node_classification_sampling.get_batches(train_nodes, train_labels, batch_size=64, shuffle=True)
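The reference gives only the signature of `get_batches`. Judging from its parameter names, it likely splits the training nodes and their labels into mini-batches of `batch_size`, optionally shuffling them first. The sketch below is a minimal, hypothetical illustration of that idea using plain Python lists. The name `get_batches_sketch` and its return type (a list of `(nodes, labels)` pairs) are assumptions, not the module's actual implementation, which may work on tensors instead.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

N = TypeVar("N")  # node identifier type
Y = TypeVar("Y")  # label type


def get_batches_sketch(
    train_nodes: Sequence[N],
    train_labels: Sequence[Y],
    batch_size: int = 64,
    shuffle: bool = True,
) -> List[Tuple[List[N], List[Y]]]:
    """Hypothetical sketch: cut aligned (node, label) pairs into mini-batches."""
    if len(train_nodes) != len(train_labels):
        raise ValueError("train_nodes and train_labels must have the same length")
    indices = list(range(len(train_nodes)))
    if shuffle:
        # Shuffle positions, not the sequences, so nodes and labels stay aligned.
        random.shuffle(indices)
    batches = []
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        batches.append(([train_nodes[i] for i in chunk], [train_labels[i] for i in chunk]))
    return batches


batches = get_batches_sketch(list(range(10)), [n % 2 for n in range(10)], batch_size=4, shuffle=False)
print([len(nodes) for nodes, _ in batches])  # [4, 4, 2]: the last batch holds the remainder
```

Shuffling indices instead of the two sequences separately keeps each node paired with its own label.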

class tasks.node_classification_sampling.NodeClassificationSampling(args)

Bases: tasks.BaseTask

Node classification task with sampling.

static add_args(parser)

Add task-specific arguments to the parser.
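Because `add_args` is a static method, a caller can register a task's options on a shared `argparse` parser before any task instance exists. It can then parse the command line and pass the resulting namespace to the constructor as `args`. The sketch below shows that pattern with a hypothetical `SketchTask` class. The flags `--batch-size`, `--no-shuffle`, and `--epochs` are illustrative assumptions, not the options this task actually registers.

```python
import argparse


class SketchTask:
    """Hypothetical stand-in for a task class; not the real NodeClassificationSampling."""

    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        """Register task-specific options on a parser owned by the caller."""
        # Flag names are hypothetical examples.
        parser.add_argument("--batch-size", type=int, default=64)
        parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")

    def __init__(self, args: argparse.Namespace) -> None:
        # The constructor reads the values that add_args registered.
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle


parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100)  # a generic option added by the caller
SketchTask.add_args(parser)  # called on the class; no instance is needed yet
args = parser.parse_args(["--batch-size", "128"])
task = SketchTask(args)
print(task.batch_size, task.shuffle, args.epochs)  # 128 True 100
```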

train(self)
_train_step(self)
_test_step(self, split='val')
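The reference lists `train`, `_train_step`, and `_test_step(split='val')` without describing how they interact. One common arrangement, assumed here rather than confirmed by the source, is that `train` runs one `_train_step` per epoch and then scores the validation split with `_test_step`, stopping early once validation stops improving. The skeleton below illustrates that control flow. `SketchTrainer`, `max_epoch`, `patience`, and the canned `val_scores` are all hypothetical stand-ins, and the real task may not use early stopping at all.

```python
from typing import Iterable


class SketchTrainer:
    """Hypothetical skeleton: train() alternates _train_step() and _test_step('val')."""

    def __init__(self, max_epoch: int, patience: int, val_scores: Iterable[float]) -> None:
        self.max_epoch = max_epoch
        self.patience = patience
        self._val_scores = iter(val_scores)  # canned metrics standing in for real evaluation
        self.steps_run = 0

    def _train_step(self) -> None:
        # A real task would loop over mini-batches here (for example, from get_batches).
        self.steps_run += 1

    def _test_step(self, split: str = "val") -> float:
        # A real task would evaluate the model on `split` and return a metric.
        return next(self._val_scores)

    def train(self) -> float:
        best_score, bad_epochs = float("-inf"), 0
        for _ in range(self.max_epoch):
            self._train_step()
            score = self._test_step(split="val")
            if score > best_score:
                best_score, bad_epochs = score, 0
            else:
                bad_epochs += 1
                if bad_epochs >= self.patience:
                    break  # early stopping: validation has not improved for `patience` epochs
        return best_score


trainer = SketchTrainer(max_epoch=10, patience=2, val_scores=[0.5, 0.6, 0.55, 0.58, 0.9])
print(trainer.train(), trainer.steps_run)  # 0.6 4
```

In this run, training stops after the fourth epoch, so the later score of 0.9 is never reached. That is the usual trade-off of patience-based early stopping.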