# Source code for nntoolbox.init.normal
from torch import Tensor
from torch.nn import Module
from torch.nn.init import normal_, constant_
def normal_init(module: Module, mean: float, std: float) -> None:
    """
    Initialize the weight of a module in place from a normal distribution.

    If the module has a bias that is not None, the bias is set to zero.
    Modules that expose no ``bias`` attribute at all are accepted; only
    their weight is initialized.
    (UNTESTED)

    :param module: module to initialize; must have a ``weight`` tensor
    :param mean: mean of the normal distribution
    :param std: standard deviation of the normal distribution
    """
    normal_(module.weight, mean=mean, std=std)
    # Look the bias up defensively: a module may carry a weight without
    # defining a bias attribute, and a plain attribute access would raise.
    bias = getattr(module, "bias", None)
    if bias is not None:
        constant_(bias, val=0.0)