Source code for cogdl.models.nn.pyg_dgcnn

import torch
import torch.nn as nn
from torch_geometric.nn import DynamicEdgeConv, global_max_pool

from cogdl.utils import split_dataset_general

from .. import BaseModel, register_model
from .mlp import MLP


@register_model("dgcnn")
class DGCNN(BaseModel):
    r"""EdgeConv and dynamic graph CNN from the paper `"Dynamic Graph CNN for Learning on Point Clouds"
    <https://arxiv.org/pdf/1801.07829.pdf>`__.

    Parameters
    ----------
    in_feats : int
        Size of each input sample.
    out_feats : int
        Size of each output sample.
    hidden_dim : int
        Dimension of hidden layer embedding.
    k : int
        Number of nearest neighbors.
    """
    @staticmethod
    def add_args(parser):
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--batch-size", type=int, default=20)
        parser.add_argument("--train-ratio", type=float, default=0.7)
        parser.add_argument("--test-ratio", type=float, default=0.1)
        parser.add_argument("--lr", type=float, default=0.001)
    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            args.num_features,
            args.hidden_size,
            args.num_classes,
        )
    @classmethod
    def split_dataset(cls, dataset, args):
        return split_dataset_general(dataset, args)
    def __init__(self, in_feats, hidden_dim, out_feats, k=20, dropout=0.5):
        super(DGCNN, self).__init__()
        # Each EdgeConv MLP takes 2 * dim inputs because DynamicEdgeConv
        # concatenates [x_i, x_j - x_i] for every k-NN edge.
        mlp1 = nn.Sequential(
            MLP(2 * in_feats, hidden_dim, hidden_dim, num_layers=3, norm="batchnorm"),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        )
        mlp2 = nn.Sequential(
            MLP(2 * hidden_dim, 2 * hidden_dim, 2 * hidden_dim, num_layers=1, norm="batchnorm"),
            nn.ReLU(),
            nn.BatchNorm1d(2 * hidden_dim),
        )
        self.conv1 = DynamicEdgeConv(mlp1, k, "max")
        self.conv2 = DynamicEdgeConv(mlp2, k, "max")

        self.linear = nn.Linear(hidden_dim + 2 * hidden_dim, 1024)
        self.final_mlp = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.Linear(256, out_feats),
        )
    def forward(self, batch):
        h = batch.x
        # Two dynamic EdgeConv layers; the k-NN graph is rebuilt in feature space each time.
        h1 = self.conv1(h, batch.batch)
        h2 = self.conv2(h1, batch.batch)
        # Concatenate per-point features from both layers, then pool over each point cloud.
        h = self.linear(torch.cat([h1, h2], dim=1))
        h = global_max_pool(h, batch.batch)
        out = self.final_mlp(h)
        return out
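
A minimal usage sketch (not part of the CogDL source), assuming torch_geometric and its torch_cluster k-NN dependency are installed; the point counts, feature size, class count, and k below are illustrative only:

# Hypothetical example: forward pass over two random point clouds.
import torch
from torch_geometric.data import Data, Batch

model = DGCNN(in_feats=3, hidden_dim=64, out_feats=10, k=4)
model.eval()

clouds = [Data(x=torch.randn(32, 3)), Data(x=torch.randn(48, 3))]
batch = Batch.from_data_list(clouds)  # batch.x: [80, 3], batch.batch: [80]

with torch.no_grad():
    logits = model(batch)  # one score vector per point cloud -> shape [2, 10]
print(logits.shape)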