Source code for cogdl.models.emb.gatne

import numpy as np
import networkx as nx
from collections import defaultdict
try:
    # gensim < 4.0 exposes Vocab here; retained to ease the loading of older models.
    # See: https://radimrehurek.com/gensim/models/keyedvectors.html?highlight=vocab#gensim.models.keyedvectors.CompatVocab
    from gensim.models.keyedvectors import Vocab
except ImportError:
    # gensim >= 4.0 removed Vocab; fall back to a minimal stand-in carrying the
    # only two attributes this module uses (count, index).
    class Vocab:
        def __init__(self, count, index):
            self.count = count
            self.index = index
import random
import math
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .. import BaseModel


class GATNE(BaseModel):
    r"""The GATNE model from the `"Representation Learning for Attributed Multiplex
    Heterogeneous Network" <https://dl.acm.org/doi/10.1145/3292500.3330964>`_ paper.

    Args:
        walk_length (int) : The walk length.
        walk_num (int) : The number of walks to sample for each node.
        window_size (int) : The actual context size which is considered in the language model.
        worker (int) : The number of workers for word2vec.
        epochs (int) : The number of training epochs.
        batch_size (int) : The size of each training batch.
        edge_dim (int) : Number of edge embedding dimensions.
        att_dim (int) : Number of attention dimensions.
        negative_samples (int) : Number of negative samples for optimization.
        neighbor_samples (int) : Number of neighbor samples for aggregation.
        schema (str) : The metapath schema used in the model. Metapaths are separated by ",",
            while node types within each metapath are connected by "-".
            For example: "0-1-0,0-1-2-1-0".
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--walk-length', type=int, default=10,
                            help='Length of walk per source. Default is 10.')
        parser.add_argument('--walk-num', type=int, default=10,
                            help='Number of walks per source. Default is 10.')
        parser.add_argument('--window-size', type=int, default=5,
                            help='Window size of skip-gram model. Default is 5.')
        parser.add_argument('--worker', type=int, default=10,
                            help='Number of parallel workers. Default is 10.')
        parser.add_argument('--epochs', type=int, default=20,
                            help='Number of epochs. Default is 20.')
        parser.add_argument('--batch-size', type=int, default=256,
                            help='Size of each training batch. Default is 256.')
        parser.add_argument('--edge-dim', type=int, default=10,
                            help='Number of edge embedding dimensions. Default is 10.')
        parser.add_argument('--att-dim', type=int, default=20,
                            help='Number of attention dimensions. Default is 20.')
        parser.add_argument('--negative-samples', type=int, default=5,
                            help='Negative samples for optimization. Default is 5.')
        parser.add_argument('--neighbor-samples', type=int, default=10,
                            help='Neighbor samples for aggregation. Default is 10.')
        parser.add_argument('--schema', type=str, default=None,
                            help='Input schema for metapath random walk.')
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.hidden_size,
            args.walk_length,
            args.walk_num,
            args.window_size,
            args.worker,
            args.epochs,
            args.batch_size,
            args.edge_dim,
            args.att_dim,
            args.negative_samples,
            args.neighbor_samples,
            args.schema,
        )

    def __init__(
        self,
        dimension,
        walk_length,
        walk_num,
        window_size,
        worker,
        epochs,
        batch_size,
        edge_dim,
        att_dim,
        negative_samples,
        neighbor_samples,
        schema,
    ):
        super(GATNE, self).__init__()
        self.embedding_size = dimension
        self.walk_length = walk_length
        self.walk_num = walk_num
        self.window_size = window_size
        self.worker = worker
        self.epochs = epochs
        self.batch_size = batch_size
        self.embedding_u_size = edge_dim
        self.dim_att = att_dim
        self.num_sampled = negative_samples
        self.neighbor_samples = neighbor_samples
        self.schema = schema
        self.multiplicity = True
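
    # `forward` below runs the whole embedding pipeline: (1) sample random
    # walks per layer of `network_data` (a dict {edge_type: [(u, v), ...]}),
    # (2) build a frequency-sorted vocabulary over walked nodes, (3) extract
    # skip-gram (center, context, layer) training pairs, (4) pad or truncate
    # each node's per-type neighbor list to `neighbor_samples`, (5) train
    # GATNEModel against NSLoss, and (6) return per-edge-type embedding
    # dictionaries.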
    def forward(self, network_data):
        device = "cpu" if not torch.cuda.is_available() else "cuda"
        all_walks = generate_walks(network_data, self.walk_num, self.walk_length, schema=self.schema)
        vocab, index_to_key = generate_vocab(all_walks)
        # Pass the configured window size through; the original call fell back
        # to generate_pairs' default of 5, silently ignoring self.window_size.
        train_pairs = generate_pairs(all_walks, vocab, self.window_size)

        edge_types = list(network_data.keys())
        num_nodes = len(index_to_key)
        edge_type_count = len(edge_types)

        epochs = self.epochs
        batch_size = self.batch_size
        embedding_size = self.embedding_size
        embedding_u_size = self.embedding_u_size
        num_sampled = self.num_sampled
        dim_att = self.dim_att
        neighbor_samples = self.neighbor_samples

        # For each node and edge type, collect neighbor indices, then pad
        # (by resampling, or with the node itself if it has no neighbors)
        # or truncate to exactly `neighbor_samples` entries.
        neighbors = [[[] for _ in range(edge_type_count)] for _ in range(num_nodes)]
        for r in range(edge_type_count):
            g = network_data[edge_types[r]]
            for x, y in g:
                ix = vocab[x].index
                iy = vocab[y].index
                neighbors[ix][r].append(iy)
                neighbors[iy][r].append(ix)
            for i in range(num_nodes):
                if len(neighbors[i][r]) == 0:
                    neighbors[i][r] = [i] * neighbor_samples
                elif len(neighbors[i][r]) < neighbor_samples:
                    neighbors[i][r].extend(
                        list(np.random.choice(neighbors[i][r], size=neighbor_samples - len(neighbors[i][r])))
                    )
                elif len(neighbors[i][r]) > neighbor_samples:
                    neighbors[i][r] = list(np.random.choice(neighbors[i][r], size=neighbor_samples))

        model = GATNEModel(num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_att)
        nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
        model.to(device)
        nsloss.to(device)

        optimizer = torch.optim.Adam([{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4)

        for epoch in range(epochs):
            random.shuffle(train_pairs)
            batches = get_batches(train_pairs, neighbors, batch_size)

            data_iter = tqdm.tqdm(
                batches,
                desc="epoch %d" % epoch,
                total=(len(train_pairs) + (batch_size - 1)) // batch_size,
                bar_format="{l_bar}{r_bar}",
            )
            avg_loss = 0.0

            for i, data in enumerate(data_iter):
                optimizer.zero_grad()
                # data = (center nodes, context nodes, edge types, neighbor lists)
                embs = model(data[0].to(device), data[2].to(device), data[3].to(device))
                loss = nsloss(data[0].to(device), embs, data[1].to(device))
                loss.backward()
                optimizer.step()

                avg_loss += loss.item()

                if i % 5000 == 0:
                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "loss": loss.item(),
                    }
                    data_iter.write(str(post_fix))

        # Export one embedding per node per edge type.
        final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
        for i in range(num_nodes):
            train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(device)
            train_types = torch.tensor(list(range(edge_type_count))).to(device)
            node_neigh = torch.tensor([neighbors[i] for _ in range(edge_type_count)]).to(device)
            node_emb = model(train_inputs, train_types, node_neigh)
            for j in range(edge_type_count):
                final_model[edge_types[j]][index_to_key[i]] = node_emb[j].cpu().detach().numpy()
        return final_model
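

# Reference for the attention step in GATNEModel.forward below, written in the
# paper's notation rather than this file's variable names (a sketch of the
# transductive GATNE-T formulation; the paper's per-type scale alpha_r is
# effectively fixed to 1 in this implementation). With U_i the stacked
# edge-type embeddings of node i:
#
#     a_{i,r} = softmax(w_r^T tanh(W_r U_i))^T     # attention over edge types
#     v_{i,r} = b_i + M_r^T U_i a_{i,r}            # base + attended type embedding
#
# followed by L2 normalization of v_{i,r}.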
class GATNEModel(nn.Module):
    def __init__(self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        self.node_type_embeddings = Parameter(torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size))
        self.trans_weights = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size))
        self.trans_weights_s1 = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, dim_a))
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))

        self.reset_parameters()

    def reset_parameters(self):
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, train_inputs, train_types, node_neigh):
        node_embed = self.node_embeddings[train_inputs]
        # Aggregate each node's edge-type embedding by summing over its
        # sampled neighbors, per edge type.
        node_embed_neighbors = self.node_type_embeddings[node_neigh]
        node_embed_tmp = torch.cat(
            [node_embed_neighbors[:, i, :, i, :].unsqueeze(1) for i in range(self.edge_type_count)],
            dim=1,
        )
        node_type_embed = torch.sum(node_embed_tmp, dim=2)

        trans_w = self.trans_weights[train_types]
        trans_w_s1 = self.trans_weights_s1[train_types]
        trans_w_s2 = self.trans_weights_s2[train_types]

        # Self-attention over edge types. torch.tanh replaces the deprecated
        # F.tanh, and the softmax dim is made explicit; both changes preserve
        # the original behavior for the shapes produced here.
        attention = F.softmax(
            torch.matmul(torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2).squeeze(), dim=-1
        ).unsqueeze(1)
        node_type_embed = torch.matmul(attention, node_type_embed)
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze()

        last_node_embed = F.normalize(node_embed, dim=1)

        return last_node_embed


class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Log-uniform (Zipfian) negative-sampling weights over node ranks,
        # proportional to (log(k + 2) - log(k + 1)) / log(num_nodes + 1); this
        # assumes the vocabulary is sorted by descending frequency, which
        # generate_vocab ensures.
        self.sample_weights = F.normalize(
            torch.Tensor([(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) for k in range(num_nodes)]),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1)))
        negs = torch.multinomial(self.sample_weights, self.num_sampled * n, replacement=True).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])
        sum_log_sampled = torch.sum(torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1).squeeze()

        loss = log_target + sum_log_sampled
        return -loss.sum() / n


class RWGraph:
    def __init__(self, nx_G, node_type=None):
        self.G = nx_G
        self.node_type = node_type

    def walk(self, walk_length, start, schema=None):
        # Simulate a random walk starting from the start node.
        G = self.G
        rand = random.Random()

        if schema:
            schema_items = schema.split("-")
            assert schema_items[0] == schema_items[-1]

        walk = [start]
        while len(walk) < walk_length:
            cur = walk[-1]
            candidates = []
            # With a schema, only step to neighbors whose type matches the
            # next position in the metapath; otherwise any neighbor qualifies.
            for node in G[cur].keys():
                if schema is None or self.node_type[node] == schema_items[len(walk) % (len(schema_items) - 1)]:
                    candidates.append(node)
            if candidates:
                walk.append(rand.choice(candidates))
            else:
                break

        return walk

    def simulate_walks(self, num_walks, walk_length, schema=None):
        G = self.G
        walks = []
        nodes = list(G.nodes())
        if schema is not None:
            schema_list = schema.split(",")
        for walk_iter in range(num_walks):
            random.shuffle(nodes)
            for node in nodes:
                if schema is None:
                    walks.append(self.walk(walk_length=walk_length, start=node))
                else:
                    # Start one walk per metapath whose first type matches this node.
                    for schema_iter in schema_list:
                        if schema_iter.split("-")[0] == self.node_type[node]:
                            walks.append(self.walk(walk_length=walk_length, start=node, schema=schema_iter))

        return walks


def get_G_from_edges(edges):
    # Count multi-edges, then build an undirected graph with edge weights
    # equal to the multiplicities.
    edge_dict = dict()
    for edge in edges:
        edge_key = str(edge[0]) + "_" + str(edge[1])
        if edge_key not in edge_dict:
            edge_dict[edge_key] = 1
        else:
            edge_dict[edge_key] += 1
    tmp_G = nx.Graph()
    for edge_key in edge_dict:
        weight = edge_dict[edge_key]
        x = int(edge_key.split("_")[0])
        y = int(edge_key.split("_")[1])
        tmp_G.add_edge(x, y)
        tmp_G[x][y]["weight"] = weight
    return tmp_G


def generate_pairs(all_walks, vocab, window_size=5):
    # Emit skip-gram (center, context, layer) triples from the walks.
    pairs = []
    skip_window = window_size // 2
    for layer_id, walks in enumerate(all_walks):
        for walk in walks:
            for i in range(len(walk)):
                for j in range(1, skip_window + 1):
                    if i - j >= 0:
                        pairs.append((vocab[walk[i]].index, vocab[walk[i - j]].index, layer_id))
                    if i + j < len(walk):
                        pairs.append((vocab[walk[i]].index, vocab[walk[i + j]].index, layer_id))
    return pairs


def generate_vocab(all_walks):
    index_to_key = []
    raw_vocab = defaultdict(int)

    for walks in all_walks:
        for walk in walks:
            for word in walk:
                raw_vocab[word] += 1

    vocab = {}
    for word, v in raw_vocab.items():
        vocab[word] = Vocab(count=v, index=len(index_to_key))
        index_to_key.append(word)

    # Re-rank so that index 0 is the most frequent node; NSLoss's log-uniform
    # sampler relies on this ordering.
    index_to_key.sort(key=lambda word: vocab[word].count, reverse=True)
    for i, word in enumerate(index_to_key):
        vocab[word].index = i

    return vocab, index_to_key


def get_batches(pairs, neighbors, batch_size):
    n_batches = (len(pairs) + (batch_size - 1)) // batch_size
    for idx in range(n_batches):
        x, y, t, neigh = [], [], [], []
        for i in range(batch_size):
            index = idx * batch_size + i
            if index >= len(pairs):
                break
            x.append(pairs[index][0])
            y.append(pairs[index][1])
            t.append(pairs[index][2])
            neigh.append(neighbors[pairs[index][0]])
        yield torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)


def generate_walks(network_data, num_walks, walk_length, schema=None):
    all_walks = []
    for layer_id in network_data:
        tmp_data = network_data[layer_id]
        # Random-walk each layer (edge type) independently. Note: RWGraph is
        # built without node_type here, so schema-guided walks are not wired
        # up by this helper.
        layer_walker = RWGraph(get_G_from_edges(tmp_data))
        layer_walks = layer_walker.simulate_walks(num_walks, walk_length, schema=schema)
        all_walks.append(layer_walks)

    return all_walks
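

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the module): one minimal,
# hypothetical way to drive this model end to end, assuming cogdl is
# installed so the relative BaseModel import resolves. The toy edge lists and
# hyperparameter values below are made up for the example; `network_data`
# maps each edge type to a list of (u, v) integer node pairs, which is the
# shape `forward` consumes above.
#
#     from cogdl.models.emb.gatne import GATNE
#
#     network_data = {
#         "1": [(0, 1), (1, 2), (2, 0), (0, 3)],   # edge type "1"
#         "2": [(0, 2), (1, 3)],                   # edge type "2"
#     }
#     model = GATNE(
#         dimension=16, walk_length=5, walk_num=2, window_size=3, worker=1,
#         epochs=1, batch_size=4, edge_dim=4, att_dim=4,
#         negative_samples=2, neighbor_samples=3, schema=None,
#     )
#     embeddings = model.forward(network_data)
#     # embeddings: {edge_type: {node_id: np.ndarray of shape (16,)}}
# ---------------------------------------------------------------------------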