# Source code for nntoolbox.callbacks.bptt

from .callbacks import Callback
from ..optim import change_lr, get_lr
import numpy as np
from typing import Dict, Any


# Public API of this module.
__all__ = ['VariableLengthBPTT']


class VariableLengthBPTT(Callback):
    """
    Change the truncated backprop-through-time (BPTT) length each epoch and
    linearly scale the learning rate to match. (UNTESTED)

    Each epoch, the base sequence length is ``default_len`` with probability
    ``p`` and ``default_len / 2`` otherwise; the actual length is drawn from a
    normal distribution centered on that base with standard deviation ``std``,
    then clamped to ``[1, default_len]``. Learning rates are scaled by
    ``epoch_length / default_len`` for the duration of the epoch and restored
    at epoch end.

    References:

        Stephen Merity, Nitish Shirish Keskar, Richard Socher.
        "Regularizing and Optimizing LSTM Language Models."
        https://arxiv.org/abs/1708.02182
    """
    def __init__(self, default_len: int, p: float, std: float):
        """
        :param default_len: base BPTT length; also the upper clamp on the sampled length.
        :param p: probability of using the full ``default_len`` as the base length.
        :param std: standard deviation of the normal perturbation of the base length.
        :raises ValueError: if ``p`` is not in (0, 1) or ``std`` is not positive.
        """
        # Raise rather than assert: asserts are stripped under `python -O`.
        if not 0.0 < p < 1.0:
            raise ValueError("p must be strictly between 0 and 1")
        if std <= 0.0:
            raise ValueError("std must be positive")
        self.default_len, self.p, self.std = default_len, p, std
        # Learning rates saved at epoch start, restored at epoch end.
        self.original_lr = None

    def on_epoch_begin(self):
        """Sample this epoch's BPTT length and rescale the learning rate proportionally."""
        base_length = np.random.choice(
            [self.default_len, self.default_len / 2],
            p=[self.p, 1.0 - self.p]
        )
        # BUG FIX: the original called np.random.normal(base_length) with the
        # default scale of 1.0, leaving the `std` parameter unused.
        epoch_length = min(
            max(int(np.random.normal(base_length, self.std)), 1),
            self.default_len
        )
        self.learner._train_iterator.bptt_len = epoch_length
        # Remember the current lrs so on_epoch_end can restore them, then
        # scale each lr by epoch_length / default_len (shorter sequences get
        # a proportionally smaller step).
        self.original_lr = get_lr(self.learner._optimizer)
        new_lr = [lr * epoch_length / self.default_len for lr in self.original_lr]
        change_lr(self.learner._optimizer, new_lr)

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        """Restore the pre-epoch learning rates; never requests early stopping."""
        change_lr(self.learner._optimizer, self.original_lr)
        return False