# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Get information about the current cluster setup."""
import os
from pathlib import Path
from cerebras_appliance.utils.units import bytes_to_human
from cerebras_pytorch.utils.utils import get_dir_size
from .cluster_resolver import TaskRole
from .service_resolver import BaseServiceResolver
from .worker_state import WorkerState
def get_worker_state():
    """Expose the state captured by each CSX Worker at a checkpoint step.

    The state info for the current run is represented in the
    :py:class:`DataLoaderCheckpoint` dataclass format.

    Returns:
        :py:class:`DataLoaderCheckpoint` instance holding worker state
        information at the checkpoint step.

    .. note::
        - This method may only be called inside of a custom implementation of
          `state_dict` for dataloaders conforming to the
          :py:class:`RestartableDataLoader` protocol, since `state_dict` is
          well-defined only at a checkpoint step.
        - Use this method to save any of the aforementioned state info
          recorded by each worker when defining `state_dict` for custom
          implementations of restartable dataloaders.
        - This state info captured by each worker is for the current run
          only, i.e. if you pause and restart a run, the counters gathering
          information returned by this function will be reset.
    """
    # Thin public wrapper: all bookkeeping lives in WorkerState.
    return WorkerState.get_worker_state()
def service_resolver():
    """Return the resolver provided by ``BaseServiceResolver.get_resolver()``."""
    return BaseServiceResolver.get_resolver()
def num_tasks():
    """Return the total number of tasks in the cluster."""
    cluster = service_resolver().cluster_resolver
    return cluster.num_tasks
def num_streamers():
    """Return how many tasks are responsible for streaming inputs."""
    streamer_ids = service_resolver().streamer_ordinals()
    return len(streamer_ids)
def num_receivers():
    """Return how many tasks are responsible for receiving outputs."""
    receiver_ids = service_resolver().receiver_ordinals()
    return len(receiver_ids)
def get_ordinal():
    """Return the ordinal number (cluster resolver rank) of the current task."""
    cluster = service_resolver().cluster_resolver
    return cluster.rank
def get_streaming_rank():
    """Return the position of the current task among the sorted streamers.

    Raises:
        AssertionError: If the current task's ordinal is not a streamer
            ordinal.
    """
    ordered_streamers = sorted(service_resolver().streamer_ordinals())
    current = get_ordinal()
    assert (
        current in ordered_streamers
    ), f"Ordinal {current} is not a streamer."
    # Rank is the index within the ascending list of streamer ordinals.
    return ordered_streamers.index(current)
def get_streaming_batch_size(effective_batch_size: int) -> int:
    """Return the streaming (local) batch size for the current task.

    In a Wafer-Scale Cluster setup with more than 1 CS-X node, the batch size
    used in compile and specified by the user is the effective batch size at
    which gradient updates are done. However, each worker node streams a
    local batch of data to a given CS-X node to constitute data parallel
    training. This helper returns the local batch size that the current task
    should use given the desired effective batch size.

    Args:
        effective_batch_size: The effective batch size of the model.

    Returns:
        The local batch size to be streamed by this task. If queried on a
        task that is not a streamer (e.g. the user node, used when compiling
        the model), the effective batch size is returned unchanged.

    Raises:
        TypeError: If ``effective_batch_size`` is not an ``int``.
        ValueError: If ``effective_batch_size`` is not positive, if the
            number of CS-X nodes is not positive, or if the effective batch
            size is not a multiple of the number of CS-X nodes.
    """
    # Validate the argument before consulting any cluster state.
    if not isinstance(effective_batch_size, int):
        received_type = type(effective_batch_size)
        raise TypeError(
            f"Expected effective batch size to be an integer, but got type "
            f"{received_type} with value {effective_batch_size}."
        )
    if effective_batch_size < 1:
        raise ValueError(
            f"Expected effective batch size to be a positive integer, but got "
            f"value {effective_batch_size}."
        )

    # Non-streamer tasks keep the effective batch size so that compile can
    # handle data parallelism and gradient accumulation automatically.
    if not is_streamer():
        return effective_batch_size

    # Streamer tasks split the effective batch evenly across CS-X nodes.
    csx_count = service_resolver().cluster_spec.num_csx
    if csx_count <= 0:
        raise ValueError(
            f"Expected number of CS-X nodes to be a positive integer, but "
            f"got {csx_count}."
        )
    leftover = effective_batch_size % csx_count
    if leftover != 0:
        raise ValueError(
            f"Expected effective batch size {effective_batch_size} to be a "
            f"multiple of number of CS-X nodes {csx_count}."
        )
    return effective_batch_size // csx_count
