# Source code for nntoolbox.vision.components.local

"""Locally Connected Layer and Subsampling layer for 2D input"""
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.nn import init
import math
from typing import Union, Tuple, Optional
from nntoolbox.vision.utils import compute_output_shape
from nntoolbox.vision.components import GlobalAveragePool


# Public API of this module.
__all__ = ['LocallyConnected2D', 'Subsampling2D', 'CondConv2d']


class LocallyConnected2D(nn.Module):
    """
    Works similarly to Conv2d, but does not share weights across spatial locations.
    Much more memory intensive, and slower (due to suboptimal native pytorch implementation)

    (UNTESTED)

    Example usages:

        Yaniv Taigman et al. "DeepFace: Closing the Gap to Human-Level Performance in Face Verification"
        https://www.cs.toronto.edu/~ranzato/publications/taigman_cvpr14.pdf
    """
    def __init__(
            self, in_channels: int, out_channels: int, in_h: int, in_w: int,
            kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]]=1,
            padding: Union[int, Tuple[int, int]]=0, dilation: Union[int, Tuple[int, int]]=1,
            groups: int=1, bias: bool=True, padding_mode: str='zeros'
    ):
        super(LocallyConnected2D, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else _pair(kernel_size)
        self.stride = stride if isinstance(stride, tuple) else _pair(stride)
        self.padding = padding if isinstance(padding, tuple) else _pair(padding)
        self.dilation = dilation if isinstance(dilation, tuple) else _pair(dilation)
        self.groups = groups
        self.padding_mode = padding_mode
        self.in_h, self.in_w = in_h, in_w
        self.output_h, self.output_w = self.compute_output_shape(in_h, in_w)
        # One independent filter per output location:
        # (out_channels, in_channels // groups, kh, kw, output_h, output_w)
        self.weight = nn.Parameter(torch.Tensor(
            out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1],
            self.output_h, self.output_w
        ))
        if bias:
            # Bias is also unshared: one value per channel per output location.
            self.bias = nn.Parameter(torch.Tensor(out_channels, self.output_h, self.output_w))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias the same way torch.nn.Conv2d does."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def compute_output_shape(self, height: int, width: int) -> Tuple[int, int]:
        """Return the (height, width) of the output feature map for the given input size."""
        return (
            compute_output_shape(height, self.padding[0], self.kernel_size[0], self.dilation[0], self.stride[0]),
            compute_output_shape(width, self.padding[1], self.kernel_size[1], self.dilation[1], self.stride[1]),
        )

    def forward(self, input: Tensor) -> Tensor:
        """
        :param input: (n, in_channels, in_h, in_w)
        :return: (n, out_channels, output_h, output_w)

        Fix: the original viewed the weight with an `in_channels * kh * kw` patch
        dimension, although the weight's channel dim is `in_channels // groups`;
        for groups > 1 this produced a shape error (or a mis-sized `-1` dim).
        The group-aware reshape below reduces to the original math when groups == 1.
        """
        assert input.shape[2] == self.in_h and input.shape[3] == self.in_w
        if self.padding_mode == 'circular':
            expanded_padding = [
                (self.padding[1] + 1) // 2, self.padding[1] // 2,
                (self.padding[0] + 1) // 2, self.padding[0] // 2
            ]
            input = F.pad(input, expanded_padding, mode='circular')
            padding = 0
        else:
            padding = self.padding
        # (n, in_channels * kh * kw, L) with L == output_h * output_w
        input = F.unfold(
            input, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=padding, stride=self.stride
        )
        n = input.shape[0]
        k = self.kernel_size[0] * self.kernel_size[1]
        cg = self.in_channels // self.groups
        og = self.out_channels // self.groups
        # Split the unfolded channel axis into groups. Within a group the layout is
        # (cg, kh, kw) flattened, matching the weight's own flattening below.
        patches = input.view(n, self.groups, cg * k, -1)
        w = self.weight.view(self.groups, og, cg * k, self.output_h * self.output_w)
        # (n, G, 1, cg*k, L) * (1, G, og, cg*k, L) -> sum over the patch dim -> (n, G, og, L)
        output = (patches.unsqueeze(2) * w.unsqueeze(0)).sum(3)
        output = output.view(n, self.out_channels, self.output_h, self.output_w)
        if self.bias is not None:
            output = output + self.bias[None, :]
        return output
