Source code for nntoolbox.init.lsuv

"""
Implement LSUV initialization from "ALL YOU NEED IS A GOOD INIT"
https://arxiv.org/pdf/1511.06422.pdf
Adopt from fastai
"""
from torch.nn import Module
from torch import Tensor,nn
from nntoolbox.hooks import Hook, OutputStatsHook
from ..utils import get_all_submodules
from torch.nn.init import orthogonal_


LINEAR_TYPE = [nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear]
__all__ = ['lsuv_init']


[docs]def lsuv_init(module: Module, input: Tensor, tol: float=1e-3, Tmax: int=100): """ LSUV initialization :param module: :param input: :param tol: maximum tolerance :param Tmax: maximum iterations to attempt to demean and normalize weight :return: final mean and std of each layer's output """ means, stds = [], [] for layer in get_all_submodules(module): for type in LINEAR_TYPE: if isinstance(layer, type): orthogonal_(layer.weight) # orginal paper starts with orthogonal initialization hook = OutputStatsHook(layer) # fastai suggests demean bias as well: if layer.bias is not None: t = 0 while module(input) is not None and abs(hook.stats[0][-1]) > tol and t < Tmax: layer.bias.data -= hook.stats[0][-1] t += 1 if layer.weight is not None: t = 0 while module(input) is not None and abs(hook.stats[1][-1] - 1.0) > tol and t < Tmax: layer.weight.data /= hook.stats[1][-1] t += 1 hook.remove() means.append(hook.stats[0][-1]) stds.append(hook.stats[1][-1]) return means, stds