import hashlib
import networkx as nx
import numpy as np
import random
from .. import BaseModel, register_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm
[docs]class Hin2vec_layer(nn.Module):
def __init__(self, num_node, num_relation, hidden_size, cpu):
super(Hin2vec_layer, self).__init__()
self.num_node = num_node
self.Wx = Parameter(torch.randn(num_node, hidden_size))
self.Wr = Parameter(torch.randn(num_relation, hidden_size))
self.device = torch.device('cpu' if cpu else 'cuda')
self.X = F.one_hot(torch.arange(num_node), num_node).float().to(self.device)
self.R = F.one_hot(torch.arange(num_relation), num_relation).float().to(self.device)
self.criterion = nn.CrossEntropyLoss()
[docs] def regulartion(self, embr):
clamp_embr = torch.clamp(embr, -6.0, 6.0)
sigmod1 = torch.sigmoid(clamp_embr)
# return sigmod1
re_embr = torch.mul(sigmod1, 1-sigmod1)
return re_embr
[docs] def forward(self, x, y, r, l):
x_one, y_one, r_one = torch.index_select(self.X, 0, x), torch.index_select(self.X, 0, y), torch.index_select(self.R, 0, r)
self.embx, self.emby, self.embr = torch.mm(x_one, self.Wx), torch.mm(y_one, self.Wx), torch.mm(r_one, self.Wr)
self.re_embr = self.regulartion(self.embr)
self.preds = torch.unsqueeze(torch.sigmoid(torch.sum(torch.mul(torch.mul(self.embx, self.emby), self.re_embr), 1)),1)
self.logits = torch.cat((self.preds, 1- self.preds), 1)
return self.logits, self.criterion(self.logits, l)
[docs] def get_emb(self,):
x = F.one_hot(torch.arange(0, self.num_node), num_classes=self.num_node).float().to(self.device)
return torch.mm(x, self.Wx)
[docs]class RWgraph():
def __init__(self, nx_G, node_type=None):
self.G = nx_G
self.node_type = node_type
[docs] def _walk(self, start_node, walk_length):
# Simulate a random walk starting from start node.
walk = [start_node]
while len(walk) < walk_length:
cur = walk[-1]
cur_nbrs = list(self.G.neighbors(cur))
if len(cur_nbrs) == 0:
break
k = int(np.floor(np.random.rand() * len(cur_nbrs)))
walk.append(cur_nbrs[k])
return walk
[docs] def _simulate_walks(self, walk_length, num_walks):
# Repeatedly simulate random walks from each node.
walks = []
nodes = list(self.G.nodes())
print("node number:", len(nodes))
for walk_iter in range(num_walks):
print(str(walk_iter + 1), "/", str(num_walks))
random.shuffle(nodes)
for node in nodes:
walks.append(self._walk(node, walk_length))
return walks
[docs] def data_preparation(self, walks, hop, negative):
# data preparation via process walks and negative sampling
node_type = self.node_type
num_node_type = len(set(node_type))
type2list = [[] for _ in range(num_node_type)]
for node, nt in enumerate(node_type):
type2list[nt].append(node)
print("number of type2list", num_node_type)
relation = dict()
pairs = []
for walk in walks:
for i in range(len(walk) - hop):
for j in range(1, hop+1):
x, y = walk[i], walk[i+j]
tx, ty = node_type[x], node_type[y]
if x ==y: continue
meta_str = "-".join([str(node_type[a]) for a in walk[i:i+j+1]])
if meta_str not in relation:
relation[meta_str] = len(meta_str)
pairs.append([x, y, relation[meta_str], 1])
for k in range(negative):
if random.random() > 0.5:
fx = random.choice(type2list[node_type[x]])
while fx == x:
fx = random.choice(type2list[node_type[x]])
pairs.append([fx, y, relation[meta_str], 0])
else:
fy = random.choice(type2list[node_type[y]])
while fy == y:
fy = random.choice(type2list[node_type[y]])
pairs.append([x, fy, relation[meta_str], 0])
print("number of relation", len(relation))
return np.asarray(pairs), relation
[docs]@register_model("hin2vec")
class Hin2vec(BaseModel):
r"""The Hin2vec model from the `"HIN2Vec: Explore Meta-paths in Heterogeneous Information Networks for Representation Learning"
<https://dl.acm.org/doi/10.1145/3132847.3132953>`_ paper.
Args:
hidden_size (int) : The dimension of node representation.
walk_length (int) : The walk length.
walk_num (int) : The number of walks to sample for each node.
batch_size (int) : The batch size of training in Hin2vec.
hop (int) : The number of hop to construct training samples in Hin2vec.
negative (int) : The number of nagative samples for each meta2path pair.
epochs (int) : The number of training iteration.
lr (float) : The initial learning rate of SGD.
cpu (bool) : Use CPU or GPU to train hin2vec.
"""
@staticmethod
[docs] def add_args(parser):
parser.add_argument("--hidden-size", type=int, default=128)
parser.add_argument('--walk-length', type=int, default=80,
help='Length of walk per source. Default is 80.')
parser.add_argument('--walk-num', type=int, default=40,
help='Number of walks per source. Default is 40.')
parser.add_argument('--batch-size', type=int, default=1000,
help='Batch size in SGD training process. Default is 1000.')
parser.add_argument("--hop", type=int, default=2)
parser.add_argument("--negative", type=int, default=5)
parser.add_argument("--epochs", type=int, default=1)
@classmethod
[docs] def build_model_from_args(cls, args):
return cls(
args.hidden_size,
args.walk_length,
args.walk_num,
args.batch_size,
args.hop,
args.negative,
args.epochs,
args.lr,
args.cpu
)
def __init__(self, hidden_dim, walk_length, walk_num, batch_size, hop, negative, epochs, lr, cpu=True):
super(Hin2vec, self).__init__()
self.hidden_dim = hidden_dim
self.walk_length = walk_length
self.walk_num = walk_num
self.batch_size = batch_size
self.hop = hop
self.negative = negative
self.epochs = epochs
self.lr = lr
self.cpu = cpu
self.device = torch.device('cpu' if self.cpu else 'cuda')
[docs] def train(self, G, node_type):
self.num_node = G.number_of_nodes()
rw = RWgraph(G, node_type)
walks = rw._simulate_walks(self.walk_length, self.walk_num)
pairs, relation = rw.data_preparation(walks, self.hop, self.negative)
self.num_relation = len(relation)
model = Hin2vec_layer(self.num_node, self.num_relation, self.hidden_dim, self.cpu)
self.model = model.to(self.device)
num_batch = int(len(pairs) / self.batch_size)
print_num_batch = 100
print("number of batch", num_batch)
opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
epoch_iter = tqdm(range(self.epochs))
for epoch in epoch_iter:
loss_n, pred, label = [], [], []
for i in range(num_batch):
batch_pairs = torch.from_numpy(pairs[i *self.batch_size :(i+1) * self.batch_size])
batch_pairs = batch_pairs.to(self.device)
batch_pairs = batch_pairs.T
x, y, r, l = batch_pairs[0], batch_pairs[1], batch_pairs[2], batch_pairs[3]
opt.zero_grad()
logits, loss = self.model.forward(x, y, r, l)
loss_n.append(loss.item())
label.append(l)
pred.extend(logits)
if i% print_num_batch ==0 and i!=0:
label = torch.cat(label).to(self.device)
pred = torch.stack(pred, dim=0)
pred = pred.max(1)[1]
acc = pred.eq(label).sum().item() / len(label)
epoch_iter.set_description(
f"Epoch: {i:03d}, Loss: {sum(loss_n)/print_num_batch:.5f}, Acc: {acc:.5f}"
)
loss_n, pred, label = [], [], []
loss.backward()
opt.step()
embedding = self.model.get_emb()
return embedding.cpu().detach().numpy()