Source code for nntoolbox.components.activation

import torch
from torch import nn, Tensor
from ..utils import to_onehotv2


__all__ = ['ZeroCenterRelu', 'LWTA']


[docs]class ZeroCenterRelu(nn.ReLU): """ As described by Jeremy of FastAI """ def __init__(self, inplace: bool=False): super(ZeroCenterRelu, self).__init__(inplace)
[docs] def forward(self, input: Tensor) -> Tensor: return super().forward(input) - 0.5
[docs]class LWTA(nn.Module): """ Local Winner-Take-All Layer For every k consecutive units, keep only the one with highest activations and zero-out the rest. References: Rupesh Kumar Srivastava, Jonathan Masci, Sohrob Kazerounian, Faustino Gomez, Jürgen Schmidhuber. "Compete to Compute." https://papers.nips.cc/paper/5059-compete-to-compute.pdf """ def __init__(self, block_size): super().__init__() self.block_size = block_size
[docs] def forward(self, input: Tensor) -> Tensor: assert input.shape[1] % self.block_size == 0 input = input.view(-1, input.shape[1] // self.block_size, self.block_size) mask = to_onehotv2(torch.max(input, -1)[1], self.block_size).to(input.dtype).to(input.device) return (input * mask).view(-1, input.shape[1] * input.shape[2])