Source code for cerebras.modelzoo.trainer.callbacks.precision

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains the MixedPrecision callback class for mixed precision training."""

from abc import ABC, abstractmethod
from contextlib import nullcontext
from numbers import Number
from typing import Literal, Optional, Union
from warnings import warn

import torch

import cerebras.pytorch as cstorch
from cerebras.modelzoo.trainer.callbacks import CoreCallback


class Precision(CoreCallback, ABC):
    """Abstract callback describing how a given precision mode is handled.

    Concrete subclasses provide the autocast context for the forward pass,
    the backward pass, gradient clipping, and the optimizer step.
    """

    @abstractmethod
    def autocast_context_manager(self):
        """Return the context manager that autocasts the forward pass."""

    @abstractmethod
    def backward(self, loss: torch.Tensor):
        """Run the backward pass.

        Args:
            loss: Loss tensor.
        """

    @abstractmethod
    def clip_gradients(self, optimizer: cstorch.optim.Optimizer):
        """Clip the gradients ahead of the optimization step.

        Args:
            optimizer: The optimizer that is about to be stepped.
        """

    @abstractmethod
    def optimizer_step(self, optimizer: cstorch.optim.Optimizer):
        """Run the optimization step.

        Args:
            optimizer: The optimizer to step.
        """
class MixedPrecision(Precision):
    """Callback class that handles mixed precision training.

    Configures the half-precision dtype, optionally creates a gradient
    scaler (the Cerebras scaler on CSX, ``torch.amp.GradScaler`` otherwise),
    and implements the backward pass, gradient unscaling/clipping and the
    optimizer step on top of it.
    """

    def __init__(
        self,
        enabled: bool = True,
        fp16_type: Literal["float16", "bfloat16", "cbfloat16"] = "bfloat16",
        precision_opt_level: Optional[Literal[0, 1, 2]] = None,
        loss_scaling_factor: Union[float, Literal["dynamic"]] = 1.0,
        initial_loss_scale: Optional[float] = None,
        steps_per_increase: int = 2000,
        min_loss_scale: Optional[float] = None,
        max_loss_scale: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
        max_gradient_value: Optional[float] = None,
        log_loss_scale: bool = False,
    ):
        """
        Args:
            enabled: If True, enables mixed precision training.
            fp16_type: Half precision type. One of "float16", "bfloat16",
                "cbfloat16".
            precision_opt_level: Precision optimization level. If not None,
                sets the global precision optimization level.
            loss_scaling_factor: A number to use as a fixed loss scale, or
                "dynamic" to use dynamic loss scaling.
            initial_loss_scale: Initial loss scale for dynamic loss scaling.
                On non-CSX backends this falls back to 65536.0 when None.
            steps_per_increase: Number of steps before increasing the loss
                scale.
            min_loss_scale: Minimum loss scale (forwarded to the CSX scaler).
            max_loss_scale: Maximum loss scale (forwarded to the CSX scaler).
            max_gradient_norm: Maximum gradient norm for gradient clipping.
            max_gradient_value: Maximum gradient value for gradient clipping.
            log_loss_scale: If True, log the gradient scaler's loss scale.

        Raises:
            ValueError: If ``loss_scaling_factor`` is not a number, None, or
                "dynamic".
        """
        if not (
            isinstance(loss_scaling_factor, Number)
            or loss_scaling_factor is None
            or loss_scaling_factor == "dynamic"
        ):
            raise ValueError(
                f"'loss_scaling_factor' must be a float, None, or 'dynamic'. "
                f"Got {loss_scaling_factor} instead."
            )

        self.scaler = None

        self.enabled = enabled
        cstorch.amp.enable_mixed_precision(enabled)

        self.fp16_type = fp16_type
        self.precision_opt_level = precision_opt_level
        self.loss_scaling_factor = loss_scaling_factor
        self.initial_loss_scale = initial_loss_scale
        self.steps_per_increase = steps_per_increase
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self.max_gradient_norm = max_gradient_norm
        self.max_gradient_value = max_gradient_value
        self.log_loss_scale = log_loss_scale

        if precision_opt_level is not None:
            cstorch.backends.csx.precision.optimization_level = (
                precision_opt_level
            )

    def pre_setup(self, trainer):
        """Validate the half dtype against the backend and register it.

        Raises:
            ValueError: If ``fp16_type`` is not a supported value, or if
                cbfloat16 is requested on CSX without dynamic loss scaling.
        """
        # pylint: disable=attribute-defined-outside-init
        self.backend = trainer.backend

        if not self.enabled:
            if not self.backend.is_csx:
                return
            warn(
                "Mixed precision must be enabled for CSX. Setting enabled = True"
            )
            self.enabled = True
            cstorch.amp.enable_mixed_precision()

        valid_dtypes = ["float16", "bfloat16", "cbfloat16"]
        if self.fp16_type not in valid_dtypes:
            raise ValueError(
                f"Invalid value for \"fp16_type\". Expected one of {valid_dtypes}. "
                f"Got: {self.fp16_type}."
            )

        # NOTE: the fallbacks below keep `fp16_type` as the *string*
        # "bfloat16" rather than a torch dtype, because `setup` compares it
        # against the string to decide whether a gradient scaler is needed.
        if self.fp16_type == "cbfloat16":
            if not self.backend.is_csx:
                self.fp16_type = "bfloat16"
                warn(
                    "cbfloat16 is only supported on CSX. Setting half dtype to bfloat16."
                )
            elif self.loss_scaling_factor != "dynamic":
                raise ValueError(
                    "In order to use cbfloat16, dynamic loss scaling must be enabled. "
                    "Otherwise, gradients might underflow/overflow in the middle of "
                    "training and cause NaNs. Please set `loss_scaling_factor` to "
                    "`dynamic` to use cbfloat16."
                )
        elif self.backend.is_cpu and self.fp16_type != "bfloat16":
            self.fp16_type = "bfloat16"
            warn(
                "Mixed precision on CPU is only supported with bfloat16. "
                "Setting half dtype to bfloat16."
            )

        cstorch.amp.set_half_dtype(self.fp16_type)

    def setup(self, trainer):
        """Create the gradient scaler (if any) and validate clipping options.

        Raises:
            ValueError: If a clipping threshold is not positive, or if both
                ``max_gradient_norm`` and ``max_gradient_value`` are set.
        """
        if self.fp16_type == "bfloat16":
            if self.loss_scaling_factor == "dynamic":
                trainer.logger.info(
                    "No need to use DLS for loss when half dtype is bfloat16. "
                    "Disabling gradient scaling."
                )
            self.scaler = None
        elif self.backend.is_csx:
            self.scaler = cstorch.amp.GradScaler(
                loss_scale=self.loss_scaling_factor,
                init_scale=self.initial_loss_scale,
                steps_per_increase=self.steps_per_increase,
                min_loss_scale=self.min_loss_scale,
                max_loss_scale=self.max_loss_scale,
            )
        elif self.loss_scaling_factor == "dynamic":
            if self.initial_loss_scale is None:
                # This is the default value in PyTorch
                self.initial_loss_scale = 65536.0
            self.scaler = torch.amp.GradScaler(
                device="cuda",
                init_scale=self.initial_loss_scale,
                growth_interval=self.steps_per_increase,
            )
        else:
            # Static scaling: a huge growth interval keeps the scale from
            # ever being increased.
            self.scaler = torch.amp.GradScaler(
                device="cuda",
                init_scale=self.loss_scaling_factor,
                growth_interval=2**63 - 1,
            )

        max_gradient_norm = self.max_gradient_norm
        max_gradient_value = self.max_gradient_value

        if max_gradient_norm is not None and max_gradient_norm <= 0.0:
            raise ValueError(
                f"max_gradient_norm has to be a positive float. Got "
                f"{max_gradient_norm}"
            )
        if max_gradient_value is not None and max_gradient_value <= 0.0:
            raise ValueError(
                f"max_gradient_value has to be a positive float. Got "
                f"{max_gradient_value}"
            )
        if max_gradient_norm is not None and max_gradient_value is not None:
            raise ValueError(
                f"Gradients can be clipped by norm(={max_gradient_norm}) or by "
                f"value(={max_gradient_value}), but not both. "
                f"Do not set both `max_gradient_norm` and `max_gradient_value`."
            )

    def autocast_context_manager(self):
        """Return the autocast context, or a no-op context when mixed
        precision is disabled or the backend is CSX."""
        if not self.enabled or self.backend.is_csx:
            return nullcontext()
        return cstorch.amp.autocast()

    def backward(self, loss):
        """Run the backward pass, scaling the loss first if a scaler is set.

        Args:
            loss: Loss tensor.
        """
        if self.scaler:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def on_before_optimizer_step(self, trainer, model, optimizer):
        """Unscales the optimizer's gradients when a gradient scaler is set."""
        if self.scaler:
            self.scaler.unscale_(optimizer)

    def clip_gradients(self, optimizer):
        """Clip gradients by norm or by value, whichever is configured.

        Does nothing when neither ``max_gradient_norm`` nor
        ``max_gradient_value`` is set.

        Args:
            optimizer: The optimizer whose parameters' gradients are clipped.
        """
        if self.max_gradient_norm is None and self.max_gradient_value is None:
            return

        params = [
            p
            for param_group in optimizer.param_groups
            for p in param_group["params"]
        ]
        if self.max_gradient_norm is not None:
            torch.nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
        else:
            torch.nn.utils.clip_grad_value_(params, self.max_gradient_value)

    def optimizer_step(self, optimizer):
        """Step the optimizer, going through the scaler when one is set.

        Args:
            optimizer: The optimizer to step.
        """
        if self.scaler:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
        """Log the current loss scale if a scaler exists and logging is on."""
        if self.scaler and self.log_loss_scale:
            trainer.log_metrics(loss_scale=self.scaler.get_scale())

    def on_save_checkpoint(self, trainer, state_dict):
        """Store the scaler state under the ``"grad_scaler"`` key."""
        if self.scaler:
            state_dict["grad_scaler"] = self.scaler.state_dict()

    def on_load_checkpoint(self, trainer, state_dict):
        """Restore the scaler state from ``"grad_scaler"`` when present."""
        if not self.scaler:
            return

        # TODO: handle conversion between CUDA grad scaler and cstorch grad scaler
        if "grad_scaler" in state_dict:
            self.scaler.load_state_dict(state_dict["grad_scaler"])
            trainer.logger.info(
                "Gradient scaler state found in checkpoint and loaded successfully."
            )
        else:
            trainer.logger.info(
                "Gradient scaler state not found in the checkpoint. "
                "Using default initialized state."
            )