nntoolbox.utils.lr_finder module

class nntoolbox.utils.lr_finder.LRFinder(model: torch.nn.modules.module.Module, train_data: torch.utils.data.dataloader.DataLoader, criterion: torch.nn.modules.module.Module, optimizer: Callable[[], torch.optim.optimizer.Optimizer], device: torch.device)[source]

Bases: object

Leslie Smith’s learning rate range finder.

Adapted from https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

https://arxiv.org/pdf/1506.01186.pdf

find_lr(lr0: float = 1e-07, lr_final: float = 10.0, warmup: int = 15, beta: float = 0.67, verbose: bool = True, display: bool = True, callbacks: Optional[List[Callback]] = None) → Tuple[float, float][source]

Starts from a very low initial learning rate, then gradually increases it toward a large final learning rate until the loss blows up
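The sweep above can be sketched in plain Python: multiply the learning rate by a fixed factor each iteration so it moves geometrically from lr0 to lr_final, track an exponentially smoothed loss, and stop once the smoothed loss far exceeds the best value seen. This is a minimal illustration following the sgugger post the class is adapted from; the function name, the 4x blow-up threshold, and the stand-in loss function are assumptions, not part of nntoolbox:

```python
import math

def lr_range_test(loss_fn, lr0=1e-7, lr_final=10.0, n_iter=100, beta=0.67):
    """Sketch of an LR range test (illustrative, not nntoolbox's code).

    Sweeps the learning rate geometrically from lr0 to lr_final,
    smooths the loss with an exponential moving average (coefficient
    beta, with bias correction), and stops early when the smoothed
    loss exceeds 4x the best smoothed loss seen so far.
    """
    mult = (lr_final / lr0) ** (1.0 / n_iter)  # per-iteration lr multiplier
    lr, avg_loss = lr0, 0.0
    best_loss, best_lr = float("inf"), lr0
    for i in range(1, n_iter + 1):
        loss = loss_fn(lr)                      # stand-in for one training step
        avg_loss = beta * avg_loss + (1 - beta) * loss
        smoothed = avg_loss / (1 - beta ** i)   # bias-corrected moving average
        if smoothed > 4 * best_loss:            # loss blew up: stop the sweep
            break
        if smoothed < best_loss:
            best_loss, best_lr = smoothed, lr
        lr *= mult
    return best_lr / 4, best_lr                 # (base_lr, best_lr)
```

Running it on a synthetic loss curve with a minimum near lr = 1e-2 recovers a best learning rate in that neighborhood.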

Parameters
  • lr0 – initial learning rate

  • lr_final – final (max) learning rate

  • warmup – number of warmup iterations

  • beta – smoothing coefficient for loss

  • verbose – whether to print out the progress

  • display – whether to display a graph of the results

  • callbacks – an optional list of callbacks to process input

Returns

a tuple of (base_lr, best_lr), where base_lr = best_lr / 4
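The returned pair maps naturally onto a cyclical learning rate schedule, as in the CLR paper linked above: base_lr becomes the lower bound and best_lr the upper bound of the cycle. A minimal sketch of the paper's triangular policy (the function name and the concrete values are illustrative, not part of nntoolbox):

```python
import math

def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical schedule from the CLR paper (illustrative).

    The learning rate rises linearly from base_lr to max_lr over
    step_size iterations, then falls back, with period 2 * step_size.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)
```

For example, with base_lr = 0.0025 and max_lr = 0.01 (a best_lr of 0.01 and its quarter), the schedule starts at 0.0025, peaks at 0.01 after step_size iterations, and returns to 0.0025 after a full cycle.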