# Source code for cogdl.models.supervised_model

from abc import ABC, abstractmethod
from typing import Any, Optional, Type
from typing import TYPE_CHECKING

from cogdl.models import BaseModel

if TYPE_CHECKING:
    # Imported only for static type checking, to avoid a circular import at runtime.
    from cogdl.trainers.supervised_trainer import (
        SupervisedHomogeneousNodeClassificationTrainer,
        SupervisedHeterogeneousNodeClassificationTrainer,
    )


class SupervisedModel(BaseModel, ABC):
    """Abstract base class for supervised models.

    Concrete subclasses must implement :meth:`loss`; the class cannot be
    instantiated until they do (enforced by ``abc.abstractmethod``).
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Compute the loss for ``data``.

        Args:
            data: Model input; its concrete type is defined by the subclass.

        Returns:
            The loss value; its type is defined by the subclass (presumably a
            scalar tensor -- confirm against implementations).

        Raises:
            NotImplementedError: If a subclass calls this base implementation
                (e.g. via ``super().loss(data)``).
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # ``raise NotImplemented`` would surface as a confusing TypeError.
        raise NotImplementedError


class SupervisedHeterogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base class for supervised heterogeneous node-classification models.

    Concrete subclasses must implement :meth:`loss`. :meth:`evaluate` is not
    abstract, but its base implementation always raises, so subclasses that
    are evaluated need to override it. :meth:`get_trainer` defaults to
    returning ``None``.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Compute the loss for ``data``.

        Args:
            data: Model input; its concrete type is defined by the subclass.

        Returns:
            The loss value; its type is defined by the subclass.

        Raises:
            NotImplementedError: If a subclass calls this base implementation.
        """
        # ``NotImplemented`` is not an exception; raising it gives a TypeError.
        raise NotImplementedError

    def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any:
        """Evaluate the model on ``nodes`` against ``targets``.

        Args:
            data: Model input; its concrete type is defined by the subclass.
            nodes: The nodes to evaluate on (format defined by the subclass).
            targets: Expected labels for ``nodes`` (format defined by the
                subclass).

        Returns:
            Evaluation result as defined by the subclass.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def get_trainer(
        taskType: Any, args: Any
    ) -> "Optional[Type[SupervisedHeterogeneousNodeClassificationTrainer]]":
        """Return a model-specific trainer class, or ``None``.

        Both arguments are unused by this default implementation; overrides
        may use them to choose a trainer. Presumably a ``None`` result makes
        the caller use its own default trainer -- confirm against the caller.

        Args:
            taskType: The task requesting a trainer.
            args: Run configuration.

        Returns:
            ``None`` in this base implementation.
        """
        return None


class SupervisedHomogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base class for supervised homogeneous node-classification models.

    Concrete subclasses must implement both :meth:`loss` and :meth:`predict`.
    :meth:`get_trainer` defaults to returning ``None``.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Compute the loss for ``data``.

        Args:
            data: Model input; its concrete type is defined by the subclass.

        Returns:
            The loss value; its type is defined by the subclass.

        Raises:
            NotImplementedError: If a subclass calls this base implementation.
        """
        # ``NotImplemented`` is not an exception; raising it gives a TypeError.
        raise NotImplementedError

    @abstractmethod
    def predict(self, data: Any) -> Any:
        """Produce predictions for ``data``.

        Args:
            data: Model input; its concrete type is defined by the subclass.

        Returns:
            Predictions in a subclass-defined format.

        Raises:
            NotImplementedError: If a subclass calls this base implementation.
        """
        raise NotImplementedError

    @staticmethod
    def get_trainer(
        taskType: Any,
        args: Any,
    ) -> "Optional[Type[SupervisedHomogeneousNodeClassificationTrainer]]":
        """Return a model-specific trainer class, or ``None``.

        Both arguments are unused by this default implementation; overrides
        may use them to choose a trainer. Presumably a ``None`` result makes
        the caller use its own default trainer -- confirm against the caller.

        Args:
            taskType: The task requesting a trainer.
            args: Run configuration.

        Returns:
            ``None`` in this base implementation.
        """
        return None