# Source code for cogdl.tasks.pretrain

import torch

from . import register_task, BaseTask
from cogdl.models import build_model


@register_task("pretrain")
class PretrainTask(BaseTask):
    """Task registered under the name ``"pretrain"``.

    Builds a model from the parsed arguments, moves it to the selected
    device, and delegates the actual training loop to the trainer object
    attached to that model.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # This task currently registers no extra arguments; the commented
        # line below is kept as a template for adding one.
        # fmt: off
        # parser.add_argument("--num-features", type=int)
        # fmt: on

    def __init__(self, args):
        """Build the model described by ``args`` and place it on the device.

        Args:
            args: Parsed arguments. ``args.cpu`` is read here to choose the
                device; the whole object is forwarded to the base class
                initializer and to ``build_model``.
        """
        super(PretrainTask, self).__init__(args)

        # A truthy ``args.cpu`` forces the CPU; otherwise "cuda" is requested.
        self.device = torch.device("cpu" if args.cpu else "cuda")

        self.model = build_model(args)
        self.model = self.model.to(self.device)

    def train(self):
        """Run training and return whatever ``self.model.trainer.fit()`` returns."""
        return self.model.trainer.fit()