# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Directory containing the implementations of the various API backends"""
from enum import Enum, auto
from typing import Optional, Union

import torch
class BackendType(Enum):
    """
    Enumerates the Cerebras backends that can be requested.

    ``CSX`` and ``WSE`` are separate members (each has its own ``auto()``
    value), but ``is_csx`` treats both as the Cerebras wafer-scale cluster.
    ``WSE`` is the deprecated spelling.
    """

    CPU = auto()
    GPU = auto()
    # CSX and WSE are synonyms for the same target; WSE is deprecated.
    CSX = auto()
    WSE = auto()

    @property
    def is_cpu(self):
        """True when this member selects the CPU backend."""
        return self is BackendType.CPU

    @property
    def is_gpu(self):
        """True when this member selects the GPU backend."""
        return self is BackendType.GPU

    @property
    def is_csx(self):
        """True when this member selects the Cerebras wafer-scale cluster,
        i.e. it is either ``CSX`` or the deprecated ``WSE``."""
        return self is BackendType.CSX or self is BackendType.WSE
class BackendMeta(type):
    """
    Metaclass that makes each class using it a one-shot singleton.

    The first call to the class constructs the object and records it.
    Any later call raises ``RuntimeError`` instead of building a second one.
    """

    # Registry mapping each class created through this metaclass to its
    # single instance. The dict is shared by every such class.
    instance = {}

    def __call__(cls, *args, **kwargs):
        # Guard: refuse a second construction for the same class.
        if cls in cls.instance:
            raise RuntimeError(
                f"Cannot instantiate multiple backends. "
                f"A backend with type {cls.instance[cls].backend_type.name} "
                f"has already been instantiated."
            )
        # Only record the object once construction has succeeded, so a
        # failing __init__ leaves the registry untouched.
        cls.instance[cls] = super().__call__(*args, **kwargs)
        return cls.instance[cls]
class Backend(metaclass=BackendMeta):
    """Externally facing Cerebras backend class.

    At most one instance can be constructed (``BackendMeta`` raises on a
    second construction). On construction, the implementation matching
    ``backend_type`` is imported lazily and stored as ``self._impl``. The
    properties below forward to that implementation.
    """

    # Only if True, initialize the backend implementation. When False,
    # ``__init__`` returns before ``self._impl`` is assigned.
    _init_impl: bool = True

    def __init__(self, backend_type: BackendType, *args, **kwargs):
        """Select and construct the backend implementation.

        Args:
            backend_type: The backend to construct.
            *args, **kwargs: Forwarded unchanged to the implementation's
                constructor.
        Raises:
            TypeError: If ``backend_type`` is not a ``BackendType``.
            ValueError: If there is no implementation for ``backend_type``.
        """
        # An explicit raise (rather than ``assert``) keeps this check active
        # when Python runs with ``-O``.
        if not isinstance(backend_type, BackendType):
            raise TypeError(
                f"Expected backend_type to be of type BackendType. "
                f"Got: {type(backend_type)}"
            )
        self.backend_type = backend_type

        if not self._init_impl:
            return

        # Implementations are imported inside each branch so that only the
        # selected backend's module is loaded.
        if self.backend_type == BackendType.CSX:
            from .ltc_backend import PyTorchLtcBackendImpl

            self._impl = PyTorchLtcBackendImpl(
                self.backend_type, *args, **kwargs
            )
        elif self.backend_type == BackendType.CPU:
            from .cpu_backend import CpuBackendImpl

            self._impl = CpuBackendImpl(self.backend_type, *args, **kwargs)
        elif self.backend_type == BackendType.GPU:
            from .gpu_backend import GpuBackendImpl

            self._impl = GpuBackendImpl(self.backend_type, *args, **kwargs)
        else:
            # NOTE(review): BackendType.WSE reports ``is_csx`` but is not
            # dispatched above, so it ends up here. Confirm that rejecting
            # the deprecated member is intended.
            raise ValueError(
                f"{self.backend_type.name} backend not yet supported. "
                f"Supported backends include: CSX, CPU, GPU"
            )

    @property
    def artifact_dir(self):
        """Returns the artifact directory being used by the backend"""
        return self._impl.config.artifact_dir

    @property
    def device(self):
        """Returns the Cerebras device being used by the backend"""
        return self._impl.device

    @property
    def torch_device(self):
        """Returns the underlying PyTorch device being used by the backend"""
        return self._impl.device.torch_device

    # Convenience aliases that delegate to the BackendType properties.
    is_cpu = property(lambda self: self.backend_type.is_cpu)
    is_gpu = property(lambda self: self.backend_type.is_gpu)
    is_csx = property(lambda self: self.backend_type.is_csx)
