Source code for layers.se_layer

import torch
import torch.nn as nn


class SELayer(nn.Module):
    """Squeeze-and-excitation layer.

    The input is averaged over its first dimension ("squeeze"), the
    averaged vector is passed through a two-layer bottleneck network ending
    in a sigmoid ("excitation"), and the resulting gate values in ``(0, 1)``
    are multiplied back onto the input.

    Args:
        in_channels: Size of the last dimension of the input; also the size
            of the gate vector that is produced.
        se_channels: Width of the hidden bottleneck layer.
    """

    def __init__(self, in_channels: int, se_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.se_channels = se_channels
        # Linear -> ELU -> Linear -> Sigmoid. The positional layout inside
        # nn.Sequential fixes the state-dict keys (encoder_decoder.0.*,
        # encoder_decoder.2.*), so keep the order when editing.
        # The linear layers keep PyTorch's default initialisation.
        self.encoder_decoder = nn.Sequential(
            nn.Linear(in_channels, se_channels),
            nn.ELU(),
            nn.Linear(se_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reweight ``x`` by a gate computed from its mean over dimension 0.

        Args:
            x: Input tensor whose last dimension has size ``in_channels``
                (required by the first linear layer).
                NOTE(review): presumably shaped ``(num_items, in_channels)``,
                so the mean pools over items - confirm against callers.

        Returns:
            ``x * s``, where ``s`` is the sigmoid gate computed from the mean
            of ``x`` along dimension 0 and broadcast against ``x``.
        """
        # Squeeze: aggregate the input representation along dimension 0.
        x_global = torch.mean(x, dim=0)

        # Excitation: compute the reweighting vector s, each entry in (0, 1).
        s = self.encoder_decoder(x_global)

        # Broadcasting applies the same gate to every slice along dim 0.
        return x * s