# Source code for cogdl.models

import importlib

from .base_model import BaseModel
from cogdl.utils import init_operator_configs


# Run the operator-configuration initializer once, when this package is imported.
init_operator_configs()

# Registry filled by ``register_model``: model name (str) -> model class.
MODEL_REGISTRY = dict()


def register_model(name):
    """Build a class decorator that registers a model under ``name``.

    New model types can be added to cogdl with the :func:`register_model`
    function decorator. For example::

        @register_model('gat')
        class GAT(BaseModel):
            (...)

    Args:
        name (str): the name of the model

    Returns:
        A decorator. Applied to a class, it stores the class in
        ``MODEL_REGISTRY[name]``, sets the class attribute ``model_name`` to
        ``name``, and returns the same class.

    Raises:
        ValueError: (from the decorator) if ``name`` is already registered, or
            if the decorated class is not a subclass of ``BaseModel``.
    """

    def _decorate(model_cls):
        # Refuse name collisions before anything else so an existing entry is never replaced.
        if name in MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        extends_base = issubclass(model_cls, BaseModel)
        if not extends_base:
            raise ValueError("Model ({}: {}) must extend BaseModel".format(name, model_cls.__name__))
        MODEL_REGISTRY[name] = model_cls
        model_cls.model_name = name
        return model_cls

    return _decorate
def try_import_model(model):
    """Make sure the model named ``model`` is present in ``MODEL_REGISTRY``.

    If the name is not registered yet but is listed in ``SUPPORTED_MODELS``,
    the mapped module is imported. Importing it is expected to register the
    model via ``@register_model``.

    Args:
        model (str): the registered name of the model.

    Returns:
        bool: ``True`` if ``model`` is in ``MODEL_REGISTRY`` when this call
        returns. Otherwise a message is printed and ``False`` is returned.

    Raises:
        ImportError: propagated unchanged if the mapped module cannot be
            imported (for example, a missing optional dependency).
    """
    if model in MODEL_REGISTRY:
        return True
    if model in SUPPORTED_MODELS:
        importlib.import_module(SUPPORTED_MODELS[model])
    # Re-check rather than assuming success. A module that imports cleanly but
    # does not register ``model`` would otherwise make callers such as
    # ``build_model`` fail later with a bare KeyError.
    if model in MODEL_REGISTRY:
        return True
    print(f"Failed to import {model} model.")
    return False
def build_model(args):
    """Build the model selected by ``args.model``.

    Args:
        args: an options object with a ``model`` attribute naming the model.
            The whole object is forwarded to ``build_model_from_args``, which
            is called on the registered model class.

    Returns:
        Whatever ``MODEL_REGISTRY[args.model].build_model_from_args(args)`` returns.

    Raises:
        SystemExit: with status 1 if ``try_import_model`` reports that the
            model could not be imported or registered.
    """
    import sys

    if not try_import_model(args.model):
        # ``sys.exit`` is always available. The bare ``exit`` builtin is
        # injected by the ``site`` module for interactive use and is missing
        # under ``python -S`` or in some embedded/frozen interpreters.
        sys.exit(1)
    return MODEL_REGISTRY[args.model].build_model_from_args(args)
