Source code for neuralnet_pytorch.optim.lookahead

from collections import defaultdict
import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    """
    PyTorch implementation of the lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610.

    Parameters
    ----------
    optimizer
        a regular optimizer such as Adam.
    la_steps
        number of lookahead steps.
        Default: 5.
    alpha
        linear interpolation coefficient.
        Default: 0.8.
    pullback_momentum
        either ``'reset'``, ``'pullback'``, or ``None``.
        Default: ``None``.
    """

    def __init__(self, optimizer, la_steps=5, alpha=0.8, pullback_momentum=None):
        self.optimizer = optimizer
        self._la_step = 0  # counter for inner optimizer steps
        self.alpha = alpha
        self._total_la_steps = la_steps
        assert pullback_momentum in ['reset', 'pullback', None]
        self.pullback_momentum = pullback_momentum
        self.state = defaultdict(dict)

        # Cache the current optimizer parameters as the slow weights
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == 'pullback':
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    @property
    def la_step(self):
        return self._la_step

    @la_step.setter
    def la_step(self, value):
        self._la_step = value

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        # Back up the fast weights and load the cached slow weights into the params
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        # Restore the backed-up fast weights and discard the backup
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self, closure=None):
        loss = self.optimizer.step(closure)
        self.la_step += 1

        if self.la_step >= self._total_la_steps:
            self.la_step = 0
            # Lookahead and cache the current optimizer parameters
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    # slow weights <- alpha * fast weights + (1 - alpha) * slow weights
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)

                    if self.pullback_momentum == 'pullback':
                        internal_momentum = self.optimizer.state[p]['momentum_buffer']
                        self.optimizer.state[p]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
                            param_state['cached_mom'], alpha=1.0 - self.alpha)
                        param_state['cached_mom'] = self.optimizer.state[p]['momentum_buffer']
                    elif self.pullback_momentum == 'reset':
                        self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p.data)

        return loss
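
Usage sketch (illustrative, not part of the module): the wrapper is dropped around a standard ``torch.optim`` optimizer and used like any other optimizer. The model, hyperparameters, and training loop below are assumptions for illustration only; note that ``pullback_momentum='pullback'`` or ``'reset'`` expects the inner optimizer to keep a ``momentum_buffer`` in its state, which holds for SGD with momentum but not for Adam.

import torch
import torch.nn as nn
from neuralnet_pytorch.optim.lookahead import Lookahead

# Illustrative model and hyperparameters (assumptions, not prescribed by the package)
model = nn.Linear(10, 1)
base_opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
opt = Lookahead(base_opt, la_steps=5, alpha=0.8, pullback_momentum='pullback')

for _ in range(100):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Every la_steps calls, the fast weights are interpolated toward the cached slow weights
    opt.step()

To evaluate at the slow weights mid-training, ``_backup_and_load_cache()`` and ``_clear_and_load_backup()`` temporarily swap the cached parameters in and out of the model.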