def is_master_ordinal(local=False):
    """Return whether the current task assumes the master role.

    Args:
        local: Unused; accepted only for compatibility with the XLA API.
    """
    cluster = service_resolver().cluster_resolver
    return cluster.assumes_role(TaskRole.MASTER)
def is_streamer():
    """Return True if the current task's ordinal is a streamer ordinal."""
    current = get_ordinal()
    return current in service_resolver().streamer_ordinals()
def is_receiver():
    """Return True if the current task's ordinal is a receiver ordinal."""
    current = get_ordinal()
    return current in service_resolver().receiver_ordinals()
# Constants governing the worker cache.
# Fraction of the cache filesystem's total size that may be occupied;
# hit_worker_cache_limit() multiplies it by the filesystem size to get the cap.
SSD_LIMIT = 0.8
# Root directory of the worker cache; hit_worker_cache_limit() runs statvfs on
# it and requires destination paths to resolve beneath it.
WORKER_CACHE_ROOT = "/n0/cache"
def hit_worker_cache_limit(src_dir: str, dest_dir: str):
    """Check whether copying ``src_dir`` into the worker cache would overflow it.

    The cache is considered full once more than ``SSD_LIMIT`` of the
    filesystem mounted at ``WORKER_CACHE_ROOT`` is occupied. Whatever is
    currently stored at ``dest_dir`` is counted as reclaimable, since the
    computation subtracts its size from the occupied space.

    Args:
        src_dir (str, required): directory path of the source.
        dest_dir (str, required): directory path of the destination within
            the worker cache.

    Returns:
        A tuple of (``is_limit_hit``, ``dir_size``,
        ``available_space_for_copy``) where ``is_limit_hit`` is a bool
        indicating whether the cache limit will be hit with the copy,
        ``dir_size`` is the size of ``src_dir`` and
        ``available_space_for_copy`` is the space available for the copy
        (including the space occupied by the directory currently cached at
        ``dest_dir``). Both sizes are passed through ``bytes_to_human``.

    Raises:
        ValueError: If ``dest_dir`` does not resolve to a path beneath
            ``WORKER_CACHE_ROOT``.
    """
    # Guard: relative_to() raises unless dest_dir lives under the cache root.
    cache_root = Path(WORKER_CACHE_ROOT).resolve()
    Path(dest_dir).resolve().relative_to(cache_root)

    # Filesystem totals for the cache mount, in bytes.
    fs_stats = os.statvfs(WORKER_CACHE_ROOT)
    total_bytes = fs_stats.f_frsize * fs_stats.f_blocks
    free_bytes = fs_stats.f_frsize * fs_stats.f_bavail
    used_bytes = total_bytes - free_bytes

    incoming_bytes = get_dir_size(src_dir)
    # The existing cached copy would be replaced, so its space is reclaimable.
    reclaimable_bytes = get_dir_size(dest_dir)

    limit_bytes = SSD_LIMIT * total_bytes
    projected_bytes = incoming_bytes + used_bytes - reclaimable_bytes
    is_limit_hit = projected_bytes > limit_bytes

    # Space left under the cap once the old copy is removed, floored at 0.
    if limit_bytes > (used_bytes - reclaimable_bytes):
        available_space_for_copy = limit_bytes - used_bytes + reclaimable_bytes
    else:
        available_space_for_copy = 0

    return (
        is_limit_hit,
        bytes_to_human(incoming_bytes),
        bytes_to_human(available_space_for_copy),
    )