import numpy as np
import networkx as nx
from collections import defaultdict
from gensim.models.keyedvectors import Vocab
import random
import math
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .. import BaseModel, register_model
@register_model("gatne")
class GATNE(BaseModel):
    r"""The GATNE model from the `"Representation Learning for Attributed Multiplex Heterogeneous Network"
    <https://dl.acm.org/doi/10.1145/3292500.3330964>`_ paper

    Args:
        walk_length (int) : The walk length.
        walk_num (int) : The number of walks to sample for each node.
        window_size (int) : The actual context size which is considered in language model.
        worker (int) : The number of workers for word2vec.
        epoch (int) : The number of training epochs.
        batch_size (int) : The size of each training batch.
        edge_dim (int) : Number of edge embedding dimensions.
        att_dim (int) : Number of attention dimensions.
        negative_samples (int) : Negative samples for optimization.
        neighbor_samples (int) : Neighbor samples for aggregation.
        schema (str) : The metapath schema used in the model. Metapaths are split with ",",
            while node types are connected with "-" in each metapath. For example: "0-1-0,0-1-2-1-0"
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--walk-length', type=int, default=10,
                            help='Length of walk per source. Default is 10.')
        parser.add_argument('--walk-num', type=int, default=10,
                            help='Number of walks per source. Default is 10.')
        parser.add_argument('--window-size', type=int, default=5,
                            help='Window size of skip-gram model. Default is 5.')
        parser.add_argument('--worker', type=int, default=10,
                            help='Number of parallel workers. Default is 10.')
        parser.add_argument('--epoch', type=int, default=20,
                            help='Number of epoch. Default is 20.')
        parser.add_argument('--batch-size', type=int, default=256,
                            help='Number of batch_size. Default is 256.')
        parser.add_argument('--edge-dim', type=int, default=10,
                            help='Number of edge embedding dimensions. Default is 10.')
        parser.add_argument('--att-dim', type=int, default=20,
                            help='Number of attention dimensions. Default is 20.')
        parser.add_argument('--negative-samples', type=int, default=5,
                            help='Negative samples for optimization. Default is 5.')
        parser.add_argument('--neighbor-samples', type=int, default=10,
                            help='Neighbor samples for aggregation. Default is 10.')
        parser.add_argument('--schema', type=str, default=None,
                            help="Input schema for metapath random walk.")
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        """Construct a GATNE instance from parsed command-line arguments."""
        return cls(
            args.hidden_size,
            args.walk_length,
            args.walk_num,
            args.window_size,
            args.worker,
            args.epoch,
            args.batch_size,
            args.edge_dim,
            args.att_dim,
            args.negative_samples,
            args.neighbor_samples,
            args.schema,
        )

    def __init__(
        self,
        dimension,
        walk_length,
        walk_num,
        window_size,
        worker,
        epoch,
        batch_size,
        edge_dim,
        att_dim,
        negative_samples,
        neighbor_samples,
        schema,
    ):
        super(GATNE, self).__init__()
        self.embedding_size = dimension
        self.walk_length = walk_length
        self.walk_num = walk_num
        self.window_size = window_size
        self.worker = worker
        self.epochs = epoch
        self.batch_size = batch_size
        self.embedding_u_size = edge_dim
        self.dim_att = att_dim
        self.num_sampled = negative_samples
        self.neighbor_samples = neighbor_samples
        self.schema = schema
        self.multiplicity = True

    def train(self, network_data):
        """Train GATNE on a multiplex network.

        Args:
            network_data (dict): maps each edge type to its edge list (pairs of node ids).

        Returns:
            dict: ``{edge_type: {node_id: np.ndarray embedding}}`` — one embedding
            per node per edge type.
        """
        all_walks = generate_walks(network_data, self.walk_num, self.walk_length, schema=self.schema)
        vocab, index2word = generate_vocab(all_walks)
        # BUG FIX: pass the configured window size; previously generate_pairs
        # silently fell back to its default (5), ignoring --window-size.
        train_pairs = generate_pairs(all_walks, vocab, self.window_size)

        edge_types = list(network_data.keys())
        num_nodes = len(index2word)
        edge_type_count = len(edge_types)
        epochs = self.epochs
        batch_size = self.batch_size
        embedding_size = self.embedding_size
        embedding_u_size = self.embedding_u_size
        num_sampled = self.num_sampled
        dim_att = self.dim_att
        neighbor_samples = self.neighbor_samples

        # neighbors[node][edge_type] is a fixed-size (neighbor_samples) list of
        # neighbor indices, padded/subsampled so tensors are rectangular.
        neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
        for r in range(edge_type_count):
            g = network_data[edge_types[r]]
            for (x, y) in g:
                ix = vocab[x].index
                iy = vocab[y].index
                neighbors[ix][r].append(iy)
                neighbors[iy][r].append(ix)
            for i in range(num_nodes):
                if len(neighbors[i][r]) == 0:
                    # Isolated node in this layer: self-loop padding.
                    neighbors[i][r] = [i] * neighbor_samples
                elif len(neighbors[i][r]) < neighbor_samples:
                    neighbors[i][r].extend(
                        list(
                            np.random.choice(
                                neighbors[i][r],
                                size=neighbor_samples - len(neighbors[i][r]),
                            )
                        )
                    )
                elif len(neighbors[i][r]) > neighbor_samples:
                    neighbors[i][r] = list(np.random.choice(neighbors[i][r], size=neighbor_samples))

        model = GATNEModel(num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_att)
        nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
        # NOTE(review): self.device is presumably provided by BaseModel — confirm.
        model.to(self.device)
        nsloss.to(self.device)

        optimizer = torch.optim.Adam([{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4)

        for epoch in range(epochs):
            random.shuffle(train_pairs)
            batches = get_batches(train_pairs, neighbors, batch_size)
            data_iter = tqdm.tqdm(
                batches,
                desc="epoch %d" % (epoch),
                total=(len(train_pairs) + (batch_size - 1)) // batch_size,
                bar_format="{l_bar}{r_bar}",
            )
            avg_loss = 0.0
            for i, data in enumerate(data_iter):
                optimizer.zero_grad()
                # data = (center, context, edge_type, neighbors)
                embs = model(
                    data[0].to(self.device),
                    data[2].to(self.device),
                    data[3].to(self.device),
                )
                loss = nsloss(data[0].to(self.device), embs, data[1].to(self.device))
                loss.backward()
                optimizer.step()
                avg_loss += loss.item()
                if i % 5000 == 0:
                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "loss": loss.item(),
                    }
                    data_iter.write(str(post_fix))

        # Export one embedding per (edge type, node).
        final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
        for i in range(num_nodes):
            train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(self.device)
            train_types = torch.tensor(list(range(edge_type_count))).to(self.device)
            node_neigh = torch.tensor([neighbors[i] for _ in range(edge_type_count)]).to(self.device)
            node_emb = model(train_inputs, train_types, node_neigh)
            for j in range(edge_type_count):
                final_model[edge_types[j]][index2word[i]] = node_emb[j].cpu().detach().numpy()
        return final_model
class GATNEModel(nn.Module):
    """Core GATNE encoder: a shared base embedding per node plus per-edge-type
    embeddings aggregated from sampled neighbors via attention.

    Shapes below use B = batch, R = edge_type_count, S = neighbors sampled,
    U = embedding_u_size, D = embedding_size.
    """

    def __init__(self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a
        # Base embedding shared across edge types.
        self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Per-node, per-edge-type embedding.
        self.node_type_embeddings = Parameter(torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size))
        # Per-edge-type projection from U to D, and the two attention matrices.
        self.trans_weights = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size))
        self.trans_weights_s1 = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, dim_a))
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init for embeddings, scaled normal init for projections."""
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, train_inputs, train_types, node_neigh):
        """Compute L2-normalized embeddings.

        Args:
            train_inputs: (B,) node indices.
            train_types: (B,) edge-type indices.
            node_neigh: (B, R, S) sampled neighbor indices per edge type.

        Returns:
            (B, D) row-normalized node embeddings.
        """
        node_embed = self.node_embeddings[train_inputs]  # (B, D)
        node_embed_neighbors = self.node_type_embeddings[node_neigh]  # (B, R, S, R, U)
        # For edge type i, keep only that type's embedding of the type-i neighbors.
        node_embed_tmp = torch.cat(
            [node_embed_neighbors[:, i, :, i, :].unsqueeze(1) for i in range(self.edge_type_count)],
            dim=1,
        )  # (B, R, S, U)
        node_type_embed = torch.sum(node_embed_tmp, dim=2)  # (B, R, U)

        trans_w = self.trans_weights[train_types]  # (B, U, D)
        trans_w_s1 = self.trans_weights_s1[train_types]  # (B, U, A)
        trans_w_s2 = self.trans_weights_s2[train_types]  # (B, A, 1)

        # torch.tanh replaces deprecated F.tanh; explicit softmax dim and
        # dimension-targeted squeezes keep batch size 1 working (a bare
        # .squeeze() used to collapse the batch dimension as well).
        attention = F.softmax(
            torch.matmul(torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2).squeeze(-1),
            dim=-1,
        ).unsqueeze(1)  # (B, 1, R)
        node_type_embed = torch.matmul(attention, node_type_embed)  # (B, 1, U)
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze(1)  # (B, D)
        last_node_embed = F.normalize(node_embed, dim=1)
        return last_node_embed
