# Source code for nntoolbox.components.shunting

"""Shunting Inhibition Modules"""
from typing import Optional

from torch import nn, Tensor, clamp
from .components import BiasLayer


__all__ = ['GeneralizedShuntingModule', 'GeneralizedShuntingMLP']


class GeneralizedShuntingModule(nn.Module):
    """
    Implement a module that exhibits the shunting inhibition mechanism:

    y = f(x) / (a + g(x))

    Here ``num`` computes the numerator f(x) and ``denom`` computes the
    denominator (a + g(x)); ``forward`` returns ``num(x) / denom(x)``.

    Difference from original implementation: clamping denominator.

    References:

        Ganesh Arulampalam, Abdesselam Bouzerdoum. "A generalized feedforward neural network architecture for
        classification and regression."
        https://www.sciencedirect.com/science/article/pii/S0893608003001163
    """

    def __init__(self, num: nn.Module, denom: nn.Module, bound_denom: bool=True, bound: float=0.1):
        """
        :param num: module applied to the input to produce the numerator
        :param denom: module applied to the input to produce the denominator
        :param bound_denom: if True, the denominator is clamped from below by ``bound`` in ``forward``
        :param bound: lower bound used for the clamp; must be strictly positive
        :raises ValueError: if ``bound`` is not strictly positive (including NaN)
        """
        super().__init__()
        # Explicit check instead of ``assert`` so the validation is not stripped under ``python -O``.
        # Written as ``not (bound > 0.0)`` so NaN is rejected too, exactly like the original assertion.
        if not (bound > 0.0):
            raise ValueError("bound must be strictly positive, got {!r}".format(bound))
        self.num = num
        self.denom = denom
        self.bound_denom = bound_denom
        self.bound = bound

    def forward(self, input: Tensor) -> Tensor:
        """
        Compute ``num(input) / denom(input)``.

        When ``bound_denom`` is set, the denominator is clamped to be at least ``bound``
        before dividing, which keeps it strictly positive.

        :param input: input tensor, passed unchanged to both ``num`` and ``denom``
        :return: element-wise quotient of the two module outputs
        """
        denom = self.denom(input)
        if self.bound_denom:
            denom = clamp(denom, min=self.bound)
        return self.num(input) / denom
class GeneralizedShuntingMLP(GeneralizedShuntingModule):
    """
    Fully connected generalized shunting layer.

    Both branches share the same structure:

        num   = Sequential(Linear(in_channels, out_channels, bias=True), num_activation,   BiasLayer((out_channels,)))
        denom = Sequential(Linear(in_channels, out_channels, bias=True), denom_activation, BiasLayer((out_channels,), init=1.0))

    and the output is ``num(x) / denom(x)`` (optionally with the denominator clamped,
    see ``GeneralizedShuntingModule``).

    NOTE(review): ``BiasLayer`` comes from ``.components``; presumably it adds a learnable
    bias of shape ``(out_channels,)``, with ``init=1.0`` playing the role of the constant
    ``a`` in the denominator -- confirm against its definition.
    """

    def __init__(
            self, in_channels: int, out_channels: int,
            num_activation: Optional[nn.Module]=None, denom_activation: Optional[nn.Module]=None,
            bound_denom: bool=True, bound: float=0.1
    ):
        """
        :param in_channels: number of input features of both linear layers
        :param out_channels: number of output features of both linear layers
        :param num_activation: activation of the numerator branch; a fresh ``nn.Identity()`` if None
        :param denom_activation: activation of the denominator branch; a fresh ``nn.ReLU()`` if None
        :param bound_denom: forwarded to ``GeneralizedShuntingModule``
        :param bound: forwarded to ``GeneralizedShuntingModule``
        """
        # Build default activations per instance. A module instance used directly as a default
        # argument is created once at definition time and would be shared (and registered as a
        # submodule) by every layer constructed with the defaults.
        if num_activation is None:
            num_activation = nn.Identity()
        if denom_activation is None:
            denom_activation = nn.ReLU()

        num = nn.Sequential(
            nn.Linear(in_channels, out_channels, True),
            num_activation,
            BiasLayer((out_channels,))
        )
        denom = nn.Sequential(
            nn.Linear(in_channels, out_channels, True),
            denom_activation,
            BiasLayer((out_channels,), init=1.0)
        )
        super().__init__(num, denom, bound_denom, bound)