models.emb.gatne

Module Contents

Classes

GATNE

The GATNE model from the “Representation Learning for Attributed Multiplex Heterogeneous Network” paper.

GATNEModel

NSLoss

RWGraph

Functions

get_G_from_edges(edges)

generate_pairs(all_walks, vocab, window_size=5)

generate_vocab(all_walks)

get_batches(pairs, neighbors, batch_size)

generate_walks(network_data, num_walks, walk_length, schema=None)

class models.emb.gatne.GATNE(dimension, walk_length, walk_num, window_size, worker, epoch, batch_size, edge_dim, att_dim, negative_samples, neighbor_samples, schema)

Bases: models.BaseModel

The GATNE model from the “Representation Learning for Attributed Multiplex Heterogeneous Network” paper

Args:

dimension (int) : The dimension of node representation.
walk_length (int) : The walk length.
walk_num (int) : The number of walks to sample for each node.
window_size (int) : The context window size considered by the language model.
worker (int) : The number of workers for word2vec.
epoch (int) : The number of training epochs.
batch_size (int) : The size of each training batch.
edge_dim (int) : The dimension of edge embeddings.
att_dim (int) : The dimension of the attention layer.
negative_samples (int) : The number of negative samples used for optimization.
neighbor_samples (int) : The number of neighbors sampled for aggregation.
schema (str) : The metapath schema used in the model. Metapaths are separated by “,”, and the node types within each metapath are joined by “-”. For example: “0-1-0,0-1-2-1-0”.
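The schema format described above can be illustrated with a small parser. `parse_schema` is a hypothetical helper written for this page, not part of `models.emb.gatne`; it only shows how the “,” and “-” separators split the string, and it keeps node types as strings.

```python
from typing import List


def parse_schema(schema: str) -> List[List[str]]:
    """Split a metapath schema such as "0-1-0,0-1-2-1-0" into per-metapath node-type lists."""
    metapaths = []
    for path in schema.split(","):
        path = path.strip()
        if not path:
            continue  # ignore empty entries, e.g. from a trailing comma
        metapaths.append(path.split("-"))  # node types inside a metapath are joined by "-"
    return metapaths


print(parse_schema("0-1-0,0-1-2-1-0"))
# [['0', '1', '0'], ['0', '1', '2', '1', '0']]
```

Both example metapaths start and end on the same node type, which lets a walker repeat them cyclically along a long walk.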

static add_args(parser)

Add model-specific arguments to the parser.

classmethod build_model_from_args(cls, args)

Build a new model instance.

train(self, network_data)

class models.emb.gatne.GATNEModel(num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a)

Bases: torch.nn.Module

reset_parameters(self)

forward(self, train_inputs, train_types, node_neigh)
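For orientation, the GATNE paper defines the transductive (GATNE-T) embedding of node $v_i$ under edge type $r$ roughly as below, written here with a single aggregation step over the sampled neighbors $\mathcal{N}_{i,r}$. Mapping these symbols onto the constructor arguments is an assumption based on parameter names only: `embedding_size` for the dimension of the base embedding $b_i$, `embedding_u_size` for the edge embeddings $u_{i,r}$, `dim_a` for the attention hidden size (rows of $W_r$), and `edge_type_count` for the number of edge types $m$. This page does not state which aggregator `forward` uses or whether it applies the coefficient $\alpha_r$.

```latex
\begin{aligned}
u_{i,r} &= \operatorname{aggregator}\left(\{\, u_{j,r} : v_j \in \mathcal{N}_{i,r} \,\}\right) \\
U_i &= \left(u_{i,1}, u_{i,2}, \ldots, u_{i,m}\right) \\
a_{i,r} &= \operatorname{softmax}\left(w_r^{\top} \tanh\left(W_r U_i\right)\right)^{\top} \\
v_{i,r} &= b_i + \alpha_r M_r^{\top} U_i \, a_{i,r}
\end{aligned}
```

Here $U_i$ stacks the node's $m$ edge embeddings as columns, $a_{i,r}$ is an $m$-dimensional attention vector over edge types, and $M_r$ projects the attended edge embedding into the space of $b_i$.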
class models.emb.gatne.NSLoss(num_nodes, num_sampled, embedding_size)

Bases: torch.nn.Module

reset_parameters(self)

forward(self, input, embs, label)
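`NSLoss` is named after negative sampling. A common form of that objective, used in word2vec-style skip-gram training, is sketched below in plain Python for a single training pair. That `NSLoss.forward(input, embs, label)` computes exactly this is an assumption; the page does not document how the `num_sampled` negatives are drawn or how the context weights are stored, so the sketch takes the positive and negative context vectors as given.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def negative_sampling_loss(
    emb: Sequence[float],
    pos_ctx: Sequence[float],
    neg_ctxs: Sequence[Sequence[float]],
) -> float:
    """Skip-gram negative-sampling loss for one (node, context) pair:

    loss = -log sigmoid(pos_ctx . emb) - sum_k log sigmoid(-neg_k . emb)
    """
    loss = -log_sigmoid(dot(emb, pos_ctx))  # pull the true context closer
    for neg in neg_ctxs:
        loss -= log_sigmoid(-dot(emb, neg))  # push each sampled negative away
    return loss


# With a zero embedding every term is log 2, so 1 positive + 2 negatives gives 3 * log 2.
print(negative_sampling_loss([0.0, 0.0], [1.0, 1.0], [[1.0, 0.0], [0.0, 1.0]]))
```

The loss is small when the node embedding aligns with its true context and points away from the negatives, which is what training drives toward.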
class models.emb.gatne.RWGraph(nx_G, node_type=None)

walk(self, walk_length, start, schema=None)

simulate_walks(self, num_walks, walk_length, schema=None)
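`RWGraph.walk` accepts an optional `schema`. The sketch below shows one way a metapath-guided random walk can work: at each step the walker only moves to neighbors whose type matches the next type in the metapath, cycling through the metapath and skipping its repeated final type. The plain-dict graph (`adj`), the `node_type` mapping, the rule that a walk must start on the metapath's first type, and the name `metapath_walk` are all assumptions for illustration; the real class takes a NetworkX graph (`nx_G`).

```python
import random
from typing import Dict, List, Optional


def metapath_walk(
    adj: Dict[str, List[str]],
    node_type: Dict[str, str],
    start: str,
    walk_length: int,
    schema: Optional[str] = None,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Uniform random walk; if `schema` (e.g. "0-1-0") is given, follow it cyclically."""
    if rng is None:
        rng = random.Random()
    types = schema.split("-") if schema else None
    if types is not None and (len(types) < 2 or node_type[start] != types[0]):
        return [start]  # assumption: a walk must start on the metapath's first type
    walk = [start]
    while len(walk) < walk_length:
        candidates = adj.get(walk[-1], [])
        if types is not None:
            # Type required at position len(walk); the last type equals the first, so skip it.
            wanted = types[len(walk) % (len(types) - 1)]
            candidates = [n for n in candidates if node_type[n] == wanted]
        if not candidates:
            break  # dead end: stop the walk early
        walk.append(rng.choice(candidates))
    return walk


adj = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
node_type = {"a": "0", "b": "1", "c": "2"}
print(metapath_walk(adj, node_type, "a", 5, schema="0-1-0", rng=random.Random(0)))
# ['a', 'b', 'a', 'b', 'a']
```

`simulate_walks(num_walks, walk_length, schema)` would then plausibly repeat such a walk `num_walks` times from every node.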
models.emb.gatne.get_G_from_edges(edges)
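`get_G_from_edges(edges)` builds the graph that `RWGraph` walks on. Its exact weighting is not documented on this page; the sketch below assumes edges are undirected and uses the number of times an edge occurs as its weight. It returns a nested dict rather than a NetworkX graph so it has no third-party dependency, and `build_weighted_adjacency` is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, Tuple


def build_weighted_adjacency(
    edges: Iterable[Tuple[Hashable, Hashable]],
) -> Dict[Hashable, Dict[Hashable, int]]:
    """Undirected adjacency where weight = number of times an edge occurs."""
    adj = defaultdict(lambda: defaultdict(int))
    for u, v in edges:
        adj[u][v] += 1
        if u != v:  # count a self-loop only once
            adj[v][u] += 1
    # Convert to plain dicts so later lookups cannot silently create entries.
    return {u: dict(nbrs) for u, nbrs in adj.items()}


print(build_weighted_adjacency([(1, 2), (2, 1), (2, 3)]))
# {1: {2: 2}, 2: {1: 2, 3: 1}, 3: {2: 1}}
```

With NetworkX, the same result could be stored by calling `G.add_edge(u, v, weight=w)` on an `nx.Graph` for each entry.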
models.emb.gatne.generate_pairs(all_walks, vocab, window_size=5)
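`generate_pairs(all_walks, vocab, window_size=5)` turns walks into training pairs for the skip-gram objective. A minimal sketch follows. It assumes that `all_walks` holds one list of walks per edge type, that `vocab` maps a node to an integer index, and that `window_size` counts both sides of the center node, so each side covers `window_size // 2` positions. Each output pair is `(center, context, edge_type)`; these layout choices and the name `make_skipgram_pairs` are assumptions, not documented behavior.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def make_skipgram_pairs(
    all_walks: Sequence[Sequence[Sequence[Hashable]]],
    vocab: Dict[Hashable, int],
    window_size: int = 5,
) -> List[Tuple[int, int, int]]:
    pairs = []
    half = window_size // 2  # assumption: the window spans both sides of the center
    for edge_type, walks in enumerate(all_walks):
        for walk in walks:
            for i, center in enumerate(walk):
                for offset in range(1, half + 1):
                    for j in (i - offset, i + offset):  # left context, then right context
                        if 0 <= j < len(walk):
                            pairs.append((vocab[center], vocab[walk[j]], edge_type))
    return pairs


print(make_skipgram_pairs([[["a", "b", "c"]]], {"a": 0, "b": 1, "c": 2}, window_size=3))
# [(0, 1, 0), (1, 0, 0), (1, 2, 0), (2, 1, 0)]
```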
models.emb.gatne.generate_vocab(all_walks)
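`generate_vocab(all_walks)` builds the node vocabulary from the sampled walks. The sketch below assumes the usual word2vec convention of indexing nodes by descending frequency, with `all_walks` holding one list of walks per edge type; the returned `(vocab, index2node)` pair and the name `build_vocab` are hypothetical.

```python
from collections import Counter
from typing import Dict, Hashable, List, Sequence, Tuple


def build_vocab(
    all_walks: Sequence[Sequence[Sequence[Hashable]]],
) -> Tuple[Dict[Hashable, int], List[Hashable]]:
    counts: Counter = Counter()
    for walks in all_walks:  # one group of walks per edge type
        for walk in walks:
            counts.update(walk)
    # Most frequent first; sorted() is stable, so ties keep first-seen order.
    index2node = sorted(counts, key=lambda n: -counts[n])
    vocab = {node: idx for idx, node in enumerate(index2node)}
    return vocab, index2node


print(build_vocab([[["a", "b", "a"]], [["c", "a", "b"]]]))
# ({'a': 0, 'b': 1, 'c': 2}, ['a', 'b', 'c'])
```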
models.emb.gatne.get_batches(pairs, neighbors, batch_size)
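`get_batches(pairs, neighbors, batch_size)` presumably groups training pairs into mini-batches and attaches each center node's sampled neighbors, which lines up with the `train_inputs`, `train_types` and `node_neigh` arguments of `GATNEModel.forward`. The generator below is a plain-Python sketch of that idea. It assumes each pair is `(center, context, edge_type)` and that `neighbors` can be indexed by a center node's index; the real function likely yields tensors, and the name `iter_batches` is hypothetical.

```python
from typing import Any, Iterator, List, Sequence, Tuple


def iter_batches(
    pairs: Sequence[Tuple[int, int, int]],
    neighbors: Any,
    batch_size: int,
) -> Iterator[Tuple[List[int], List[int], List[int], List[Any]]]:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]  # the last batch may be smaller
        centers = [p[0] for p in chunk]
        contexts = [p[1] for p in chunk]
        edge_types = [p[2] for p in chunk]
        neigh = [neighbors[c] for c in centers]  # sampled neighbors of each center node
        yield centers, contexts, edge_types, neigh
```

Yielding batches lazily keeps memory flat when the number of pairs is large, since pairs grow with walk count, walk length and window size.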
models.emb.gatne.generate_walks(network_data, num_walks, walk_length, schema=None)