nntoolbox.callbacks.lookahead module

class nntoolbox.callbacks.lookahead.LookaheadOptimizer(step_size: float = 0.5, update_every: int = 1, timescale: str = 'iter', device=device(type='cpu'))

Bases: nntoolbox.callbacks.callbacks.Callback

Lookahead optimizer: keeps a set of “slow weights” that are updated only periodically, while the regular (“fast”) weights are updated at every step. (UNTESTED)

References:

Michael R. Zhang, James Lucas, Geoffrey Hinton, Jimmy Ba. “Lookahead Optimizer: k steps forward, 1 step back.” https://arxiv.org/abs/1907.08610

get_final_model() → torch.nn.modules.module.Module

Return the averaged model obtained after training.

Returns: the averaged model

on_batch_end(logs: Dict[str, Any])
on_epoch_end(logs: Dict[str, Any]) → bool
on_train_begin()
on_train_end()
update_slow_weights()
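
The slow-weight update described in the referenced paper can be sketched in plain Python, without torch or nntoolbox. This is a minimal illustration, not this callback's actual implementation: the helper names lookahead_update and train_with_lookahead are hypothetical, and it is an assumption that step_size and update_every correspond to the paper's interpolation factor and synchronization period.

```python
from typing import Callable, List


def lookahead_update(slow: List[float], fast: List[float],
                     step_size: float = 0.5) -> List[float]:
    """Move each slow weight a fraction `step_size` toward its fast weight."""
    return [s + step_size * (f - s) for s, f in zip(slow, fast)]


def train_with_lookahead(weights: List[float],
                         grads_fn: Callable[[List[float]], List[float]],
                         lr: float = 0.1,
                         step_size: float = 0.5,
                         update_every: int = 5,
                         n_steps: int = 20) -> List[float]:
    """Plain SGD on the fast weights with a periodic slow-weight update.

    Hypothetical helper for illustration only; assumes `update_every`
    plays the role of k and `step_size` the role of alpha in the paper.
    """
    slow = list(weights)
    fast = list(weights)
    for step in range(1, n_steps + 1):
        # Inner optimizer step: ordinary gradient descent on the fast weights.
        grads = grads_fn(fast)
        fast = [w - lr * g for w, g in zip(fast, grads)]
        if step % update_every == 0:
            # "1 step back": interpolate the slow weights toward the fast
            # weights, then restart the fast weights from the new slow weights.
            slow = lookahead_update(slow, fast, step_size)
            fast = list(slow)
    return slow


# Minimize f(w) = w^2, whose gradient is 2w, starting from w = 1.0.
final = train_with_lookahead([1.0], lambda ws: [2.0 * w for w in ws])
```

With step_size = 1.0 the slow weights simply copy the fast weights at each synchronization, so the sketch reduces to plain gradient descent; smaller values pull the fast weights back toward the slower trajectory.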