Source code for nntoolbox.callbacks.lookahead

from typing import Dict, Any

import torch
from torch.nn import Module

from ..utils import copy_model, get_device
from .callbacks import Callback


# Names exported by `from <this module> import *`.
__all__ = ['LookaheadOptimizer']


class LookaheadOptimizer(Callback):
    """
    Lookahead Optimizer: keep a set of "slow weights" that is only updated
    periodically from the "fast weights" being trained. (UNTESTED)

    Every ``update_every`` iterations (or epochs) the slow weights are moved
    a fraction ``step_size`` of the way towards the fast weights, and the
    fast weights are then reset to the new slow weights.

    References:

        Michael R. Zhang, James Lucas, Geoffrey Hinton, Jimmy Ba.
        "Lookahead Optimizer: k steps forward, 1 step back."
        https://arxiv.org/abs/1907.08610
    """

    def __init__(
            self, step_size: float=0.5, update_every: int=1,
            timescale: str="iter", device=get_device()
    ):
        """
        :param step_size: interpolation factor used when moving the slow
            weights towards the fast weights
        :param update_every: how many epochs/iters between each slow weight
            update
        :param timescale: unit of ``update_every``; either "epoch" or "iter"
        :param device: device on which the slow copy of the model is kept
        :raises ValueError: if ``timescale`` is neither "epoch" nor "iter"
        """
        # Raise explicitly instead of asserting: assertions are stripped
        # under `python -O`, which would silently disable the callback.
        if timescale not in ("epoch", "iter"):
            raise ValueError(
                'timescale must be "epoch" or "iter", got ' + repr(timescale)
            )

        self.step_size = step_size
        self._update_every = update_every
        self._timescale = timescale
        self._device = device

    def on_train_begin(self):
        """Grab the learner's model and create the slow-weight copy of it."""
        self._model = self.learner._model
        self._model_slow = copy_model(self._model).to(self._device)

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """
        Update the slow weights when running on the "epoch" timescale.

        :param logs: training logs; ``logs["epoch"]`` is read
        :return: always False
        """
        if self._timescale == "epoch" and logs["epoch"] % self._update_every == 0:
            self.update_slow_weights()
            print("Update slow weights after epoch " + str(logs["epoch"]))
        return False

    def on_batch_end(self, logs: Dict[str, Any]):
        """
        Update the slow weights when running on the "iter" timescale.

        :param logs: training logs; ``logs["iter_cnt"]`` is read
        """
        if self._timescale == "iter" and logs["iter_cnt"] % self._update_every == 0:
            self.update_slow_weights()
            print("Update slow weights after iteration " + str(logs["iter_cnt"]))

    def on_train_end(self):
        """
        Move the slow model to the learner's device, run it once over the
        training data and install it as the learner's model.
        """
        self._model_slow.to(self.learner._device)
        # NOTE(review): this pass presumably refreshes statistics kept in
        # buffers (e.g. batch norm) for the slow weights -- confirm. No
        # gradients are needed for it, so autograd is switched off.
        with torch.no_grad():
            for inputs, _ in self.learner._train_data:
                self._model_slow(inputs.to(self.learner._device))
        self.learner._model = self._model_slow

    def update_slow_weights(self):
        """
        Perform one Lookahead synchronisation step:

            slow <- slow + step_size * (fast - slow)
            fast <- slow
        """
        with torch.no_grad():
            for model_p, slow_p in zip(self._model.parameters(), self._model_slow.parameters()):
                # The slow copy may live on another device and/or use another
                # dtype than the model being trained, so convert both at once.
                fast = model_p.detach().to(device=slow_p.device, dtype=slow_p.dtype)
                slow_p.add_(fast - slow_p, alpha=self.step_size)
                # "1 step back": restart the fast weights from the slow ones.
                # copy_ converts device and dtype back as needed.
                model_p.copy_(slow_p)

    def get_final_model(self) -> Module:
        """
        Return the model holding the slow weights

        :return: the slow-weight model
        """
        return self._model_slow