Source code for cogdl.models.nn.dgl_gcc

import math
import os

import dgl
import numpy as np
import scipy.sparse as sparse
import sklearn.preprocessing as preprocessing
import torch
import torch.nn.functional as F
from cogdl.layers.gcc_module import GraphEncoder
from cogdl.utils import download_url
from scipy.sparse import linalg

from .. import BaseModel, register_model


def batcher():
    def batcher_dev(batch):
        graph_q, graph_k = zip(*batch)
        graph_q, graph_k = dgl.batch(graph_q), dgl.batch(graph_k)
        return graph_q, graph_k

    return batcher_dev
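

# Illustrative usage (a sketch, not part of the original module): batcher() is
# intended as the DataLoader's collate_fn, merging a list of (graph_q, graph_k)
# pairs into two batched DGLGraphs:
#
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=32, collate_fn=batcher(), shuffle=False
#     )
#     graph_q, graph_k = next(iter(loader))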


def test_moco(train_loader, model, opt):
    """
    one epoch training for moco
    """

    model.eval()

    emb_list = []
    for idx, batch in enumerate(train_loader):
        graph_q, graph_k = batch
        bsz = graph_q.batch_size
        graph_q.to(opt.device)
        graph_k.to(opt.device)

        with torch.no_grad():
            feat_q = model(graph_q)
            feat_k = model(graph_k)

        assert feat_q.shape == (bsz, opt.hidden_size)
        emb_list.append(((feat_q + feat_k) / 2).detach().cpu())
    return torch.cat(emb_list)


def eigen_decomposision(n, k, laplacian, hidden_size, retry):
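    # Top-k eigenvectors of `laplacian` via ARPACK (scipy eigsh), with each
    # row L2-normalized and the result zero-padded from k to hidden_size
    # columns; returns an (n, hidden_size) float32 tensor.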
    if k <= 0:
        return torch.zeros(n, hidden_size)
    laplacian = laplacian.astype("float64")
    ncv = min(n, max(2 * k + 1, 20))
    # follows https://stackoverflow.com/questions/52386942/scipy-sparse-linalg-eigsh-with-fixed-seed
    v0 = np.random.rand(n).astype("float64")
    for i in range(retry):
        try:
            s, u = linalg.eigsh(laplacian, k=k, which="LA", ncv=ncv, v0=v0)
        except sparse.linalg.eigen.arpack.ArpackError:
            # print("arpack error, retry=", i)
            ncv = min(ncv * 2, n)
            if i + 1 == retry:
                sparse.save_npz("arpack_error_sparse_matrix.npz", laplacian)
                u = np.zeros((n, k))  # fall back to zero features (numpy, for normalize below)
        else:
            break
    x = preprocessing.normalize(u, norm="l2")
    x = torch.from_numpy(x.astype("float32"))
    x = F.pad(x, (0, hidden_size - k), "constant", 0)
    return x


def _add_undirected_graph_positional_embedding(g, hidden_size, retry=10):
    # We use eigenvectors of the normalized graph Laplacian as vertex features.
    # This can be viewed as a generalization of the positional embedding in the
    # "Attention Is All You Need" paper: recall that the eigenvectors of the
    # normalized Laplacian of a path graph are cos/sin functions.
    # See section 2.4 of http://www.cs.yale.edu/homes/spielman/561/2009/lect02-09.pdf
    n = g.number_of_nodes()
    adj = g.adjacency_matrix_scipy(transpose=False, return_edge_ids=False).astype(float)
    norm = sparse.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
    laplacian = norm * adj * norm
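    # Despite the variable name, this is the normalized adjacency
    # D^-1/2 A D^-1/2, not I - D^-1/2 A D^-1/2; its largest eigenvectors
    # (eigsh with which="LA" in eigen_decomposision) coincide with the
    # smallest eigenvectors of the normalized Laplacian.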
    k = min(n - 2, hidden_size)
    x = eigen_decomposision(n, k, laplacian, hidden_size, retry)
    g.ndata["pos_undirected"] = x.float()
    return g
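

# A minimal sketch (editorial example, not part of the original module) of the
# positional embedding on a small path graph, using the same legacy DGL API as
# the rest of this file (mutable dgl.DGLGraph + readonly()):
def _example_positional_embedding():
    g = dgl.DGLGraph()
    g.add_nodes(6)
    src, dst = [0, 1, 2, 3, 4], [1, 2, 3, 4, 5]
    g.add_edges(src, dst)
    g.add_edges(dst, src)
    g.readonly()
    g = _add_undirected_graph_positional_embedding(g, hidden_size=4)
    # Each row is a 4-dim positional feature; on a path graph the columns
    # resemble sampled cosine waves (cf. the comment above).
    return g.ndata["pos_undirected"]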


def _rwr_trace_to_dgl_graph(g, seed, trace, positional_embedding_size, entire_graph=False):
    subv = torch.unique(torch.cat(trace)).tolist()
    try:
        subv.remove(seed)
    except ValueError:
        pass
    subv = [seed] + subv
    if entire_graph:
        subg = g.subgraph(g.nodes())
    else:
        subg = g.subgraph(subv)

    subg = _add_undirected_graph_positional_embedding(subg, positional_embedding_size)

    subg.ndata["seed"] = torch.zeros(subg.number_of_nodes(), dtype=torch.long)
    if entire_graph:
        subg.ndata["seed"][seed] = 1
    else:
        subg.ndata["seed"][0] = 1
    return subg
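

# Sketch (editorial example): building one contextual subgraph from a
# random-walk-with-restart trace, mirroring what __getitem__ does below;
# relies on dgl.contrib.sampling from the same legacy DGL API:
def _example_rwr_subgraph(g, seed=0):
    traces = dgl.contrib.sampling.random_walk_with_restart(
        g, seeds=[seed], restart_prob=0.8, max_nodes_per_seed=16
    )
    subg = _rwr_trace_to_dgl_graph(
        g=g, seed=seed, trace=traces[0], positional_embedding_size=32
    )
    # The seed is relabeled to node 0 of the subgraph and flagged in
    # subg.ndata["seed"].
    return subg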


class NodeClassificationDataset(object):
    def __init__(
        self,
        data,
        rw_hops=64,
        subgraph_size=64,
        restart_prob=0.8,
        positional_embedding_size=32,
        step_dist=[1.0, 0.0, 0.0],
    ):
        self.rw_hops = rw_hops
        self.subgraph_size = subgraph_size
        self.restart_prob = restart_prob
        self.positional_embedding_size = positional_embedding_size
        self.step_dist = step_dist
        assert positional_embedding_size > 1

        self.data = data
        self.graphs = [self._create_dgl_graph(self.data)]
        self.length = sum([g.number_of_nodes() for g in self.graphs])
        self.total = self.length

    def _create_dgl_graph(self, data):
        graph = dgl.DGLGraph()
        src, dst = data.edge_index.tolist()
        num_nodes = data.edge_index.max().item() + 1
        graph.add_nodes(num_nodes)
        graph.add_edges(src, dst)
        graph.add_edges(dst, src)
        graph.readonly()
        return graph

    def __len__(self):
        return self.length

    def _convert_idx(self, idx):
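        # Map a flat dataset index to (graph index, node index within that
        # graph); e.g. with graph sizes [5, 3], idx=6 -> (1, 1).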
        graph_idx = 0
        node_idx = idx
        for i in range(len(self.graphs)):
            if node_idx < self.graphs[i].number_of_nodes():
                graph_idx = i
                break
            else:
                node_idx -= self.graphs[i].number_of_nodes()
        return graph_idx, node_idx

    def __getitem__(self, idx):
        graph_idx, node_idx = self._convert_idx(idx)

        step = np.random.choice(len(self.step_dist), 1, p=self.step_dist)[0]
        if step == 0:
            other_node_idx = node_idx
        else:
            other_node_idx = dgl.contrib.sampling.random_walk(
                g=self.graphs[graph_idx], seeds=[node_idx], num_traces=1, num_hops=step
            )[0][0][-1].item()

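        # Walk budget: at least rw_hops nodes per seed. For high-degree seeds
        # it scales as deg * (e / (e - 1)) / restart_prob: a walk segment has
        # expected length 1/restart_prob, and the e/(e-1) ~ 1.58 factor
        # roughly compensates for revisits so each neighbor of the seed is
        # likely visited at least once.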
        max_nodes_per_seed = max(
            self.rw_hops,
            int((self.graphs[graph_idx].out_degree(node_idx) * math.e / (math.e - 1) / self.restart_prob) + 0.5),
        )
        traces = dgl.contrib.sampling.random_walk_with_restart(
            self.graphs[graph_idx],
            seeds=[node_idx, other_node_idx],
            restart_prob=self.restart_prob,
            max_nodes_per_seed=max_nodes_per_seed,
        )

        graph_q = _rwr_trace_to_dgl_graph(
            g=self.graphs[graph_idx],
            seed=node_idx,
            trace=traces[0],
            positional_embedding_size=self.positional_embedding_size,
            entire_graph=hasattr(self, "entire_graph") and self.entire_graph,
        )
        graph_k = _rwr_trace_to_dgl_graph(
            g=self.graphs[graph_idx],
            seed=other_node_idx,
            trace=traces[1],
            positional_embedding_size=self.positional_embedding_size,
            entire_graph=hasattr(self, "entire_graph") and self.entire_graph,
        )
        return graph_q, graph_k


class GraphClassificationDataset(NodeClassificationDataset):
    def __init__(
        self,
        data,
        rw_hops=64,
        subgraph_size=64,
        restart_prob=0.8,
        positional_embedding_size=32,
        step_dist=[1.0, 0.0, 0.0],
    ):
        # NodeClassificationDataset.__init__ requires `data` and would try to
        # build a single DGL graph from edge_index, so the parent constructor
        # is intentionally not called here.
        self.rw_hops = rw_hops
        self.subgraph_size = subgraph_size
        self.restart_prob = restart_prob
        self.positional_embedding_size = positional_embedding_size
        self.step_dist = step_dist
        self.entire_graph = True
        assert positional_embedding_size > 1

        self.graphs = data

        self.length = len(self.graphs)
        self.total = self.length

    def _convert_idx(self, idx):
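        # Each dataset index selects a whole graph; the walk seed is the node
        # with the largest out-degree in that graph.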
        graph_idx = idx
        node_idx = self.graphs[idx].out_degrees().argmax().item()
        return graph_idx, node_idx


@register_model("gcc")
class GCC(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--load-path", type=str, default="./saved/gcc_pretrained.pth")
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--epoch", type=int, default=0)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(args.load_path)

    def __init__(self, load_path):
        super(GCC, self).__init__()
        self.load_path = load_path

    def train(self, data):
        if not os.path.isfile(self.load_path):
            print("=> no checkpoint found at '{}'".format(self.load_path))
            url = "https://github.com/cenyk1230/gcc-data/raw/master/saved/gcc_pretrained.pth"
            path = "/".join(self.load_path.split("/")[:-1])
            name = self.load_path.split("/")[-1]
            download_url(url, path, name=name)
        print("=> loading checkpoint '{}'".format(self.load_path))
        checkpoint = torch.load(self.load_path, map_location="cpu")
        print("=> loaded successfully '{}' (epoch {})".format(self.load_path, checkpoint["epoch"]))
        args = checkpoint["opt"]
        args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if isinstance(data, list):
            train_dataset = GraphClassificationDataset(
                data=data,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
        else:
            train_dataset = NodeClassificationDataset(
                data=data,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
        args.batch_size = len(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            collate_fn=batcher(),
            shuffle=False,
            num_workers=args.num_workers,
        )

        # create the encoder from the checkpointed hyperparameters and load
        # the pretrained weights (no optimizer: this is inference only)
        model = GraphEncoder(
            positional_embedding_size=args.positional_embedding_size,
            max_node_freq=args.max_node_freq,
            max_edge_freq=args.max_edge_freq,
            max_degree=args.max_degree,
            freq_embedding_size=args.freq_embedding_size,
            degree_embedding_size=args.degree_embedding_size,
            output_dim=args.hidden_size,
            node_hidden_dim=args.hidden_size,
            edge_hidden_dim=args.hidden_size,
            num_layers=args.num_layer,
            num_step_set2set=args.set2set_iter,
            num_layer_set2set=args.set2set_lstm_layer,
            gnn_model=args.model,
            norm=args.norm,
            degree_input=True,
        )
        model = model.to(args.device)
        model.load_state_dict(checkpoint["model"])
        del checkpoint

        emb = test_moco(train_loader, model, args)
        return emb.numpy()
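
# Sketch of how this wrapper is typically driven (editorial example; the
# variable names here are assumptions, not guarantees about the cogdl API):
#
#     model = GCC(load_path="./saved/gcc_pretrained.pth")
#     emb = model.train(data)  # downloads the pretrained checkpoint if absent
#     # emb: numpy array of shape (len(dataset), hidden_size) of GCC embeddings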