nntoolbox.callbacks.regularization module

A few regularizers, implemented as callbacks (UNTESTED)

class nntoolbox.callbacks.regularization.ActivationRegularization(output_hook: nntoolbox.hooks.io.OutputHook, regularizer: Callable[[torch.Tensor], torch.Tensor], lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.callbacks.Callback

Regularization by penalizing activations

after_losses(losses: Dict[str, torch.Tensor], train: bool) → Dict[str, torch.Tensor]
on_train_end()

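The page does not describe what `after_losses` does. The following framework-free sketch shows one plausible behaviour, assuming the callback adds `lambd * regularizer(activations)` to `losses[loss_name]` during training only. Plain Python floats and lists stand in for `torch.Tensor`. `FakeOutputHook` and `SketchActivationRegularization` are hypothetical names and are not part of nntoolbox.

```python
from typing import Callable, Dict, List


class FakeOutputHook:
    """Hypothetical stand-in for OutputHook: holds the last captured activations."""

    def __init__(self) -> None:
        self.outputs: List[float] = []


class SketchActivationRegularization:
    """Assumed behaviour: loss <- loss + lambd * regularizer(activations) while training."""

    def __init__(self, output_hook: FakeOutputHook,
                 regularizer: Callable[[List[float]], float],
                 lambd: float, loss_name: str = "loss") -> None:
        self.hook = output_hook
        self.regularizer = regularizer
        self.lambd = lambd
        self.loss_name = loss_name

    def after_losses(self, losses: Dict[str, float], train: bool) -> Dict[str, float]:
        # Assumption: the penalty is applied only in training, not during evaluation.
        if train and self.hook.outputs:
            penalty = self.regularizer(self.hook.outputs)
            losses[self.loss_name] = losses[self.loss_name] + self.lambd * penalty
        return losses


hook = FakeOutputHook()
hook.outputs = [1.0, -2.0, 3.0]
cb = SketchActivationRegularization(hook, lambda acts: sum(a * a for a in acts), lambd=0.5)
print(cb.after_losses({"loss": 1.0}, train=True))   # {'loss': 8.0}
print(cb.after_losses({"loss": 1.0}, train=False))  # {'loss': 1.0}
```

Here the regularizer is a sum of squares (1 + 4 + 9 = 14), so the training loss becomes 1.0 + 0.5 * 14 = 8.0.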
class nntoolbox.callbacks.regularization.L1AR(output_hook: nntoolbox.hooks.io.OutputHook, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.ActivationRegularization
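L1AR presumably supplies an L1 penalty as the `regularizer` of ActivationRegularization. The following sketch assumes the penalty is the mean absolute activation. The library may use a sum or a norm instead. `l1_activation_penalty` is a hypothetical name.

```python
from typing import Sequence


def l1_activation_penalty(activations: Sequence[float]) -> float:
    """Mean absolute activation (assumed form of the L1AR penalty)."""
    return sum(abs(a) for a in activations) / len(activations)


print(l1_activation_penalty([1.0, -2.0, 3.0, -2.0]))  # 2.0
```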

class nntoolbox.callbacks.regularization.L1TAR(output_hook: nntoolbox.hooks.io.OutputHook, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.TemporalActivationRegularization
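L1TAR presumably penalizes the absolute change of each unit between consecutive time steps. The following sketch assumes a mean over all time steps and units. `hidden[t]` is the activation vector at time step t, and `l1_temporal_penalty` is a hypothetical name.

```python
from typing import List, Sequence


def l1_temporal_penalty(hidden: Sequence[Sequence[float]]) -> float:
    """Mean |h_{t+1} - h_t| over all time steps and units (assumed L1TAR form)."""
    diffs: List[float] = [
        abs(b - a)
        for prev, curr in zip(hidden[:-1], hidden[1:])
        for a, b in zip(prev, curr)
    ]
    return sum(diffs) / len(diffs)


# Per-unit changes are 1, 0 (t0 -> t1) and 0, 2 (t1 -> t2).
print(l1_temporal_penalty([[0.0, 1.0], [1.0, 1.0], [1.0, 3.0]]))  # 0.75
```

A sequence whose activations never change incurs zero penalty, whatever its magnitude.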

class nntoolbox.callbacks.regularization.L1WR(lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.WeightRegularization
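L1WR presumably applies an L1 penalty to the model weights. The following sketch assumes the penalty is the sum of absolute values over every weight of every parameter. Parameters are represented as plain lists, and `l1_weight_penalty` is a hypothetical name.

```python
from typing import Iterable, Sequence


def l1_weight_penalty(parameters: Iterable[Sequence[float]]) -> float:
    """Sum of |w| over all weights of all parameters (assumed L1WR form)."""
    return sum(abs(w) for param in parameters for w in param)


print(l1_weight_penalty([[0.5, -1.5], [2.0]]))  # 4.0
```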

class nntoolbox.callbacks.regularization.L2AR(output_hook: nntoolbox.hooks.io.OutputHook, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.ActivationRegularization
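L2AR presumably penalizes the squared magnitude of the activations. The following sketch assumes a mean of squares, which is one common way to implement this penalty. `l2_activation_penalty` is a hypothetical name.

```python
from typing import Sequence


def l2_activation_penalty(activations: Sequence[float]) -> float:
    """Mean squared activation (assumed form of the L2AR penalty)."""
    return sum(a * a for a in activations) / len(activations)


print(l2_activation_penalty([1.0, -2.0, 3.0, -2.0]))  # 4.5
```

Unlike the L1 variant, squaring penalizes the single large activation (3.0) much more heavily than the small ones.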

class nntoolbox.callbacks.regularization.L2TAR(output_hook: nntoolbox.hooks.io.OutputHook, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.TemporalActivationRegularization
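L2TAR presumably penalizes the squared change between consecutive time steps. The following sketch assumes a mean of squared differences over all time steps and units. `l2_temporal_penalty` is a hypothetical name.

```python
from typing import List, Sequence


def l2_temporal_penalty(hidden: Sequence[Sequence[float]]) -> float:
    """Mean of (h_{t+1} - h_t)^2 over all time steps and units (assumed L2TAR form)."""
    sq_diffs: List[float] = [
        (b - a) ** 2
        for prev, curr in zip(hidden[:-1], hidden[1:])
        for a, b in zip(prev, curr)
    ]
    return sum(sq_diffs) / len(sq_diffs)


# Per-unit changes 1, 0, 0, 2 square to 1, 0, 0, 4.
print(l2_temporal_penalty([[0.0, 1.0], [1.0, 1.0], [1.0, 3.0]]))  # 1.25
```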

class nntoolbox.callbacks.regularization.L2WR(lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.WeightRegularization
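L2WR presumably applies an L2 (weight-decay-style) penalty. The following sketch assumes a plain sum of squared weights. Some implementations add a factor of 0.5 so that the gradient is exactly `lambd * w`. `l2_weight_penalty` is a hypothetical name.

```python
from typing import Iterable, Sequence


def l2_weight_penalty(parameters: Iterable[Sequence[float]]) -> float:
    """Sum of w^2 over all weights of all parameters (assumed L2WR form)."""
    return sum(w * w for param in parameters for w in param)


print(l2_weight_penalty([[0.5, -1.5], [2.0]]))  # 6.5
```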

class nntoolbox.callbacks.regularization.StudentTPenaltyAR(output_hook: nntoolbox.hooks.io.OutputHook, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.ActivationRegularization

Student’s t activation regularization:

$\omega(t) = \sum_i \log(1 + t_i^2)$
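The formula above translates directly into code. The following sketch operates on a plain list standing in for the activation tensor. It uses `math.log1p`, which computes log(1 + x) accurately for small x. `student_t_penalty` is a hypothetical name.

```python
import math
from typing import Sequence


def student_t_penalty(activations: Sequence[float]) -> float:
    """omega(t) = sum_i log(1 + t_i^2), as in the formula above."""
    return sum(math.log1p(t * t) for t in activations)


print(student_t_penalty([0.0, 1.0, -1.0]))  # ≈ 1.3863 (= 2 ln 2)
```

Because the penalty grows only logarithmically, a large activation costs far less than it would under an L2 penalty. This favours sparse activations with a few large values.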

class nntoolbox.callbacks.regularization.TemporalActivationRegularization(output_hook: nntoolbox.hooks.io.OutputHook, regularizer: Callable[[torch.Tensor], torch.Tensor], lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.callbacks.Callback

Regularization by penalizing changes in activations between consecutive time steps.

References:

Stephen Merity, Bryan McCann, Richard Socher. “Revisiting Activation Regularization for Language RNNs.” https://arxiv.org/pdf/1708.01009.pdf

after_losses(losses: Dict[str, torch.Tensor], train: bool) → Dict[str, torch.Tensor]
on_train_end()

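In the referenced paper, temporal activation regularization (TAR) is written as follows. Here $h_t$ is the RNN output at time step $t$, $L_2(\cdot)$ is the L2 norm, and $\beta$ is a scaling coefficient, presumably the role played by `lambd` here:

$$\mathrm{TAR} = \beta \, L_2(h_t - h_{t+1})$$

L1TAR presumably replaces the L2 norm with an L1 norm.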
class nntoolbox.callbacks.regularization.WeightElimination(scale: float, lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.regularization.WeightRegularization
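Weight elimination is commonly written as $\sum_i \frac{(w_i/s)^2}{1 + (w_i/s)^2}$ for a scale $s$. The following sketch assumes the `scale` argument plays the role of $s$, which is not confirmed by this page. Each weight contributes at most 1, so small weights are pushed toward zero while large weights are barely penalized further. `weight_elimination_penalty` is a hypothetical name.

```python
from typing import Iterable, Sequence


def weight_elimination_penalty(parameters: Iterable[Sequence[float]], scale: float) -> float:
    """Sum of r / (1 + r) with r = (w / scale)^2 (assumed WeightElimination form)."""
    total = 0.0
    for param in parameters:
        for w in param:
            r = (w / scale) ** 2
            total += r / (1.0 + r)
    return total


# A weight equal to the scale contributes exactly 0.5; a zero weight contributes 0.
print(weight_elimination_penalty([[1.0, -1.0], [0.0]], scale=1.0))  # 1.0
```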

class nntoolbox.callbacks.regularization.WeightRegularization(regularizer: Callable[[torch.Tensor], torch.Tensor], lambd: float, loss_name: str = 'loss')

Bases: nntoolbox.callbacks.callbacks.Callback

Regularization by penalizing weights

after_losses(losses: Dict[str, torch.Tensor], train: bool) → Dict[str, torch.Tensor]
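The constructor takes no model, so the real callback presumably obtains the parameters from the learner. The following framework-free sketch passes them in directly as lists of floats. It assumes `after_losses` adds `lambd` times the summed per-parameter penalty to `losses[loss_name]` during training. `SketchWeightRegularization` is a hypothetical name, not part of nntoolbox.

```python
from typing import Callable, Dict, List, Sequence


class SketchWeightRegularization:
    """Hypothetical stand-in for WeightRegularization (parameters passed in directly)."""

    def __init__(self, parameters: List[Sequence[float]],
                 regularizer: Callable[[Sequence[float]], float],
                 lambd: float, loss_name: str = "loss") -> None:
        self.parameters = parameters
        self.regularizer = regularizer
        self.lambd = lambd
        self.loss_name = loss_name

    def after_losses(self, losses: Dict[str, float], train: bool) -> Dict[str, float]:
        # Assumption: the penalty is added only while training.
        if train:
            penalty = sum(self.regularizer(p) for p in self.parameters)
            losses[self.loss_name] = losses[self.loss_name] + self.lambd * penalty
        return losses


params = [[0.5, -1.5], [2.0]]
cb = SketchWeightRegularization(params, lambda p: sum(w * w for w in p), lambd=0.5)
print(cb.after_losses({"loss": 1.0}, train=True))  # {'loss': 4.25}
```

With a sum-of-squares regularizer, the penalty is 0.25 + 2.25 + 4.0 = 6.5, so the loss becomes 1.0 + 0.5 * 6.5 = 4.25.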