Source code for cogdl.trainers.sampled_trainer

from abc import ABC, abstractmethod
from typing import Any
import copy

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from cogdl.data import Dataset
from cogdl.data.sampler import NodeSampler, EdgeSampler, RWSampler, MRWSampler, LayerSampler
from cogdl.models.supervised_model import (
    SupervisedHeterogeneousNodeClassificationModel,
    SupervisedHomogeneousNodeClassificationModel,
)
from cogdl.trainers.supervised_trainer import SupervisedHeterogeneousNodeClassificationTrainer


class SampledTrainer(SupervisedHeterogeneousNodeClassificationTrainer):
    @abstractmethod
    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        # Raise NotImplementedError (the exception), not the bare
        # NotImplemented constant, for an abstract hook.
        raise NotImplementedError

class SAINTTrainer(SampledTrainer):
    def __init__(self, args):
        self.device = args.device_id[0] if not args.cpu else "cpu"
        self.patience = args.patience
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.args_sampler = self.sampler_from_args(args)

    @staticmethod
    def build_trainer_from_args(args):
        pass

    def sampler_from_args(self, args):
        # Collect the sampler-related hyperparameters into a plain dict.
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler
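
    # Illustrative note (not in the original source): the dict above simply
    # mirrors cogdl's sampler command-line flags. For a hypothetical
    # random-walk ("rw") configuration it might look like:
    #
    #     {"sampler": "rw", "sample_coverage": 20, "size_subgraph": 4000,
    #      "num_walks": 50, "walk_length": 20, "size_frontier": None}
    #
    # The concrete values here are assumptions for illustration only.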

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model
        if self.args_sampler["sampler"] == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of hitting an AttributeError on self.sampler later.
            raise ValueError(f"Unknown sampler: {self.args_sampler['sampler']}")
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        epoch_iter = tqdm(range(self.max_epoch))
        patience = 0
        best_score = 0
        best_loss = np.inf
        max_score = 0
        min_loss = np.inf
        best_model = copy.deepcopy(self.model)  # guard against max_epoch == 0
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            epoch_iter.set_description(
                f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
            )
            if val_loss <= min_loss or val_acc >= max_score:
                if val_acc >= best_score:  # the sampled SAINT loss is noisy, so select on accuracy
                    best_loss = val_loss
                    best_score = val_acc
                    best_model = copy.deepcopy(self.model)
                min_loss = np.min((min_loss, val_loss))
                max_score = np.max((max_score, val_acc))
                patience = 0
            else:
                patience += 1
                if patience == self.patience:
                    self.model = best_model
                    epoch_iter.close()
                    break
        return best_model

    def _train_step(self):
        # Train on a freshly sampled subgraph.
        self.data = self.sampler.get_subgraph("train")
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        self.data = self.sampler.get_subgraph(split)
        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask
        logits = self.model.predict(self.data)
        # Per-node NLL weighted by the sampler's loss-normalization
        # coefficients (GraphSAINT's bias correction for subgraph sampling).
        loss = (
            torch.nn.NLLLoss(reduction="none")(logits[mask], self.data.y[mask])
            * self.data.norm_loss[mask]
        ).sum()
        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        # Return a Python float so fit() can mix it with np.inf in np.min/np.max.
        return acc, loss.item()
""" class LayerSampledTrainer(SampledTrainer): def __init__(self, args): self.device = torch.device('cpu' if args.cpu else 'cuda') self.patience = args.patience self.max_epoch = args.max_epoch self.batch_size = args.batch_size def fit(self, model: SamplingNodeClassificationModel, dataset: Dataset): self.model = model.to(self.device) self.data = dataset.data self.data.apply(lambda x: x.to(self.device)) self.sampler = LayerSampler(self.data, self.model, {}) self.num_nodes = self.data.x.shape[0] self.adj_list = self.data.edge_index.detach().cpu().numpy() self.model.set_adj(self.adj_list, self.num_nodes) self.optimizer = torch.optim.Adam( self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) epoch_iter = tqdm(range(self.max_epoch)) patience = 0 best_score = 0 best_loss = np.inf max_score = 0 min_loss = np.inf for epoch in epoch_iter: self._train_step() train_acc, _ = self._test_step(split="train") val_acc, val_loss = self._test_step(split="val") epoch_iter.set_description( f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}" ) if val_loss <= min_loss or val_acc >= max_score: if val_loss <= best_loss: # and val_acc >= best_score: best_loss = val_loss best_score = val_acc best_model = copy.deepcopy(self.model) min_loss = np.min((min_loss, val_loss)) max_score = np.max((max_score, val_acc)) patience = 0 else: patience += 1 if patience == self.patience: self.model = best_model epoch_iter.close() break def _train_step(self): self.model.train() train_nodes = np.where(self.data.train_mask.detach().cpu().numpy())[0] train_labels = self.data.y.detach().cpu().numpy() for batch_nodes, batch_labels in get_batches(train_nodes, train_labels, batch_size=self.batch_size): batch_nodes = torch.LongTensor(batch_nodes) batch_labels = torch.LongTensor(batch_labels).to(self.device) sampled_x, sampled_adj, var_loss = self.sampler.sampling(self.data.x, batch_nodes) self.optimizer.zero_grad() output = self.model(sampled_x, sampled_adj) loss = F.nll_loss(output, batch_labels) + 0.5 * var_loss loss.backward() self.optimizer.step() def _test_step(self, split="val"): self.model.eval() _, mask = list(self.data(f"{split}_mask"))[0] test_nodes = np.where(mask.detach().cpu().numpy())[0] test_labels = self.data.y.detach().cpu().numpy() all_loss = [] all_acc = [] for batch_nodes, batch_labels in get_batches(test_nodes, test_labels, batch_size=self.batch_size): batch_nodes = torch.LongTensor(batch_nodes) batch_labels = torch.LongTensor(batch_labels).to(self.device) sampled_x, sampled_adj, var_loss = self.model.sampling(self.data.x, batch_nodes) with torch.no_grad(): logits = self.model(sampled_x, sampled_adj) loss = F.nll_loss(logits, batch_labels) pred = logits.max(1)[1] acc = pred.eq(self.data.y[batch_nodes]).sum().item() / batch_nodes.shape[0] all_loss.append(loss.item()) all_acc.append(acc) return np.mean(all_acc), np.mean(all_loss) """