"""Cerebras Gradient Scaler implementation"""
from contextlib import ExitStack, contextmanager
from enum import Enum, auto
from typing import Union
import torch
import cerebras_pytorch.experimental.amp as amp
from cerebras_pytorch.amp._amp_state import _amp_state, maybe_print
from cerebras_pytorch.experimental.backend import current_backend_impl
from cerebras_pytorch.experimental.utils._tensor import conditional_update
class OptState(Enum):
    """
    Lifecycle marker for an optimizer's gradients within one iteration.

    The members record whether the gradients are waiting to be scaled
    (``READY``), have been scaled (``SCALED``), have been unscaled
    (``UNSCALED``), or have already been consumed by an optimizer step
    (``STEPPED``).
    """

    READY = auto()
    SCALED = auto()
    UNSCALED = auto()
    STEPPED = auto()

    def is_unscaled(self):
        """Return True only when this member is ``OptState.UNSCALED``."""
        # Enum members are singletons, so identity is equivalent to equality
        return self is OptState.UNSCALED
class GradScaler:
    """
    Facilitates mixed precision training and DLS, DLS + GCC

    For more details please see docs for amp.initialize.

    Args:
        loss_scale:
            If loss_scale == "dynamic", then configure dynamic loss
            scaling. Otherwise, it is the loss scale value used in static
            loss scaling.
        init_scale:
            The initial loss scale value if loss_scale == "dynamic"
        steps_per_increase:
            The number of steps after which to increase the loss
            scaling condition
        min_loss_scale:
            The minimum loss scale value that can be chosen by dynamic
            loss scaling
        max_loss_scale:
            The maximum loss scale value that can be chosen by dynamic
            loss scaling
        overflow_tolerance:
            The maximum fraction of steps involving infinite or undefined
            values in the gradient we allow. We reduce the loss scale if
            the tolerance is exceeded
        max_gradient_norm:
            The maximum gradient norm to use for global gradient clipping
            Only applies in the DLS + GCC case. If GCC is not enabled,
            then this parameter has no effect
    """

    # Class-wide flag: set once the "unscaling non-fp32 grad" warning has
    # been emitted so that it is not repeated on later unscale calls.
    warned_unscaling_non_fp32_grad = False

    def __init__(
        self,
        loss_scale: Union[str, float, None] = None,
        init_scale: float = None,
        steps_per_increase: int = None,
        min_loss_scale: float = None,
        max_loss_scale: float = None,
        overflow_tolerance: float = 0.05,
        max_gradient_norm: float = None,
    ):
        """
        Configure static or dynamic loss scaling and register this scaler
        with the current backend.

        Any falsy argument (``None`` or ``0``) is replaced by its default.

        Raises:
            ValueError: in dynamic mode, if ``min_loss_scale`` is below
                ``2 ** -14``, if ``overflow_tolerance`` is negative, or if
                ``steps_per_increase`` is too large for a pipeline backend.
        """
        loss_scale = loss_scale if loss_scale else 1.0
        init_scale = init_scale if init_scale else 2.0 ** 15
        steps_per_increase = steps_per_increase if steps_per_increase else 2000
        min_loss_scale = min_loss_scale if min_loss_scale else 2.0 ** -14
        max_loss_scale = max_loss_scale if max_loss_scale else 2.0 ** 15

        self.loss_is_scaled = loss_scale != 1.0
        self.backend = current_backend_impl()

        if loss_scale == "dynamic":
            if min_loss_scale < 2.0 ** -14:
                raise ValueError("min_loss_scale too small")
            if overflow_tolerance < 0:
                raise ValueError(
                    "loss scaling counter threshold must be set >= 0"
                )
            if (
                self.backend.backend_type.is_pipeline
                and steps_per_increase >= 2 ** 15 - 1
            ):
                raise ValueError(
                    f"loss scaling counters are implemented as int16 on WSE "
                    f"and thus cannot exceed {2**15 - 1}."
                )

            self.dynamic = True
            self._loss_scale = torch.tensor(
                min(max_loss_scale, init_scale), dtype=torch.float32,
            )
            self._steps_since_rescale = torch.tensor(0, dtype=torch.int64)
            self._overflows_since_rescale = torch.tensor(0, dtype=torch.int64)
            self._overflow_tolerance = overflow_tolerance
            self._max_gradient_norm = max_gradient_norm
            # Will be filled in `_unscale_helper`
            self._squared_local_norms = []
            # Will be set in `step` (via
            # `clip_gradients_and_return_isfinite`)
            self.isfinite = None
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
            self.isfinite = True
            # Gradient clipping is only applied together with dynamic loss
            # scaling. Store the disabled value on the instance (the
            # original assigned a dead local) so the public clipping method
            # never hits a missing attribute.
            self._max_gradient_norm = None
            self._squared_local_norms = []

        self._max_loss_scale = max_loss_scale
        self._min_loss_scale = min_loss_scale
        self._steps_per_increase = steps_per_increase

        self.global_norm = None

        self.backend.setup_grad_scaler(self)

        for optimizer in self.backend.optimizer_registry:
            amp.setup_optimizer(optimizer)
            # pylint: disable=protected-access
            optimizer._amp_stash.state = OptState.READY

    def state_dict(self, destination=None):
        """
        Returns a dictionary containing the state to be saved to a checkpoint

        Returns an empty dictionary when the backend is not CSX. The
        ``destination`` argument is accepted but not used.
        """
        if not self.backend.backend_type.is_csx:
            return {}

        if self.dynamic:
            return {
                "loss_scale": self._loss_scale,
                "steps_since_rescale": self._steps_since_rescale,
                "overflows_since_rescale": self._overflows_since_rescale,
            }
        return {"loss_scale": self._loss_scale}

    def load_state_dict(self, state_dict):
        """
        Loads the state dictionary into the current params

        An empty ``state_dict`` (which is what ``state_dict()`` returns on
        non-CSX backends) is treated as "nothing to restore" and leaves the
        scaler untouched.

        Raises:
            KeyError: if a non-empty ``state_dict`` is missing an expected key.
        """
        if not state_dict:
            return

        def load_param(param, param_name):
            value = state_dict[param_name]
            if not isinstance(param, torch.Tensor):
                return value

            if isinstance(value, torch.Tensor):
                # Only move to device if the param device is not CPU
                # Otherwise keep the original value's device
                if (
                    value.device.type != param.device.type
                    and param.device.type != "cpu"
                ):
                    return value.to(param.device)
                return value

            return torch.tensor(value, dtype=param.dtype).to(param.device)

        self._loss_scale = load_param(self._loss_scale, "loss_scale")
        if self.dynamic:
            self._steps_since_rescale = load_param(
                self._steps_since_rescale, "steps_since_rescale"
            )
            self._overflows_since_rescale = load_param(
                self._overflows_since_rescale, "overflows_since_rescale"
            )

    def scale(self, loss: torch.Tensor):
        """
        Scales the loss in preparation of the backwards pass

        On non-CSX backends the loss is returned unchanged.

        Raises:
            RuntimeError: if a registered optimizer is neither ``READY`` nor
                ``SCALED`` when scaling is requested.
        """
        # TODO: handle the case of outputs being iterable
        # which is supported by the torch interface
        if not self.backend.backend_type.is_csx:
            return loss

        with self.backend.name_scope("grad_scaler.scale"):
            loss = self.backend.pre_backward(loss)

            if (not self.dynamic) and self._loss_scale == 1.0:
                # Mark optimizers as having been unscaled since there is
                # no scaling to be done
                for optimizer in self.backend.optimizer_registry:
                    # pylint: disable=protected-access
                    optimizer._amp_stash.state = OptState.UNSCALED
                return loss.float()

            for optimizer in self.backend.optimizer_registry:
                # pylint: disable=protected-access
                if optimizer._amp_stash.state == OptState.READY:
                    optimizer._prepare_amp_backward()
                    optimizer._amp_stash.state = OptState.SCALED
                    continue
                if optimizer._amp_stash.state != OptState.SCALED:
                    raise RuntimeError(
                        "Optimizer parameter gradients already scaled"
                    )

            return (loss.float()) * self._loss_scale

    def get_scale(self):
        """Return the loss scale"""
        return self._loss_scale

    def _unscale_helper(self, model_grads, master_grads, scale):
        """
        Copy model grads into master grads and divide them by ``scale``.

        In dynamic mode the squared L2 norm of every unscaled master grad is
        appended to ``self._squared_local_norms``.
        """
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if (
                    master is not model
                ):  # copy_ probably internally short-circuits this
                    master.copy_(model)

        if not self.dynamic and scale == 1.0:
            return

        if not GradScaler.warned_unscaling_non_fp32_grad:
            for master in master_grads:
                if master.dtype != torch.float32:
                    maybe_print(
                        f"Attempting to unscale a grad with type {master.type()} "
                        f"Unscaling non-fp32 grads may indicate an error. "
                        f"When using Amp, you don't need to call .half() on your model."
                    )
                    GradScaler.warned_unscaling_non_fp32_grad = True

        if self.dynamic:
            # `scale` is a tensor here, so keep the division in tensor space
            inv_scale = torch.tensor(1.0, dtype=torch.float32) / scale
        else:
            inv_scale = torch.tensor(1.0 / scale, dtype=torch.float32)

        for master in master_grads:
            master.mul_(inv_scale)

        if self.dynamic:
            # Use CS1 compatible algorithm for detecting NaN/inf by using
            # global L2 norm of all gradients
            norms_squared = [torch.sum(g * g) for g in master_grads]
            self._squared_local_norms.extend(norms_squared)

    def _unscale(
        self,
        model_grads,
        master_grads,
        unused_scale,
        models_are_masters=False,
        scale_override=None,
    ):
        """
        Unscale ``master_grads`` using this scaler's loss scale, or
        ``scale_override`` when one is given. ``unused_scale`` is ignored.
        """
        scale = self._loss_scale if scale_override is None else scale_override

        # Skip only when nothing needs copying or scaling
        if self.dynamic or not models_are_masters or scale != 1.0:
            self._unscale_helper(model_grads, master_grads, scale)

    def _unscale_with_stashed_python(
        self, model_grads, stashed_master_grads, master_grads, a, b
    ):  # pylint: disable=missing-function-docstring
        raise NotImplementedError("stashed grads not supported")

    def _unscale_with_stashed(
        self,
        model_grads,
        stashed_master_grads,
        master_grads,
        scale_override=None,
    ):  # pylint: disable=missing-function-docstring
        raise NotImplementedError("stashed grads not supported")

    def unscale_(self, optimizer):
        """
        Unscales the optimizer's params gradients inplace

        A no-op when the optimizer is already unscaled.

        Raises:
            RuntimeError: if called after the optimizer has been stepped.
        """
        # pylint: disable=protected-access
        if optimizer._amp_stash.state == OptState.UNSCALED:
            return  # no-op
        if optimizer._amp_stash.state == OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # if not dynamic, short circuit to match the implicit context
        # manager case
        if (not self.dynamic) and self._loss_scale == 1.0:
            optimizer._amp_stash.state = OptState.UNSCALED
            return

        with self.backend.name_scope("grad_scaler.unscale_"):
            optimizer._post_amp_backward(self)
            optimizer._amp_stash.params_have_scaled_gradients = False
            optimizer._amp_stash.state = OptState.UNSCALED

    # TODO: reconsider this name, `_handle_undefined_state` is not very
    # descriptive of what the context manager does.
    @contextmanager
    def _handle_undefined_state(self, optimizer):
        """
        Context manager that, in dynamic mode, wraps every parameter with a
        gradient (and its optimizer state entries) in a ``conditional_update``
        keyed on ``self.isfinite``. In static mode it does nothing.
        """
        if not self.dynamic:
            yield
            return

        # Rather than conditionally execute the optimizer using python flow
        # control, we must trace the execution of the optimizer. However, if
        # the gradients were not finite, the state should not be updated.
        # Therefore, all state update (parameters and optimizer state) must
        # be conditional. Use the conditional_update context manager for
        # each tensor which may be updated.
        with ExitStack() as all_conditional_updates:
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    all_conditional_updates.enter_context(
                        conditional_update(p, self.isfinite)
                    )
                    for state in optimizer.state[p].values():
                        all_conditional_updates.enter_context(
                            conditional_update(state, self.isfinite)
                        )
            yield  # Update optimizer parameter and state

    def step_if_finite(self, optimizer, *args, **kwargs):
        """
        Directly conditionalize the call to optimizer.step(*args, **kwargs) but
        only if this GradScaler detected finite grads.

        Args:
            optimizer (cerebras_pytorch.experimental.optim.Optimizer):
                Optimizer that applies the gradients.
            args:
                Any arguments.
            kwargs:
                Any keyword arguments.

        Returns:
            The result of optimizer.step()
        """
        with self._handle_undefined_state(
            optimizer
        ):  # combine static and dynamic next
            with amp.disable_casts():
                return optimizer.step(*args, **kwargs)

    def clip_gradients_and_return_isfinite(self, optimizers):
        """
        Clip the optimizer's params's gradients and return whether or not the
        norm is finite

        The global norm is computed from ``self._squared_local_norms``.
        Gradients are only rescaled when a maximum gradient norm was
        configured.
        """
        # Compute global norm from all squared local norms
        self.global_norm = torch.sqrt(
            torch.sum(torch.stack(self._squared_local_norms))
        )

        def float32(value):
            return torch.tensor(
                value, dtype=torch.float32, device=self.global_norm.device
            )

        # TODO: torch.isfinite hits a lowering error, so compare against
        # +inf instead (NaN also compares False here).
        self.isfinite = self.global_norm < float32(float("inf"))

        if self._max_gradient_norm:
            # Then we're doing combo GGC + DLS
            # https://github.com/pytorch/pytorch/blob/release/1.9/torch/nn/utils/clip_grad.py#L56-L59
            clip_coef = float32(self._max_gradient_norm) / (
                self.global_norm + 1e-6
            )
            # Never scale gradients up: cap the coefficient at 1. A tensor
            # is used for the cap, matching the other torch.where calls in
            # this class.
            clip_coef = torch.where(clip_coef < 1, clip_coef, float32(1.0))
            for optimizer in optimizers:
                for group in optimizer.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.detach().mul_(clip_coef)

        return self.isfinite

    def step(self, optimizer, *args, **kwargs):
        """
        `Step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless `unscale_` was
            explicitly called for ``optimizer`` earlier in the iteration). As
            part of the `unscale_`, gradients are checked for infs/NaNs.
        2.  Invokes ``optimizer.step()`` using the unscaled gradients. Ensure
            that previous optimizer state or params carry over if we encounter
            NaNs in the gradients.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (cerebras_pytorch.optim.Optimizer):
                Optimizer that applies the gradients.
            args:
                Any arguments.
            kwargs:
                Any keyword arguments.

        Raises:
            RuntimeError: if ``optimizer`` was already stepped since the last
                ``update()``.
        """
        # pylint: disable=protected-access
        if optimizer._amp_stash.state == OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        # must unscale all optimizers prior to step and update
        # so global grad norm can be computed correctly
        # guaranteed to also unscale the given optimizer if needed
        #
        # `scale()` moves optimizers from READY to SCALED, so both states
        # mean "not yet unscaled this iteration". UNSCALED and STEPPED
        # optimizers are skipped so they are neither unscaled twice nor
        # rejected by `unscale_`.
        for _optimizer in self.backend.optimizer_registry:
            if _optimizer._amp_stash.state in (
                OptState.READY,
                OptState.SCALED,
            ):
                self.unscale_(_optimizer)

        with self.backend.name_scope("grad_scaler.step"):
            if self.dynamic and self.isfinite is None:
                self.isfinite = self.clip_gradients_and_return_isfinite(
                    self.backend.optimizer_registry
                )

            # Run optimizer's base step if self.isfinite is true
            return_val = self.step_if_finite(optimizer, *args, **kwargs)
            optimizer._amp_stash.state = OptState.STEPPED
            return return_val

    def update_scale(self, optimizers):
        """
        Update the dynamic loss scale from this iteration's overflow status.

        Does nothing in static mode. ``optimizers`` is accepted but not used.
        The loss scale is halved when the overflow ratio exceeds the
        tolerance, doubled when the step counter exceeds
        ``steps_per_increase``, and finally clamped to the min/max range.
        All of this is expressed as tensor arithmetic rather than Python
        branching.
        """
        if not self.dynamic:
            return

        # Compute global norm from all squared local norms
        self.global_norm = torch.sqrt(
            torch.sum(torch.stack(self._squared_local_norms))
        )

        def float32(value):
            return torch.tensor(
                value, dtype=torch.float32, device=self.global_norm.device
            )

        def int64(value):
            return torch.tensor(
                value, dtype=torch.int64, device=self.global_norm.device
            )

        # Reset local norms for next iteration
        self._squared_local_norms = []

        # integer representation of isfinite
        isfinite_int = self.isfinite.long()

        # Increment the step counter
        self._steps_since_rescale.add_(1)

        # If overflow, increment the overflow counter
        self._overflows_since_rescale.add_(1 - isfinite_int)

        ratio = (
            self._overflows_since_rescale.float()
            / self._steps_since_rescale.float()
        )

        # Decrease loss scale
        # decrease loss scaling condition
        #   1 if we've exceeded our overflow tolerance
        #   0 if we haven't hit too many overflows
        overflow_tolerance_exceeded = (
            float32(self._overflow_tolerance) < ratio
        ).long()
        # decrease loss scale 2x if we're decreasing, otherwise unchanged
        loss_scale_divisor = (1 + overflow_tolerance_exceeded).float()
        self._loss_scale.div_(loss_scale_divisor)
        # reset counters (multiplier is 0 when decreasing, 1 otherwise)
        reset_because_decreasing = 1 - overflow_tolerance_exceeded
        self._overflows_since_rescale.mul_(reset_because_decreasing)
        self._steps_since_rescale.mul_(reset_because_decreasing)

        # Increasing loss scale
        # (done purposefully after decrease logic in case counter reset)
        # increase loss scaling condition
        #   1 if we've exceeded our steps per increase counter
        #   0 if we haven't yet.
        increase_counter_exceeded = (
            int64(self._steps_per_increase) < self._steps_since_rescale
        ).long()
        # increase loss scale 2x if we're increasing, otherwise unchanged
        loss_scale_multipler = (1 + increase_counter_exceeded).float()
        self._loss_scale.mul_(loss_scale_multipler)
        # reset counters (multiplier is 0 when increasing, 1 otherwise)
        reset_because_increasing = 1 - increase_counter_exceeded
        self._overflows_since_rescale.mul_(reset_because_increasing)
        self._steps_since_rescale.mul_(reset_because_increasing)

        # clamp loss scale to within min/max
        max_ls = float32(self._max_loss_scale)
        self._loss_scale.data = torch.where(
            self._loss_scale < max_ls, self._loss_scale, max_ls,
        )
        min_ls = float32(self._min_loss_scale)
        self._loss_scale.data = torch.where(
            min_ls < self._loss_scale, self._loss_scale, min_ls,
        )

    def update(self, new_scale=None):
        """
        Update the gradient scaler after all optimizers have been stepped

        Raises:
            ValueError: if ``new_scale`` is provided; it is not supported.
        """
        # Any explicitly provided value is unsupported, including 0 and
        # tensors (whose truthiness may be ambiguous).
        if new_scale is not None:
            raise ValueError(
                "cstorch.amp.GradScaler does not support providing a `new_scale`"
            )

        # Update scale
        if self.dynamic or self._loss_scale != 1.0:
            with self.backend.name_scope("grad_scaler.update"):
                self.update_scale(self.backend.optimizer_registry)

                # pylint: disable=protected-access,no-member
                _amp_state.handle._clear_cache()

        # clear all data from this iteration for the next
        self.isfinite = None
        for optimizer in self.backend.optimizer_registry:
            # pylint: disable=protected-access
            optimizer._amp_stash.state = OptState.READY