class NSLoss(nn.Module):
def __init__(self, num_nodes, num_sampled, embedding_size):
super(NSLoss, self).__init__()
self.num_nodes = num_nodes
self.num_sampled = num_sampled
self.embedding_size = embedding_size
self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
self.sample_weights = F.normalize(
torch.Tensor([(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) for k in range(num_nodes)]),
dim=0,
)
self.reset_parameters()
def reset_parameters(self):
self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
def forward(self, input, embs, label):
n = input.shape[0]
log_target = torch.log(torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1)))
negs = torch.multinomial(self.sample_weights, self.num_sampled * n, replacement=True).view(n, self.num_sampled)
noise = torch.neg(self.weights[negs])
sum_log_sampled = torch.sum(torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1).squeeze()
loss = log_target + sum_log_sampled
return -loss.sum() / n
class RWGraph:
    """Random-walk generator over a graph, optionally constrained by a
    metapath schema (node types joined with '-', first == last)."""

    def __init__(self, nx_G, node_type=None):
        self.G = nx_G
        self.node_type = node_type  # node -> type string; only needed with schemas

    def walk(self, walk_length, start, schema=None):
        """Simulate one walk of up to `walk_length` nodes from `start`.

        With a schema, step i may only visit nodes whose type matches position
        i of the metapath (cycled, excluding the repeated last element).
        Stops early if no admissible neighbor exists.
        """
        graph = self.G
        rng = random.Random()
        pattern = None
        if schema:
            pattern = schema.split("-")
            assert pattern[0] == pattern[-1]
        path = [start]
        while len(path) < walk_length:
            current = path[-1]
            if pattern is None:
                options = list(graph[current].keys())
            else:
                wanted = pattern[len(path) % (len(pattern) - 1)]
                options = [nbr for nbr in graph[current].keys() if self.node_type[nbr] == wanted]
            if not options:
                break
            path.append(rng.choice(options))
        return path

    def simulate_walks(self, num_walks, walk_length, schema=None):
        """Run `num_walks` rounds of walks starting from every node.

        With a schema string ("mp1,mp2,..."), each node starts one walk per
        metapath whose first type matches the node's type.
        """
        node_list = list(self.G.nodes())
        metapaths = schema.split(",") if schema is not None else None
        walks = []
        for _ in range(num_walks):
            random.shuffle(node_list)
            for node in node_list:
                if metapaths is None:
                    walks.append(self.walk(walk_length=walk_length, start=node))
                else:
                    for mp in metapaths:
                        if mp.split("-")[0] == self.node_type[node]:
                            walks.append(
                                self.walk(
                                    walk_length=walk_length,
                                    start=node,
                                    schema=mp,
                                )
                            )
        return walks
