Source code for nntoolbox.optim.lr_scheduler
from torch.optim.lr_scheduler import LambdaLR, Optimizer, _LRScheduler
from .utils import plot_schedule
from typing import Optional, Callable, List
# Public API of this module: the learning-rate scheduler classes exported on `import *`.
__all__ = [
    'FunctionalLR',
    'CyclicalTriangularLR',
    'TriangularLR',
]
class FunctionalLR(LambdaLR):
    """
    Scheduler whose schedule function returns the learning rate itself.

    ``LambdaLR`` multiplies each parameter group's base learning rate by the lambda's
    output. This subclass instead overrides ``get_lr`` so that the value returned by
    ``schedule_fn`` is used directly as the new learning rate of every parameter group.
    """

    def __init__(self, optimizer: Optimizer, schedule_fn: Callable[[int], float], last_epoch: int=-1):
        """
        :param optimizer: pytorch optimizer whose learning rates this scheduler sets
        :param schedule_fn: maps the scheduler's step counter (``self.last_epoch``) to a learning rate
        :param last_epoch: index of the last step, forwarded unchanged to ``LambdaLR``
        """
        super().__init__(optimizer=optimizer, lr_lambda=schedule_fn, last_epoch=last_epoch)

    def get_lr(self) -> List[float]:
        """Evaluate every stored schedule function at the current step and return the results."""
        epoch = self.last_epoch
        # LambdaLR keeps one callable per parameter group in ``lr_lambdas``.
        return [fn(epoch) for fn in self.lr_lambdas]
# UNTESTED
class CyclicalTriangularLR(FunctionalLR):
    """
    Cyclical (slanted) triangular learning-rate schedule.

    Within each cycle of ``cycle_length`` steps the learning rate rises linearly from
    ``min_lr`` at the start of the cycle towards ``max_lr`` (the peak lies at step
    ``inc_fraction * cycle_length`` of the cycle), then decreases linearly back towards
    ``min_lr``. The pattern repeats every ``cycle_length`` steps.
    """

    def __init__(self, optimizer: Optimizer, min_lr: float, max_lr: float, cycle_length: int,
                 inc_fraction: float, last_epoch: int = -1):
        """
        Cyclical (slanted) triangular LR, based on:
        https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/learning_rate_schedules_advanced.html

        :param optimizer: pytorch optimizer
        :param min_lr: minimum learning rate (value at the start of every cycle)
        :param max_lr: maximum learning rate (value at the peak of every cycle)
        :param cycle_length: length of each cycle in scheduler steps (i.e. from one min to the next);
            must be positive
        :param inc_fraction: fraction of the cycle length spent rising towards ``max_lr``;
            must lie strictly between 0 and 1
        :param last_epoch: index of the last step, forwarded to ``FunctionalLR`` (default -1,
            the same default ``FunctionalLR`` uses)
        :raises ValueError: if ``inc_fraction`` is not in (0, 1) or ``cycle_length`` is not positive
        """
        # Explicit checks rather than ``assert`` so validation is not stripped under ``python -O``.
        if not 0.0 < inc_fraction < 1.0:
            raise ValueError("inc_fraction must lie strictly between 0 and 1, got {}".format(inc_fraction))
        if cycle_length <= 0:
            # A zero length would otherwise only fail later with ZeroDivisionError inside step().
            raise ValueError("cycle_length must be positive, got {}".format(cycle_length))

        # Invariant across calls; for integer positions ``pos <= int(x)`` is equivalent to ``pos <= x``.
        peak_iter = int(inc_fraction * cycle_length)

        def schedule_fn(iteration: int) -> float:
            position = iteration % cycle_length  # step index within the current cycle
            if position <= peak_iter:
                # Rising edge: 0 at the start of the cycle, 1 at the (possibly fractional) peak.
                unit_cycle = position / cycle_length / inc_fraction
            else:
                # Falling edge: linear from 1 at the peak down towards 0 at the end of the cycle.
                unit_cycle = (cycle_length - position) / cycle_length / (1 - inc_fraction)
            return unit_cycle * (max_lr - min_lr) + min_lr

        super().__init__(optimizer, schedule_fn=schedule_fn, last_epoch=last_epoch)
# UNTESTED
class TriangularLR(FunctionalLR):
    """
    One-cycle (slanted) triangular learning-rate schedule.

    The learning rate rises linearly from ``min_lr`` towards ``max_lr`` (the peak lies at
    step ``inc_fraction * cycle_length``), decreases linearly back towards ``min_lr`` until
    step ``cycle_length``, and is held at ``min_lr`` from step ``cycle_length`` onwards.
    """

    def __init__(self, optimizer: Optimizer, min_lr: float, max_lr: float, cycle_length: int,
                 inc_fraction: float, last_epoch: int = -1):
        """
        One cycle (slanted) triangular LR, based on:
        https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/learning_rate_schedules_advanced.html

        :param optimizer: pytorch optimizer
        :param min_lr: minimum learning rate (value at step 0 and after the cycle ends)
        :param max_lr: maximum learning rate (value at the peak of the cycle)
        :param cycle_length: length of the single cycle in scheduler steps; must be positive
        :param inc_fraction: fraction of the cycle length spent rising towards ``max_lr``;
            must lie strictly between 0 and 1
        :param last_epoch: index of the last step, forwarded to ``FunctionalLR`` (default -1,
            the same default ``FunctionalLR`` uses)
        :raises ValueError: if ``inc_fraction`` is not in (0, 1) or ``cycle_length`` is not positive
        """
        # Explicit checks rather than ``assert`` so validation is not stripped under ``python -O``.
        if not 0.0 < inc_fraction < 1.0:
            raise ValueError("inc_fraction must lie strictly between 0 and 1, got {}".format(inc_fraction))
        if cycle_length <= 0:
            # A zero length would otherwise only fail later with ZeroDivisionError inside step().
            raise ValueError("cycle_length must be positive, got {}".format(cycle_length))

        # Invariant across calls; for integer steps ``step <= int(x)`` is equivalent to ``step <= x``.
        peak_iter = int(inc_fraction * cycle_length)

        def schedule_fn(iteration: int) -> float:
            if iteration <= peak_iter:
                # Rising edge: 0 at step 0, 1 at the (possibly fractional) peak.
                unit_cycle = iteration / cycle_length / inc_fraction
            elif iteration < cycle_length:
                # Falling edge: linear from 1 at the peak down towards 0 at cycle_length.
                unit_cycle = (cycle_length - iteration) / cycle_length / (1 - inc_fraction)
            else:
                # The single cycle is over: hold the learning rate at min_lr.
                unit_cycle = 0.0
            return unit_cycle * (max_lr - min_lr) + min_lr

        super().__init__(optimizer, schedule_fn=schedule_fn, last_epoch=last_epoch)