import torch
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Tuple
import numpy as np
__all__ = ['L2NormalizationLayer', 'AdaIN', 'SelfStabilizer', 'BatchRenorm2D']


class L2NormalizationLayer(nn.Module):
    """Normalize the input to unit L2 norm along its last dimension."""
    def __init__(self):
        super(L2NormalizationLayer, self).__init__()

    def forward(self, input):
        return F.normalize(input, dim=-1, p=2)
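
# A minimal usage sketch for L2NormalizationLayer (shapes are illustrative
# assumptions, not part of this module):
#
#     layer = L2NormalizationLayer()
#     x = torch.randn(8, 128)                       # batch of 8 feature vectors
#     y = layer(x)                                  # each row now has unit L2 norm
#     assert torch.allclose(y.norm(dim=-1), torch.ones(8), atol=1e-5)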


class AdaIN(nn.Module):
    """
    Adaptive instance normalization (AdaIN) layer: re-normalizes the input's
    per-channel statistics to match those of a previously set style tensor
    (Huang & Belongie, "Arbitrary Style Transfer in Real-time with Adaptive
    Instance Normalization").
    """
    def __init__(self):
        super(AdaIN, self).__init__()
        self._style = None

    def forward(self, input: Tensor) -> Tensor:
        if self._style is None:
            # No style has been set yet: use the input itself as the style,
            # which makes this call (approximately) the identity.
            self.set_style(input)
            return self.forward(input)
        else:
            input_mean, input_std = AdaIN.compute_mean_std(input)        # (batch_size, C, 1, 1)
            style_mean, style_std = AdaIN.compute_mean_std(self._style)  # (batch_size, C, 1, 1)
            return (input - input_mean) / (input_std + 1e-8) * style_std + style_mean

    def set_style(self, style):
        self._style = style

    @staticmethod
    def compute_mean_std(images: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute the per-image, per-channel mean and std over the spatial dimensions.

        :param images: (n_img, C, H, W)
        :return: two tensors of shape (n_img, C, 1, 1) -- the per-channel mean and std
        """
        images_reshaped = images.view(images.shape[0], images.shape[1], -1)
        images_mean = images_reshaped.mean(2).unsqueeze(-1).unsqueeze(-1)
        images_std = images_reshaped.std(2).unsqueeze(-1).unsqueeze(-1)
        return images_mean, images_std
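
# A minimal usage sketch for AdaIN (tensor shapes are illustrative assumptions):
#
#     adain = AdaIN()
#     style_features = torch.randn(4, 64, 32, 32)    # feature maps of the style images
#     content_features = torch.randn(4, 64, 32, 32)  # feature maps of the content images
#     adain.set_style(style_features)
#     stylized = adain(content_features)             # content stats mapped onto style stats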


class SelfStabilizer(nn.Module):
    """
    Self-stabilizer layer: scales its input by a learned, always-positive scalar.
    Based on:
    https://www.cntk.ai/pythondocs/cntk.layers.blocks.html
    https://www.cntk.ai/pythondocs/_modules/cntk/layers/blocks.html#Stabilizer
    https://www.cntk.ai/pythondocs/layerref.html#batchnormalization-layernormalization-stabilizer
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/SelfLR.pdf
    """
    def __init__(self, steepness: float = 4.0):
        super(SelfStabilizer, self).__init__()
        self.steepness = steepness
        # Initialize so that softplus(param, beta=steepness) == 1, i.e. the layer starts
        # as the identity. Cast to float32 explicitly: np.log returns float64, which would
        # otherwise promote the parameter (and downstream tensors) to float64.
        self.param = nn.Parameter(torch.tensor(np.log(np.exp(steepness) - 1) / steepness, dtype=torch.float32))

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(self.param, beta=self.steepness) * input
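
# A minimal usage sketch for SelfStabilizer: it multiplies its input by a learned
# positive scalar, so it is typically placed in front of another layer. The Linear
# layer and sizes here are illustrative assumptions, not part of this module.
#
#     stabilized_linear = nn.Sequential(SelfStabilizer(), nn.Linear(256, 128))
#     out = stabilized_linear(torch.randn(16, 256))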


class BatchRenorm2D(nn.Module):
    """
    Batch renormalization, modified from the batch-norm implementation in the
    fast.ai course-v3 part-2 notebook. Intended to work better than plain batch
    norm for small batches (UNTESTED).

    References:
    https://github.com/fastai/course-v3/blob/master/nbs/dl2/07_batchnorm.ipynb
    Ioffe, Sergey. "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models."
    https://arxiv.org/pdf/1702.03275.pdf
    """
    def __init__(self, num_features: int, r_max: float, d_max: float, eps: float = 1e-6, momentum: float = 0.1):
        assert r_max > 1.0
        assert d_max > 0.0
        assert 0.0 < momentum < 1.0
        super(BatchRenorm2D, self).__init__()
        # Affine parameters, shaped to broadcast over (N, C, H, W) and initialized
        # to the identity transform (scale 1, shift 0).
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.r_max, self.d_max = r_max, d_max
        self.eps, self.momentum = eps, momentum

    def update_stats(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Compute the batch mean/std, the clamped renormalization corrections r and d,
        and update the running statistics in place.
        """
        batch_mean = input.mean((0, 2, 3), keepdim=True)
        batch_var = input.var((0, 2, 3), keepdim=True)
        batch_std = (batch_var + self.eps).sqrt()
        running_std = (self.running_var + self.eps).sqrt()
        r = torch.clamp(batch_std / running_std, min=1 / self.r_max, max=self.r_max).detach()
        d = torch.clamp((batch_mean - self.running_mean) / running_std, min=-self.d_max, max=self.d_max).detach()
        self.running_mean.lerp_(batch_mean, self.momentum)
        self.running_var.lerp_(batch_var, self.momentum)
        return batch_mean, batch_std, r, d

    def forward(self, input: Tensor) -> Tensor:
        if self.training:
            with torch.no_grad():
                mean, std, r, d = self.update_stats(input)
            # Normalize with the batch statistics, corrected by r and d.
            input = (input - mean) / std * r + d
        else:
            # At inference time, normalize with the running statistics only.
            mean, var = self.running_mean, self.running_var
            input = (input - mean) / (var + self.eps).sqrt()
        return self.weight * input + self.bias
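
# A minimal usage sketch for BatchRenorm2D (the r_max/d_max values and tensor
# shapes are illustrative assumptions, not recommendations from this module):
#
#     brn = BatchRenorm2D(num_features=64, r_max=3.0, d_max=5.0)
#     brn.train()
#     y = brn(torch.randn(4, 64, 16, 16))    # training: batch stats corrected by (r, d)
#     brn.eval()
#     y = brn(torch.randn(4, 64, 16, 16))    # inference: running stats only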