Source code for cogdl.wrappers.data_wrapper.node_classification.pprgo_dw

import os
import scipy.sparse as sp
import torch

from .. import DataWrapper
from cogdl.utils.ppr_utils import build_topk_ppr_matrix_from_data


class PPRGoDataWrapper(DataWrapper):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument("--alpha", type=float, default=0.5)
        parser.add_argument("--topk", type=int, default=32)
        parser.add_argument("--norm", type=str, default="sym")
        parser.add_argument("--eps", type=float, default=1e-4)
        parser.add_argument("--batch-size", type=int, default=512)
        parser.add_argument("--test-batch-size", type=int, default=-1)
        # fmt: on

    def __init__(self, dataset, topk, alpha=0.2, norm="sym", batch_size=512, eps=1e-4, test_batch_size=-1):
        super(PPRGoDataWrapper, self).__init__(dataset)
        self.batch_size, self.test_batch_size = batch_size, test_batch_size
        self.topk, self.alpha, self.norm, self.eps = topk, alpha, norm, eps
        self.dataset = dataset

    def train_wrapper(self):
        """
        batch: tuple(x, targets, ppr_scores, y)
            x: shape=(num_edges_of_batch, num_features), features of the
                top-k PPR neighbors of the batch's source nodes
            targets: shape=(num_edges_of_batch,), maps each neighbor row
                back to its source node within the batch
            ppr_scores: shape=(num_edges_of_batch,)
            y: shape=(b,), labels of the b source nodes
                (or (b, num_classes) for multi-label data)
        """
        self.dataset.data.train()
        ppr_dataset_train = pre_transform(self.dataset, self.topk, self.alpha, self.eps, self.norm, mode="train")
        train_loader = setup_dataloader(ppr_dataset_train, self.batch_size)
        return train_loader
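
    # A hedged sketch of how a PPRGo-style training step might consume these
    # batches; `model` and `criterion` are assumed to be supplied by the
    # surrounding model wrapper and training loop, not by this module:
    #
    #     for x, targets, ppr_scores, y in train_loader:
    #         pred = model(x, targets, ppr_scores)  # aggregate neighbor features per source node
    #         loss = criterion(pred, y)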

    def val_wrapper(self):
        self.dataset.data.eval()
        if self.test_batch_size > 0:
            ppr_dataset_val = pre_transform(self.dataset, self.topk, self.alpha, self.eps, self.norm, mode="val")
            val_loader = setup_dataloader(ppr_dataset_val, self.test_batch_size)
            return val_loader
        else:
            return self.dataset.data

    def test_wrapper(self):
        self.dataset.data.eval()
        if self.test_batch_size > 0:
            ppr_dataset_test = pre_transform(self.dataset, self.topk, self.alpha, self.eps, self.norm, mode="test")
            test_loader = setup_dataloader(ppr_dataset_test, self.test_batch_size)
            return test_loader
        else:
            return self.dataset.data


def setup_dataloader(ppr_dataset, batch_size):
    # batch_size=None disables the DataLoader's own collation: the
    # BatchSampler yields a list of node indices per step, and
    # PPRGoDataset.__getitem__ assembles the whole batch itself.
    data_loader = torch.utils.data.DataLoader(
        dataset=ppr_dataset,
        sampler=torch.utils.data.BatchSampler(
            torch.utils.data.SequentialSampler(ppr_dataset),
            batch_size=batch_size,
            drop_last=False,
        ),
        batch_size=None,
    )
    return data_loader


def pre_transform(dataset, topk, alpha, epsilon, normalization, mode="train"):
    dataset_name = dataset.__class__.__name__
    data = dataset[0]
    num_nodes = data.x.shape[0]
    nodes = torch.arange(num_nodes)
    mask = getattr(data, f"{mode}_mask")
    index = nodes[mask].numpy()
    if mode == "train":
        data.train()
    else:
        data.eval()
    edge_index = data.edge_index

    # Computing the top-k approximate personalized PageRank matrix is
    # expensive, so cache it on disk keyed by dataset and hyperparameters.
    if not os.path.exists("./pprgo_saved"):
        os.mkdir("pprgo_saved")
    path = f"./pprgo_saved/{dataset_name}_{topk}_{alpha}_{normalization}.{mode}.npz"
    if os.path.exists(path):
        print(f"Loaded cached {mode} PPR matrix")
        topk_matrix = sp.load_npz(path)
    else:
        print(f"No cached {mode} PPR matrix found, generating...")
        topk_matrix = build_topk_ppr_matrix_from_data(edge_index, alpha, epsilon, index, topk, normalization)
        sp.save_npz(path, topk_matrix)
    result = PPRGoDataset(data.x, topk_matrix, index, data.y)
    return result


class PPRGoDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        features: torch.Tensor,
        ppr_matrix: sp.csr_matrix,
        node_indices: torch.Tensor,  # in practice a numpy array of node ids built in pre_transform
        labels_all: torch.Tensor = None,
    ):
        self.features = features
        self.matrix = ppr_matrix
        self.node_indices = node_indices
        self.labels_all = labels_all
        self.cache = dict()

    def __len__(self):
        return self.node_indices.shape[0]

    def __getitem__(self, items):
        # `items` is a list of row indices produced by the BatchSampler;
        # one call builds one complete batch, cached for reuse across epochs.
        key = str(items)
        if key not in self.cache:
            sample_matrix = self.matrix[items]
            source, neighbor = sample_matrix.nonzero()
            ppr_scores = torch.from_numpy(sample_matrix.data).float()
            features = self.features[neighbor].float()
            targets = torch.from_numpy(source).long()
            labels = self.labels_all[self.node_indices[items]]
            self.cache[key] = (features, targets, ppr_scores, labels)
        return self.cache[key]
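

# Minimal usage sketch of how PPRGoDataset and setup_dataloader fit
# together, using only torch and scipy. The identity matrix below is a toy
# stand-in for a real cached top-k PPR matrix, and every node is treated as
# a training node; batch shapes follow the train_wrapper docstring.
if __name__ == "__main__":
    import numpy as np

    num_nodes, num_features, num_classes = 8, 4, 3
    x = torch.randn(num_nodes, num_features)
    y = torch.randint(0, num_classes, (num_nodes,))
    index = np.arange(num_nodes)  # toy stand-in: all nodes are "training" nodes
    toy_topk_matrix = sp.eye(num_nodes, format="csr")  # each node attends only to itself
    toy_dataset = PPRGoDataset(x, toy_topk_matrix, index, y)
    loader = setup_dataloader(toy_dataset, batch_size=4)
    for features, targets, ppr_scores, labels in loader:
        # features: (num_edges_of_batch, num_features)
        # targets, ppr_scores: (num_edges_of_batch,)
        # labels: (batch_size,)
        print(features.shape, targets.shape, ppr_scores.shape, labels.shape)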