# Source code for neuralnet_pytorch.optim.nadam

import torch as T
from torch import optim

class NAdam(optim.Adam):
    """
    Adaptive moment with Nesterov gradients.

    Parameters
    ----------
    params
        iterable of parameters to optimize or dicts defining parameter groups
    lr
        learning rate (default: 1e-3)
    betas
        coefficients used for computing running averages of gradient and its
        square (default: (0.9, 0.999))
    eps
        term added to the denominator to improve numerical stability
        (default: 1e-8)
    weight_decay
        weight decay (L2 penalty) (default: 0)
    decay
        a decay scheme for `betas[0]`, called as ``decay(beta1, t)``.
        Default: :math:`\\beta * (1 - 0.5 * 0.96^{\\frac{t}{250}})`
        where `t` is the training step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 decay=lambda x, t: x * (1. - .5 * .96 ** (t / 250.))):
        super().__init__(params, lr, betas, eps, weight_decay)
        self.decay = decay

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Parameters
        ----------
        closure
            optional callable that reevaluates the model and returns the
            loss. It is called once, before any parameter is updated.

        Returns
        -------
        The value returned by `closure`, or ``None`` if no closure was given.

        Raises
        ------
        RuntimeError
            if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('NAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = T.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = T.zeros_like(p.data)
                    # Beta1 accumulation (running product of the decayed beta1 values)
                    state['beta1_cum'] = 1.

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    # Out-of-place on purpose: the L2 term must not be written
                    # back into ``p.grad``, which belongs to the caller.
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Scheduled momentum for the current and the next step.
                beta1_t = self.decay(beta1, state['step'])
                beta1_tp1 = self.decay(beta1, state['step'] + 1.)
                beta1_cum = state['beta1_cum'] * beta1_t

                # Bias-corrected gradient and first moment.
                g_hat_t = grad / (1. - beta1_cum)
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                m_hat_t = exp_avg / (1. - beta1_cum * beta1_tp1)

                # Bias-corrected second moment.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                v_hat_t = exp_avg_sq / (1. - beta2 ** state['step'])

                # Nesterov look-ahead: blend the current gradient with the
                # next step's momentum.
                # NOTE(review): published NAdam formulations appear to weight
                # g_hat_t by (1 - beta1_t) rather than (1 - beta1); kept as in
                # the original implementation -- confirm which is intended.
                m_bar_t = (1. - beta1) * g_hat_t + beta1_tp1 * m_hat_t

                # v_hat_t is a fresh tensor, so the in-place add_ is safe here.
                denom = v_hat_t.sqrt().add_(group['eps'])
                p.data.addcdiv_(m_bar_t, denom, value=-group['lr'])

                state['beta1_cum'] = beta1_cum

        return loss