Source code for nntoolbox.callbacks.regularization

"""A few regularizers, implemented as callbacks (UNTESTED)"""
import torch
from torch import Tensor
from .callbacks import Callback
from ..hooks import OutputHook
from typing import Dict, Callable


# Names exported by `from nntoolbox.callbacks.regularization import *`.
# NOTE(review): LowActivityPrior is defined in this module but is not listed
# here -- confirm whether leaving it out of the public API is intentional.
__all__ = [
    'WeightRegularization', 'WeightElimination', 'L1WR', 'L2WR',
    'ActivationRegularization', 'L1AR', 'L2AR', 'StudentTPenaltyAR',
    'TemporalActivationRegularization', 'L1TAR', 'L2TAR'
]


class WeightRegularization(Callback):
    """
    Regularization by penalizing weights.

    After the losses are computed, adds ``lambd * sum_p regularizer(p)``, summed
    over every parameter of the learner's model, to the loss stored under
    ``loss_name``.
    """
    def __init__(self, regularizer: Callable[[Tensor], Tensor], lambd: float, loss_name: str='loss'):
        """
        :param regularizer: function mapping one parameter tensor to its penalty
        :param lambd: coefficient multiplying the total penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        self.loss_name = loss_name
        self.regularizer = regularizer
        self.lambd = lambd

    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        """
        Add the weight penalty to the tracked loss.

        :param losses: dictionary of losses; must contain ``loss_name``
        :param train: unused here, the penalty is added whatever its value
        :return: the same dictionary, with the penalized loss under ``loss_name``
        """
        assert self.loss_name in losses
        # NOTE(review): self.learner is not assigned in this class; presumably it
        # is attached by the learner that runs the callback -- confirm.
        reg = 0.0
        for p in self.learner._model.parameters():
            # Pass the parameter itself rather than p.data: p.data is detached
            # from the autograd graph, so a penalty computed from it would only
            # shift the loss value and never contribute gradients to the weights.
            reg = reg + self.regularizer(p)
        # Out-of-place add so the tensor object previously stored in the
        # dictionary is not mutated.
        losses[self.loss_name] = losses[self.loss_name] + self.lambd * reg
        return losses
class WeightElimination(WeightRegularization):
    """
    Weight elimination penalty: omega(w) = sum_i w_i^2 / (w_i^2 + scale^2)
    """
    def __init__(self, scale: float, lambd: float, loss_name: str='loss'):
        """
        :param scale: positive scale constant in the denominator of the penalty
        :param lambd: coefficient multiplying the total penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        assert scale > 0.0
        scale_sq = scale ** 2

        def weight_elimination(t: Tensor) -> Tensor:
            t_sq = t.pow(2)
            # Sum the element-wise ratios. Applying .sum() to the denominator
            # alone would return a tensor shaped like t instead of a scalar
            # penalty, which cannot be accumulated across parameters of
            # different shapes.
            return (t_sq / (t_sq + scale_sq)).sum()

        super().__init__(
            regularizer=weight_elimination,
            lambd=lambd,
            loss_name=loss_name
        )
class L1WR(WeightRegularization):
    """Weight regularization whose per-parameter penalty is ``t.norm(1)``."""
    def __init__(self, lambd: float, loss_name: str='loss'):
        """
        :param lambd: coefficient multiplying the total penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l1_penalty(weights: Tensor) -> Tensor:
            # The 1-norm is computed first; .mean() is then applied to that result.
            return weights.norm(1).mean()

        super().__init__(regularizer=l1_penalty, lambd=lambd, loss_name=loss_name)
class L2WR(WeightRegularization):
    """Weight regularization whose per-parameter penalty is ``t.norm(2)``."""
    def __init__(self, lambd: float, loss_name: str='loss'):
        """
        :param lambd: coefficient multiplying the total penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l2_penalty(weights: Tensor) -> Tensor:
            # The 2-norm (not squared) is computed first; .mean() is then applied.
            return weights.norm(2).mean()

        super().__init__(regularizer=l2_penalty, lambd=lambd, loss_name=loss_name)
