Source code for nntoolbox.callbacks.transfer

from .callbacks import Callback, GroupCallback
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, Module, Sequential
from torch.optim import Optimizer
from typing import List, Dict, Any, Optional, Union


__all__ = ['FreezeBN', 'GradualUnfreezing', 'FineTuning']
BN_TYPE = [BatchNorm1d, BatchNorm2d, BatchNorm3d]


class FreezeBN(Callback):
    """
    Freeze the statistics of non-trainable batch norm layers so that they do not
    accumulate running statistics during training (UNTESTED)
    """
    def on_epoch_begin(self):
        freeze_bn(self.learner._model)
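# Illustrative sketch (not part of the original module): switching a batch norm
# layer to eval mode stops its running statistics from updating, which is the
# mechanism that freeze_bn (defined below) relies on.
def _demo_frozen_bn_stats():
    import torch

    bn = BatchNorm1d(4)
    for param in bn.parameters():
        param.requires_grad = False  # mark the layer non-trainable, as freeze_bn expects
    bn.eval()
    before = bn.running_mean.clone()
    bn(torch.randn(8, 4))
    assert torch.equal(bn.running_mean, before)  # running stats unchanged in eval mode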
def freeze_bn(module: Module):
    """
    Set every non-trainable batch norm submodule of a module to eval mode so that
    its running statistics stop updating
    """
    for submodule in module.modules():
        for bn_type in BN_TYPE:
            if isinstance(submodule, bn_type):
                if not next(submodule.parameters()).requires_grad:
                    submodule.eval()
                    # freeze_bn(submodule)


def unfreeze(module: Sequential, optimizer: Optimizer, unfreeze_from: int, unfreeze_to: int, **kwargs):
    """
    Unfreeze the submodules of a sequential model with indices in [unfreeze_from, unfreeze_to),
    registering their parameters with the optimizer, and freeze every submodule before unfreeze_from

    :param module: the sequential model to partially unfreeze
    :param optimizer: the optimizer to register the newly unfrozen parameters with
    :param unfreeze_from: index of the first submodule to unfreeze
    :param unfreeze_to: index one past the last submodule to unfreeze
    :param kwargs: extra settings (e.g. lr) for the new parameter groups
    """
    for ind in range(len(module)):
        submodule = module._modules[str(ind)]
        if ind < unfreeze_from:
            for param in submodule.parameters():
                param.requires_grad = False
        elif ind < unfreeze_to:
            for param in submodule.parameters():
                param.requires_grad = True
            optimizer.add_param_group({'params': submodule.parameters(), **kwargs})
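# Illustrative sketch (not part of the original module): how unfreeze interacts
# with a plain Sequential and optimizer. The toy model and learning rates here
# are made up for demonstration.
def _demo_unfreeze():
    from torch.nn import Linear, ReLU
    from torch.optim import SGD

    model = Sequential(Linear(4, 8), ReLU(), Linear(8, 8), ReLU(), Linear(8, 2))
    # Start with everything frozen except the head (index 4), which the
    # optimizer is initialized with:
    for param in model.parameters():
        param.requires_grad = False
    for param in model[4].parameters():
        param.requires_grad = True
    optimizer = SGD(model[4].parameters(), lr=1e-2)

    # Unfreeze submodule 2 with a smaller, discriminative learning rate;
    # submodules before index 2 stay frozen.
    unfreeze(model, optimizer, unfreeze_from=2, unfreeze_to=3, lr=1e-3)

    assert not any(p.requires_grad for p in model[0].parameters())
    assert all(p.requires_grad for p in model[2].parameters())
    assert len(optimizer.param_groups) == 2  # new group registered with lr=1e-3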
class GradualUnfreezing(Callback):
    """
    Gradually unfreeze pretrained layers, optionally with discriminative learning rates (UNTESTED)
    """
    def __init__(
            self, unfreeze_every: int, freeze_inds: Optional[List[int]] = None,
            lr: Optional[Union[List[float], float]] = None
    ):
        self._freeze_inds = freeze_inds
        self._unfreeze_every = unfreeze_every
        if lr is None:
            self.lr = None
        else:
            # freeze_inds must be given explicitly when learning rates are
            # specified, since the lr list is expanded against it here:
            assert freeze_inds is not None
            if isinstance(lr, list):
                assert len(lr) == len(freeze_inds)
                self.lr = lr
            else:
                self.lr = [lr for _ in range(len(freeze_inds))]
    def on_train_begin(self):
        n_layer = len(self.learner._model._modules['0'])
        if self._freeze_inds is None:
            # Default schedule: unfreeze one submodule at a time, last to first
            self._freeze_inds = [n_layer - 1 - i for i in range(n_layer)]
        self._freeze_inds = [n_layer] + self._freeze_inds
    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        if logs['epoch'] % self._unfreeze_every == 0 \
                and logs['epoch'] > 0 \
                and logs['epoch'] // self._unfreeze_every < len(self._freeze_inds):
            unfreeze_from = self._freeze_inds[logs['epoch'] // self._unfreeze_every]
            unfreeze_to = self._freeze_inds[logs['epoch'] // self._unfreeze_every - 1]
            if self.lr is not None:
                unfreeze(
                    self.learner._model._modules['0'], self.learner._optimizer,
                    unfreeze_from, unfreeze_to,
                    lr=self.lr[logs['epoch'] // self._unfreeze_every - 1]
                )
            else:
                unfreeze(
                    self.learner._model._modules['0'], self.learner._optimizer,
                    unfreeze_from, unfreeze_to
                )
            print("Unfreezing submodules from index " + str(unfreeze_from) + " to " + str(unfreeze_to))
        return False
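# Illustrative sketch (not part of the original module): the unfreezing schedule
# produced by on_epoch_end, simulated without a learner. With freeze_inds left as
# None, on_train_begin builds _freeze_inds = [n_layer, n_layer - 1, ..., 0], so one
# more submodule is unfrozen every `unfreeze_every` epochs, deepest first.
def _demo_unfreeze_schedule():
    n_layer, unfreeze_every = 3, 2
    freeze_inds = [n_layer] + [n_layer - 1 - i for i in range(n_layer)]  # [3, 2, 1, 0]
    for epoch in range(1, 8):
        if epoch % unfreeze_every == 0 and epoch // unfreeze_every < len(freeze_inds):
            i = epoch // unfreeze_every
            print("epoch", epoch, "-> unfreeze_from =", freeze_inds[i], "unfreeze_to =", freeze_inds[i - 1])
    # Prints: epoch 2 -> (2, 3), epoch 4 -> (1, 2), epoch 6 -> (0, 1);
    # lr[i - 1] (i.e. lr[0] first) is applied to each newly unfrozen block.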
class FineTuning(GroupCallback):
    """
    Combine freezing of batch norm statistics with gradual unfreezing of layers
    """
    def __init__(
            self, unfreeze_every: int, freeze_inds: Optional[List[int]] = None,
            lr: Optional[Union[List[float], float]] = None
    ):
        super(FineTuning, self).__init__([
            GradualUnfreezing(unfreeze_every, freeze_inds, lr),
            FreezeBN()
        ])
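# Usage sketch (illustrative; assumes a learner object that exposes `_model`,
# whose child '0' is the pretrained Sequential feature extractor, and
# `_optimizer`, as the callbacks above do; the training call itself is omitted
# since it depends on the learner's API):
def _demo_fine_tuning_callback():
    # One block is unfrozen every 2 epochs; lr[0] goes to the block unfrozen
    # first (the deepest one). freeze_inds must be supplied when lr is a list,
    # since the two are matched index by index.
    return FineTuning(unfreeze_every=2, freeze_inds=[2, 1, 0], lr=[1e-3, 1e-4, 1e-5])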