Source code for options

import argparse

from cogdl.datasets import DATASET_REGISTRY
from cogdl.models import MODEL_REGISTRY
from cogdl.tasks import TASK_REGISTRY


def get_parser():
    """Build the base argument parser shared by the cogdl entry points.

    The parser is created with ``conflict_handler="resolve"``, so an option
    registered later with the same flag replaces the earlier definition
    instead of raising an error (useful when extra ``add_args`` hooks
    extend this parser afterwards).

    Returns:
        argparse.ArgumentParser: a parser holding the generic options
        ``--seed``, ``--max-epoch``, ``--patience``, ``--lr``,
        ``--weight-decay``, ``--cpu``, ``--device-id``, ``--save-dir`` and
        ``--enhance``.
    """
    base_parser = argparse.ArgumentParser(conflict_handler="resolve")

    # (flag, add_argument keyword arguments); registration order is kept so
    # that --help lists the options in the same sequence.
    option_specs = [
        ("--seed", dict(default=[1], type=int, nargs="+", metavar="N",
                        help="pseudo random number generator seed")),
        ("--max-epoch", dict(default=500, type=int)),
        ("--patience", dict(type=int, default=100)),
        ("--lr", dict(default=0.01, type=float)),
        ("--weight-decay", dict(default=5e-4, type=float)),
        ("--cpu", dict(action="store_true",
                       help="use CPU instead of CUDA")),
        ("--device-id", dict(default=[0], type=int, nargs="+",
                             help="which GPU to use")),
        ("--save-dir", dict(default=".", type=str)),
        ("--enhance", dict(action="store_true",
                           help="use prone to enhance embedding")),
    ]
    for flag, options in option_specs:
        base_parser.add_argument(flag, **options)

    return base_parser
def add_task_args(parser):
    """Add the mandatory ``--task`` / ``-t`` option to ``parser``.

    The option is placed in a "Task configuration" argument group and must be
    one of the keys of ``TASK_REGISTRY``.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    group = parser.add_argument_group("Task configuration")
    # No default on purpose: the option is required, so argparse exits with
    # an error when it is missing and a default value could never be used.
    group.add_argument(
        "--task",
        "-t",
        metavar="TASK",
        required=True,
        choices=TASK_REGISTRY.keys(),
        help="Task",
    )
    return group
def add_dataset_args(parser):
    """Add the mandatory ``--dataset`` / ``-dt`` option to ``parser``.

    One or more dataset names may be given; each must be a key of
    ``DATASET_REGISTRY``. The option lives in a "Dataset and data loading"
    argument group.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    dataset_group = parser.add_argument_group("Dataset and data loading")
    dataset_group.add_argument(
        "--dataset",
        "-dt",
        metavar="DATASET",
        nargs="+",
        required=True,
        choices=DATASET_REGISTRY.keys(),
        help="Dataset",
    )
    return dataset_group
def add_model_args(parser):
    """Add the mandatory ``--model`` / ``-m`` option to ``parser``.

    One or more model names may be given; each must be a key of
    ``MODEL_REGISTRY``. The option lives in a "Model configuration"
    argument group.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    model_group = parser.add_argument_group("Model configuration")
    model_group.add_argument(
        "--model",
        "-m",
        metavar="MODEL",
        nargs="+",
        required=True,
        choices=MODEL_REGISTRY.keys(),
        help="Model Architecture",
    )
    return model_group
def get_training_parser():
    """Return a parser for training runs.

    It contains the generic options from ``get_parser`` plus the required
    ``--task``, ``--dataset`` and ``--model`` options, added in that order.
    """
    training_parser = get_parser()
    for register_options in (add_task_args, add_dataset_args, add_model_args):
        register_options(training_parser)
    return training_parser
def get_display_data_parser():
    """Return a parser for displaying dataset information.

    It extends ``get_parser`` with the required ``--dataset`` option and an
    integer ``--depth`` option (default 3).
    """
    display_parser = get_parser()
    add_dataset_args(display_parser)
    display_parser.add_argument("--depth", type=int, default=3)
    return display_parser
def get_download_data_parser():
    """Return a parser for downloading datasets.

    It is the generic ``get_parser`` parser plus the required ``--dataset``
    option.
    """
    download_parser = get_parser()
    add_dataset_args(download_parser)
    return download_parser
def parse_args_and_arch(parser, args, input_args=None):
    """Parse a second time after registering task- and model-specific options.

    The parser doesn't know about model-specific args, so we parse twice: the
    caller's first pass only knows the generic options. This function lets
    the selected task and every selected model add their own options to
    ``parser`` through their ``add_args`` hooks, then parses again.

    Args:
        parser: the parser used for the first pass (for example one built by
            ``get_training_parser``); it is extended in place.
        args: namespace from the first pass. Only ``args.task`` and
            ``args.model`` (an iterable of model names) are read.
        input_args: optional list of argument strings for the second pass.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is the original behaviour. Pass the same list that was used
            for the first pass when that pass did not read ``sys.argv``;
            otherwise the two passes would see different input.

    Returns:
        argparse.Namespace: the fully parsed arguments, including the
        task- and model-specific options.
    """
    # NOTE(review): the first pass presumably has to tolerate task/model
    # flags it does not know yet (e.g. via parse_known_args); confirm this
    # at the call sites.
    TASK_REGISTRY[args.task].add_args(parser)
    for model in args.model:
        MODEL_REGISTRY[model].add_args(parser)
    return parser.parse_args(input_args)