class Subsampling2D(nn.AvgPool2d):
    """
    Pool each feature map patch-by-patch, then apply a per-channel affine
    transformation (trainable scale, and optional shift). Used in LeNet.
    (The LeNet paper sums over each patch; this layer average-pools, which
    differs only by a constant factor absorbed by the trainable scale.)

    (UNTESTED)

    References:

        Yann Lecun et al. "Gradient-Based Learning Applied to Document Recognition."
        http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
    """
    def __init__(
            self, in_channels: int, kernel_size: Union[int, Tuple[int, int]]=2,
            stride: Union[int, Tuple[int, int]]=2, padding: Union[int, Tuple[int, int]]=0,
            bias: bool=True, trainable: bool=True, ceil_mode: bool=False,
            count_include_pad: bool=True
    ):
        super().__init__(kernel_size, stride, padding, ceil_mode, count_include_pad)
        # Per-channel scale; trainable=False freezes both scale and shift,
        # reducing the layer to plain average pooling.
        self.weight = nn.Parameter(torch.ones(in_channels), requires_grad=trainable)
        if bias:
            self.bias = nn.Parameter(torch.zeros(in_channels), requires_grad=trainable)
        else:
            self.register_parameter('bias', None)

    def forward(self, input: Tensor) -> Tensor:
        """Average-pool, then scale (and optionally shift) each channel."""
        pooled = super().forward(input)
        scaled = pooled * self.weight[None, :, None, None]
        if self.bias is None:
            return scaled
        return scaled + self.bias[None, :, None, None]
class CondConv2d(nn.Conv2d):
    """
    Conditionally Parameterized Convolution Layer: holds `num_experts` sets of
    conv weights and mixes them per-example with input-dependent routing weights.

    References:

        Brandon Yang, Gabriel Bender, Quoc V. Le, Jiquan Ngiam.
        "CondConv: Conditionally Parameterized Convolutions for Efficient Inference."
        https://arxiv.org/abs/1904.04971

        Pytorch implementation of Conv2d
    """
    def __init__(
            self, num_experts: int, in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, groups, bias, padding_mode
        )
        # Throwaway convs, created purely to reuse Conv2d's default
        # initialization for each expert's weight (and bias).
        convs = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
            for _ in range(num_experts)
        ]
        # weight: (num_experts, out_channels, in_channels // groups, kh, kw)
        self.weight = nn.Parameter(torch.stack([conv.weight for conv in convs], dim=0))
        if bias:
            # bias: (num_experts, out_channels)
            self.bias = nn.Parameter(torch.stack([conv.bias for conv in convs], dim=0))
        else:
            self.bias = None
        # Routing: pool the input, then a per-expert sigmoid gate.
        # NOTE(review): assumes GlobalAveragePool outputs (n, in_channels) — confirm.
        self.routing_weight_fn = nn.Sequential(GlobalAveragePool(), nn.Linear(in_channels, num_experts), nn.Sigmoid())
        self.num_experts = num_experts

    def forward(self, input: Tensor) -> Tensor:
        """
        Dispatch to the appropriate code path.

        Fix: the original tested `self.train`, which is nn.Module's `train()`
        method — a bound method, hence always truthy — so the branched path was
        taken whenever num_experts < 4 even in eval mode. `self.training` is
        the actual boolean mode flag.
        """
        if self.training and self.num_experts < 4:
            return self.branched_forward(input)
        else:
            return self.efficient_forward(input)

    def branched_forward(self, input: Tensor) -> Tensor:
        """Run every expert over the whole batch, then mix the outputs
        (preferred for training with few experts)."""
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
        routing_weights = self.routing_weight_fn(input).transpose(0, 1)  # (num_experts, n)
        outputs = []
        for e in range(self.num_experts):
            weight = self.weight[e]
            if self.bias is None:
                bias = self.bias
            else:
                bias = self.bias[e]
            if self.padding_mode == 'circular':
                output = F.conv2d(
                    F.pad(input, expanded_padding, mode='circular'),
                    weight, bias, self.stride, (0, 0), self.dilation, self.groups
                )
            else:
                output = F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
            outputs.append(output)
        # Per-example weighted sum of the expert outputs.
        return (torch.stack(outputs, dim=0) * routing_weights[:, :, None, None, None]).sum(0)

    def efficient_forward(self, input: Tensor) -> Tensor:
        """Mix the expert kernels first, then run one conv per example
        (preferred for inference / many experts)."""
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
        routing_weights = self.routing_weight_fn(input)  # (n, num_experts)
        outputs = []
        for i in range(len(input)):
            # Collapse the expert axis into a single per-example kernel.
            weight = (self.weight * routing_weights[i][:, None, None, None, None]).sum(0)
            if self.bias is None:
                bias = self.bias
            else:
                bias = (self.bias * routing_weights[i][:, None]).sum(0)
            if self.padding_mode == 'circular':
                output = F.conv2d(
                    F.pad(input[i:i + 1], expanded_padding, mode='circular'),
                    weight, bias, self.stride, (0, 0), self.dilation, self.groups
                )
            else:
                output = F.conv2d(input[i:i + 1], weight, bias, self.stride, self.padding, self.dilation, self.groups)
            outputs.append(output)
        return torch.cat(outputs, dim=0)