# Lazy-import table consulted by ``try_import_model``: maps each model name to
# the dotted path of the module that is imported the first time that name is
# requested. Several names may share one module (e.g. "gae"/"vgae").
SUPPORTED_MODELS = {
    # Models under cogdl.models.emb
    "hope": "cogdl.models.emb.hope",
    "spectral": "cogdl.models.emb.spectral",
    "hin2vec": "cogdl.models.emb.hin2vec",
    "netmf": "cogdl.models.emb.netmf",
    "distmult": "cogdl.models.emb.distmult",
    "transe": "cogdl.models.emb.transe",
    "deepwalk": "cogdl.models.emb.deepwalk",
    "rotate": "cogdl.models.emb.rotate",
    "gatne": "cogdl.models.emb.gatne",
    "dgk": "cogdl.models.emb.dgk",
    "grarep": "cogdl.models.emb.grarep",
    "dngr": "cogdl.models.emb.dngr",
    "prone++": "cogdl.models.emb.pronepp",
    "graph2vec": "cogdl.models.emb.graph2vec",
    "metapath2vec": "cogdl.models.emb.metapath2vec",
    "node2vec": "cogdl.models.emb.node2vec",
    "complex": "cogdl.models.emb.complex",
    "pte": "cogdl.models.emb.pte",
    "netsmf": "cogdl.models.emb.netsmf",
    "line": "cogdl.models.emb.line",
    "sdne": "cogdl.models.emb.sdne",
    "prone": "cogdl.models.emb.prone",
    # Models under cogdl.models.agc
    "daegc": "cogdl.models.agc.daegc",
    "agc": "cogdl.models.agc.agc",
    # Models under cogdl.models.nn
    "gae": "cogdl.models.nn.gae",
    "vgae": "cogdl.models.nn.gae",
    "dgi": "cogdl.models.nn.dgi",
    "dgi_sampling": "cogdl.models.nn.dgi",
    "mvgrl": "cogdl.models.nn.mvgrl",
    "patchy_san": "cogdl.models.nn.patchy_san",
    "chebyshev": "cogdl.models.nn.pyg_cheb",
    "gcn": "cogdl.models.nn.gcn",
    "gdc_gcn": "cogdl.models.nn.gdc_gcn",
    "hgpsl": "cogdl.models.nn.pyg_hgpsl",
    "graphsage": "cogdl.models.nn.graphsage",
    "compgcn": "cogdl.models.nn.compgcn",
    "drgcn": "cogdl.models.nn.drgcn",
    "gpt_gnn": "cogdl.models.nn.pyg_gpt_gnn",
    "unet": "cogdl.models.nn.pyg_graph_unet",
    "gcnmix": "cogdl.models.nn.gcnmix",
    "diffpool": "cogdl.models.nn.diffpool",
    "gcnii": "cogdl.models.nn.gcnii",
    "sign": "cogdl.models.nn.sign",
    "pyg_gcn": "cogdl.models.nn.pyg_gcn",
    "mixhop": "cogdl.models.nn.mixhop",
    "gat": "cogdl.models.nn.gat",
    "han": "cogdl.models.nn.han",
    "ppnp": "cogdl.models.nn.ppnp",
    "grace": "cogdl.models.nn.grace",
    "jknet": "cogdl.models.nn.dgl_jknet",
    "pprgo": "cogdl.models.nn.pprgo",
    "gin": "cogdl.models.nn.gin",
    "dgcnn": "cogdl.models.nn.pyg_dgcnn",
    "grand": "cogdl.models.nn.grand",
    "gtn": "cogdl.models.nn.pyg_gtn",
    "rgcn": "cogdl.models.nn.rgcn",
    "deepergcn": "cogdl.models.nn.deepergcn",
    "drgat": "cogdl.models.nn.drgat",
    "infograph": "cogdl.models.nn.infograph",
    "dropedge_gcn": "cogdl.models.nn.dropedge_gcn",
    "disengcn": "cogdl.models.nn.disengcn",
    "fastgcn": "cogdl.models.nn.fastgcn",
    "mlp": "cogdl.models.nn.mlp",
    "sgc": "cogdl.models.nn.sgc",
    "stpgnn": "cogdl.models.nn.stpgnn",
    "sortpool": "cogdl.models.nn.sortpool",
    "srgcn": "cogdl.models.nn.pyg_srgcn",
    "asgcn": "cogdl.models.nn.asgcn",
    "gcc": "cogdl.models.nn.dgl_gcc",
    "unsup_graphsage": "cogdl.models.nn.unsup_graphsage",
    "sagpool": "cogdl.models.nn.pyg_sagpool",
    "graphsaint": "cogdl.models.nn.graphsaint",
    "m3s": "cogdl.models.nn.m3s",
    "supergat": "cogdl.models.nn.pyg_supergat",
    "moe_gcn": "cogdl.models.nn.moe_gcn",
}