# Source code for cogdl.wrappers.model_wrapper.node_classification.correct_smooth_mw

import copy

import torch
import numpy as np

from cogdl.wrappers.data_wrapper import DataWrapper

from .node_classification_mw import NodeClfModelWrapper


class CorrectSmoothModelWrapper(NodeClfModelWrapper):
    """Node-classification wrapper that postprocesses predictions before evaluation.

    This class does not override training; it only overrides validation and
    testing. In both, the raw model output is passed through
    ``self.model.postprocess(graph, pred)`` before the loss and metric are
    computed on the nodes of the relevant split.
    NOTE(review): ``postprocess`` is presumably the Correct & Smooth propagation
    step (per the module/class name) -- confirm against the model.
    """

    @staticmethod
    def add_args(parser):
        """Register wrapper-specific command-line arguments (this wrapper adds none)."""

    def __init__(self, model, optimizer_cfg):
        """Wrap ``model`` together with its optimizer configuration.

        Args:
            model: the underlying model; must be callable on a graph and expose
                ``postprocess(graph, pred)``.
            optimizer_cfg: optimizer configuration, forwarded to the base class.
        """
        super(CorrectSmoothModelWrapper, self).__init__(model, optimizer_cfg)
        self.model = model
        # Kept under this attribute name so existing readers of
        # ``self.optimizer_cfg`` continue to work.
        self.optimizer_cfg = optimizer_cfg

    def _evaluate_split(self, graph, split):
        """Score postprocessed predictions on one split and record the results.

        Args:
            graph: the graph to evaluate; must expose ``y`` and ``<split>_mask``.
            split: split name prefix, e.g. ``"val"`` or ``"test"``. It selects
                the ``<split>_mask`` attribute and names the recorded entries
                ``<split>_loss`` (a Python number, via ``.item()``) and
                ``<split>_metric``.
        """
        pred = self.model(graph)
        pred = self.model.postprocess(graph, pred)
        # Read the mask after running the model, as the original per-split code did.
        mask = getattr(graph, "{}_mask".format(split))
        masked_pred = pred[mask]
        masked_target = graph.y[mask]
        loss = self.default_loss_fn(masked_pred, masked_target)
        metric = self.evaluate(masked_pred, masked_target, metric="auto")
        self.note("{}_loss".format(split), loss.item())
        self.note("{}_metric".format(split), metric)

    def val_step(self, subgraph):
        """Evaluate on validation nodes; records ``val_loss`` and ``val_metric``."""
        self._evaluate_split(subgraph, "val")

    def test_step(self, batch):
        """Evaluate on test nodes; records ``test_loss`` and ``test_metric``."""
        self._evaluate_split(batch, "test")