Source code for nntoolbox.optim.layerwise

"""Scaling the learning rate layerwise (HIGHLY EXPERIMENTAL)"""


from torch.optim import SGD, Adam
import torch
from typing import Callable, Tuple
from torch import Tensor


__all__ = ['LARS', 'LAMB']


class LARS(SGD):
    """
    Implement Layer-wise Adaptive Rate Scaling (LARS) algorithm for training
    with large batch and learning rate.

    Every parameter tensor gets its own local learning rate

        trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)

    which multiplies the group's global ``lr``. When either ``||w||`` or
    ``||g||`` is exactly zero the local rate falls back to 1 so that
    zero-initialised tensors (e.g. biases) are still updated.

    Arguments:
        params: iterable of parameters or dicts defining parameter groups.
        lr: global learning rate.
        momentum: momentum factor (0 disables the momentum buffer).
        weight_decay: L2 penalty coefficient.
        trust_coefficient: multiplier of the layer-wise trust ratio.
        eps: term added to the trust ratio denominator for numerical
            stability.

    References:

        https://arxiv.org/pdf/1708.03888.pdf
    """

    def __init__(
            self, params, lr: float, momentum: float = 0.0, weight_decay: float = 0.0,
            trust_coefficient: float = 0.001, eps: float = 1e-8
    ):
        # Keyword arguments keep this call valid even if SGD grows new
        # parameters; dampening and nesterov are not used by LARS.
        super(LARS, self).__init__(
            params, lr=lr, momentum=momentum, dampening=0.0,
            weight_decay=weight_decay, nesterov=False
        )
        self.trust_coefficient, self.eps = trust_coefficient, eps

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                data_norm = p.norm(2)
                grad_norm = grad.norm(2)

                local_lr = self.trust_coefficient * data_norm / (
                    grad_norm + weight_decay * data_norm + self.eps
                )
                # A zero weight (or gradient) norm would give a zero trust
                # ratio and freeze the tensor forever; use a neutral ratio.
                local_lr = torch.where(
                    (data_norm > 0) & (grad_norm > 0),
                    local_lr,
                    torch.ones_like(local_lr),
                )

                if weight_decay != 0:
                    # Out-of-place so that p.grad itself is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                # The scaled update is used with and without momentum, so the
                # learning rate is always applied.
                update = grad * (lr * local_lr)

                if momentum != 0:
                    param_state = self.state[p]
                    buf = param_state.get('momentum_buffer')
                    if buf is None:
                        buf = param_state['momentum_buffer'] = update
                    else:
                        buf.mul_(momentum).add_(update)
                    update = buf

                p.sub_(update)

        return loss
class LAMB(Adam):
    """
    Implement LAMB algorithm for training with large batch and learning rate.

    The Adam-style direction (plus decoupled weight decay) is rescaled per
    parameter tensor by the trust ratio
    ``scaling_fn(||w||) / (||direction|| + eps)``. When ``scaling_fn(||w||)``
    is exactly zero the ratio falls back to 1 so that zero-initialised
    tensors (e.g. biases) are still updated.

    Note that in second version of the paper, bias correction for betas is
    missing; ``correct_bias`` switches it on or off here.

    Arguments:
        params: iterable of parameters or dicts defining parameter groups.
        lr: global learning rate.
        betas: coefficients of the running averages of the gradient and of
            its square.
        eps: term added to denominators for numerical stability.
        weight_decay: decoupled weight decay coefficient.
        scaling_fn: function applied to the weight norm before it enters the
            trust ratio; must map a 0-dim tensor to a 0-dim tensor.
        amsgrad: whether to normalise by the running maximum of the second
            moment estimate.
        correct_bias: whether to apply bias correction to both moments.

    References:

        https://arxiv.org/pdf/1904.00962.pdf
    """

    def __init__(
            self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8, weight_decay: float = 0,
            scaling_fn: Callable[[Tensor], Tensor] = lambda x: x,
            amsgrad: bool = False, correct_bias: bool = True
    ):
        super(LAMB, self).__init__(
            params, lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay, amsgrad=amsgrad
        )
        self.scaling_fn = scaling_fn
        self.correct_bias = correct_bias

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            amsgrad = group['amsgrad']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('LAMB does not support sparse gradients.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                if self.correct_bias:
                    # Out-of-place: the stored running averages stay uncorrected.
                    exp_avg = exp_avg / (1 - beta1 ** state['step'])
                    exp_avg_sq = exp_avg_sq / (1 - beta2 ** state['step'])

                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(eps)
                else:
                    denom = exp_avg_sq.sqrt().add_(eps)

                direction = exp_avg / denom
                if weight_decay != 0:
                    # Decoupled weight decay joins the direction before the
                    # trust ratio is computed.
                    direction = direction.add(p, alpha=weight_decay)

                scaled_weight_norm = self.scaling_fn(p.norm())
                trust_ratio = scaled_weight_norm / (direction.norm() + eps)
                # A zero (scaled) weight norm would give a zero step size and
                # freeze the tensor forever; use a neutral ratio instead.
                trust_ratio = torch.where(
                    scaled_weight_norm > 0,
                    trust_ratio,
                    torch.ones_like(trust_ratio),
                )

                p.sub_(direction * (group['lr'] * trust_ratio))

        return loss