def get_G_from_edges(edges):
    """Build an undirected nx.Graph where each edge's 'weight' attribute is
    its multiplicity in `edges`.

    Keys the counter on the (u, v) tuple directly instead of the previous
    str(u) + "_" + str(v) round-trip through int(), which crashed on
    non-integer node ids (and on ids containing "_"). Behavior for integer
    node ids is unchanged, including the last-write-wins weight when both
    (u, v) and (v, u) appear.
    """
    edge_counts = defaultdict(int)
    for edge in edges:
        edge_counts[(edge[0], edge[1])] += 1
    tmp_G = nx.Graph()
    for (x, y), weight in edge_counts.items():
        tmp_G.add_edge(x, y)
        tmp_G[x][y]["weight"] = weight
    return tmp_G
def generate_pairs(all_walks, vocab, window_size=5):
    """Turn per-layer random walks into skip-gram training triples.

    Args:
        all_walks: list (one entry per edge-type layer) of lists of walks.
        vocab: mapping node -> object with an ``index`` attribute.
        window_size: context window; offsets up to window_size // 2 are used.

    Returns:
        list of (center_index, context_index, layer_id) tuples.
    """
    span = window_size // 2
    pairs = []
    for layer_id, walks in enumerate(all_walks):
        for walk in walks:
            length = len(walk)
            for center in range(length):
                for offset in range(1, span + 1):
                    left = center - offset
                    if left >= 0:
                        pairs.append((vocab[walk[center]].index, vocab[walk[left]].index, layer_id))
                    right = center + offset
                    if right < length:
                        pairs.append((vocab[walk[center]].index, vocab[walk[right]].index, layer_id))
    return pairs
