# Source code for nntoolbox.vision.components.shunting

from torch import nn
from ...components import GeneralizedShuntingModule
from .layers import BiasLayer2D


__all__ = ['SiConv2D']


class SiConv2D(GeneralizedShuntingModule):
    """
    Implement a shunting inhibition convolution layer.

    Computes a per-pixel ratio of two parallel convolution branches
    (numerator / denominator). Right now only supports the channelwise
    fully connected variant.

    Difference from original implementation: clamping denominator.

    References:

        Fok Hing Chi Tivive and Abdesselam Bouzerdoum.
        "Efficient Training Algorithms for a Class of Shunting Inhibitory
        Convolutional Neural Networks."
        https://ieeexplore.ieee.org/document/1427760
    """
    def __init__(
            self, in_channels: int, out_channels: int, kernel_size,
            stride: int = 1, padding=0, dilation=1, groups: int = 1,
            bias: bool = True, padding_mode: str = 'zeros',
            num_activation: nn.Module = nn.Identity(),
            denom_activation: nn.Module = nn.ReLU(),
            bound_denom: bool = True, bound: float = 0.1
    ):
        """
        :param in_channels: number of input channels for both conv branches
        :param out_channels: number of output channels for both conv branches
        :param kernel_size: conv kernel size (int or tuple), shared by both branches
        :param stride: conv stride
        :param padding: conv padding
        :param dilation: conv dilation
        :param groups: conv groups
        :param bias: whether the conv layers learn an additive bias
        :param padding_mode: conv padding mode (e.g. 'zeros')
        :param num_activation: activation applied to the numerator branch
        :param denom_activation: activation applied to the denominator branch
        :param bound_denom: whether to clamp the denominator (passed to base class)
        :param bound: clamp bound for the denominator (passed to base class)
        """
        # NOTE(review): the module-instance defaults (nn.Identity()/nn.ReLU())
        # are shared across SiConv2D instances; both are stateless, so this is
        # presumed harmless — confirm if stateful activations are ever defaulted.
        # Numerator branch: convolution followed by its activation.
        num = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      dilation, groups, bias, padding_mode),
            num_activation
        )
        # Denominator branch: convolution, activation, then a learnable
        # per-channel bias initialized at 1.0 so the denominator starts near one.
        denom = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      dilation, groups, bias, padding_mode),
            denom_activation,
            BiasLayer2D(out_channels, init=1.0)
        )
        super().__init__(num, denom, bound_denom, bound)