Source code for cogdl.models.nn.gcnmix

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from cogdl.utils import spmm
from .. import BaseModel, register_model


def mix_hidden_state(feat, target, train_index, alpha):
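    # Mixup-style interpolation on the rows selected by `train_index`:
    # sample lamb ~ Beta(alpha, alpha), blend each selected row with a randomly
    # permuted row from the same set, and return both sets of targets so the
    # caller can mix the labels with the same coefficient.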
    if alpha > 0:
        lamb = np.random.beta(alpha, alpha)
    else:
        lamb = 1
    permuted_index = train_index[torch.randperm(train_index.size(0))]
    feat[train_index] = lamb * feat[train_index] + (1 - lamb) * feat[permuted_index]
    return feat, target[train_index], target[permuted_index], lamb


def sharpen(prob, temperature):
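    # Sharpen a row-wise probability distribution: raise to the power 1/temperature
    # and renormalize each row to sum to one.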
    prob = torch.pow(prob, 1.0 / temperature)
    row_sum = torch.sum(prob, dim=1).reshape(-1, 1)
    return prob / row_sum


def get_one_hot_label(labels, index):
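    # Build a (num_nodes, num_classes) one-hot target matrix; only the rows in
    # `index` are filled, all other rows stay zero.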
    num_classes = int(torch.max(labels) + 1)
    target = torch.zeros(labels.shape[0], num_classes).to(labels.device)

    target[index, labels[index]] = 1
    return target


def get_current_consistency_weight(final_consistency_weight, rampup_starts, rampup_ends, epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    rampup_length = rampup_ends - rampup_starts
    rampup = 1.0
    epoch = epoch - rampup_starts
    if rampup_length != 0:
        current = np.clip(epoch, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        rampup = float(np.exp(-5.0 * phase * phase))
    return final_consistency_weight * rampup


class GCNConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNConv, self).__init__()
        self.weight = nn.Linear(in_features=in_feats, out_features=out_feats)
        self.edge_index = None
        self.edge_attr = None

    def forward(self, graph, x):
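        # Linear transform followed by sparse propagation over the graph (spmm).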
        h = self.weight(x)
        h = spmm(graph, h)
        return h

    def forward_aux(self, x):
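        # Linear transform only, without neighborhood aggregation.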
        return self.weight(x)


class BaseGNNMix(BaseModel):
    def __init__(self, in_feat, hidden_size, num_classes, k, temperature, alpha, dropout):
        super(BaseGNNMix, self).__init__()
        self.dropout = dropout
        self.alpha = alpha
        self.k = k
        self.temperature = temperature

        self.input_gnn = GCNConv(in_feat, hidden_size)
        self.hidden_gnn = GCNConv(hidden_size, num_classes)
        self.loss_f = nn.BCELoss()

    def forward(self, graph):
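        # Two-layer GCN over the symmetrically normalized adjacency; returns logits.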
        graph.sym_norm()
        x = graph.x
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.input_gnn(graph, h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.hidden_gnn(graph, h)
        return h

    def forward_aux(self, x, label, train_index, mix_hidden=True, layer_mix=1):
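        # Propagation-free forward pass that mixes hidden states either at the
        # input features (layer_mix=0) or after the first layer (layer_mix=1),
        # returning logits together with the correspondingly mixed soft targets.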
        h = F.dropout(x, p=self.dropout, training=self.training)
        assert layer_mix in (0, 1)
        if layer_mix == 0:
            h, target, target_mix, lamb = mix_hidden_state(h, label, train_index, self.alpha)
        h = self.input_gnn.forward_aux(h)
        h = F.relu(h)
        if layer_mix == 1:
            h, target, target_mix, lamb = mix_hidden_state(h, label, train_index, self.alpha)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.hidden_gnn.forward_aux(h)
        target_label = lamb * target + (1 - lamb) * target_mix
        return h, target_label

    def update_aux(self, data, vector_labels, train_index, opt):
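        # Semi-supervised step: average k noisy predictions for the unlabelled
        # nodes, sharpen them into pseudo-labels, and combine a supervised mixup
        # loss with a consistency loss weighted by the ramp-up schedule.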
        device = data.x.device
        train_unlabelled = torch.where(~data.train_mask)[0].to(device)
        temp_labels = torch.zeros(self.k, vector_labels.shape[0], vector_labels.shape[1]).to(device)
        with torch.no_grad():
            for i in range(self.k):
                temp_labels[i, :, :] = self.predict_noise(data)

        target_labels = temp_labels.mean(dim=0)
        target_labels = sharpen(target_labels, self.temperature)
        vector_labels[train_unlabelled] = target_labels[train_unlabelled]
        sampled_unlabelled = torch.randint(0, train_unlabelled.shape[0], size=(train_index.shape[0],))
        train_unlabelled = train_unlabelled[sampled_unlabelled]

        def get_loss(index):
            mix_logits, target = self.forward_aux(data.x, vector_labels, index, mix_hidden=True)
            temp_loss = self.loss_f(F.softmax(mix_logits[index], -1), target)
            return temp_loss

        sup_loss = get_loss(train_index)
        unsup_loss = get_loss(train_unlabelled)

        mixup_weight = get_current_consistency_weight(
            opt["final_consistency_weight"], opt["rampup_starts"], opt["rampup_ends"], opt["epoch"]
        )

        loss_sum = sup_loss + mixup_weight * unsup_loss
        return loss_sum

    def update_soft(self, graph, labels, train_index):
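        # Plain supervised objective: negative log-likelihood on the labelled nodes.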
        out = self.forward(graph)
        out = F.log_softmax(out, dim=-1)
        loss_sum = F.nll_loss(out[train_index], labels[train_index])
        return loss_sum

    def loss(self, data, opt):
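        # Randomly alternate between the mixup/consistency objective and the
        # plain supervised loss at each training step.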
        device = data.x.device
        train_index = torch.where(data.train_mask)[0].to(device)
        rand_n = random.randint(0, 1)
        if rand_n == 0:
            vector_labels = get_one_hot_label(data.y, train_index).to(device)
            return self.update_aux(data, vector_labels, train_index, opt)
        else:
            return self.update_soft(data, data.y, train_index)

    def predict_noise(self, data, tau=1):
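        # Temperature-scaled forward pass; while the model is in training mode,
        # dropout makes repeated calls yield different ("noisy") predictions.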
        out = self.forward(data) / tau
        return out


[docs]@register_model("gcnmix") class GCNMix(BaseModel):
[docs] @staticmethod def add_args(parser): parser.add_argument("--dropout", type=float, default=0.5) parser.add_argument("--hidden-size", type=int, default=64) parser.add_argument("--alpha", type=float, default=1.0) parser.add_argument("--k", type=int, default=10) parser.add_argument("--temperature", type=float, default=0.1) parser.add_argument("--rampup-starts", type=int, default=500) parser.add_argument("--rampup_ends", type=int, default=1000) parser.add_argument("--mixup-consistency", type=float, default=10.0) parser.add_argument("--ema-decay", type=float, default=0.999) parser.add_argument("--tau", type=float, default=1.0)
[docs] @classmethod def build_model_from_args(cls, args): return cls( in_feat=args.num_features, hidden_size=args.hidden_size, num_classes=args.num_classes, k=args.k, temperature=args.temperature, alpha=args.alpha, rampup_starts=args.rampup_starts, rampup_ends=args.rampup_ends, final_consistency_weight=args.mixup_consistency, ema_decay=args.ema_decay, dropout=args.dropout, )
def __init__( self, in_feat, hidden_size, num_classes, k, temperature, alpha, rampup_starts, rampup_ends, final_consistency_weight, ema_decay, dropout, ): super(GCNMix, self).__init__() self.final_consistency_weight = final_consistency_weight self.rampup_starts = rampup_starts self.rampup_ends = rampup_ends self.ema_decay = ema_decay self.base_gnn = BaseGNNMix(in_feat, hidden_size, num_classes, k, temperature, alpha, dropout) self.ema_gnn = BaseGNNMix(in_feat, hidden_size, num_classes, k, temperature, alpha, dropout) for param in self.ema_gnn.parameters(): param.detach_() self.epoch = 0
[docs] def forward(self, graph): return self.base_gnn.forward(graph)
[docs] def forward_ema(self, graph): return self.ema_gnn(graph)
[docs] def node_classification_loss(self, data): opt = { "epoch": self.epoch, "final_consistency_weight": self.final_consistency_weight, "rampup_starts": self.rampup_starts, "rampup_ends": self.rampup_ends, } self.base_gnn.train() loss_n = self.base_gnn.loss(data, opt) alpha = min(1 - 1 / (self.epoch + 1), self.ema_decay) for ema_param, param in zip(self.ema_gnn.parameters(), self.base_gnn.parameters()): ema_param.data.mul_(alpha).add_((1 - alpha) * param.data) self.epoch += 1 return loss_n
[docs] def predict(self, data): prediction = self.forward_ema(data) return prediction