Source code for common.pytorch.optim.Lamb

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    Copyright cybertronai and Cerebras, see LICENSE_LambOptimizer
"""

import torch

from modelzoo.common.pytorch.optim.CSOptimizer import CSOptimizer


class Lamb(CSOptimizer):
    r"""Implements Lamb algorithm. It has been proposed in
    `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this
            into Adam. Useful for comparison purposes.

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """
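
    # In brief, step() below applies the following update to each parameter
    # tensor p with gradient g_t (paper v3 omits bias correction; see step()):
    #
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2
    #   r_t = m_t / (sqrt(v_t) + eps) + weight_decay * p
    #   trust_ratio = ||p|| / ||r_t||, with ||p|| clamped to [0, 10] and a
    #       fallback to 1 when either norm is zero
    #   p <- p - lr * trust_ratio * r_t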

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0,
        adam=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, adam=adam
        )
        super(Lamb, self).__init__(params, defaults)

    def state_names_to_sparsify(self):
        return ["exp_avg", "exp_avg_sq"]

    def preinitialize(self):
        """
        Allocates tensors for the optimizer state to allow direct
        compilation of the model before the first step.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p, device="cpu").to(
                    p.device
                )
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p, device="cpu").to(
                    p.device
                )

    @torch.no_grad()
    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'Lamb does not support sparse gradients, consider SparseAdam instead.'
                    )

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group[
                    'lr'
                ]  # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.pow(2).sum().sqrt().clamp(0, 10).to(torch.float)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt().to(torch.float)

                # pytorch version for future reference (we don't support
                # weight_norm == 0 or adam_norm == 0):
                # if weight_norm == 0 or adam_norm == 0:
                #     trust_ratio = 1
                # else:
                #     trust_ratio = weight_norm / adam_norm
                zero = torch.tensor(
                    0.0, dtype=torch.float32, device=weight_norm.device
                )
                one = torch.tensor(
                    1.0, dtype=torch.float32, device=weight_norm.device
                )
                trust_ratio = torch.where(
                    torch.gt(weight_norm, zero),
                    torch.where(
                        torch.gt(adam_norm, zero),
                        weight_norm / adam_norm,
                        one,
                    ),
                    one,
                )
                if group['adam']:
                    trust_ratio = 1

                update_step = adam_step.mul(trust_ratio)
                p.sub_(update_step * step_size)

        return loss
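
# Minimal usage sketch (illustrative only). The toy model, shapes, and
# hyperparameters below are invented for this example; it also assumes that
# the CSOptimizer base class calls preinitialize() during construction (as
# that method's docstring suggests) and otherwise behaves like
# torch.optim.Optimizer.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

    # One training step on random data: compute a loss, backpropagate, and
    # let step() apply the trust-ratio-scaled update.
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()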