Source code for nntoolbox.vision.components.regularization
from torch import nn, Tensor
import torch
from typing import Tuple, Optional
__all__ = ['ShakeShake']
class ShakeShake(nn.Module):
    """
    Implements the shake-shake regularizer:

        y = x + sum_i alpha_i * branch_i

    where the alpha_i >= 0 are random variables with sum_i alpha_i = 1.

    At test time:

        y = x + (1 / n_branch) * sum_i branch_i

    Based on "Shake-Shake regularization": https://arxiv.org/abs/1705.07485
    (A usage sketch follows the class definition below.)
    """
    def __init__(self, keep: str = 'shake'):
        """
        :param keep: backward-pass behavior; one of 'keep', 'even' or 'shake'
            (see ShakeShakeFunction.forward)
        """
        super(ShakeShake, self).__init__()
        self._keep = keep
    def forward(self, branches: Tensor, training: bool) -> Tensor:
        return ShakeShakeFunction.apply(branches, training, self._keep)
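# Example (an illustrative addition, not part of the original module): a minimal
# sketch of a residual block whose two convolutional branches are combined by
# ShakeShake. The class name and layer shapes here are assumptions.
class ExampleShakeShakeBlock(nn.Module):
    def __init__(self, n_channel: int):
        super().__init__()
        self.branch1 = nn.Conv2d(n_channel, n_channel, kernel_size=3, padding=1)
        self.branch2 = nn.Conv2d(n_channel, n_channel, kernel_size=3, padding=1)
        self.shake_shake = ShakeShake(keep='shake')

    def forward(self, input: Tensor) -> Tensor:
        # Stack the branch outputs along a new leading dimension:
        # (cardinality, batch_size, n_channel, h, w)
        branches = torch.stack([self.branch1(input), self.branch2(input)], dim=0)
        return input + self.shake_shake(branches, self.training)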
class ShakeShakeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, branches: Tensor, training: bool, mode: str) -> Tensor:
        """
        :param ctx: context (to save info for the backward pass)
        :param branches: outputs of all branches, stacked along a new leading
            dimension: (cardinality, batch_size, n_channel, h, w)
        :param training: boolean, true if in training mode
        :param mode: 'keep': reuse the forward weights for backward;
            'even': backward with uniform 1 / n_branch weights;
            'shake': sample new random weights for backward
        :return: weighted sum of all branches' outputs
        """
        if training:
            branch_weights = ShakeShakeFunction.get_branch_weights(
                len(branches), branches[0].shape[0]
            ).to(branches.dtype).to(branches.device)
            output = branch_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * branches
            # The first saved tensor flags the backward mode:
            # 1 -> 'keep', -1 -> 'even', 0 -> 'shake'
            if mode == 'keep':
                ctx.save_for_backward(torch.ones(1), branch_weights)
            elif mode == 'even':
                ctx.save_for_backward(-torch.ones(1), len(branches) * torch.ones(1).int())
            else:  # shake mode
                ctx.save_for_backward(torch.zeros(1), len(branches) * torch.ones(1).int())
            return torch.sum(output, dim=0)
        else:
            return torch.mean(branches, dim=0)
    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        if ctx.saved_tensors[0] == 1:  # keep mode: reuse the forward weights
            branch_weights = ctx.saved_tensors[1]
        elif ctx.saved_tensors[0] == -1:  # even mode: uniform 1 / n_branch weights
            cardinality = ctx.saved_tensors[1].item()
            branch_weights = 1.0 / cardinality * torch.ones(
                size=(cardinality, grad_output.shape[0])
            ).to(grad_output.device)
        else:  # shake mode: sample fresh random weights
            cardinality = ctx.saved_tensors[1].item()
            branch_weights = ShakeShakeFunction.get_branch_weights(
                cardinality, grad_output.shape[0]
            ).to(grad_output.device)
        # Gradients flow only to the branches; training and mode get None.
        return branch_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(grad_output.dtype) * grad_output, None, None
    @staticmethod
    def get_branch_weights(cardinality: int, batch_size: int) -> Tensor:
        """Sample nonnegative weights of shape (cardinality, batch_size) summing to 1 over the branches."""
        branch_weights = torch.rand(size=(cardinality, batch_size))
        branch_weights /= torch.sum(branch_weights, dim=0, keepdim=True)
        return branch_weights
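# Usage sketch (an illustrative addition, not part of the original module):
# a quick sanity check of the training and evaluation paths. The tensor shapes
# below are arbitrary assumptions.
if __name__ == '__main__':
    branches = torch.randn(2, 4, 3, 8, 8, requires_grad=True)  # (cardinality, batch, channel, h, w)
    shake = ShakeShake(keep='even')

    train_out = shake(branches, True)   # random convex combination per example
    eval_out = shake(branches, False)   # plain average over the branches
    print(train_out.shape, eval_out.shape)  # torch.Size([4, 3, 8, 8]) twice

    # In 'even' mode the backward pass uses uniform 1/2 weights for both branches:
    train_out.sum().backward()
    print(torch.allclose(branches.grad, 0.5 * torch.ones_like(branches)))  # True

    # Sampled weights always form a convex combination over the branches:
    weights = ShakeShakeFunction.get_branch_weights(2, 4)
    print(torch.allclose(weights.sum(dim=0), torch.ones(4)))  # True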