class ActivationRegularization(Callback):
    """Callback that adds a penalty on a hooked module's output to a loss."""
    def __init__(
            self, output_hook: OutputHook, regularizer: Callable[[Tensor], Tensor],
            lambd: float, loss_name: str='loss'
    ):
        """
        :param output_hook: output hook of the module we want to regularize
        :param regularizer: regularization function (e.g L2)
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        self.lambd = lambd
        self.regularizer = regularizer
        self.loss_name = loss_name
        self.hook = output_hook

    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        """
        During training, add the activation penalty to the tracked loss.

        :param losses: dictionary of losses; must contain ``loss_name`` when training
        :param train: whether this is a training step; nothing is done otherwise
        :return: the same dictionary of losses
        """
        if not train:
            return losses

        assert self.loss_name in losses
        activations = self.hook.store
        if isinstance(activations, tuple):
            # Only the first element of a tuple output is penalized.
            activations = activations[0]
        penalty = self.regularizer(activations) * self.lambd
        losses[self.loss_name] += penalty
        # Drop the reference to the stored output once it has been consumed.
        self.hook.store = None
        return losses

    def on_train_end(self):
        """Remove the output hook when training ends."""
        self.hook.remove()
class L2AR(ActivationRegularization):
    """Activation regularization whose penalty is ``t.norm(2)`` of the hooked output."""
    def __init__(self, output_hook: OutputHook, lambd: float, loss_name: str='loss'):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l2_penalty(activations: Tensor) -> Tensor:
            return activations.norm(2).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=l2_penalty,
            lambd=lambd,
            loss_name=loss_name
        )
class L1AR(ActivationRegularization):
    """Activation regularization whose penalty is ``t.norm(1)`` of the hooked output."""
    def __init__(self, output_hook: OutputHook, lambd: float, loss_name: str='loss'):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l1_penalty(activations: Tensor) -> Tensor:
            return activations.norm(1).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=l1_penalty,
            lambd=lambd,
            loss_name=loss_name
        )
class StudentTPenaltyAR(ActivationRegularization):
    """
    Student's T Activation Regularization.

    The penalty computed here is the mean over all elements of log(1 + t_i^2).
    """
    def __init__(self, output_hook: OutputHook, lambd: float, loss_name: str='loss'):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def student_t_penalty(activations: Tensor) -> Tensor:
            squared = activations.pow(2)
            return torch.log1p(squared).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=student_t_penalty,
            lambd=lambd,
            loss_name=loss_name
        )
class LowActivityPrior(ActivationRegularization):
    """
    Constrain the activation to be small. Coupled with a variance force, this will drive the activation to
    sparsity. (UNTESTED)

    The penalty is the mean squared difference between the activation averaged over
    dimension 0 and the target value alpha.

    References:

        Sven Behnke. "Hierarchical Neural Networks for Image Interpretation," page 124.
        https://www.ais.uni-bonn.de/books/LNCS2766.pdf
    """
    def __init__(self, output_hook: OutputHook, lambd: float, alpha: float=0.1, loss_name: str='loss'):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param alpha: target value for the activation averaged over dimension 0
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def low_activity_penalty(activations: Tensor) -> Tensor:
            deviation = activations.mean(0) - alpha
            return deviation.pow(2).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=low_activity_penalty,
            lambd=lambd,
            loss_name=loss_name
        )
class TemporalActivationRegularization(Callback):
    """
    Regularizing by penalizing activation change.

    The penalty is computed on the difference between consecutive slices of the
    hooked output along its first dimension.

    References:

        Stephen Merity, Bryan McCann, Richard Socher. "Revisiting Activation Regularization for Language RNNs."
        https://arxiv.org/pdf/1708.01009.pdf
    """
    def __init__(
            self, output_hook: OutputHook, regularizer: Callable[[Tensor], Tensor],
            lambd: float, loss_name: str = 'loss'
    ):
        """
        :param output_hook: output hook of the module we want to regularize
        :param regularizer: regularization function applied to the activation change
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        self.hook = output_hook
        self.regularizer = regularizer
        self.loss_name = loss_name
        self.lambd = lambd

    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        """
        During training, add the temporal activation penalty to the tracked loss.

        :param losses: dictionary of losses; must contain ``loss_name`` when training
        :param train: whether this is a training step; nothing is done otherwise
        :return: the same dictionary of losses
        """
        if not train:
            return losses

        assert self.loss_name in losses
        activations = self.hook.store
        if isinstance(activations, tuple):
            # Only the first element of a tuple output is penalized.
            activations = activations[0]
        # Difference between each slice along dim 0 and the slice that follows it.
        # NOTE(review): presumably dim 0 is the time axis -- confirm against the hooked module.
        earlier, later = activations[:-1], activations[1:]
        penalty = self.regularizer(earlier - later) * self.lambd
        losses[self.loss_name] = penalty + losses[self.loss_name]
        # Drop the reference to the stored output once it has been consumed.
        self.hook.store = None
        return losses

    def on_train_end(self):
        """Remove the output hook when training ends."""
        self.hook.remove()
class L2TAR(TemporalActivationRegularization):
    """Temporal activation regularization whose penalty is ``t.norm(2)`` of the activation change."""
    def __init__(
            self, output_hook: OutputHook, lambd: float, loss_name: str = 'loss'
    ):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l2_penalty(change: Tensor) -> Tensor:
            return change.norm(2).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=l2_penalty,
            lambd=lambd,
            loss_name=loss_name
        )
class L1TAR(TemporalActivationRegularization):
    """Temporal activation regularization whose penalty is ``t.norm(1)`` of the activation change."""
    def __init__(
            self, output_hook: OutputHook, lambd: float, loss_name: str = 'loss'
    ):
        """
        :param output_hook: output hook of the module we want to regularize
        :param lambd: coefficient multiplying the penalty
        :param loss_name: name of the loss stored in loss logs. Default to 'loss'
        """
        def l1_penalty(change: Tensor) -> Tensor:
            return change.norm(1).mean()

        super().__init__(
            output_hook=output_hook,
            regularizer=l1_penalty,
            lambd=lambd,
            loss_name=loss_name
        )