cerebras.pytorch.amp#
Automatic mixed precision#
The following classes and functions facilitate automatic mixed precision on the Cerebras Wafer-Scale Cluster.
GradScaler#
- class cerebras.pytorch.amp.GradScaler[source]#
Facilitates mixed precision training with dynamic loss scaling (DLS) and DLS + global gradient clipping (GCC).
For more details, please see the docs for amp.initialize.
- Parameters
loss_scale – If loss_scale == “dynamic”, then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.
init_scale – The initial loss scale value if loss_scale == “dynamic”
steps_per_increase – The number of steps after which to increase the loss scale when using dynamic loss scaling
min_loss_scale – The minimum loss scale value that can be chosen by dynamic loss scaling
max_loss_scale – The maximum loss scale value that can be chosen by dynamic loss scaling
overflow_tolerance – The maximum fraction of steps with infinite or NaN values in the gradients that is tolerated. The loss scale is reduced if this tolerance is exceeded
max_gradient_norm – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case; if GCC is not enabled, this parameter has no effect
Example usage:
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
- __init__(loss_scale: Optional[Union[str, float]] = None, init_scale: Optional[float] = None, steps_per_increase: Optional[int] = None, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = None, overflow_tolerance: float = 0.0, max_gradient_norm: Optional[float] = None)[source]#
- state_dict(destination=None)[source]#
Returns a dictionary containing the state to be saved to a checkpoint
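For example, the returned dictionary can be stored alongside the model and optimizer state when checkpointing. The sketch below assumes cstorch.save/cstorch.load for checkpoint I/O and a load_state_dict counterpart mirroring the standard PyTorch GradScaler API; neither is documented in this section.
# Save the grad scaler state together with model and optimizer state
state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "grad_scaler": grad_scaler.state_dict(),
}
cstorch.save(state, "checkpoint.mdl")

# Later, restore it (assumes a load_state_dict counterpart exists)
state = cstorch.load("checkpoint.mdl")
grad_scaler.load_state_dict(state["grad_scaler"])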
- step_if_finite(optimizer, *args, **kwargs)[source]#
Calls optimizer.step(*args, **kwargs), but only if this GradScaler detected finite gradients.
- Parameters
optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.
args – Any arguments.
kwargs – Any keyword arguments.
- Returns
The result of optimizer.step()
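A minimal sketch of using step_if_finite in place of the plain step call, assuming the same scale/unscale flow shown in the class example above:
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)

# Only invokes optimizer.step() if no infs/NaNs were detected in the gradients
grad_scaler.step_if_finite(optimizer)
grad_scaler.update()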
- clip_gradients_and_return_isfinite(optimizers)[source]#
Clips the gradients of the optimizers' parameters and returns whether the gradient norm is finite
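A hedged sketch of a possible call with a single optimizer, assuming the optimizers argument accepts a list and that gradients have already been unscaled:
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)

# Clip the parameter gradients and check whether the resulting norm is finite
isfinite = grad_scaler.clip_gradients_and_return_isfinite([optimizer])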
- step(optimizer, *args, **kwargs)[source]#
Step carries out the following two operations:
1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.
2. Invokes optimizer.step() using the unscaled gradients, ensuring that the previous optimizer state or params carry over if NaNs are encountered in the gradients.
*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).
- Parameters
optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.
args – Any arguments.
kwargs – Any keyword arguments.
set_half_dtype#
- cerebras.pytorch.amp.set_half_dtype(value: Union[Literal['float16', 'bfloat16', 'cbfloat16'], torch.dtype]) → torch.dtype[source]#
Sets the underlying 16-bit floating point dtype to use.
- Parameters
value – Either a 16-bit floating point torch dtype or one of “float16”, “bfloat16”, or “cbfloat16” string.
- Returns
The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the same as value passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.
By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.
Example usage:
cstorch.amp.set_half_dtype("cbfloat16")
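Since the function returns the (proxy) dtype, the return value can also be captured and used for explicit casts, for example on an input tensor. This is an illustrative sketch, not a requirement of the API; features is a hypothetical input tensor.
half_dtype = cstorch.amp.set_half_dtype("cbfloat16")

# Cast a tensor to the configured 16-bit (proxy) dtype
features = features.to(half_dtype)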
optimizer_step#
- cerebras.pytorch.amp.optimizer_step(loss: torch.Tensor, optimizer: cerebras.pytorch.optim.optimizer.Optimizer, grad_scaler: cerebras.pytorch.amp.grad_scaler.GradScaler, max_gradient_norm: Optional[float] = None, max_gradient_value: Optional[float] = None)[source]#
Performs loss scaling, gradient scaling and optimizer step
- Parameters
loss – The loss value to scale. loss.backward should be called before this function
optimizer – The optimizer to step
grad_scaler – The gradient scaler to use to scale the parameter gradients
max_gradient_norm – The maximum gradient norm to use for gradient clipping
max_gradient_value – The maximum gradient value to use for gradient clipping
Example usage:
cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)
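For gradient value clipping instead of norm clipping, the same call can pass max_gradient_value (a sketch mirroring the example above; which of the two clipping arguments to supply depends on the desired clipping behavior):
cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_value=0.5,  # clip gradient values rather than the global norm
)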