"""Source code for nntoolbox.callbacks.fge."""

from .callbacks import Callback
from ..utils import copy_model
from typing import Dict, Any, List
from torch.nn import Module
from collections import deque


__all__ = ['FastGeometricEnsembling']


# UNTESTED
class FastGeometricEnsembling(Callback):
    """
    Collect periodic snapshots of a model during training for fast geometric
    ensembling.

    References:
        https://arxiv.org/pdf/1802.10026.pdf
        https://arxiv.org/pdf/1704.00109.pdf
    """

    def __init__(
        self, model: Module, max_n_model: int, save_after: int,
        save_every: int = 1, timescale: str = "iter"
    ):
        """
        :param model: the model currently being trained
        :param max_n_model: maximum number of snapshots to retain
        :param save_after: the first epoch/iteration at which to start saving snapshots
        :param save_every: how many epochs/iterations between consecutive snapshots
        :param timescale: "epoch" or "iter" — the unit that save_after/save_every count in
        :raises ValueError: if timescale is neither "epoch" nor "iter"
        """
        # Raise instead of assert: assert statements are stripped under `python -O`,
        # which would silently accept an invalid timescale.
        if timescale not in ("epoch", "iter"):
            raise ValueError('timescale must be "epoch" or "iter", got ' + repr(timescale))
        self._model = model
        # maxlen guarantees the OLDEST snapshot is evicted automatically when the
        # deque is full, so capacity handling cannot accidentally drop a new snapshot.
        self.models = deque(maxlen=max_n_model)
        self._save_every = save_every
        self._save_after = save_after
        self._timescale = timescale
        self._max_n_model = max_n_model
        # Set externally by the training loop before on_train_end is called.
        self.learner = None
[docs] def on_epoch_end(self, logs: Dict[str, Any]) -> bool: if self._timescale == "epoch": if logs["epoch"] >= self._save_after and (logs["epoch"] - self._save_after) % self._save_every == 0: self.models.append(copy_model(self._model)) if len(self.models) > self._max_n_model: self.models.pop() print("Save model after epoch " + str(logs["epoch"])) return False
[docs] def on_batch_end(self, logs: Dict[str, Any]): if self._timescale == "iter": if logs["iter_cnt"] >= self._save_after and (logs["iter_cnt"] - self._save_after) % self._save_every == 0: self.models.append(copy_model(self._model)) if len(self.models) > self._max_n_model: self.models.pop() print("Save model after iteration " + str(logs["iter_cnt"]))
[docs] def on_train_end(self): self.models = list(self.models) self.models = [model.to(self.learner.device) for model in self.models]
[docs] def get_models(self) -> List[Module]: """ Return the post-training average model :return: the averaged model """ return list(self.models)