Source code for nntoolbox.callbacks.checkpoint

from typing import Dict, Any, Optional
from ..utils import save_model, load_model
from ..optim.utils import save_optimizer, load_optimizer
from .callbacks import Callback


# Public API of this module: the names exported by a star-import.
__all__ = ['ModelCheckpoint', 'OptimizerCheckPoint', 'ResumeFromCheckpoint']


class ModelCheckpoint(Callback):
    """
    Save the learner's model at the end of training epochs.

    Every ``period`` epochs the callback either saves unconditionally
    (``save_best_only=False``) or saves only when the monitored metric is at
    least as good as every value recorded at earlier checkpoints.
    """

    def __init__(
            self, learner, filepath: str, monitor: str='loss',
            save_best_only: bool=True, mode: str='min', period: int=1
    ):
        """
        :param learner: object whose ``_model`` attribute is handed to ``save_model``
        :param filepath: destination handed to ``save_model``
        :param monitor: key looked up in ``logs['epoch_metrics']``
        :param save_best_only: if True, save only when the monitored value is the
            best recorded so far; if False, save at every checkpoint
        :param mode: ``"min"`` means lower is better; any other value is treated
            as "higher is better"
        :param period: number of epochs between checkpoints; values <= 1
            checkpoint at every epoch
        """
        self._learner = learner
        self._filepath = filepath
        self._monitor = monitor
        self._period = period
        self._mode = mode
        self._save_best_only = save_best_only
        self._metrics = []
        # Epochs elapsed since the last checkpoint; used to honor ``period``.
        self._epochs_since_last_check = 0

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """
        Possibly save the model once an epoch has finished.

        :param logs: training logs; ``logs['epoch_metrics']`` is read only when
            ``save_best_only`` is True
        :return: always False (presumably "do not stop training" -- confirm
            against the ``Callback`` contract)
        :raises KeyError: if ``save_best_only`` is True and the monitored metric
            is missing from ``logs['epoch_metrics']``
        """
        # Skip epochs that fall between checkpoints.
        self._epochs_since_last_check += 1
        if self._epochs_since_last_check < self._period:
            return False
        self._epochs_since_last_check = 0

        if not self._save_best_only:
            save_model(self._learner._model, self._filepath)
            return False

        epoch_metrics = logs['epoch_metrics']
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self._monitor not in epoch_metrics:
            raise KeyError(
                "Monitored metric {!r} not found in epoch metrics".format(self._monitor)
            )

        current = epoch_metrics[self._monitor]
        self._metrics.append(current)
        if self._mode == "min":
            best = min(self._metrics)
        else:
            best = max(self._metrics)

        # Ties with the best value also trigger a save.
        if current == best:
            save_model(self._learner._model, self._filepath)
        return False
class OptimizerCheckPoint(Callback):
    """
    Save the learner's optimizer at the end of training epochs.

    Every ``period`` epochs the callback either saves unconditionally
    (``save_best_only=False``) or saves only when the monitored metric is at
    least as good as every value recorded at earlier checkpoints.

    NOTE(review): the optimizer is read from ``self.learner``, which this class
    never assigns; presumably it is attached by the ``Callback`` base class or
    by the learner when the callback is registered -- confirm.
    """

    def __init__(
            self, filepath: str, monitor: str='loss',
            save_best_only: bool=True, mode: str='min', period: int=1
    ):
        """
        :param filepath: destination handed to ``save_optimizer``
        :param monitor: key looked up in ``logs['epoch_metrics']``
        :param save_best_only: if True, save only when the monitored value is the
            best recorded so far; if False, save at every checkpoint
        :param mode: ``"min"`` means lower is better; any other value is treated
            as "higher is better"
        :param period: number of epochs between checkpoints; values <= 1
            checkpoint at every epoch
        """
        self._filepath = filepath
        self._monitor = monitor
        self._period = period
        self._mode = mode
        self._save_best_only = save_best_only
        self._metrics = []
        # Epochs elapsed since the last checkpoint; used to honor ``period``.
        self._epochs_since_last_check = 0

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """
        Possibly save the optimizer once an epoch has finished.

        :param logs: training logs; ``logs['epoch_metrics']`` is read only when
            ``save_best_only`` is True
        :return: always False (presumably "do not stop training" -- confirm
            against the ``Callback`` contract)
        :raises KeyError: if ``save_best_only`` is True and the monitored metric
            is missing from ``logs['epoch_metrics']``
        """
        # Skip epochs that fall between checkpoints.
        self._epochs_since_last_check += 1
        if self._epochs_since_last_check < self._period:
            return False
        self._epochs_since_last_check = 0

        if not self._save_best_only:
            save_optimizer(self.learner._optimizer, self._filepath)
            return False

        epoch_metrics = logs['epoch_metrics']
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self._monitor not in epoch_metrics:
            raise KeyError(
                "Monitored metric {!r} not found in epoch metrics".format(self._monitor)
            )

        current = epoch_metrics[self._monitor]
        self._metrics.append(current)
        if self._mode == "min":
            best = min(self._metrics)
        else:
            best = max(self._metrics)

        # Ties with the best value also trigger a save.
        if current == best:
            save_optimizer(self.learner._optimizer, self._filepath)
        return False
# UNTESTED
class ResumeFromCheckpoint(Callback):
    """
    Resume from previous checkpoint

    At the start of training, loads the model and/or optimizer state from the
    given paths. Loading is best-effort: a failure is reported on stdout and
    training continues.

    NOTE(review): ``self.learner`` is never assigned in this class; presumably
    it is attached by the ``Callback`` base class or by the learner when the
    callback is registered -- confirm.
    """

    def __init__(self, model_path: Optional[str]=None, optimizer_path: Optional[str]=None):
        """
        :param model_path: path handed to ``load_model``; None skips loading the model
        :param optimizer_path: path handed to ``load_optimizer``; None skips
            loading the optimizer
        """
        self.model_path, self.optimizer_path = model_path, optimizer_path

    def on_train_begin(self):
        """Load whichever of model / optimizer has a path configured."""
        if self.model_path is not None:
            try:
                load_model(self.learner._model, self.model_path)
            except Exception as err:
                # ``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                print("Load model failed.", repr(err))

        if self.optimizer_path is not None:
            try:
                load_optimizer(self.learner._optimizer, self.optimizer_path)
            except Exception as err:
                print("Load optimizer failed.", repr(err))