Source code for nntoolbox.vision.components.upsample

import torch
from torch import nn
from torch.nn import functional as F
from .layers import ConvolutionalLayer


# Public names exported by a star-import of this module.
__all__ = ['ResizeConvolutionalLayer', 'PixelShuffleConvolutionLayer']


class ResizeConvolutionalLayer(nn.Module):
    """
    Resize-then-convolve upsampling block.

    The input is first resized to an explicit spatial size with
    ``F.interpolate`` and the result is then fed through a
    ``ConvolutionalLayer`` built with ``kernel_size=1``.
    """

    def __init__(
            self, in_channels, out_channels,
            activation=nn.ReLU, normalization=nn.BatchNorm2d, mode='bilinear'
    ):
        """
        :param in_channels: number of channels of the input, forwarded to the conv layer
        :param out_channels: number of channels of the output, forwarded to the conv layer
        :param activation: forwarded unchanged to ``ConvolutionalLayer``
        :param normalization: forwarded unchanged to ``ConvolutionalLayer``
        :param mode: interpolation mode handed to ``F.interpolate`` in ``forward``
        """
        super().__init__()
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            activation=activation,
            normalization=normalization,
        )
        self.conv = ConvolutionalLayer(**conv_kwargs)
        self._mode = mode

    def forward(self, input, out_h, out_w):
        """
        Resize ``input`` to ``(out_h, out_w)`` and apply the conv layer.

        :param input: tensor accepted by ``F.interpolate`` with a two-element ``size``
            (presumably a (N, C, H, W) image batch -- confirm against callers)
        :param out_h: target height after interpolation
        :param out_w: target width after interpolation
        :return: output of ``self.conv`` applied to the resized tensor
        """
        target_size = (out_h, out_w)
        resized = F.interpolate(input, size=target_size, mode=self._mode)
        return self.conv(resized)
class PixelShuffleConvolutionLayer(nn.Sequential):
    """
    Upsample the image using a normal convolution followed by pixel shuffling.

    The layers are applied in this order: a 1x1 convolution producing
    ``out_channels * upscale_factor ** 2`` channels, the activation, the
    normalization, ``nn.PixelShuffle(upscale_factor)`` and, when ``blur`` is
    true, a replication pad of one pixel on the left/top followed by a 2x2
    average pool with stride 1 (which keeps the spatial size unchanged).

    References:

        https://arxiv.org/pdf/1609.05158.pdf

        https://arxiv.org/pdf/1806.02658.pdf (additional blurring at the end)
    """

    def __init__(
            self, in_channels: int, out_channels: int, upscale_factor: int,
            activation=nn.ReLU, normalization=nn.BatchNorm2d, blur: bool=True
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param upscale_factor: factor to increase spatial resolution by
        :param activation: activation function (a zero-argument factory, e.g. a module class)
        :param normalization: normalization function (a factory accepting ``num_features``)
        :param blur: whether to blur at the end to remove checkerboard artifact
        """
        # Pixel shuffle folds upscale_factor ** 2 channels into each output channel.
        expanded_channels = out_channels * (upscale_factor ** 2)
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=expanded_channels,
            kernel_size=1,
            padding=0,
        )
        # Safe to call before nn.Module.__init__: the method only touches `conv`.
        self.initialize_conv(conv, in_channels, out_channels, upscale_factor)

        layers = [
            conv,
            activation(),
            normalization(num_features=expanded_channels),
            nn.PixelShuffle(upscale_factor)
        ]
        if blur:
            layers += [nn.ReplicationPad2d((1, 0, 1, 0)), nn.AvgPool2d(kernel_size=2, stride=1)]
        super(PixelShuffleConvolutionLayer, self).__init__(*layers)

    def initialize_conv(self, conv, in_channels: int, out_channels: int, upscale_factor: int):
        """
        Initialize the convolution weight according to: https://arxiv.org/pdf/1707.02937.pdf

        A base kernel with ``out_channels`` filters is drawn with
        ``kaiming_uniform_`` and every filter is then duplicated
        ``upscale_factor ** 2`` times in consecutive channel positions, so that
        all sub-pixels of one output channel start from the same filter.
        Only ``conv.weight`` is modified (in place); the bias is left untouched.

        :param conv: convolution whose weight has
            ``out_channels * upscale_factor ** 2`` output filters
        :param in_channels: input channels of ``conv``
        :param out_channels: output channels after pixel shuffling
        :param upscale_factor: factor to increase spatial resolution by
        :return: None
        """
        import math

        # Follow the conv's own kernel size instead of assuming 1x1.
        kernel_shape = tuple(conv.weight.shape[2:])
        base_kernel = torch.empty(out_channels, in_channels, *kernel_shape)
        nn.init.kaiming_uniform_(base_kernel, a=math.sqrt(5))

        # nn.PixelShuffle takes conv channels c * r^2 ... c * r^2 + r^2 - 1 as the
        # sub-pixels of output channel c, so the copies of each base filter must
        # be adjacent (repeat_interleave). Tiling the whole tensor with
        # Tensor.repeat would interleave different filters inside one output channel.
        weight_tensor = base_kernel.repeat_interleave(upscale_factor ** 2, dim=0)
        with torch.no_grad():
            conv.weight.copy_(weight_tensor)