Source code for cogdl.layers.gcc_module

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GATConv, NNConv, Set2Set
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import AvgPooling, MaxPooling, SumPooling


class GATLayer(nn.Module):
    """Single-head graph attention layer bound to a fixed graph ``g``."""

    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # Score each edge from the concatenated source/destination projections.
        z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
        a = self.attn_fc(z2)
        return {"e": F.leaky_relu(a)}

    def message_func(self, edges):
        return {"z": edges.src["z"], "e": edges.data["e"]}

    def reduce_func(self, nodes):
        # Normalize attention scores over each node's mailbox, then aggregate.
        alpha = F.softmax(nodes.mailbox["e"], dim=1)
        h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
        return {"h": h}

    def forward(self, h):
        z = self.fc(h)
        self.g.ndata["z"] = z
        self.g.apply_edges(self.edge_attention)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop("h")
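
# A minimal usage sketch for GATLayer (hypothetical demo, not part of the
# original module). It assumes a recent DGL where ``dgl.graph`` builds a graph
# from endpoint lists; the layer is bound to the graph at construction time.
def _demo_gat_layer():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # 4-node directed cycle
    layer = GATLayer(g, in_dim=8, out_dim=4)
    h = torch.randn(g.num_nodes(), 8)
    out = layer(h)  # one round of single-head attention over g
    assert out.shape == (4, 4)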

class SELayer(nn.Module):
    """Squeeze-and-excitation networks"""

    def __init__(self, in_channels, se_channels):
        super(SELayer, self).__init__()
        self.in_channels = in_channels
        self.se_channels = se_channels
        self.encoder_decoder = nn.Sequential(
            nn.Linear(in_channels, se_channels),
            nn.ELU(),
            nn.Linear(se_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reweight each channel of ``x`` by a globally pooled gating vector."""
        # Aggregate input representation
        x_global = torch.mean(x, dim=0)
        # Compute reweighting vector s
        s = self.encoder_decoder(x_global)
        return x * s
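
# A minimal sketch of SELayer (hypothetical demo): the gating vector is computed
# from the mean over rows, so every row is rescaled channel-wise by values in (0, 1).
def _demo_se_layer():
    se = SELayer(in_channels=16, se_channels=4)
    x = torch.randn(10, 16)  # 10 nodes, 16 channels
    y = se(x)                # same shape, channel-wise reweighted
    assert y.shape == x.shape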

class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU."""

    def __init__(self, mlp, use_selayer):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        self.bn = (
            SELayer(self.mlp.output_dim, int(np.sqrt(self.mlp.output_dim)))
            if use_selayer
            else nn.BatchNorm1d(self.mlp.output_dim)
        )

    def forward(self, h):
        h = self.mlp(h)
        h = self.bn(h)
        h = F.relu(h)
        return h

class MLP(nn.Module):
    """MLP with linear output"""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_selayer):
        """MLP layers construction

        Parameters
        ----------
        num_layers: int
            The number of linear layers
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        use_selayer: bool
            If True, use an SELayer instead of BatchNorm1d after each hidden layer
        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for _ in range(num_layers - 1):
                self.batch_norms.append(
                    SELayer(hidden_dim, int(np.sqrt(hidden_dim))) if use_selayer else nn.BatchNorm1d(hidden_dim)
                )

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)
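
# A minimal sketch of MLP (hypothetical demo): with num_layers > 1, each hidden
# layer is followed by BatchNorm1d (or SELayer) and ReLU; the output layer stays linear.
def _demo_mlp():
    mlp = MLP(num_layers=3, input_dim=8, hidden_dim=16, output_dim=4, use_selayer=False)
    x = torch.randn(5, 8)
    assert mlp(x).shape == (5, 4)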

class UnsupervisedGAT(nn.Module):
    def __init__(self, node_input_dim, node_hidden_dim, edge_input_dim, num_layers, num_heads):
        super(UnsupervisedGAT, self).__init__()
        assert node_hidden_dim % num_heads == 0
        # dgl's GATConv provides the multi-head attention used here; the
        # single-head GATLayer above is bound to a fixed graph and has an
        # incompatible interface. Per-head outputs are flattened in forward,
        # so the hidden dimension stays node_hidden_dim throughout.
        self.layers = nn.ModuleList(
            [
                GATConv(
                    in_feats=node_input_dim if i == 0 else node_hidden_dim,
                    out_feats=node_hidden_dim // num_heads,
                    num_heads=num_heads,
                    feat_drop=0.0,
                    attn_drop=0.0,
                    negative_slope=0.2,
                    residual=False,
                    activation=F.leaky_relu if i + 1 < num_layers else None,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, g, n_feat, e_feat):
        for layer in self.layers:
            n_feat = layer(g, n_feat).flatten(1)  # (N, heads, D) -> (N, heads * D)
        return n_feat
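
# A minimal sketch of UnsupervisedGAT (hypothetical demo). Edge features are
# accepted for interface compatibility but unused; node_hidden_dim must be
# divisible by num_heads.
def _demo_unsupervised_gat():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node directed cycle
    model = UnsupervisedGAT(node_input_dim=8, node_hidden_dim=16, edge_input_dim=8, num_layers=2, num_heads=4)
    out = model(g, torch.randn(3, 8), None)
    assert out.shape == (3, 16)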

class UnsupervisedMPNN(nn.Module):
    """
    MPNN from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__

    Parameters
    ----------
    output_dim : int
        Dimension of prediction, default to be 32.
    node_input_dim : int
        Dimension of input node feature, default to be 32.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 32.
    edge_input_dim : int
        Dimension of input edge feature, default to be 32.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 32.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    lstm_as_gate : bool
        If True, use an LSTM instead of a GRU to update node states.
    """

    def __init__(
        self,
        output_dim=32,
        node_input_dim=32,
        node_hidden_dim=32,
        edge_input_dim=32,
        edge_hidden_dim=32,
        num_step_message_passing=6,
        lstm_as_gate=False,
    ):
        super(UnsupervisedMPNN, self).__init__()

        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim),
        )
        self.conv = NNConv(
            in_feats=node_hidden_dim,
            out_feats=node_hidden_dim,
            edge_func=edge_network,
            aggregator_type="sum",
        )
        self.lstm_as_gate = lstm_as_gate
        if lstm_as_gate:
            self.lstm = nn.LSTM(node_hidden_dim, node_hidden_dim)
        else:
            self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

    def forward(self, g, n_feat, e_feat):
        """Compute node representations by message passing.

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.

        Returns
        -------
        res : node representations of shape (B1, node_hidden_dim)
        """
        out = F.relu(self.lin0(n_feat))  # (B1, H1)
        h = out.unsqueeze(0)  # (1, B1, H1)
        c = torch.zeros_like(h)

        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
            if self.lstm_as_gate:
                out, (h, c) = self.lstm(m.unsqueeze(0), (h, c))
            else:
                out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        return out
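
# A minimal sketch of UnsupervisedMPNN (hypothetical demo): edge features must
# have edge_input_dim columns, since NNConv feeds them through the edge network.
def _demo_mpnn():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    model = UnsupervisedMPNN()  # all dims default to 32
    n_feat = torch.randn(g.num_nodes(), 32)
    e_feat = torch.randn(g.num_edges(), 32)
    out = model(g, n_feat, e_feat)  # per-node embeddings
    assert out.shape == (3, 32)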

class UnsupervisedGIN(nn.Module):
    """GIN model"""

    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
        use_selayer,
    ):
        """Model parameters setting

        Parameters
        ----------
        num_layers: int
            The number of linear layers in the neural network
        num_mlp_layers: int
            The number of linear layers in mlps
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        final_dropout: float
            dropout ratio on the final linear layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors
            If False, aggregate neighbors and center nodes altogether.
        graph_pooling_type: str
            how to aggregate entire nodes in a graph (sum, mean or max)
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        use_selayer: bool
            If True, use an SELayer instead of BatchNorm1d after each GIN layer
        """
        super(UnsupervisedGIN, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of MLPs
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            if layer == 0:
                mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim, use_selayer)
            else:
                mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim, use_selayer)

            self.ginlayers.append(
                GINConv(
                    ApplyNodeFunc(mlp, use_selayer),
                    neighbor_pooling_type,
                    0,
                    self.learn_eps,
                )
            )
            self.batch_norms.append(
                SELayer(hidden_dim, int(np.sqrt(hidden_dim))) if use_selayer else nn.BatchNorm1d(hidden_dim)
            )

        # Linear functions for graph poolings of the output of each layer,
        # which map the output of different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == "sum":
            self.pool = SumPooling()
        elif graph_pooling_type == "mean":
            self.pool = AvgPooling()
        elif graph_pooling_type == "max":
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

    def forward(self, g, h, efeat):
        # list of hidden representations at each layer (including the input)
        hidden_rep = [h]

        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0

        # perform pooling over all nodes in each graph at every layer
        all_outputs = []
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            all_outputs.append(pooled_h)
            score_over_layer += self.drop(self.linears_prediction[i](pooled_h))

        return score_over_layer, all_outputs[1:]
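
# A minimal sketch of UnsupervisedGIN (hypothetical demo): the model returns the
# summed per-layer graph scores plus the pooled outputs of the GIN layers
# (the input-layer pooling is dropped, hence num_layers - 1 entries).
def _demo_gin():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    model = UnsupervisedGIN(
        num_layers=3, num_mlp_layers=2, input_dim=8, hidden_dim=16, output_dim=4,
        final_dropout=0.5, learn_eps=False, graph_pooling_type="sum",
        neighbor_pooling_type="sum", use_selayer=False,
    )
    h = torch.randn(g.num_nodes(), 8)
    score, layer_outputs = model(g, h, None)
    assert score.shape == (1, 4) and len(layer_outputs) == 2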

class GraphEncoder(nn.Module):
    """
    GCC graph encoder: a GNN backbone (MPNN from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__,
    GAT, or GIN) followed by a Set2Set readout.

    Parameters
    ----------
    positional_embedding_size : int
        Dimension of the positional node embedding, default to be 32.
    max_node_freq : int
        Maximum node frequency, default to be 8.
    max_edge_freq : int
        Maximum edge frequency, default to be 8.
    max_degree : int
        Node degrees are clamped to this value before embedding, default to be 128.
    freq_embedding_size : int
        Dimension of the frequency embedding, default to be 32.
    degree_embedding_size : int
        Dimension of the degree embedding, default to be 32.
    output_dim : int
        Dimension of the graph representation, default to be 32.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 32.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 32.
    num_layers : int
        Number of GNN layers (message passing steps for the MPNN), default to be 6.
    num_heads : int
        Number of attention heads (GAT only), default to be 4.
    num_step_set2set : int
        Number of set2set steps
    num_layer_set2set : int
        Number of set2set layers
    norm : bool
        If True, L2-normalize the output representation.
    gnn_model : str
        Backbone to use: "mpnn", "gat", or "gin".
    degree_input : bool
        If True, append a degree embedding to the input node features.
    """

    def __init__(
        self,
        positional_embedding_size=32,
        max_node_freq=8,
        max_edge_freq=8,
        max_degree=128,
        freq_embedding_size=32,
        degree_embedding_size=32,
        output_dim=32,
        node_hidden_dim=32,
        edge_hidden_dim=32,
        num_layers=6,
        num_heads=4,
        num_step_set2set=6,
        num_layer_set2set=3,
        norm=False,
        gnn_model="mpnn",
        degree_input=False,
        lstm_as_gate=False,
    ):
        super(GraphEncoder, self).__init__()

        if degree_input:
            node_input_dim = positional_embedding_size + degree_embedding_size + 1
        else:
            node_input_dim = positional_embedding_size + 1
        edge_input_dim = freq_embedding_size + 1

        if gnn_model == "mpnn":
            self.gnn = UnsupervisedMPNN(
                output_dim=output_dim,
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                edge_hidden_dim=edge_hidden_dim,
                num_step_message_passing=num_layers,
                lstm_as_gate=lstm_as_gate,
            )
        elif gnn_model == "gat":
            self.gnn = UnsupervisedGAT(
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                num_layers=num_layers,
                num_heads=num_heads,
            )
        elif gnn_model == "gin":
            self.gnn = UnsupervisedGIN(
                num_layers=num_layers,
                num_mlp_layers=2,
                input_dim=node_input_dim,
                hidden_dim=node_hidden_dim,
                output_dim=output_dim,
                final_dropout=0.5,
                learn_eps=False,
                graph_pooling_type="sum",
                neighbor_pooling_type="sum",
                use_selayer=False,
            )
        else:
            raise NotImplementedError
        self.gnn_model = gnn_model

        self.max_node_freq = max_node_freq
        self.max_edge_freq = max_edge_freq
        self.max_degree = max_degree
        self.degree_input = degree_input

        if degree_input:
            self.degree_embedding = nn.Embedding(num_embeddings=max_degree + 1, embedding_dim=degree_embedding_size)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin_readout = nn.Sequential(
            nn.Linear(2 * node_hidden_dim, node_hidden_dim),
            nn.ReLU(),
            nn.Linear(node_hidden_dim, output_dim),
        )
        self.norm = norm

    def forward(self, g, return_all_outputs=False):
        """Encode a batch of (sub)graphs.

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph with ``pos_undirected`` and ``seed`` node data.
        return_all_outputs : bool
            If True, also return the per-layer pooled outputs (GIN only).

        Returns
        -------
        x : tensor of shape (num_graphs, output_dim)
            Graph representations.
        """
        if self.degree_input:
            device = g.ndata["seed"].device
            # in_degrees may be computed on CPU; move it to the feature device
            degrees = g.in_degrees().to(device)
            n_feat = torch.cat(
                (
                    g.ndata["pos_undirected"],
                    self.degree_embedding(degrees.clamp(0, self.max_degree)),
                    g.ndata["seed"].unsqueeze(1).float(),
                ),
                dim=-1,
            )
        else:
            n_feat = torch.cat(
                (
                    g.ndata["pos_undirected"],
                    g.ndata["seed"].unsqueeze(1).float(),
                ),
                dim=-1,
            )

        e_feat = None

        if self.gnn_model == "gin":
            # GIN already returns a graph-level representation
            x, all_outputs = self.gnn(g, n_feat, e_feat)
        else:
            # MPNN/GAT return node features; pool them with Set2Set
            x, all_outputs = self.gnn(g, n_feat, e_feat), None
            x = self.set2set(g, x)
            x = self.lin_readout(x)
        if self.norm:
            x = F.normalize(x, p=2, dim=-1, eps=1e-5)
        if return_all_outputs:
            return x, all_outputs
        else:
            return x
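
# A minimal end-to-end sketch of GraphEncoder (hypothetical demo). GCC subgraphs
# carry a positional embedding ("pos_undirected") and a seed-node indicator
# ("seed") in ndata; both are filled with placeholder values here.
def _demo_graph_encoder():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g.ndata["pos_undirected"] = torch.randn(g.num_nodes(), 32)
    g.ndata["seed"] = torch.zeros(g.num_nodes(), dtype=torch.long)
    g.ndata["seed"][0] = 1  # mark the ego/seed node
    encoder = GraphEncoder(positional_embedding_size=32, gnn_model="gin", degree_input=True)
    emb = encoder(g)  # one embedding per (sub)graph
    assert emb.shape == (1, 32)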