Source code for cogdl.trainers.supervised_trainer

from abc import ABC, abstractmethod
from typing import Any

from cogdl.data import Dataset
from cogdl.models.supervised_model import (
    SupervisedHeterogeneousNodeClassificationModel,
    SupervisedHomogeneousNodeClassificationModel,
)
from cogdl.trainers.base_trainer import BaseTrainer


class SupervisedTrainer(BaseTrainer, ABC):
    """Abstract interface for supervised trainers.

    Concrete subclasses must override both :meth:`fit` and :meth:`predict`;
    the implementations here only signal that no behavior is provided.
    """

    @abstractmethod
    def fit(self) -> None:
        """Fit the trainer; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``NotImplemented`` is the binary-operator sentinel, not an exception:
        # raising it fails with ``TypeError``. ``NotImplementedError`` is the
        # exception meant for unimplemented abstract methods.
        raise NotImplementedError

    @abstractmethod
    def predict(self) -> Any:
        """Produce predictions; must be overridden by subclasses.

        Returns:
            Defined by the concrete subclass (annotated as ``Any`` here).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class SupervisedHeterogeneousNodeClassificationTrainer(BaseTrainer, ABC):
    """Abstract interface for trainers taking a heterogeneous node-classification model.

    Concrete subclasses must override :meth:`fit`.
    """

    @abstractmethod
    def fit(
        self,
        model: SupervisedHeterogeneousNodeClassificationModel,
        dataset: Dataset,
    ) -> None:
        """Fit ``model`` on ``dataset``; must be overridden by subclasses.

        Args:
            model: The supervised heterogeneous node-classification model.
            dataset: The dataset passed alongside the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``NotImplemented`` is the binary-operator sentinel, not an exception:
        # raising it fails with ``TypeError``. ``NotImplementedError`` is the
        # exception meant for unimplemented abstract methods.
        raise NotImplementedError

    # Disabled abstract method, kept for reference:
    # @abstractmethod
    # def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any:
    #     raise NotImplementedError
class SupervisedHomogeneousNodeClassificationTrainer(BaseTrainer, ABC):
    """Abstract interface for trainers taking a homogeneous node-classification model.

    Concrete subclasses must override :meth:`fit`.
    """

    @abstractmethod
    def fit(
        self,
        model: SupervisedHomogeneousNodeClassificationModel,
        dataset: Dataset,
    ) -> None:
        """Fit ``model`` on ``dataset``; must be overridden by subclasses.

        Args:
            model: The supervised homogeneous node-classification model.
            dataset: The dataset passed alongside the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``NotImplemented`` is the binary-operator sentinel, not an exception:
        # raising it fails with ``TypeError``. ``NotImplementedError`` is the
        # exception meant for unimplemented abstract methods.
        raise NotImplementedError

    # Disabled abstract method, kept for reference:
    # @abstractmethod
    # def predictAll(self) -> Any:
    #     raise NotImplementedError