from abc import ABC, abstractmethod
from typing import Any
from cogdl.data import Dataset
from cogdl.models.supervised_model import (
SupervisedHeterogeneousNodeClassificationModel,
SupervisedHomogeneousNodeClassificationModel,
)
from cogdl.trainers.base_trainer import BaseTrainer
class SupervisedTrainer(BaseTrainer, ABC):
    """Abstract base class for supervised trainers.

    Concrete subclasses must implement both :meth:`fit` and :meth:`predict`.
    Both are declared with ``@abstractmethod`` and ``ABC`` is a base class,
    so a subclass that omits either cannot be instantiated.
    """

    @abstractmethod
    def fit(self) -> None:
        """Fit the trainer; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, if this base body is reached
                (e.g. through ``super().fit()``).
        """
        # ``NotImplemented`` is a sentinel for binary dunder methods, not an
        # exception class; ``raise NotImplemented`` yields a TypeError.
        raise NotImplementedError

    @abstractmethod
    def predict(self) -> Any:
        """Produce predictions; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, if this base body is reached
                (e.g. through ``super().predict()``).
        """
        raise NotImplementedError
class SupervisedHeterogeneousNodeClassificationTrainer(BaseTrainer, ABC):
    """Abstract trainer for heterogeneous node-classification models.

    Subclasses must implement :meth:`fit`. Note that this class derives from
    ``BaseTrainer`` directly rather than from ``SupervisedTrainer``, so it
    does not require a ``predict`` method.
    """

    @abstractmethod
    def fit(
        self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset
    ) -> None:
        """Fit ``model`` on ``dataset``; must be overridden by subclasses.

        Args:
            model: the heterogeneous node-classification model to fit.
            dataset: the dataset to fit it on.

        Raises:
            NotImplementedError: always, if this base body is reached
                (e.g. through ``super().fit(...)``).
        """
        # ``raise NotImplemented`` would raise a TypeError, because
        # NotImplemented is not an exception.
        raise NotImplementedError

    # NOTE(review): an abstract ``evaluate(self, data, nodes, targets) -> Any``
    # hook was sketched here but left disabled; subclasses are not required
    # to implement it.
class SupervisedHomogeneousNodeClassificationTrainer(BaseTrainer, ABC):
    """Abstract trainer for homogeneous node-classification models.

    Subclasses must implement :meth:`fit`. Note that this class derives from
    ``BaseTrainer`` directly rather than from ``SupervisedTrainer``, so it
    does not require a ``predict`` method.
    """

    @abstractmethod
    def fit(
        self, model: SupervisedHomogeneousNodeClassificationModel, dataset: Dataset
    ) -> None:
        """Fit ``model`` on ``dataset``; must be overridden by subclasses.

        Args:
            model: the homogeneous node-classification model to fit.
            dataset: the dataset to fit it on.

        Raises:
            NotImplementedError: always, if this base body is reached
                (e.g. through ``super().fit(...)``).
        """
        # ``raise NotImplemented`` would raise a TypeError, because
        # NotImplemented is not an exception.
        raise NotImplementedError

    # NOTE(review): an abstract ``predictAll(self) -> Any`` hook was sketched
    # here but left disabled; subclasses are not required to implement it.