Source code for nntoolbox.callbacks.warmup
"""Learning rate warmup (UNTESTED)"""
from .callbacks import Callback
from typing import Dict, Any
from torch import Tensor

__all__ = ['LRWarmup', 'ConstantLRWarmup', 'GradualLRWarmup']


class LRWarmup(Callback):
    """
    Start training with a small learning rate.

    References:

        Priya Goyal et al. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour."
        https://arxiv.org/abs/1706.02677
    """
    def __init__(self, duration: int, timescale: str="iter"):
        self.order = 99  # callback ordering priority: run warmup after other callbacks
        self.duration = duration  # length of the warmup, in iterations or epochs
        self.timescale = timescale  # "iter" or "epoch"
        self.cur = 0  # number of warmup updates applied so far

    def on_batch_begin(self, data: Dict[str, Tensor], train) -> Dict[str, Tensor]:
        # per-iteration warmup: keep adjusting the learning rate until `duration` iterations have passed
        if self.timescale == "iter":
            if self.cur < self.duration:
                self.update_lr()
        return data

    def on_epoch_begin(self):
        # per-epoch warmup: keep adjusting the learning rate until `duration` epochs have passed
        if self.timescale == "epoch":
            if self.cur < self.duration:
                self.update_lr()

    def update_lr(self):
        # apply the current warmup learning rate to every parameter group, then advance the counter
        for param_group in self.learner._optimizer.param_groups:
            param_group['lr'] = self.get_lr()
        self.cur += 1

    def get_lr(self) -> float:
        # the warmup schedule itself; defined by subclasses
        pass
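

# Illustrative sketch only, not part of nntoolbox: the shape of the warmup is
# determined entirely by get_lr(), so other schedules can be plugged in by
# subclassing LRWarmup. For example, a hypothetical exponential warmup (assumed
# name and parameters, reusing the cur/duration bookkeeping above) might look like:

class ExponentialLRWarmup(LRWarmup):
    """Hypothetical example, not exported by this module."""
    def __init__(self, min_lr: float, max_lr: float, duration: int, timescale: str="iter"):
        assert 0.0 < min_lr < max_lr
        super().__init__(duration, timescale)
        self.min_lr, self.max_lr = min_lr, max_lr

    def get_lr(self) -> float:
        # geometric interpolation: equals min_lr at cur == 0 and max_lr at cur == duration
        return self.min_lr * (self.max_lr / self.min_lr) ** (self.cur / self.duration)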


class ConstantLRWarmup(LRWarmup):
    """Keep the learning rate at a small constant value for several iterations/epochs."""
    def __init__(self, min_lr: float, duration: int, timescale: str="iter"):
        super().__init__(duration, timescale)
        self.min_lr = min_lr

    def get_lr(self) -> float:
        return self.min_lr


class GradualLRWarmup(LRWarmup):
    """Gradually increase the learning rate from a small value over several iterations/epochs."""
    def __init__(self, min_lr: float, max_lr: float, duration: int, timescale: str="iter"):
        assert min_lr < max_lr
        super().__init__(duration, timescale)
        self.min_lr, self.max_lr = min_lr, max_lr

    def get_lr(self) -> float:
        # linear interpolation from min_lr toward max_lr as cur goes from 0 to duration
        return self.min_lr + (self.max_lr - self.min_lr) * (self.cur / self.duration)
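

# Minimal usage sketch (assumptions noted inline): nntoolbox's training loop normally
# attaches a Learner to each callback, so the stub below only provides the single
# attribute LRWarmup.update_lr actually touches (_optimizer). It shows the learning
# rate ramping linearly over the first few iterations with GradualLRWarmup.

from torch import nn, optim
from nntoolbox.callbacks.warmup import GradualLRWarmup


class _StubLearner:
    """Stand-in for the framework's Learner; only _optimizer is needed here."""
    def __init__(self, optimizer):
        self._optimizer = optimizer


model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

warmup = GradualLRWarmup(min_lr=1e-4, max_lr=0.1, duration=5, timescale="iter")
warmup.learner = _StubLearner(optimizer)  # normally assigned by the training framework

for it in range(7):
    warmup.on_batch_begin(data={}, train=True)
    print(it, optimizer.param_groups[0]["lr"])

# The learning rate climbs linearly from min_lr; after `duration` updates the warmup
# stops and the last value sticks (note it stops one step short of max_lr, since the
# final applied value uses cur == duration - 1).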