# Source code for experimental.amp.grad_scaler

"""Cerebras Gradient Scaler implementation"""
from contextlib import ExitStack, contextmanager
from enum import Enum, auto
from typing import Union

import torch

import cerebras_pytorch.experimental.amp as amp
from cerebras_pytorch.amp._amp_state import _amp_state, maybe_print
from cerebras_pytorch.experimental.backend import current_backend_impl
from cerebras_pytorch.experimental.utils._tensor import conditional_update


class OptState(Enum):
    """
    Loss-scaling lifecycle state of an optimizer: whether its gradients are
    ready, have been scaled, have been unscaled, or the optimizer has already
    stepped.
    """

    # Explicit values match what ``auto()`` would assign (starting at 1).
    READY = 1
    SCALED = 2
    UNSCALED = 3
    STEPPED = 4

    def is_unscaled(self):
        """Return True only for the UNSCALED member."""
        return self is OptState.UNSCALED


class GradScaler:
    """
    Facilitates mixed precision training with dynamic loss scaling (DLS),
    optionally combined with global gradient clipping (DLS + GCC).

    For more details please see docs for amp.initialize.

    Args:
        loss_scale: If loss_scale == "dynamic", then configure dynamic loss
            scaling. Otherwise, it is the loss scale value used in static
            loss scaling.
        init_scale: The initial loss scale value if loss_scale == "dynamic"
        steps_per_increase: The number of steps after which to increase the
            loss scaling condition
        min_loss_scale: The minimum loss scale value that can be chosen by
            dynamic loss scaling
        max_loss_scale: The maximum loss scale value that can be chosen by
            dynamic loss scaling
        overflow_tolerance: The maximum fraction of steps involving infinite
            or undefined values in the gradient we allow. We reduce the loss
            scale if the tolerance is exceeded
        max_gradient_norm: The maximum gradient norm to use for global
            gradient clipping. Only applies in the DLS + GCC case. If GCC is
            not enabled, then this parameter has no effect
    """

    # Class-level flag shared by all GradScaler instances so that the
    # non-fp32 gradient warning in `_unscale_helper` is emitted only once.
    warned_unscaling_non_fp32_grad = False

    def __init__(
        self,
        loss_scale: Union[str, float] = None,
        init_scale: float = None,
        steps_per_increase: int = None,
        min_loss_scale: float = None,
        max_loss_scale: float = None,
        overflow_tolerance: float = 0.05,
        max_gradient_norm: float = None,
    ):
        # Falsy arguments (None, 0) fall back to these defaults.
        loss_scale = loss_scale if loss_scale else 1.0
        init_scale = init_scale if init_scale else 2.0 ** 15
        steps_per_increase = steps_per_increase if steps_per_increase else 2000
        min_loss_scale = min_loss_scale if min_loss_scale else 2.0 ** -14
        max_loss_scale = max_loss_scale if max_loss_scale else 2.0 ** 15

        self.loss_is_scaled = loss_scale != 1.0

        self.backend = current_backend_impl()

        if loss_scale == "dynamic":
            if min_loss_scale < 2.0 ** -14:
                raise ValueError("min_loss_scale too small")
            if overflow_tolerance < 0:
                raise ValueError(
                    "loss scaling counter threshold must be set >= 0"
                )
            # The step counter can reach steps_per_increase + 1 before it is
            # reset, so it must still fit in an int16 on pipeline backends.
            if (
                self.backend.backend_type.is_pipeline
                and steps_per_increase >= 2 ** 15 - 1
            ):
                raise ValueError(
                    f"loss scaling counters are implemented as int16 on WSE "
                    f"and thus cannot exceed {2**15 - 1}."
                )

            self.dynamic = True
            self._loss_scale = torch.tensor(
                min(max_loss_scale, init_scale), dtype=torch.float32,
            )
            self._steps_since_rescale = torch.tensor(0, dtype=torch.int64)
            self._overflows_since_rescale = torch.tensor(0, dtype=torch.int64)
            self._overflow_tolerance = overflow_tolerance
            self._max_gradient_norm = max_gradient_norm
            # Will be populated in `_unscale_helper`
            self._squared_local_norms = []
            # Will be set in `step` (via `clip_gradients_and_return_isfinite`)
            self.isfinite = None
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
            # Static scaling never checks for overflow.
            self.isfinite = True
            # Global gradient clipping only applies together with DLS.
            self._max_gradient_norm = None

        self._max_loss_scale = max_loss_scale
        self._min_loss_scale = min_loss_scale
        self._steps_per_increase = steps_per_increase
        self.global_norm = None

        self.backend.setup_grad_scaler(self)

        for optimizer in self.backend.optimizer_registry:
            amp.setup_optimizer(optimizer)
            # pylint: disable=protected-access
            optimizer._amp_stash.state = OptState.READY

    def state_dict(self, destination=None):
        """
        Returns a dictionary containing the state to be saved to a checkpoint.

        Returns an empty dict on non-CSX backends.
        """
        if not self.backend.backend_type.is_csx:
            return {}

        if self.dynamic:
            return {
                "loss_scale": self._loss_scale,
                "steps_since_rescale": self._steps_since_rescale,
                "overflows_since_rescale": self._overflows_since_rescale,
            }
        else:
            return {"loss_scale": self._loss_scale}

    def load_state_dict(self, state_dict):
        """
        Loads the state dictionary into the current params.

        Tensor-valued params receive tensors. Plain values are converted to
        tensors of the param's dtype and device. Non-tensor params (e.g. a
        static float loss scale) take the stored value as-is.
        """

        def load_param(param, param_name):
            value = state_dict[param_name]
            if isinstance(param, torch.Tensor):
                if isinstance(value, torch.Tensor):
                    # Only move to device if the param device is not CPU.
                    # Otherwise keep the original value's device
                    if (
                        value.device.type != param.device.type
                        and param.device.type != "cpu"
                    ):
                        return value.to(param.device)
                    return value
                else:
                    return torch.tensor(value, dtype=param.dtype).to(
                        param.device
                    )
            else:
                return value

        self._loss_scale = load_param(self._loss_scale, "loss_scale")
        if self.dynamic:
            self._steps_since_rescale = load_param(
                self._steps_since_rescale, "steps_since_rescale"
            )
            self._overflows_since_rescale = load_param(
                self._overflows_since_rescale, "overflows_since_rescale"
            )

    def scale(self, loss: torch.Tensor):
        """
        Scales the loss in preparation of the backwards pass.

        Moves every READY optimizer to SCALED. Raises RuntimeError if a
        registered optimizer is neither READY nor SCALED. With static scaling
        and a scale of 1.0, optimizers are marked UNSCALED and the fp32 loss
        is returned unscaled.
        """
        # TODO: handle the case of outputs being iterable
        # which is supported by the torch interface
        if not self.backend.backend_type.is_csx:
            return loss

        with self.backend.name_scope("grad_scaler.scale"):
            loss = self.backend.pre_backward(loss)

            if (not self.dynamic) and self._loss_scale == 1.0:
                # Mark optimizers has having been unscaled since there is
                # no scaling to be done
                for optimizer in self.backend.optimizer_registry:
                    # pylint: disable=protected-access
                    optimizer._amp_stash.state = OptState.UNSCALED
                return loss.float()

            for optimizer in self.backend.optimizer_registry:
                # pylint: disable=protected-access
                if optimizer._amp_stash.state == OptState.READY:
                    optimizer._prepare_amp_backward()
                    optimizer._amp_stash.state = OptState.SCALED
                    continue
                if optimizer._amp_stash.state != OptState.SCALED:
                    raise RuntimeError(
                        "Optimizer parameter gradients already scaled"
                    )

            return (loss.float()) * self._loss_scale

    def get_scale(self):
        """Return the loss scale (a tensor for DLS, the static value otherwise)"""
        return self._loss_scale

    def _unscale_helper(self, model_grads, master_grads, scale):
        """
        Copy model grads into master grads and divide the master grads by
        `scale` in place. For DLS, also record each master grad's squared L2
        norm in `self._squared_local_norms`. `None` entries are skipped.
        """
        for model, master in zip(model_grads, master_grads):
            # copy_ probably internally short-circuits the identical case
            if model is not None and master is not model:
                master.copy_(model)

        if not self.dynamic and scale == 1.0:
            return

        # Entries may be None (see the `model is not None` guard above).
        present_grads = [grad for grad in master_grads if grad is not None]

        if not GradScaler.warned_unscaling_non_fp32_grad:
            for master in present_grads:
                if master.dtype != torch.float32:
                    maybe_print(
                        f"Attempting to unscale a grad with type {master.type()} "
                        f"Unscaling non-fp32 grads may indicate an error. "
                        f"When using Amp, you don't need to call .half() on your model."
                    )
                    GradScaler.warned_unscaling_non_fp32_grad = True
                    # Warn only once, as the flag's name promises.
                    break

        if self.dynamic:
            inv_scale = torch.tensor(1.0, dtype=torch.float32) / scale
        else:
            inv_scale = torch.tensor(1.0 / scale, dtype=torch.float32)
        for master in present_grads:
            master.mul_(inv_scale)

        if self.dynamic:
            # Use CS1 compatible algorithm for detecting NaN/inf by using
            # global L2 norm of all gradients
            norms_squared = [torch.sum(g * g) for g in present_grads]
            self._squared_local_norms.extend(norms_squared)

    def _unscale(
        self,
        model_grads,
        master_grads,
        unused_scale,
        models_are_masters=False,
        scale_override=None,
    ):
        """Unscale using this scaler's loss scale (or `scale_override`)."""
        scale = self._loss_scale
        if scale_override is not None:
            scale = scale_override

        if self.dynamic or not models_are_masters or scale != 1.0:
            self._unscale_helper(model_grads, master_grads, scale)

    def _unscale_with_stashed_python(
        self, model_grads, stashed_master_grads, master_grads, a, b
    ):
        # pylint: disable=missing-function-docstring
        raise NotImplementedError("stashed grads not supported")

    def _unscale_with_stashed(
        self,
        model_grads,
        stashed_master_grads,
        master_grads,
        scale_override=None,
    ):
        # pylint: disable=missing-function-docstring
        raise NotImplementedError("stashed grads not supported")

    def unscale_(self, optimizer):
        """
        Unscales the optimizer's params gradients inplace.

        No-op if the optimizer is already UNSCALED. Raises RuntimeError if it
        has already stepped.
        """
        # pylint: disable=protected-access
        if optimizer._amp_stash.state == OptState.UNSCALED:
            return  # no-op
        elif optimizer._amp_stash.state == OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # if not dynamic, short circuit to match the implicit context manager case
        if (not self.dynamic) and self._loss_scale == 1.0:
            optimizer._amp_stash.state = OptState.UNSCALED
            return

        with self.backend.name_scope("grad_scaler.unscale_"):
            optimizer._post_amp_backward(self)
            optimizer._amp_stash.params_have_scaled_gradients = False

        optimizer._amp_stash.state = OptState.UNSCALED

    # TODO: reconsider this name; `_handle_undefined_state` is not descriptive.
    @contextmanager
    def _handle_undefined_state(self, optimizer):
        """
        For DLS, make every updatable tensor of `optimizer` conditional on
        `self.isfinite` for the duration of the `with` body.

        The covered tensors are each param with a gradient and each of its
        existing state values. For static scaling this is a no-op context.
        """
        if not self.dynamic:
            yield
            return

        # Rather than conditionally execute the optimizer using python flow
        # control, we must trace the execution of the optimizer. However, if
        # the gradients were not finite, the state should not be updated.
        # Therefore, all state update (parameters and optimizer state) must
        # be conditional. Use the conditional_update context manager for each
        # tensor which may be updated.
        with ExitStack() as all_conditional_updates:
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    all_conditional_updates.enter_context(
                        conditional_update(p, self.isfinite)
                    )
                    for state in optimizer.state[p].values():
                        all_conditional_updates.enter_context(
                            conditional_update(state, self.isfinite)
                        )
            # Update optimizer parameter and state
            yield

    def step_if_finite(self, optimizer, *args, **kwargs):
        """
        Directly conditionalize the call to optimizer.step(*args, **kwargs)
        but only if this GradScaler detected finite grads.

        Args:
            optimizer (cerebras_pytorch.experimental.optim.Optimizer):
                Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.
        Returns:
            The result of optimizer.step()
        """
        # combine static and dynamic next
        with self._handle_undefined_state(optimizer):
            with amp.disable_casts():
                return optimizer.step(*args, **kwargs)

    def _compute_global_norm(self):
        """
        Compute the global L2 norm of all gradients from the squared local
        norms collected during unscaling. Store it on `self.global_norm` and
        return it.

        Raises:
            RuntimeError: if no squared local norms have been recorded.
        """
        if not self._squared_local_norms:
            raise RuntimeError(
                "No gradient norms were recorded; gradients must be unscaled "
                "before the global gradient norm can be computed."
            )
        self.global_norm = torch.sqrt(
            torch.sum(torch.stack(self._squared_local_norms))
        )
        return self.global_norm

    def clip_gradients_and_return_isfinite(self, optimizers):
        """
        Clip the optimizers' params' gradients (only when a max gradient norm
        is configured) and return whether or not the global norm is finite.
        """
        global_norm = self._compute_global_norm()

        def float32(value):
            return torch.tensor(
                value, dtype=torch.float32, device=global_norm.device
            )

        # self.isfinite = torch.isfinite(self.global_norm)
        # TODO: torch.isfinite^ hits a lowering error! so use:
        self.isfinite = global_norm < float32(float("inf"))

        if self._max_gradient_norm:
            # Then we're doing combo GCC + DLS
            # https://github.com/pytorch/pytorch/blob/release/1.9/torch/nn/utils/clip_grad.py#L56-L59
            clip_coef = float32(self._max_gradient_norm) / (global_norm + 1e-6)
            clip_coef = torch.where(clip_coef < 1, clip_coef, 1.0)
            for optimizer in optimizers:
                for group in optimizer.param_groups:
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        p.grad.detach().mul_(clip_coef)

        return self.isfinite

    def step(self, optimizer, *args, **kwargs):
        """
        `Step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless `unscale_` was
            explicitly called for ``optimizer`` earlier in the iteration).
            As part of the `unscale_`, gradients are checked for infs/NaNs.
        2.  Invokes ``optimizer.step()`` using the unscaled gradients. Ensure
            that previous optimizer state or params carry over if we
            encounter NaNs in the gradients.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (cerebras_pytorch.optim.Optimizer): Optimizer that
                applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.
        """
        # pylint: disable=protected-access
        if optimizer._amp_stash.state == OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        # Must unscale all optimizers prior to step and update so the global
        # grad norm can be computed correctly. `scale()` moves optimizers to
        # SCALED, so both READY and SCALED optimizers still need unscaling.
        # This is guaranteed to also unscale the given optimizer if needed.
        for _optimizer in self.backend.optimizer_registry:
            if _optimizer._amp_stash.state in (OptState.READY, OptState.SCALED):
                self.unscale_(_optimizer)

        with self.backend.name_scope("grad_scaler.step"):
            if self.dynamic and self.isfinite is None:
                self.isfinite = self.clip_gradients_and_return_isfinite(
                    self.backend.optimizer_registry
                )

            # Run optimizer's base step if self.isfinite is true
            return_val = self.step_if_finite(optimizer, *args, **kwargs)

        optimizer._amp_stash.state = OptState.STEPPED

        return return_val

    def update_scale(self, optimizers):
        """
        Update the dynamic loss scale and its overflow/step counters from
        `self.isfinite`. No-op for static scaling.

        The scale is halved when the overflow ratio exceeds the tolerance,
        doubled after more than `steps_per_increase` steps, and clamped to
        [min_loss_scale, max_loss_scale]. Counters reset on either change.
        """
        if not self.dynamic:
            return

        # Only used to pick the device for the constants below.
        global_norm = self._compute_global_norm()

        def float32(value):
            return torch.tensor(
                value, dtype=torch.float32, device=global_norm.device
            )

        def int64(value):
            return torch.tensor(
                value, dtype=torch.int64, device=global_norm.device
            )

        # Reset local norms for next iteration
        self._squared_local_norms = []

        # integer representation of isfinite
        isfinite_int = self.isfinite.long()

        # Increment the step counter
        self._steps_since_rescale.add_(1)

        # If overflow, increment the overflow counter
        self._overflows_since_rescale.add_(1 - isfinite_int)

        ratio = (
            self._overflows_since_rescale.float()
            / self._steps_since_rescale.float()
        )

        # Decrease loss scale
        # decrease loss scaling condition
        #   1 if we've exceeded our overflow tolerance
        #   0 if we haven't hit too many overflows
        overflow_tolerance_exceeded = (
            float32(self._overflow_tolerance) < ratio
        ).long()

        # decrease loss scale 2x if we're decreasing, otherwise unchanged
        loss_scale_divisor = (1 + overflow_tolerance_exceeded).float()
        self._loss_scale.div_(loss_scale_divisor)

        # reset counters
        reset_because_decreasing = 1 - overflow_tolerance_exceeded
        self._overflows_since_rescale.mul_(reset_because_decreasing)
        self._steps_since_rescale.mul_(reset_because_decreasing)

        # Increasing loss scale
        # (done purposefully after decrease logic in case counter reset)
        # increase loss scaling condition
        #   1 if we've exceeded our steps per increase counter
        #   0 if we haven't yet.
        increase_counter_exceeded = (
            int64(self._steps_per_increase) < self._steps_since_rescale
        ).long()

        # increase loss scale 2x if we're increasing, otherwise unchanged
        loss_scale_multipler = (1 + increase_counter_exceeded).float()
        self._loss_scale.mul_(loss_scale_multipler)

        # reset counters
        reset_because_increasing = 1 - increase_counter_exceeded
        self._overflows_since_rescale.mul_(reset_because_increasing)
        self._steps_since_rescale.mul_(reset_because_increasing)

        # clamp loss scale to within min/max
        max_ls = float32(self._max_loss_scale)
        self._loss_scale.data = torch.where(
            self._loss_scale < max_ls, self._loss_scale, max_ls,
        )
        min_ls = float32(self._min_loss_scale)
        self._loss_scale.data = torch.where(
            min_ls < self._loss_scale, self._loss_scale, min_ls,
        )

    def update(self, new_scale=None):
        """
        Update the gradient scaler after all optimizers have been stepped.

        Raises:
            ValueError: if `new_scale` is provided (any non-None value).
        """
        if new_scale is not None:
            raise ValueError(
                "cstorch.amp.GradScaler does not support providing a `new_scale`"
            )

        # Update scale
        if self.dynamic or self._loss_scale != 1.0:
            with self.backend.name_scope("grad_scaler.update"):
                self.update_scale(self.backend.optimizer_registry)

        # pylint: disable=protected-access,no-member
        _amp_state.handle._clear_cache()

        # clear all data from this iteration for the next. Only DLS
        # recomputes `isfinite` per step; static scaling keeps it True,
        # matching __init__.
        self.isfinite = None if self.dynamic else True
        for optimizer in self.backend.optimizer_registry:
            optimizer._amp_stash.state = OptState.READY