Source code for nntoolbox.callbacks.swa
from .callbacks import Callback
from ..utils import copy_model, get_device
from typing import Dict, Any
from torch.nn import Module
# Names exported by `from <this module> import *`.
__all__ = ['StochasticWeightAveraging']
class StochasticWeightAveraging(Callback):
    """
    Callback that maintains a running average of the learner's model weights
    (Stochastic Weight Averaging).

    https://arxiv.org/pdf/1803.05407.pdf

    A copy of the learner's model is kept in ``model_swa``. At every scheduled
    update (per epoch or per iteration, depending on ``timescale``) each of
    its parameters is replaced by the running mean of the trained model's
    parameters. When training ends, the training data is passed through the
    averaged model once so that layers tracking running statistics can refresh
    them for the averaged weights.
    """

    def __init__(
            self, learner, average_after: int,
            update_every: int=1, timescale: str="iter", device=get_device()
    ):
        """
        https://arxiv.org/pdf/1803.05407.pdf

        :param learner: the learner whose model (``learner._model``) is being trained
        :param average_after: the first epoch/iter to start averaging
        :param update_every: how many epochs/iters between each average update
        :param timescale: unit for ``average_after`` and ``update_every``; either "epoch" or "iter"
        :param device: device on which the averaged copy of the model is stored
            (NOTE: the default is evaluated once, when the class is defined)
        """
        assert timescale == "epoch" or timescale == "iter", \
            'timescale must be either "epoch" or "iter"'

        self.learner = learner
        self._model = learner._model
        self.model_swa = copy_model(self._model).to(device)
        self._update_every = update_every
        self._average_after = average_after
        self._timescale = timescale

    def _should_update(self, step: int) -> bool:
        """Return True when ``step`` (an epoch or iteration index) is a scheduled averaging step."""
        if step < self._average_after:
            return False
        return (step - self._average_after) % self._update_every == 0

    def _update_average(self, step: int) -> None:
        """Fold the current model weights into the running average held by ``model_swa``."""
        # Number of models already accumulated in the running mean.
        n_model = (step - self._average_after) // self._update_every
        for model_p, swa_p in zip(self._model.parameters(), self.model_swa.parameters()):
            # The averaged copy may live on a different device and/or use a different
            # dtype than the model being trained, so align both before combining.
            incoming = model_p.data.to(device=swa_p.data.device, dtype=swa_p.data.dtype)
            swa_p.data = (swa_p.data * n_model + incoming) / (n_model + 1)

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """
        Update the averaged model at the end of an epoch when ``timescale`` is "epoch".

        :param logs: training logs; ``logs["epoch"]`` is read as the current epoch index
        :return: always False
        """
        if self._timescale == "epoch" and self._should_update(logs["epoch"]):
            self._update_average(logs["epoch"])
            print("Update averaged model after epoch " + str(logs["epoch"]))
        return False

    def on_batch_end(self, logs: Dict[str, Any]):
        """
        Update the averaged model at the end of a batch when ``timescale`` is "iter".

        :param logs: training logs; ``logs["iter_cnt"]`` is read as the current iteration index
        """
        if self._timescale == "iter" and self._should_update(logs["iter_cnt"]):
            self._update_average(logs["iter_cnt"])
            print("Update averaged model after iteration " + str(logs["iter_cnt"]))

    def on_train_end(self):
        """
        Move the averaged model to the learner's device and run the training data
        through it once.

        The outputs are discarded; the pass exists so that layers keeping running
        statistics (e.g. batch normalization) refresh them for the averaged weights.
        """
        import torch  # local import: only needed for the no-grad context below

        device = self.learner._device
        self.model_swa.to(device)

        # Running statistics are only refreshed in training mode; remember the
        # current mode so it can be restored once the pass is finished.
        was_training = self.model_swa.training
        self.model_swa.train()
        try:
            # No gradients are needed for this pass, so avoid building autograd graphs.
            with torch.no_grad():
                for images, _ in self.learner._train_data:
                    self.model_swa(images.to(device))
        finally:
            self.model_swa.train(was_training)

    def get_averaged_model(self) -> Module:
        """
        Return the post-training average model

        :return: the averaged model
        """
        return self.model_swa