def generate_vocab(all_walks):
    """Build a frequency-sorted vocabulary from all walks.

    Returns:
        (vocab, index2word) where vocab maps node -> Vocab(count, index) with
        index 0 assigned to the most frequent node, and index2word is the
        node list sorted by descending count.
    """
    counts = defaultdict(int)
    for walks in all_walks:
        for walk in walks:
            for node in walk:
                counts[node] += 1
    index2word = []
    vocab = {}
    for node, freq in counts.items():
        vocab[node] = Vocab(count=freq, index=len(index2word))
        index2word.append(node)
    # Reorder by frequency and reassign indices to match the new order.
    index2word.sort(key=lambda w: vocab[w].count, reverse=True)
    for rank, node in enumerate(index2word):
        vocab[node].index = rank
    return vocab, index2word
def get_batches(pairs, neighbors, batch_size):
    """Yield mini-batches of training tensors.

    Each yielded tuple is (centers, contexts, layer_ids, neighbor_lists),
    with neighbor_lists looked up by the center node of each pair. The last
    batch may be smaller than batch_size.
    """
    total = len(pairs)
    num_batches = (total + batch_size - 1) // batch_size
    for b in range(num_batches):
        chunk = pairs[b * batch_size: (b + 1) * batch_size]
        centers = [p[0] for p in chunk]
        contexts = [p[1] for p in chunk]
        layers = [p[2] for p in chunk]
        neighs = [neighbors[p[0]] for p in chunk]
        yield (
            torch.tensor(centers),
            torch.tensor(contexts),
            torch.tensor(layers),
            torch.tensor(neighs),
        )
def generate_walks(network_data, num_walks, walk_length, schema=None):
    """Run random walks on each edge-type layer independently.

    Args:
        network_data: dict mapping layer id (edge type) to its edge list.
        num_walks: walks to start from each node per layer.
        walk_length: maximum length of each walk.
        schema: optional comma-separated metapath schemas.

    Returns:
        list with one entry per layer, each a list of walks.
    """
    all_walks = []
    for layer_id in network_data:
        edges = network_data[layer_id]
        walker = RWGraph(get_G_from_edges(edges))
        all_walks.append(walker.simulate_walks(num_walks, walk_length, schema=schema))
    return all_walks