def backend(backend_type: "Union[str, BackendType]", *args, **kwargs):
    """Instantiates a backend with the given type.

    Args:
        backend_type: Either a ``BackendType`` member or the name of one as a
            string. Strings are upper-cased before lookup, so ``"cpu"`` is
            accepted.
        *args, **kwargs: Forwarded unchanged to ``Backend``.
    Returns:
        The newly constructed ``Backend``.
    Raises:
        ValueError: If a string does not name a ``BackendType`` member, or
            (from ``Backend``) if the type has no implementation.
        TypeError: If ``backend_type`` is neither a string nor a
            ``BackendType``.
        RuntimeError: From ``BackendMeta`` if a backend has already been
            instantiated.
    """
    if isinstance(backend_type, str):
        name = backend_type.upper()
        if name not in BackendType.__members__:
            # Report the value the caller actually passed, not the
            # upper-cased lookup key.
            raise ValueError(
                f"Invalid Cerebras PyTorch backend type specified. "
                f"Expected one of {list(BackendType.__members__)}. "
                f"Got {backend_type}. "
            )
        backend_type = BackendType[name]
    elif not isinstance(backend_type, BackendType):
        raise TypeError(
            f"Expected backend_type to be of type BackendType, "
            "or a string representing the backend type. "
            f"Got: {type(backend_type)}"
        )
    return Backend(backend_type, *args, **kwargs)
def current_backend(raise_exception: bool = True) -> "Optional[Backend]":
    """Gets instance of the current backend.

    Args:
        raise_exception: If True, raise an exception if no backend has been
            instantiated. Otherwise return None.
    Returns:
        The ``Backend`` instance recorded by ``BackendMeta``, or None if
        there is none and ``raise_exception`` is False.
    Raises:
        RuntimeError: If no backend has been instantiated and
            ``raise_exception`` is True.
    """
    # BackendMeta stores the constructed Backend under the Backend class key.
    active = BackendMeta.instance.get(Backend)
    if active is None and raise_exception:
        raise RuntimeError(
            "No active Cerebras backend found. Please make sure that "
            "your model has been prepared for compilation.\n"
            "You can do this using a call to:\n\n"
            "\tcompiled_model = cstorch.compile(model, backend=...)\n\n"
            "Or by explicitly instantiating a backend, e.g.\n\n"
            "\tbackend = cstorch.backend(...)"
        )
    return active
def current_torch_device() -> torch.device:
    """
    Gets the torch device of the current backend.

    Returns:
        The active backend implementation's ``torch_device``, or
        ``torch.device("cpu")`` if no backend has been instantiated yet.
    """
    _backend = current_backend(raise_exception=False)
    if _backend is None:
        # No backend yet: fall back to the CPU device.
        return torch.device("cpu")
    # NOTE(review): Backend.torch_device reads ``_impl.device.torch_device``
    # while this reads ``_impl.torch_device``; confirm they always agree.
    # pylint: disable=protected-access
    return _backend._impl.torch_device
def current_backend_impl(raise_exception: bool = True):
    """Fetch the implementation object behind the active backend.

    Args:
        raise_exception: When True, having no backend is an error. When
            False, having no backend yields None.
    Returns:
        The active backend's ``_impl`` attribute, or None when no backend
        exists and ``raise_exception`` is False.
    Raises:
        RuntimeError: Propagated from ``current_backend`` when no backend
            exists and ``raise_exception`` is True.
    """
    active = current_backend(raise_exception=raise_exception)
    # pylint: disable=protected-access
    return None if active is None else active._impl
def use_cs() -> bool:
    """Returns True if the active backend targets a Cerebras CSX device.

    Returns False when no backend has been instantiated yet. Both
    ``BackendType.CSX`` and the deprecated ``BackendType.WSE`` count as CSX.
    """
    _backend = current_backend(raise_exception=False)
    return _backend is not None and _backend.is_csx