Source code for cogdl.models.nn.pyg_unsup_graphsage

import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import random_walk

from .dgi import LogRegTrainer
from .. import register_model, BaseModel
from cogdl.models.nn.graphsage import sage_sampler, GraphSAGELayer


class SAGE(nn.Module):
    def __init__(
        self, num_features, hidden_size, num_layers, sample_size, dropout, walk_length, negative_samples
    ):
        super(SAGE, self).__init__()
        self.adjlist = {}
        self.num_features = num_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sample_size = sample_size
        self.dropout = dropout
        self.walk_length = walk_length
        self.num_negative_samples = negative_samples
        self.walk_res = None
        self.num_nodes = None
        self.negative_samples = None
        # Layer widths: input feature size followed by one hidden size per layer.
        shapes = [num_features] + [hidden_size] * num_layers
        self.convs = nn.ModuleList(
            [GraphSAGELayer(shapes[layer], shapes[layer + 1]) for layer in range(num_layers)]
        )

    def sampling(self, edge_index, num_sample):
        return sage_sampler(self.adjlist, edge_index, num_sample)
    def forward(self, x, edge_index):
        for i in range(self.num_layers):
            # Subsample neighbors before each layer's aggregation.
            edge_index_sp = self.sampling(edge_index, self.sample_size[i]).to(x.device)
            x = self.convs[i](x, edge_index_sp)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x
    def loss(self, data):
        x = self.forward(data.x, data.edge_index)
        device = x.device

        # Positive pairs: nodes co-visited by truncated random walks
        # (the first column is the start node itself, hence [:, 1:]).
        # if self.walk_res is None:
        self.walk_res = random_walk(
            data.edge_index[0],
            data.edge_index[1],
            start=torch.arange(0, x.shape[0]).to(device),
            walk_length=self.walk_length,
        )[:, 1:]

        if not self.num_nodes:
            self.num_nodes = int(torch.max(data.edge_index)) + 1

        # Negative pairs: nodes drawn uniformly at random.
        # if self.negative_samples is None:
        self.negative_samples = torch.from_numpy(
            np.random.choice(self.num_nodes, (self.num_nodes, self.num_negative_samples))
        ).to(device)

        pos_loss = -torch.log(
            torch.sigmoid(torch.sum(x.unsqueeze(1).repeat(1, self.walk_length, 1) * x[self.walk_res], dim=-1))
        ).mean()
        neg_loss = -torch.log(
            torch.sigmoid(
                -torch.sum(x.unsqueeze(1).repeat(1, self.num_negative_samples, 1) * x[self.negative_samples], dim=-1)
            )
        ).mean()
        return (pos_loss + neg_loss) / 2
    def embed(self, data):
        emb = self.forward(data.x, data.edge_index)
        return emb
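The objective in `SAGE.loss` above is the skip-gram-style unsupervised loss from the GraphSAGE paper: embeddings of walk co-visited nodes are pulled together, embeddings of uniformly sampled pairs are pushed apart. Below is a minimal standalone sketch of the same computation on made-up tensors; the shapes are illustrative only, and random integers stand in for the walk and negative-sample indices. Note that broadcasting makes the `.repeat(...)` calls in `loss` unnecessary:

import torch

num_nodes, dim, walk_length, num_neg = 8, 4, 3, 5
x = torch.randn(num_nodes, dim)                                # toy node embeddings
walks = torch.randint(0, num_nodes, (num_nodes, walk_length))  # stand-in for random-walk co-visits
negs = torch.randint(0, num_nodes, (num_nodes, num_neg))       # stand-in for uniform negatives

# [N, 1, D] * [N, L, D] broadcasts to pairwise dot products; no repeat() needed.
pos_loss = -torch.log(torch.sigmoid((x.unsqueeze(1) * x[walks]).sum(dim=-1))).mean()
neg_loss = -torch.log(torch.sigmoid(-(x.unsqueeze(1) * x[negs]).sum(dim=-1))).mean()
loss = (pos_loss + neg_loss) / 2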
@register_model("unsup_graphsage")
class Graphsage(BaseModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--hidden-size", type=int, default=128)
        parser.add_argument("--num-layers", type=int, default=2)
        parser.add_argument("--sample-size", type=int, nargs="+", default=[10, 10])
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--walk-length", type=int, default=10)
        parser.add_argument("--negative-samples", type=int, default=30)
        parser.add_argument("--lr", type=float, default=0.001)
        parser.add_argument("--max-epochs", type=int, default=3000)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        # num_classes and patience are not declared in add_args above; cogdl's
        # experiment pipeline attaches them to args before this is called.
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
            args.num_layers,
            args.sample_size,
            args.dropout,
            args.walk_length,
            args.negative_samples,
            args.lr,
            args.max_epochs,
            args.patience,
        )

    def __init__(
        self,
        num_features,
        hidden_size,
        num_classes,
        num_layers,
        sample_size,
        dropout,
        walk_length,
        negative_samples,
        lr,
        epochs,
        patience,
    ):
        super(Graphsage, self).__init__()
        self.model = SAGE(num_features, hidden_size, num_layers, sample_size, dropout, walk_length, negative_samples)
        self.epochs = epochs
        self.patience = patience
        self.lr = lr
        self.nhid = hidden_size
        self.nclass = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def train(self, data):
        data.apply(lambda x: x.to(self.device))
        self.model.to(self.device)
        best = 1e9
        cnt_wait = 0
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0.0)
        epoch_iter = tqdm.tqdm(range(self.epochs))
        for epoch in epoch_iter:
            self.model.train()
            optimizer.zero_grad()
            loss = self.model.loss(data)
            epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item():.4f}")
            # Early stopping on the training loss itself.
            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
            else:
                cnt_wait += 1
            if cnt_wait == self.patience:
                print("Early stopping!")
                break
            loss.backward()
            optimizer.step()

        # Evaluate the frozen embeddings with a logistic-regression probe.
        self.model.eval()
        embeds = self.model.embed(data).detach()
        opt = {
            "idx_train": data.train_mask,
            "idx_val": data.val_mask,
            "idx_test": data.test_mask,
            "num_classes": self.nclass,
        }
        result = LogRegTrainer().train(embeds, data.y, opt)
        return result
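For context, a hedged sketch of how this model might be driven outside cogdl's CLI, assuming cogdl's `build_dataset_from_name` dataset factory; the dataset name and the `num_classes`/`patience` values (normally injected by cogdl's experiment pipeline) are assumptions for illustration:

import argparse
from cogdl.datasets import build_dataset_from_name

parser = argparse.ArgumentParser()
Graphsage.add_args(parser)
args = parser.parse_args([])                 # take the declared defaults

dataset = build_dataset_from_name("cora")    # assumed dataset name
args.num_features = dataset.num_features
args.num_classes = dataset.num_classes       # consumed by the LogRegTrainer probe
args.patience = 20                           # assumed; normally set by the pipeline

model = Graphsage.build_model_from_args(args)
result = model.train(dataset[0])             # trains, then evaluates the embeddings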