# Source code for cerebras.pytorch.amp._amp_state

# ###############################################################
# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
#
# From original Apex:
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
#     1. Redistributions of source code must retain the above copyright notice,
#        this list of conditions and the following disclaimer.
#
#     2. Redistributions in binary form must reproduce the above copyright notice,
#        this list of conditions and the following disclaimer in the documentation
#        and/or other materials provided with the distribution.
#
#     3. Neither the name of the copyright holder nor the names of its contributors
#        may be used to endorse or promote products derived from this software without
#        specific prior written permission.
#
#        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#        AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#        WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
#        IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
#        INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
#        NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
#        PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
#        WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#        ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#        POSSIBILITY OF SUCH DAMAGE.
#
#################################################################

from typing import Literal, Union

import torch

import cerebras.pytorch.distributed as dist
from cerebras.appliance.environment import appliance_environ

# Environment variable through which the mixed-precision enabled flag is
# published (as "0"/"1") for non-master ordinals to read back.
_ENABLED_ENV_VAR = "CEREBRAS_MIXED_PRECISION"
# Environment variable through which the half dtype string is published for
# non-master ordinals to read back.
_DTYPE_ENV_VAR = "CEREBRAS_FP16_DTYPE"
# The string names accepted for the 16-bit floating point dtype.
HalfDtypeLiteral = Literal["float16", "bfloat16", "cbfloat16"]


class AmpState:
    """Holds the mixed-precision (AMP) configuration for this process.

    On the master ordinal the values live on this object and are mirrored into
    ``appliance_environ`` whenever they are set. On any other ordinal the values
    are read back from ``appliance_environ`` (falling back to the local value
    when the variable is absent), and attempting to set them raises
    ``RuntimeError``.
    """

    # Maps each accepted half-dtype string to the torch dtype used in the model.
    # "cbfloat16" has no native torch dtype, so float16 is used as a proxy.
    _STR_TO_DTYPE = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "cbfloat16": torch.float16,
    }
    # Torch dtypes accepted by the ``half_dtype`` setter and their string names.
    _DTYPE_TO_STR = {
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
    }

    def __init__(self):
        # When True, warn_or_err() prints a warning instead of raising.
        self.hard_override = False
        # NOTE(review): not read anywhere in this module; presumably kept for
        # parity with Apex -- confirm whether any caller still uses it.
        self.allow_incoming_model_not_fp32 = False
        # maybe_print() only prints when this is greater than zero.
        self.verbosity = 1
        # Local (master-ordinal) copies of the enabled flag and dtype string.
        self._enabled = False
        self._real_dtype_str = "float16"

    @property
    def enabled(self) -> bool:
        """Whether mixed precision is enabled.

        On non-master ordinals this is read from ``appliance_environ``, where
        the setter stores it as ``"0"``/``"1"``.
        """
        if dist.is_master_ordinal():
            return self._enabled
        return bool(int(appliance_environ.get(_ENABLED_ENV_VAR, self._enabled)))

    @enabled.setter
    def enabled(self, value: bool):
        """Enable or disable mixed precision.

        Raises:
            RuntimeError: If called on a non-master ordinal.
            ValueError: If ``value`` is not a ``bool``.
        """
        if not dist.is_master_ordinal():
            raise RuntimeError(
                "Enabling/disabling mixed precision in the dataloader is not allowed as it might "
                "conflict with what the model was compiled with. Please ensure to enable/disable "
                "mixed precision outside of the dataloader before constructing the model."
            )
        if not isinstance(value, bool):
            raise ValueError(f"Expected a boolean value, got: {type(value)}")

        self._enabled = value
        # Publish so that non-master ordinals observe the same setting.
        appliance_environ[_ENABLED_ENV_VAR] = str(int(value))

    @property
    def half_dtype(self) -> torch.dtype:
        """The torch dtype to use in the model for 16-bit floating point.

        For "cbfloat16" this is the proxy dtype ``torch.float16``.

        Raises:
            ValueError: If the configured dtype string is not recognized. This
                can happen on non-master ordinals, where the string is read from
                the environment without validation.
        """
        dtype_str = self.half_dtype_str
        dtype = self._STR_TO_DTYPE.get(dtype_str)
        if dtype is None:
            # An explicit raise instead of ``assert``, which ``python -O`` strips.
            raise ValueError(f"Invalid dtype str: {dtype_str}")
        return dtype

    @half_dtype.setter
    def half_dtype(self, value: Union[HalfDtypeLiteral, torch.dtype]):
        """Set the 16-bit floating point dtype; this also enables mixed precision.

        Args:
            value: ``torch.float16``, ``torch.bfloat16``, or one of the strings
                "float16", "bfloat16", "cbfloat16".

        Raises:
            RuntimeError: If called on a non-master ordinal.
            ValueError: If ``value`` is not one of the accepted values.
        """
        if not dist.is_master_ordinal():
            raise RuntimeError(
                "Setting half dtype in the dataloader is not allowed as it might conflict with "
                "what the model was compiled with. Please ensure to set the half dtype outside "
                "of the dataloader before constructing the model."
            )

        # The isinstance checks keep unhashable inputs from reaching the dict
        # lookups, so they raise ValueError rather than TypeError.
        if isinstance(value, torch.dtype) and value in self._DTYPE_TO_STR:
            self._real_dtype_str = self._DTYPE_TO_STR[value]
        elif isinstance(value, str) and value in self._STR_TO_DTYPE:
            self._real_dtype_str = value
        else:
            raise ValueError(
                f"Invalid half dtype: {value}. Accepted values are: "
                f"\"float16\", \"bfloat16\", \"cbfloat16\", {torch.float16}, {torch.bfloat16}."
            )

        # Setting the half dtype implies that mixed precision is being used
        self.enabled = True

        # Publish so that non-master ordinals observe the same dtype.
        appliance_environ[_DTYPE_ENV_VAR] = self._real_dtype_str

    @property
    def half_dtype_str(self) -> HalfDtypeLiteral:
        """The string name of the configured 16-bit floating point dtype."""
        # TODO: Temporarily read the value in workers through an env variable. Once RT IR has the
        # value in the module, we should read it from there instead.
        if dist.is_master_ordinal():
            return self._real_dtype_str
        return appliance_environ.get(_DTYPE_ENV_VAR, self._real_dtype_str)

    def get_floating_point_dtype(self, default=torch.float32) -> torch.dtype:
        """Return ``half_dtype`` if mixed precision is enabled, else ``default``."""
        if self.enabled:
            return self.half_dtype
        return default

    def get_floating_point_dtype_str(self, default="float32") -> str:
        """Return ``half_dtype_str`` if mixed precision is enabled, else ``default``."""
        if self.enabled:
            return self.half_dtype_str
        return default


