Source code for cogdl.trainers.gpt_gnn_trainer

import multiprocessing.pool as mp
import os
import time
from copy import deepcopy
from typing import Any

import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score

from cogdl.data import Dataset
from cogdl.layers.gpt_gnn_module import (
    sample_subgraph,
    feature_reddit,
    to_torch,
    randint,
    GNN,
    load_gnn,
    Classifier,
    preprocess_dataset,
)
from cogdl.models.supervised_model import SupervisedHeterogeneousNodeClassificationModel
from cogdl.trainers.supervised_trainer import (
    SupervisedHomogeneousNodeClassificationTrainer,
    SupervisedHeterogeneousNodeClassificationTrainer,
)


# Module-level handle to the preprocessed graph; set in fit() before the
# worker pool is created so the sampling workers can read it.
graph_pool = None


def node_classification_sample(args, target_type, seed, nodes, time_range):
    """
    Sub-graph sampling and label preparation for node classification:
    sample `args.batch_size` output nodes, then extract their sampled
    sub-graph together with the matching labels.
    """
    global graph_pool
    np.random.seed(seed)
    samp_nodes = np.random.choice(nodes, args.batch_size, replace=False)
    feature, times, edge_list, _, texts = sample_subgraph(
        graph_pool,
        time_range,
        inp={
            target_type: np.concatenate([samp_nodes, np.ones(args.batch_size)])
            .reshape(2, -1)
            .transpose()
        },
        sampled_depth=args.sample_depth,
        sampled_number=args.sample_width,
        feature_extractor=feature_reddit,
    )
    (
        node_feature,
        node_type,
        edge_time,
        edge_index,
        edge_type,
        node_dict,
        edge_dict,
    ) = to_torch(feature, times, edge_list, graph_pool)
    x_ids = np.arange(args.batch_size)
    return (
        node_feature,
        node_type,
        edge_time,
        edge_index,
        edge_type,
        x_ids,
        graph_pool.y[samp_nodes],
    )
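
# A minimal sketch of calling the sampler directly (assumptions: the
# module-level `graph_pool` has already been set to a preprocessed dataset,
# and `args` provides batch_size, sample_depth and sample_width; "def" is the
# single node type that fit() below uses for homogeneous graphs):
#
#   graph_pool = preprocess_dataset(dataset)
#   (node_feature, node_type, edge_time, edge_index,
#    edge_type, x_ids, ylabel) = node_classification_sample(
#       args, "def", randint(), graph_pool.train_target_nodes, {1: True}
#   )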


def prepare_data(
    args, graph, target_type, train_target_nodes, valid_target_nodes, pool
):
    """
    Sample and prepare the training and validation batches using
    multi-process parallelization: `args.n_batch` training jobs followed
    by a single validation job.
    """
    jobs = []
    for batch_id in np.arange(args.n_batch):
        p = pool.apply_async(
            node_classification_sample,
            args=(args, target_type, randint(), train_target_nodes, {1: True}),
        )
        jobs.append(p)
    p = pool.apply_async(
        node_classification_sample,
        args=(args, target_type, randint(), valid_target_nodes, {1: True}),
    )
    jobs.append(p)
    return jobs
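

# Hypothetical helper (not part of the original API) showing how the jobs
# returned by prepare_data are consumed: the last AsyncResult holds the
# validation batch, the rest hold training batches, mirroring fit() below.
def _collect_prepared_data(jobs):
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    return train_data, valid_data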


class GPT_GNNHomogeneousTrainer(SupervisedHomogeneousNodeClassificationTrainer):
    def __init__(self, args):
        super(GPT_GNNHomogeneousTrainer, self).__init__()
        self.args = args

    def fit(
        self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset
    ) -> dict:
        args = self.args
        self.device = args.device_id[0] if not args.cpu else "cpu"
        self.data = preprocess_dataset(dataset)
        global graph_pool
        graph_pool = self.data
        self.target_type = "def"
        self.train_target_nodes = self.data.train_target_nodes
        self.valid_target_nodes = self.data.valid_target_nodes
        self.test_target_nodes = self.data.test_target_nodes
        self.types = self.data.get_types()
        self.criterion = torch.nn.NLLLoss()
        self.stats = []
        self.res = []
        self.best_val = 0
        self.train_step = 0

        self.pool = mp.Pool(args.n_pool)
        self.st = time.time()
        self.jobs = prepare_data(
            args,
            self.data,
            self.target_type,
            self.train_target_nodes,
            self.valid_target_nodes,
            self.pool,
        )

        """
        Initialize GNN (model is specified by conv_name) and Classifier.
        """
        self.gnn = GNN(
            conv_name=args.conv_name,
            in_dim=len(self.data.node_feature[self.target_type]["emb"].values[0]),
            n_hid=args.n_hid,
            n_heads=args.n_heads,
            n_layers=args.n_layers,
            dropout=args.dropout,
            num_types=len(self.types),
            num_relations=len(self.data.get_meta_graph()) + 1,
            prev_norm=args.prev_norm,
            last_norm=args.last_norm,
            use_RTE=False,
        )
        if args.use_pretrain:
            self.gnn.load_state_dict(
                load_gnn(torch.load(args.pretrain_model_dir)), strict=False
            )
            print("Load Pre-trained Model from (%s)" % args.pretrain_model_dir)
        self.classifier = Classifier(args.n_hid, self.data.y.max().item() + 1)
        self.model = torch.nn.Sequential(self.gnn, self.classifier).to(self.device)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4)
        if args.scheduler == "cycle":
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer,
                pct_start=0.02,
                anneal_strategy="linear",
                final_div_factor=100,
                max_lr=args.max_lr,
                total_steps=args.n_batch * args.n_epoch + 1,
            )
        elif args.scheduler == "cosine":
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, 500, eta_min=1e-6
            )
        else:
            raise ValueError("Unknown scheduler: %s" % args.scheduler)

        for epoch in np.arange(self.args.n_epoch) + 1:
            """
            Prepare training and validation data: collect the sampled batches,
            close the exhausted pool, then reopen it for the next epoch.
            """
            train_data = [job.get() for job in self.jobs[:-1]]
            valid_data = self.jobs[-1].get()
            self.pool.close()
            self.pool.join()
            self.pool = mp.Pool(self.args.n_pool)
            self.jobs = prepare_data(
                self.args,
                self.data,
                self.target_type,
                self.train_target_nodes,
                self.valid_target_nodes,
                self.pool,
            )
            self.et = time.time()
            print("Data Preparation: %.1fs" % (self.et - self.st))

            """
            Train
            """
            self.model.train()
            train_losses = []
            for (
                node_feature,
                node_type,
                edge_time,
                edge_index,
                edge_type,
                x_ids,
                ylabel,
            ) in train_data:
                node_rep = self.gnn.forward(
                    node_feature.to(self.device),
                    node_type.to(self.device),
                    edge_time.to(self.device),
                    edge_index.to(self.device),
                    edge_type.to(self.device),
                )
                res = self.classifier.forward(node_rep[x_ids])
                loss = self.criterion(res, ylabel.to(self.device))

                self.optimizer.zero_grad()
                torch.cuda.empty_cache()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()

                train_losses += [loss.cpu().detach().tolist()]
                self.train_step += 1
                self.scheduler.step(self.train_step)
                del res, loss

            """
            Valid
            """
            self.model.eval()
            with torch.no_grad():
                (
                    node_feature,
                    node_type,
                    edge_time,
                    edge_index,
                    edge_type,
                    x_ids,
                    ylabel,
                ) = valid_data
                node_rep = self.gnn.forward(
                    node_feature.to(self.device),
                    node_type.to(self.device),
                    edge_time.to(self.device),
                    edge_index.to(self.device),
                    edge_type.to(self.device),
                )
                res = self.classifier.forward(node_rep[x_ids])
                loss = self.criterion(res, ylabel.to(self.device))

                """
                Calculate valid F1 and keep the model with the highest score.
                """
                valid_f1 = f1_score(
                    ylabel.tolist(), res.argmax(dim=1).cpu().tolist(), average="micro"
                )
                if valid_f1 > self.best_val:
                    self.best_val = valid_f1
                    # torch.save(
                    #     self.model,
                    #     os.path.join(
                    #         self.args.model_dir,
                    #         self.args.task_name + "_" + self.args.conv_name,
                    #     ),
                    # )
                    self.best_model_dict = deepcopy(self.model.state_dict())
                    print("UPDATE!!!")

                self.st = time.time()
                print(
                    (
                        "Epoch: %d (%.1fs)  LR: %.5f  Train Loss: %.2f  "
                        "Valid Loss: %.2f  Valid F1: %.4f"
                    )
                    % (
                        epoch,
                        (self.st - self.et),
                        self.optimizer.param_groups[0]["lr"],
                        np.average(train_losses),
                        loss.cpu().detach().tolist(),
                        valid_f1,
                    )
                )
                self.stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
                del res, loss
            del train_data, valid_data

        self.model.load_state_dict(self.best_model_dict)
        best_model = self.model.to(self.device)
        # best_model = torch.load(
        #     os.path.join(
        #         self.args.model_dir, self.args.task_name + "_" + self.args.conv_name
        #     )
        # ).to(self.device)
        best_model.eval()
        gnn, classifier = best_model

        """
        Evaluate the best model on 10 freshly sampled test sub-graphs.
        """
        with torch.no_grad():
            test_res = []
            for _ in range(10):
                (
                    node_feature,
                    node_type,
                    edge_time,
                    edge_index,
                    edge_type,
                    x_ids,
                    ylabel,
                ) = node_classification_sample(
                    self.args,
                    self.target_type,
                    randint(),
                    self.test_target_nodes,
                    {1: True},
                )
                paper_rep = gnn.forward(
                    node_feature.to(self.device),
                    node_type.to(self.device),
                    edge_time.to(self.device),
                    edge_index.to(self.device),
                    edge_type.to(self.device),
                )[x_ids]
                res = classifier.forward(paper_rep)
                test_acc = accuracy_score(
                    ylabel.tolist(), res.argmax(dim=1).cpu().tolist()
                )
                test_res += [test_acc]
            return dict(Acc=np.average(test_res))
        # print("Best Test F1: %.4f" % np.average(test_res))

    @classmethod
    def build_trainer_from_args(cls, args):
        pass
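

# A hedged usage sketch (hypothetical setup; in practice cogdl's experiment
# pipeline constructs `args`, `model` and `dataset` and invokes the trainer):
#
#   trainer = GPT_GNNHomogeneousTrainer(args)
#   result = trainer.fit(model, dataset)  # returns dict(Acc=...)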


class GPT_GNNHeterogeneousTrainer(SupervisedHeterogeneousNodeClassificationTrainer):
    def __init__(self, model, dataset):
        super(GPT_GNNHeterogeneousTrainer, self).__init__(model, dataset)
    def fit(self) -> None:
        raise NotImplementedError
    def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any:
        raise NotImplementedError