Source code for nntoolbox.callbacks.lr_scheduler

from torch.optim.lr_scheduler import ReduceLROnPlateau
from .callbacks import Callback
from torch.optim import Optimizer
from typing import Dict, Any


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['LRSchedulerCB', 'ReduceLROnPlateauCB']


class LRSchedulerCB(Callback):
    """Callback that advances a learning-rate scheduler during training.

    The wrapped scheduler's ``step()`` is called either at the end of every
    batch or at the end of every epoch, depending on ``timescale``.

    :param scheduler: object exposing a ``step()`` method callable without
        arguments.
    :param timescale: ``"iter"`` to step after each batch, ``"epoch"`` to step
        after each epoch.
    :raises ValueError: if ``timescale`` is neither ``"epoch"`` nor ``"iter"``.
    """

    def __init__(self, scheduler, timescale: str="iter"):
        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O``, which would let an invalid timescale
        # through and leave the scheduler silently never stepped.
        if timescale not in ("epoch", "iter"):
            raise ValueError(
                'timescale must be "epoch" or "iter", got {!r}'.format(timescale)
            )
        self._scheduler = scheduler
        self._timescale = timescale

    def on_batch_end(self, logs: Dict[str, Any]):
        """Step the scheduler after a batch when on the ``"iter"`` timescale.

        :param logs: batch logs; not read by this callback.
        """
        if self._timescale == "iter":
            self._scheduler.step()

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """Step the scheduler after an epoch when on the ``"epoch"`` timescale.

        :param logs: epoch logs; not read by this callback.
        :return: always ``False``.
        """
        if self._timescale == "epoch":
            self._scheduler.step()
        return False
class ReduceLROnPlateauCB(Callback):
    """Callback that drives ``ReduceLROnPlateau`` from a monitored epoch metric.

    At the end of each epoch the value stored under ``monitor`` in
    ``logs["epoch_metrics"]`` is passed to the scheduler's ``step()``.

    :param optimizer: optimizer handed to ``ReduceLROnPlateau``.
    :param monitor: key looked up in ``logs["epoch_metrics"]``.
    :param mode: forwarded to ``ReduceLROnPlateau``.
    :param factor: forwarded to ``ReduceLROnPlateau``.
    :param patience: forwarded to ``ReduceLROnPlateau``.
    :param verbose: forwarded to ``ReduceLROnPlateau``.
    :param threshold: forwarded to ``ReduceLROnPlateau``.
    :param threshold_mode: forwarded to ``ReduceLROnPlateau``.
    :param cooldown: forwarded to ``ReduceLROnPlateau``.
    :param min_lr: forwarded to ``ReduceLROnPlateau``.
    :param eps: forwarded to ``ReduceLROnPlateau``.
    """

    def __init__(
            self, optimizer: Optimizer, monitor: str='accuracy', mode: str='max', factor: float=0.1,
            patience: int=10, verbose: bool=True, threshold: float=0.0001, threshold_mode: str='rel',
            cooldown: int=0, min_lr: float=0, eps: float=1e-08
    ):
        # Forward by keyword. Positional forwarding relies on the exact
        # parameter order of ReduceLROnPlateau, and the position of ``verbose``
        # is not the same in every torch release; a mismatch would silently
        # bind ``verbose`` to ``threshold`` and shift every later argument.
        # NOTE(review): newer torch releases deprecate ``verbose`` -- confirm
        # against the torch version this project pins.
        self._scheduler = ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            verbose=verbose,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
        )
        self._monitor = monitor

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """Feed the monitored epoch metric to the scheduler.

        Does nothing when ``logs`` has no ``"epoch_metrics"`` entry.

        :param logs: epoch logs; ``logs["epoch_metrics"]`` is read when present.
        :return: always ``False``.
        :raises AssertionError: if ``"epoch_metrics"`` is present but does not
            contain the monitored key.
        """
        if "epoch_metrics" in logs:
            metrics = logs["epoch_metrics"]
            # With assertions disabled (``-O``) the lookup below still fails
            # loudly with ``KeyError``, so the assert is kept for its message.
            assert self._monitor in metrics, (
                "monitored metric {!r} not found in epoch_metrics".format(self._monitor)
            )
            self._scheduler.step(metrics[self._monitor])
        return False