# Attribute stash: the module-level AmpState instance that the helper functions
# in this module read and write. Could also just stash things as global module
# attributes.
_amp_state = AmpState()


def warn_or_err(msg):
    """Report ``msg`` as an error, or only as a printed warning when overridden.

    Raises:
        RuntimeError: With ``msg`` unless ``_amp_state.hard_override`` is set.
    """
    if not _amp_state.hard_override:
        raise RuntimeError(msg)
    print("Warning:  " + msg)


def maybe_print(msg):
    """Print ``msg`` only when the AMP verbosity level is positive."""
    if not (_amp_state.verbosity > 0):
        return
    print(msg)


def mixed_precision() -> bool:
    """Report whether mixed precision is currently enabled."""
    state = _amp_state
    return state.enabled


def enable_mixed_precision(enable: bool = True) -> bool:
    """Turn mixed precision on (or off) and return the resulting setting.

    Args:
        enable: Whether mixed precision should be enabled. Must be a ``bool``.

    Returns:
        The enabled flag as read back after setting it.
    """
    state = _amp_state
    state.enabled = enable
    return state.enabled


def set_half_dtype(value: Union[HalfDtypeLiteral, torch.dtype]) -> torch.dtype:
    """Sets the underlying 16-bit floating point dtype to use.

    Setting the half dtype also enables mixed precision.

    Args:
        value: Either a 16-bit floating point torch dtype or one of "float16",
            "bfloat16", or "cbfloat16" string.

    Returns:
        The proxy torch dtype to use for the model. For dtypes that have a
        torch representation, this returns the same as `value` passed in.
        Otherwise, it returns a proxy dtype to use in the model. On CSX, these
        proxy dtypes are automatically and transparently converted to the real
        dtype during compilation.

    Raises:
        RuntimeError: If called on a non-master ordinal.
        ValueError: If `value` is not one of the accepted dtypes.
    """
    _amp_state.half_dtype = value
    return _amp_state.half_dtype
def get_half_dtype() -> torch.dtype:
    """Gets the 16-bit floating point dtype to use in the model.

    This returns the value set through `set_half_dtype()`. For "cbfloat16" the
    returned dtype is the proxy `torch.float16`.
    """
    return _amp_state.half_dtype


def get_half_dtype_str() -> str:
    """Gets the string representation of the 16-bit floating point dtype to use
    in the model.
    """
    return _amp_state.half_dtype_str


def is_cbfloat16_tensor(tensor: torch.Tensor) -> bool:
    """Return true if tensor dtype is cbfloat16.

    cbfloat16 is represented by the proxy dtype `torch.float16`. So while the
    configured half dtype is "cbfloat16", every float16 tensor is reported as
    cbfloat16.
    """
    return (
        tensor.dtype == _amp_state.half_dtype
        and _amp_state.half_dtype_str == "cbfloat16"
    )


def get_floating_point_dtype(default=torch.float32) -> torch.dtype:
    """Gets the floating point dtype to use in the model.

    Args:
        default: Dtype returned when mixed precision is disabled.

    Returns:
        The value set through `set_half_dtype()` if mixed precision is enabled,
        otherwise `default` (`torch.float32` unless overridden).
    """
    return _amp_state.get_floating_point_dtype(default)


def get_floating_point_dtype_str(default="float32") -> str:
    """Gets the string representation of the floating point dtype to use in the
    model.

    Args:
        default: String returned when mixed precision is disabled.

    Returns:
        The string form of the value set through `set_half_dtype()` if mixed
        precision is enabled, otherwise `default` ("float32" unless overridden).
    """
    return _amp_state.get_floating_point_dtype_str(default)