"""Cerebras Gradient Scaler implementation"""
from contextlib import ExitStack, contextmanager
from enum import Enum, auto
from typing import Union
import torch
import cerebras_pytorch.experimental.amp as amp
from cerebras_pytorch.amp._amp_state import _amp_state, maybe_print
from cerebras_pytorch.experimental.backend import current_backend_impl
from cerebras_pytorch.experimental.utils._tensor import conditional_update
class OptState(Enum):
    """An enum to specify the optimizer's current state regarding whether
    its gradients have been scaled or not.

    ``GradScaler`` moves each registered optimizer through these states during
    an iteration: ``READY`` -> ``SCALED`` (or directly ``UNSCALED`` when no
    scaling is needed) -> ``UNSCALED`` -> ``STEPPED``, and back to ``READY``
    on ``update()``.
    """

    READY = auto()
    SCALED = auto()
    UNSCALED = auto()
    STEPPED = auto()

    def is_unscaled(self):
        """Returns True if the state is ``UNSCALED``."""
        return self == OptState.UNSCALED
class GradScaler:
    """Facilitates mixed precision training with static loss scaling, dynamic
    loss scaling (DLS), or DLS combined with global gradient clipping (GCC).

    For more details please see docs for amp.initialize.

    Args:
        loss_scale: If ``loss_scale == "dynamic"``, then configure dynamic
            loss scaling. Otherwise, it is the loss scale value used in
            static loss scaling.
        init_scale: The initial loss scale value if
            ``loss_scale == "dynamic"``.
        steps_per_increase: The number of steps after which to increase the
            loss scaling condition.
        min_loss_scale: The minimum loss scale value that can be chosen by
            dynamic loss scaling.
        max_loss_scale: The maximum loss scale value that can be chosen by
            dynamic loss scaling.
        overflow_tolerance: The maximum fraction of steps involving infinite
            or undefined values in the gradient we allow. We reduce the loss
            scale if the tolerance is exceeded.
        max_gradient_norm: The maximum gradient norm to use for global
            gradient clipping. Only applies in the DLS + GCC case. If GCC is
            not enabled, then this parameter has no effect.
    """

    # Class-level (shared by all instances) flag guarding a one-time warning
    # about unscaling non-fp32 gradients; it is read and set via the class.
    warned_unscaling_non_fp32_grad = False
