Source code for cogdl.tasks.base_task

from abc import ABC, ABCMeta
import argparse
import atexit
import os
import torch
from cogdl.trainers import build_trainer

class LoadFrom(ABCMeta):
    """Metaclass that finishes setting up an instance once ``__init__`` returns.

    After the object is built, its ``load_from_pretrained`` hook is invoked.
    Then, only when the object exposes both a ``model`` and a ``device``
    attribute, ``model.set_device(device)`` is called.
    """

    def __call__(cls, *args, **kwargs):
        instance = type.__call__(cls, *args, **kwargs)
        instance.load_from_pretrained()
        # Short-circuits exactly like ``hasattr(model) and hasattr(device)``.
        has_model_and_device = all(hasattr(instance, attr) for attr in ("model", "device"))
        if has_model_and_device:
            instance.model.set_device(instance.device)
        return instance
class BaseTask(ABC, metaclass=LoadFrom):
    """Abstract base class for tasks.

    Provides the checkpoint plumbing shared by tasks:

    * Weights are restored from ``args.checkpoint`` through
      ``load_from_pretrained``, which the ``LoadFrom`` metaclass calls right
      after construction.
    * Weights are saved to ``args.save_model`` at normal interpreter exit
      through an ``atexit`` hook.

    Subclasses are expected to set ``self.model`` (and ``self.device`` if the
    model should be moved) in their own ``__init__``.
    """

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add task-specific arguments to the parser."""

    def __init__(self, args):
        """Initialise shared task state from parsed arguments.

        Args:
            args: Namespace-like object. Optional attributes read here:
                ``checkpoint`` is the path of weights to restore, and
                ``save_model`` is the path to write weights to at exit.
        """
        super().__init__()
        os.makedirs("./checkpoints", exist_ok=True)
        self.loss_fn = None
        self.evaluator = None

        checkpoint = getattr(args, "checkpoint", None)
        self.load_from_checkpoint = bool(checkpoint)
        self._checkpoint = checkpoint if self.load_from_checkpoint else None

        self.save_path = getattr(args, "save_model", None)
        if self.save_path is not None:
            # Save the trained weights when the interpreter exits normally.
            atexit.register(self.save_checkpoint)

    def train(self):
        """Run the task. Subclasses must override this."""
        raise NotImplementedError

    def load_from_pretrained(self):
        """Restore model weights from ``self._checkpoint`` if one was given.

        A missing checkpoint file is reported on stdout and otherwise
        ignored, so the task still runs with its initial weights.

        Returns:
            The task's ``model``, or ``None`` if the task has no ``model``.
        """
        if self.load_from_checkpoint:
            try:
                # map_location="cpu" lets a checkpoint saved on a GPU load on a
                # CPU-only host; load_state_dict copies the tensors onto the
                # device the model's parameters already live on.
                # NOTE(review): torch.load unpickles the file -- only load
                # trusted checkpoints.
                state_dict = torch.load(self._checkpoint, map_location="cpu")
            except FileNotFoundError:
                print(f"'{self._checkpoint}' doesn't exists")
            else:
                self.model.load_state_dict(state_dict)
        # The LoadFrom metaclass calls this for every task, including tasks
        # without a model, so do not assume the attribute exists.
        return getattr(self, "model", None)

    def save_checkpoint(self):
        """Write the model's ``state_dict`` to ``self.save_path``.

        ``__init__`` registers this with ``atexit`` when ``save_model`` is set.
        It does nothing unless a save path is configured and the model looks
        like a ``torch.nn.Module``, i.e. it has a ``_parameters`` attribute.
        """
        model = getattr(self, "model", None)
        if self.save_path and hasattr(model, "_parameters"):
            torch.save(model.state_dict(), self.save_path)
            print(f"Model saved in {self.save_path}")

    def set_loss_fn(self, dataset):
        """Take the loss function from ``dataset`` and hand it to the model."""
        self.loss_fn = dataset.get_loss_fn()
        self.model.set_loss_fn(self.loss_fn)

    def set_evaluator(self, dataset):
        """Take the evaluator from ``dataset``."""
        self.evaluator = dataset.get_evaluator()

    def get_trainer(self, model, args):
        """Build the trainer for ``model``, if any.

        An explicit ``args.trainer`` takes precedence and is built with
        ``build_trainer``. Otherwise the model may supply a trainer class
        through ``model.get_trainer(None, args)``.

        Returns:
            A trainer instance, or ``None`` when neither source provides one.

        Raises:
            ValueError: if a ``self_auxiliary_task`` trainer is requested for
                a model without a ``get_embeddings`` method.
        """
        if getattr(args, "trainer", None) is not None:
            if "self_auxiliary_task" in args.trainer and not hasattr(model, "get_embeddings"):
                model_name = getattr(args, "model", type(model).__name__)
                raise ValueError("Model ({}) must implement get_embeddings method".format(model_name))
            return build_trainer(args)
        # Query the model once rather than calling get_trainer twice.
        trainer_cls = model.get_trainer(None, args)
        if trainer_cls is not None:
            return trainer_cls(args)
        return None