Source code for nntoolbox.vision.components.ho

"""Some higher order layers"""
import torch
from torch import nn, Tensor


__all__ = ['QuadraticPolynomialConv2D']


[docs]class QuadraticPolynomialConv2D(nn.Module): """ h(x) = sum_k(A_k * x)^2 + b * x + c where the * represents convolution References: Bergstra et al. "Quadratic Polynomials Learn Better Image Features." http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205 (dead link, use web archive) """ def __init__( self, in_channels, out_channels, kernel_size, rank: int, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', sqrt: bool=False, eps: float=1e-6 ): super().__init__() self.linear = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode ) self.quadratic = nn.Conv2d( in_channels, out_channels * rank, kernel_size, stride, padding, dilation, groups, False, padding_mode ) self.out_channels = out_channels self.rank = rank self.sqrt = sqrt self.eps = eps
[docs] def forward(self, input: Tensor) -> Tensor: linear_features = self.linear(input) quadratic_features = self.quadratic(input).pow(2) quadratic_features = quadratic_features.view( -1, self.rank, self.out_channels, quadratic_features.shape[2], quadratic_features.shape[3] ).sum(1) if self.sqrt: quadratic_features = torch.sqrt(quadratic_features + self.eps) return quadratic_features + linear_features