# Construct the scaler. `loss_scale="dynamic"` enables dynamic loss scaling;
# any other value is used as a fixed (static) loss scale.
# NOTE(review): recovered from a scraped docs page. Indentation, the `self,`
# parameter, the closing `):` of the signature and the `else:` of the
# dynamic/static branch appear to have been lost. The leading `[docs]` is a
# Sphinx viewcode artifact, not code.
[docs] def __init__(
loss_scale: Union[str, float] = None,
init_scale: float = None,
steps_per_increase: int = None,
min_loss_scale: float = None,
max_loss_scale: float = None,
overflow_tolerance: float = 0.05,
max_gradient_norm: float = None,
# Falsy arguments (None, 0, 0.0) fall back to these defaults, so e.g. an
# explicit init_scale=0 is silently replaced by 2**15.
loss_scale = loss_scale if loss_scale else 1.0
init_scale = init_scale if init_scale else 2.0 ** 15
steps_per_increase = steps_per_increase if steps_per_increase else 2000
min_loss_scale = min_loss_scale if min_loss_scale else 2.0 ** -14
max_loss_scale = max_loss_scale if max_loss_scale else 2.0 ** 15
# True for "dynamic" as well as for any static scale other than 1.0.
self.loss_is_scaled = loss_scale != 1.0
self.backend = current_backend_impl()
if loss_scale == "dynamic":
# Validate the dynamic loss scaling configuration.
if min_loss_scale < 2.0 ** -14:
raise ValueError("min_loss_scale too small")
if overflow_tolerance < 0:
raise ValueError(
"loss scaling counter threshold must be set >= 0"
# NOTE(review): the first operand of this condition (and closing
# parentheses) were lost in the scrape; only the `steps_per_increase` bound
# is visible. `>=` rather than `>` appears deliberate: `update_scale` only
# increases the scale once the step counter is strictly greater than
# `steps_per_increase`, so the counter must be able to reach
# steps_per_increase + 1 within int16.
if (
and steps_per_increase >= 2 ** 15 - 1
raise ValueError(
f"loss scaling counters are implemented as int16 on WSE "
f"and thus cannot exceed {2**15 - 1}."
self.dynamic = True
# Dynamic state lives in tensors: the float32 loss scale (capped at
# max_loss_scale) and int64 counters of steps/overflows since last rescale.
self._loss_scale = torch.tensor(
min(max_loss_scale, init_scale), dtype=torch.float32,
self._steps_since_rescale = torch.tensor(0, dtype=torch.int64)
self._overflows_since_rescale = torch.tensor(0, dtype=torch.int64)
self._overflow_tolerance = overflow_tolerance
self._max_gradient_norm = max_gradient_norm
# Will be set in `_unscale_helper`
self._squared_local_norms = []
# Will be set in `update_scale`
self.isfinite = None
# Static loss scaling (the lost `else:` branch; see NOTE at the top).
self.dynamic = False
self._loss_scale = loss_scale
# `isfinite` is fixed to True: the overflow check in `step` only runs when
# `self.dynamic` is set.
self.isfinite = True
# NOTE(review): this rebinds the local argument, which is not read again in
# the visible code, and `self._max_gradient_norm` is not set on this path.
# Confirm whether `self._max_gradient_norm = None` was intended.
max_gradient_norm = None
self._max_loss_scale = max_loss_scale
self._min_loss_scale = min_loss_scale
self._steps_per_increase = steps_per_increase
self.global_norm = None
# Every optimizer registered with the backend starts out READY.
for optimizer in self.backend.optimizer_registry:
optimizer._amp_stash.state = OptState.READY
def state_dict(self, destination=None):
    """Returns a dictionary containing the state to be saved to a checkpoint.

    Args:
        destination: Unused; kept so existing call signatures keep working.

    Returns:
        ``{}`` when the current backend is not a CSX backend. Otherwise
        ``{"loss_scale": ...}``, plus ``"steps_since_rescale"`` and
        ``"overflows_since_rescale"`` when dynamic loss scaling is enabled.
        The values are the scaler's live attributes, not copies.
    """
    if not self.backend.backend_type.is_csx:
        return {}

    state = {"loss_scale": self._loss_scale}
    if self.dynamic:
        state["steps_since_rescale"] = self._steps_since_rescale
        state["overflows_since_rescale"] = self._overflows_since_rescale
    return state
def load_state_dict(self, state_dict):
    """Loads the state dictionary into the current params.

    Args:
        state_dict: Mapping produced by ``state_dict()``. An empty mapping
            (what ``state_dict()`` returns on non-CSX backends) is a no-op.

    Raises:
        KeyError: If ``"loss_scale"`` (or, for dynamic scaling, one of the
            counter entries) is missing from a non-empty ``state_dict``.
    """
    # `state_dict()` returns `{}` on non-CSX backends; loading that back must
    # not fail with a KeyError.
    if not state_dict:
        return

    def load_param(param, param_name):
        """Return the checkpoint value for ``param_name``, matched to ``param``."""
        value = state_dict[param_name]
        if not isinstance(param, torch.Tensor):
            # Non-tensor attributes (e.g. a static float loss scale) are
            # replaced as-is.
            return value
        if isinstance(value, torch.Tensor):
            # Only move to the param's device if that device is not CPU;
            # otherwise keep the loaded value's device.
            if (
                value.device.type != param.device.type
                and param.device.type != "cpu"
            ):
                return value.to(param.device)
            return value
        # Plain Python values become a tensor with the param's dtype/device.
        return torch.tensor(value, dtype=param.dtype).to(param.device)

    self._loss_scale = load_param(self._loss_scale, "loss_scale")
    if self.dynamic:
        self._steps_since_rescale = load_param(
            self._steps_since_rescale, "steps_since_rescale"
        )
        self._overflows_since_rescale = load_param(
            self._overflows_since_rescale, "overflows_since_rescale"
        )
# NOTE(review): `[docs]` is a Sphinx viewcode artifact and indentation was
# lost in the scrape. In particular, this copy does not show which of the
# statements below sit inside the `name_scope` block.
[docs] def scale(self, loss: torch.Tensor):
"""Scales the loss in preparation of the backwards pass.

Returns ``loss`` unchanged on non-CSX backends. On CSX the loss goes through
``self.backend.pre_backward``, is cast with ``.float()`` and multiplied by
the current loss scale. With static scaling at 1.0 the multiply is skipped
and every registered optimizer is marked UNSCALED instead.

Raises:
    RuntimeError: if a registered optimizer is neither READY nor SCALED.
"""
# TODO: handle the case of outputs being iterable
# which is supported by the torch interface
if not self.backend.backend_type.is_csx:
return loss
with self.backend.name_scope("grad_scaler.scale"):
loss = self.backend.pre_backward(loss)
if (not self.dynamic) and self._loss_scale == 1.0:
# Mark optimizers as having been unscaled since there is
# no scaling to be done
for optimizer in self.backend.optimizer_registry:
# pylint: disable=protected-access
optimizer._amp_stash.state = OptState.UNSCALED
return loss.float()
# Advance READY optimizers to SCALED. Any other state (UNSCALED, or STEPPED
# when `update()` was not called after `step()`) is rejected.
# NOTE(review): despite the message, an optimizer that is already SCALED
# passes this check, so calling `scale` twice does not raise.
for optimizer in self.backend.optimizer_registry:
# pylint: disable=protected-access
if optimizer._amp_stash.state == OptState.READY:
optimizer._amp_stash.state = OptState.SCALED
if optimizer._amp_stash.state != OptState.SCALED:
raise RuntimeError(
"Optimizer parameter gradients already scaled"
return (loss.float()) * self._loss_scale
def get_scale(self):
    """Return the loss scale.

    With dynamic loss scaling this is the live float32 tensor that
    ``update_scale`` adjusts in place; with static scaling it is the
    configured value.
    """
    return self._loss_scale
# Helper for `_unscale`: rescales gradients by the inverse loss scale and,
# for dynamic scaling, computes per-gradient squared L2 norms.
# NOTE(review): indentation and several statements were lost in the scraped
# source:
#   - the bodies of the first two `if`s (presumably a copy of model grads into
#     master grads, and an early return for static scale 1.0);
#   - the call that prints the f-string warning;
#   - the `else:` between the two `inv_scale` assignments;
#   - the body of the `for master in master_grads` loop;
#   - what is done with `norms_squared` (`__init__` says
#     `_squared_local_norms` is set here).
# Restore from upstream before relying on this code.
def _unscale_helper(self, model_grads, master_grads, scale):
for model, master in zip(model_grads, master_grads):
if model is not None:
if (
master is not model
): # copy_ probably internally short-circuits this
if not self.dynamic and scale == 1.0:
# One-time warning about non-fp32 master grads, guarded by the
# class-level flag.
if not GradScaler.warned_unscaling_non_fp32_grad:
for master in master_grads:
if master.dtype != torch.float32:
f"Attempting to unscale a grad with type {master.type()} "
f"Unscaling non-fp32 grads may indicate an error. "
f"When using Amp, you don't need to call .half() on your model."
GradScaler.warned_unscaling_non_fp32_grad = True
# Inverse scale: tensor division when dynamic (the dynamic scale is a
# tensor), plain Python division otherwise.
if self.dynamic:
inv_scale = torch.tensor(1.0, dtype=torch.float32) / scale
inv_scale = torch.tensor(1.0 / scale, dtype=torch.float32)
for master in master_grads:
if self.dynamic:
# Use CS1 compatible algorithm for detecting NaN/inf by using global
# L2 norm of all gradients
norms_squared = [torch.sum(g * g) for g in master_grads]
# Unscale `master_grads` by the loss scale (or `scale_override` if given).
# The work is skipped only when scaling is static, the scale is 1.0 and the
# model grads are the master grads.
# NOTE(review): the parameter list was lost in the scraped source. The body
# reads `model_grads`, `master_grads`, `models_are_masters` and
# `scale_override`, so those are presumably parameters.
def _unscale(
# implementation
scale = self._loss_scale
if scale_override is not None:
scale = scale_override
if self.dynamic or not models_are_masters or scale != 1.0:
self._unscale_helper(model_grads, master_grads, scale)
def _unscale_with_stashed_python(
    self, model_grads, stashed_master_grads, master_grads, a, b
):
    """Unscaling with stashed master gradients is not supported.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("stashed grads not supported")
# Unscaling with stashed master gradients is not supported; this always raises.
# NOTE(review): the parameter list was lost in the scraped source.
def _unscale_with_stashed(
): # pylint: disable=missing-function-docstring
raise NotImplementedError("stashed grads not supported")
# NOTE(review): `[docs]` is a Sphinx viewcode artifact; indentation and some
# statements were lost in the scraped source (see notes below).
[docs] def unscale_(self, optimizer):
"""Unscales the optimizer's params gradients inplace.

A no-op if the optimizer is already UNSCALED. Raises RuntimeError if the
optimizer has already STEPPED. Otherwise the optimizer ends up UNSCALED.
"""
# Go unscale all the gradients
# pylint: disable=protected-access
if optimizer._amp_stash.state == OptState.UNSCALED:
return # no-op
elif optimizer._amp_stash.state == OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# if not dynamic, short circuit to match the implicit context manager case
# NOTE(review): no `return` is visible after this assignment. This copy does
# not show whether one was lost or whether execution is meant to continue
# into the block below.
if (not self.dynamic) and self._loss_scale == 1.0:
optimizer._amp_stash.state = OptState.UNSCALED
# NOTE(review): no call that actually unscales the gradients (e.g. through
# `self._unscale`) is visible in this block; those lines appear to have
# been lost in the scrape.
with self.backend.name_scope("grad_scaler.unscale_"):
optimizer._amp_stash.params_have_scaled_gradients = False
optimizer._amp_stash.state = OptState.UNSCALED
# TODO: reconsider this name; `_handle_undefined_state` does not describe what
# the method does (making optimizer updates conditional on `self.isfinite`).
# Generator used as a context manager around `optimizer.step()` (see
# `step_if_finite`).
# NOTE(review): because it yields and is used in a `with` statement, this needs
# `@contextmanager` (imported at the top of the file). That decorator line was
# lost in the scrape, as were:
#   - the body of `if not self.dynamic:`;
#   - the statement(s) after `if p.grad is None:`.
# The `all_conditional_updates` stack is never used in the visible lines. This
# suggests the `conditional_update(...)` calls were originally entered into it
# (e.g. via `enter_context`); confirm against upstream.
def _handle_undefined_state(self, optimizer):
if not self.dynamic:
# Rather than conditionally execute the optimizer using python flow control,
# we must trace the execution of the optimizer. However, if the gradients
# were not finite, the state should not be updated. Therefore, all state
# update (parameters and optimizer state) must be conditional. Use the
# conditional_update context manager for each tensor which may be updated.
with ExitStack() as all_conditional_updates:
for group in optimizer.param_groups:
for p in group["params"]:
if p.grad is None:
conditional_update(p, self.isfinite)
for state in optimizer.state[p].values():
conditional_update(state, self.isfinite)
yield # Update optimizer parameter and state
# NOTE(review): `[docs]` is a Sphinx viewcode artifact. The scrape lost:
#   - the docstring's triple quotes;
#   - its `Args:` / `Returns:` headers and the `*args` / `**kwargs` entry names;
#   - the `optimizer` argument passed to `_handle_undefined_state`, whose
#     signature requires it.
[docs] def step_if_finite(self, optimizer, *args, **kwargs):
Directly conditionalize the call to optimizer.step(*args, **kwargs) but
only if this GradScaler detected finite grads.
optimizer (cerebras_pytorch.experimental.optim.Optimizer):
Optimizer that applies the gradients.
Any arguments.
Any keyword arguments.
The result of optimizer.step()
with self._handle_undefined_state(
): # combine static and dynamic next
# The optimizer step itself runs inside `amp.disable_casts()`.
with amp.disable_casts():
return optimizer.step(*args, **kwargs)
# NOTE(review): `[docs]` is a Sphinx viewcode artifact. The scrape lost:
#   - the docstring quotes and indentation;
#   - the argument of `torch.sqrt(` (presumably built from
#     `self._squared_local_norms`);
#   - the statement(s) after `if p.grad is None:`.
# No gradient is visibly multiplied by `clip_coef`.
[docs] def clip_gradients_and_return_isfinite(self, optimizers):
Clip the optimizer's params's gradients and return whether or not the
norm is finite
# Compute global norm from all squared local norms
# if not self.global_norm:
self.global_norm = torch.sqrt(
def float32(value):
return torch.tensor(
value, dtype=torch.float32, device=self.global_norm.device
# self.isfinite = torch.isfinite(self.global_norm)
# TODO: torch.isfinite^ hits a lowering error! so use:
# (NaN < inf is False, so a NaN norm is also reported as non-finite.)
self.isfinite = self.global_norm < float32(float("inf"))
# A falsy max_gradient_norm (None or 0) disables clipping.
if self._max_gradient_norm:
# Then we're doing combo GCC + DLS
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/nn/utils/clip_grad.py#L56-L59
# clip_coef = max_norm / (norm + eps), capped at 1 so gradients are only
# ever scaled down.
clip_coef = float32(self._max_gradient_norm) / (
self.global_norm + 1e-6
clip_coef = torch.where(clip_coef < 1, clip_coef, 1.0,)
for optimizer in optimizers:
for group in optimizer.param_groups:
for p in group['params']:
if p.grad is None:
return self.isfinite
# NOTE(review): `[docs]` is a Sphinx viewcode artifact. The scrape lost:
#   - the docstring quotes, section headers and indentation;
#   - the body of the `for _optimizer ...` loop (presumably the `unscale_`
#     call the preceding comment describes);
#   - the arguments of `clip_gradients_and_return_isfinite(`.
[docs] def step(self, optimizer, *args, **kwargs):
`Step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless `unscale_` was
explicitly called for ``optimizer`` earlier in the iteration). As
part of the `unscale_`, gradients are checked for infs/NaNs.
2. Invokes ``optimizer.step()`` using the unscaled gradients. Ensure
that previous optimizer state or params carry over if we encounter
NaNs in the gradients.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
optimizer (cerebras_pytorch.optim.Optimizer):
Optimizer that applies the gradients.
Any arguments.
Any keyword arguments.
# pylint: disable=protected-access
if optimizer._amp_stash.state == OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update()."
# must unscale all optimizers prior to step and update
# so global grad norm can be computed correctly
# guaranteed to also unscale the given optimizer if needed
# NOTE(review): `scale()` moves every registered optimizer from READY to
# SCALED, so after `scale()` this check only matches optimizers it did not
# touch. Confirm the intended state (the loop body is not visible).
for _optimizer in self.backend.optimizer_registry:
if _optimizer._amp_stash.state == OptState.READY:
with self.backend.name_scope("grad_scaler.step"):
# The finiteness check (and optional clipping) runs once per iteration:
# `isfinite` stays cached until `update()` resets it to None.
if self.dynamic and self.isfinite is None:
self.isfinite = self.clip_gradients_and_return_isfinite(
# Run optimizer's base step if self.isfinite is true
return_val = self.step_if_finite(optimizer, *args, **kwargs)
optimizer._amp_stash.state = OptState.STEPPED
return return_val
# NOTE(review): `[docs]` is a Sphinx viewcode artifact. Indentation and many
# statements of this method were lost in the scraped source:
#   - the body of `if not self.dynamic:`;
#   - the argument of `torch.sqrt(`;
#   - the step-counter increment;
#   - the numerator of `ratio`;
#   - the statements that apply the divisor/multiplier and the reset flags
#     computed below.
# Restore from upstream before relying on this code.
[docs] def update_scale(self, optimizers):
"""Update the dynamic loss scale after an iteration.

Per the visible logic and comments: the scale is decreased 2x when the
overflow ratio exceeds ``self._overflow_tolerance``, increased 2x when the
step counter exceeds ``self._steps_per_increase``, and finally clamped to
``[self._min_loss_scale, self._max_loss_scale]``. Reads ``self.isfinite``
as a tensor, so it presumably relies on ``step()`` having set it.
"""
if not self.dynamic:
# Compute global norm from all squared local norms
# if not self.global_norm:
self.global_norm = torch.sqrt(
def float32(value):
return torch.tensor(
value, dtype=torch.float32, device=self.global_norm.device
def int64(value):
return torch.tensor(
value, dtype=torch.int64, device=self.global_norm.device
# Reset local norms for next iteration
self._squared_local_norms = []
# integer representation of isfinite
isfinite_int = self.isfinite.long()
# Increment the step counter
# If overflow, increment the overflow counter
self._overflows_since_rescale.add_(1 - isfinite_int)
ratio = (
/ self._steps_since_rescale.float()
# Decrease loss scale
# decrease loss scaling condition
# 1 if we've exceeded our overflow tolerance
# 0 if we haven't hit too many overflows
overflow_tolerance_exceeded = (
float32(self._overflow_tolerance) < ratio
# decrease loss scale 2x if we're decreasing, otherwise unchanged
loss_scale_divisor = (1 + overflow_tolerance_exceeded).float()
# reset counters
# NOTE(review): PyTorch rejects `1 - <bool tensor>`. The comparison result
# is presumably converted to an integer tensor on a closing line lost in
# the scrape; the same applies to `increase_counter_exceeded` below.
reset_because_decreasing = 1 - overflow_tolerance_exceeded
# Increasing loss scale
# (done purposefully after decrease logic in case counter reset)
# increase loss scaling condition
# 1 if we've exceeded our steps per increase counter
# 0 if we haven't yet.
increase_counter_exceeded = (
int64(self._steps_per_increase) < self._steps_since_rescale
# increase loss scale 2x if we're increasing, otherwise unchanged
loss_scale_multipler = (1 + increase_counter_exceeded).float()
# reset counters
reset_because_increasing = 1 - increase_counter_exceeded
# clamp loss scale to within min/max
# (the two `where`s compute min(scale, max_ls) then max(scale, min_ls),
# written as tensor selects and assigned in place via `.data`)
max_ls = float32(self._max_loss_scale)
self._loss_scale.data = torch.where(
self._loss_scale < max_ls, self._loss_scale, max_ls,
min_ls = float32(self._min_loss_scale)
self._loss_scale.data = torch.where(
min_ls < self._loss_scale, self._loss_scale, min_ls,
# NOTE(review): `[docs]` is a Sphinx viewcode artifact. Indentation and the
# body of the `name_scope` block were lost in the scrape. The missing body is
# presumably the `update_scale(...)` call that the "# Update scale" comment
# refers to. This copy also does not show whether the reset lines below sit
# inside the `if`.
[docs] def update(self, new_scale=None):
"""Update the gradient scaler after all optimizers have been stepped.

Clears ``self.isfinite`` and returns every registered optimizer to READY
for the next iteration.

Raises:
    ValueError: if a truthy ``new_scale`` is given (not supported).
"""
if new_scale:
raise ValueError(
"cstorch.amp.GradScaler does not support providing a `new_scale`"
# Update scale
if self.dynamic or self._loss_scale != 1.0:
with self.backend.name_scope("grad_scaler.update"):
# pylint: disable=protected-access,no-member
# clear all data from this iteration for the next
self.isfinite = None
for optimizer in self.backend.optimizer_registry:
optimizer._amp_stash.state = OptState.READY