Source code for nntoolbox.vision.components.kervolution

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import numpy as np

"""
Implement kervolution (kernel convolution) layers
https://arxiv.org/pdf/1904.03955.pdf
"""


class LinearKernel(nn.Module):
    def __init__(self, cp: float=1.0, trainable: bool=True):
        assert cp > 0
        super(LinearKernel, self).__init__()
        # store cp in log space so it stays positive during training;
        # requires_grad belongs on the Parameter, not on the wrapped tensor,
        # and the float() cast keeps the parameter in the default dtype
        self.log_cp = nn.Parameter(torch.tensor(float(np.log(cp))), requires_grad=trainable)
    def forward(self, input: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        weight = weight.view(weight.shape[0], -1).t()  # (patch_size, out_channels)
        # (batch_size, out_channels, n_patches) after the two permutes
        output = input.permute(0, 2, 1).matmul(weight).permute(0, 2, 1) + torch.exp(self.log_cp)
        if bias is not None:
            output = output + bias.unsqueeze(0).unsqueeze(-1)
        return output
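# A minimal sanity check (an illustrative sketch, not part of the original
# module; the helper name `_check_linear_kernel` is hypothetical): on unfolded
# patches, the linear kernel should match F.conv2d up to the additive constant cp.
def _check_linear_kernel():
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    bias = torch.randn(4)
    patches = F.unfold(x, kernel_size=3)                # (2, 27, 36)
    out = LinearKernel(cp=1.0)(patches, weight, bias)   # (2, 4, 36)
    ref = F.conv2d(x, weight, bias) + 1.0               # cp added back
    assert torch.allclose(out.view(2, 4, 6, 6), ref, atol=1e-5)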
class PolynomialKernel(LinearKernel):
    def __init__(self, dp: int=3, cp: float=2.0, trainable: bool=True):
        super(PolynomialKernel, self).__init__(cp, trainable)
        self._dp = dp
    def forward(self, input: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        # note: the bias, when present, is added before exponentiation
        return super().forward(input, weight, bias).pow(self._dp)
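# Quick check (illustrative sketch, hypothetical helper name): with bias=None
# the polynomial kernel reduces to (conv2d + cp) ** dp.
def _check_polynomial_kernel():
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    patches = F.unfold(x, kernel_size=3)
    out = PolynomialKernel(dp=3, cp=2.0)(patches, weight, None)
    ref = (F.conv2d(x, weight) + 2.0) ** 3
    assert torch.allclose(out.view(2, 4, 6, 6), ref, atol=1e-3)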
class GaussianKernel(nn.Module):
    def __init__(self, bandwidth: float=1.0, trainable: bool=True):
        assert bandwidth > 0
        super(GaussianKernel, self).__init__()
        # store the bandwidth in log space so it stays positive during training
        self.log_bandwidth = nn.Parameter(torch.tensor(float(np.log(bandwidth))), requires_grad=trainable)
    def forward(self, input: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        """
        :param input: (batch_size, patch_size, n_patches)
        :param weight: (out_channels, in_channels, kernel_height, kernel_width)
        :param bias: (out_channels,) or None
        :return: (batch_size, out_channels, n_patches)
        """
        input = input.unsqueeze(-2)                                               # (batch_size, patch_size, 1, n_patches)
        weight = weight.view(weight.shape[0], -1).t().unsqueeze(0).unsqueeze(-1)  # (1, patch_size, out_channels, 1)
        output = torch.exp(-torch.exp(self.log_bandwidth) * (input - weight).pow(2).sum(1))
        if bias is not None:
            output = output + bias.unsqueeze(0).unsqueeze(-1)
        return output
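# Shape and value check (illustrative sketch, hypothetical helper name): each
# output element should equal exp(-bandwidth * ||patch - flattened filter||^2).
def _check_gaussian_kernel():
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    patches = F.unfold(x, kernel_size=3)                # (2, 27, 36)
    out = GaussianKernel(bandwidth=0.5)(patches, weight, None)
    assert out.shape == (2, 4, 36)
    ref = torch.exp(-0.5 * (patches[0, :, 0] - weight[2].reshape(-1)).pow(2).sum())
    assert torch.allclose(out[0, 2, 0], ref, atol=1e-6)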
class Kervolution2D(nn.Conv2d):
    def __init__(
            self, in_channels, out_channels, kernel, kernel_size,
            stride=1, padding=0, dilation=1, bias=True, padding_mode='zeros'
    ):
        super(Kervolution2D, self).__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, 1, bias, padding_mode  # groups is fixed to 1
        )
        self.kernel = kernel()
    def compute_output_shape(self, height, width):
        def compute_shape_helper(inp_dim, padding, kernel_size, dilation, stride):
            # standard convolution output-size formula (exercised in the sketch after this class)
            return np.floor(
                (inp_dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
            ).astype(np.uint32)

        return (
            compute_shape_helper(height, self.padding[0], self.kernel_size[0], self.dilation[0], self.stride[0]),
            compute_shape_helper(width, self.padding[1], self.kernel_size[1], self.dilation[1], self.stride[1]),
        )
    def forward(self, input: Tensor) -> Tensor:
        output_h, output_w = self.compute_output_shape(input.shape[2], input.shape[3])
        if self.padding_mode == 'circular':
            expanded_padding = [
                (self.padding[1] + 1) // 2, self.padding[1] // 2,
                (self.padding[0] + 1) // 2, self.padding[0] // 2
            ]
            input = F.pad(input, expanded_padding, mode='circular')
            padding = 0
        else:
            padding = self.padding
        # extract sliding patches: (batch_size, in_channels * kh * kw, n_patches)
        input = F.unfold(
            input, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=padding, stride=self.stride
        )
        output = self.kernel(input, self.weight, self.bias)
        # output = torch.clamp(output, min=-10.0, max=10.0)
        return output.view(-1, self.out_channels, int(output_h), int(output_w))
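# Illustrative usage sketch (hypothetical helper name): Kervolution2D is meant
# as a drop-in replacement for nn.Conv2d, and compute_output_shape should agree
# with the spatial size of the tensor it actually produces.
def _check_kervolution2d():
    layer = Kervolution2D(3, 8, GaussianKernel, kernel_size=3, stride=2, padding=1)
    x = torch.randn(2, 3, 11, 11)
    out = layer(x)
    output_h, output_w = layer.compute_output_shape(11, 11)
    assert out.shape == (2, 8, int(output_h), int(output_w))  # (2, 8, 6, 6)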
class KervolutionalLayer(nn.Sequential):
    """
    Simple kervolutional layer: input -> kervolution2d -> activation -> norm 2d
    """
    def __init__(
            self, in_channels, out_channels, kernel, kernel_size=3, stride=1, padding=0,
            bias=False, activation=nn.ReLU, normalization=nn.BatchNorm2d
    ):
        super(KervolutionalLayer, self).__init__()
        self.add_module(
            "main",
            nn.Sequential(
                Kervolution2D(
                    in_channels=in_channels, out_channels=out_channels, kernel=kernel,
                    kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
                ),
                activation(),
                normalization(num_features=out_channels)
            )
        )
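# Illustrative usage (a sketch with arbitrary hyperparameters, not part of the
# original module): a small classifier mixing a kervolutional block with
# ordinary layers.
def _example_model():
    model = nn.Sequential(
        KervolutionalLayer(3, 16, PolynomialKernel, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10)
    )
    logits = model(torch.randn(4, 3, 32, 32))
    assert logits.shape == (4, 10)
    return model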