"""Uniform parameter initialization helpers (module ``nntoolbox.init.uniform``)."""

import math
from torch.nn import init, Module


# Public API of this module (what ``from ... import *`` exposes).
__all__ = ['sqrt_uniform_init']


def sqrt_uniform_init(component: Module) -> None:
    """Re-initialize every parameter of ``component`` in place from a uniform range.

    Each parameter ``p`` is filled from ``U(-b, b)`` with
    ``b = 1 / sqrt(p.shape[-1])``, i.e. the bound depends on the size of that
    parameter's own last dimension. For a 2-D weight this is its second axis.
    For a 1-D parameter such as a bias it is the parameter's length.

    Parameters with no trailing dimension to derive a bound from are left
    untouched:
      * 0-dim (scalar) parameters, which have no ``shape[-1]``;
      * empty parameters, which contain nothing to fill. A zero-length last
        dimension would otherwise make the bound ``1 / sqrt(0)``.

    Args:
        component: module whose ``parameters()`` (recursively, including
            submodules) are overwritten in place.
    """
    for param in component.parameters():
        # Skip shapes that give no usable bound (see docstring) instead of
        # failing with IndexError / ZeroDivisionError.
        if param.dim() == 0 or param.numel() == 0:
            continue
        # Half-width of the uniform range (not a standard deviation).
        bound = 1.0 / math.sqrt(param.shape[-1])
        init.uniform_(param, -bound, bound)