# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Defines the Cerebras DataLoader class and RestartableDataLoader protocol class."""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Union
import torch
from typing_extensions import Protocol, runtime_checkable
from cerebras.appliance.log import ClassLogger, named_class_logger
from cerebras.pytorch.backend import Backend, current_backend_impl
from cerebras.pytorch.utils.data.utils import infer_batch_size
@named_class_logger
class DataLoader(ClassLogger):
    """
    Wrapper around torch.utils.data.DataLoader that facilitates
    moving data generated by the dataloader to a Cerebras system

    If ``input_fn`` returns a ``torch.utils.data.DataLoader``, that loader is
    reconfigured in place at construction time to use ``num_workers=0`` and
    ``persistent_workers=False``.

    Args:
        input_fn: A callable that returns a torch.utils.data.DataLoader
            instance or an iterable that returns a structure containing torch
            tensors.
        *args, **kwargs: Any other positional or keyword arguments
            are passed into the input_fn when each worker instantiates
            their respective dataloaders
    """

    # Class-wide counter used to assign each instance a unique ``id``.
    _id_counter = 0

    # Sentinel values for ``cached_state``. ``STATE_UNKNOWN`` is the initial
    # value; ``STATE_UNAVAILABLE`` is never assigned inside this class
    # (presumably the backend sets it). ``state_dict`` re-queries the state
    # whenever ``cached_state`` is either sentinel.
    STATE_UNKNOWN = object()
    STATE_UNAVAILABLE = object()

    # The name that ``obj.__initialized`` mangles to inside a class called
    # ``DataLoader``, which is where torch's DataLoader keeps its private
    # ``__initialized`` flag. Spelled out so the lookup does not silently
    # depend on this wrapper class also being named ``DataLoader``.
    _TORCH_DATALOADER_INITIALIZED_ATTR = "_DataLoader__initialized"

    def __init__(
        self,
        input_fn: Callable[..., Union[torch.utils.data.DataLoader, Iterable]],
        *args,
        **kwargs,
    ):
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "`torch.utils.data.DataLoader` or an iterable that "
                "returns a structure containing torch tensors."
            )

        # Properties accessed by the backend
        DataLoader._id_counter += 1
        self.id = DataLoader._id_counter
        self.input_fn = input_fn
        # Deep-copied so later mutation of the caller's arguments does not
        # affect the stored copy.
        self.input_fn_params = deepcopy((args, kwargs))
        self.cached_state = self.STATE_UNKNOWN

        self.dataloader = input_fn(*args, **kwargs)
        if isinstance(self.dataloader, torch.utils.data.DataLoader):
            self._disable_worker_processes()

        # Populated lazily from the batches yielded by ``__iter__``.
        self.batch_size = None

    def _disable_worker_processes(self) -> None:
        """Reconfigure the wrapped torch DataLoader to use no worker processes.

        Sets ``persistent_workers=False`` and ``num_workers=0``. If the loader
        originally had ``num_workers > 0`` and a ``worker_init_fn``, that
        function is called once with ``worker_id=0`` so the initialization it
        performs still happens.

        This is best effort: if any of these steps raises an ``Exception``,
        the original ``persistent_workers``/``num_workers`` values are
        restored and a warning is emitted instead of propagating the error.
        """
        import warnings

        dataloader = self.dataloader
        original_persistent_workers = dataloader.persistent_workers
        original_num_workers = dataloader.num_workers
        try:
            # torch's DataLoader refuses to re-assign some attributes
            # (including ``persistent_workers``) while its private
            # initialized flag is True, so lower it for the overrides.
            setattr(dataloader, self._TORCH_DATALOADER_INITIALIZED_ATTR, False)
            dataloader.persistent_workers = False
            dataloader.num_workers = 0
            # If the original num workers is greater than zero and a
            # worker_init_fn was provided, we need to call it with
            # worker_id=0 to ensure that the dataloader is initialized
            # correctly.
            if dataloader.worker_init_fn is not None and original_num_workers > 0:
                dataloader.worker_init_fn(0)
        except Exception as e:  # pylint: disable=broad-except
            # Restore while the flag is still lowered so the assignment of
            # ``persistent_workers`` is permitted.
            dataloader.persistent_workers = original_persistent_workers
            dataloader.num_workers = original_num_workers
            warnings.warn(
                f"Failed to reconfigure the dataloader to run without worker "
                f"processes ({e!r}); restored the original `num_workers` and "
                f"`persistent_workers` values.",
                stacklevel=3,
            )
        finally:
            setattr(dataloader, self._TORCH_DATALOADER_INITIALIZED_ATTR, True)

    @property
    def is_restartable(self) -> bool:
        """Returns True if dataloader is restartable.

        That is, the wrapped dataloader is an instance of the runtime-checkable
        ``RestartableDataLoader`` protocol.
        """
        return isinstance(self.dataloader, RestartableDataLoader)

    @property
    def _backend(self) -> Backend:
        """Returns the current backend implementation."""
        return current_backend_impl()

    def state_dict(self) -> Dict[str, Any]:
        """Returns dataloader state to save in a checkpoint
        by invoking the saving mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        Only allowed at the pre-initial step, a checkpoint step, or the final
        step of the run. If no state is cached, it is obtained as follows.

        * On a CSX backend at the pre-initial step, it is computed directly
          from the wrapped dataloader and returned without being cached.
        * On a CSX backend at any other step, it is fetched from the appliance
          workers, aggregated and cached.
        * On any other backend, it is computed from the wrapped dataloader and
          cached.

        Returns:
            `dict` capturing dataloader state as specified in the
            implementation of the dataloader's `aggregate_state_dict`
            method. Always a deep copy, so callers may mutate it freely.

        Raises:
            RuntimeError: If the dataloader is not restartable, if called
                outside the allowed steps, or if the cached state is
                ``STATE_UNAVAILABLE`` at the pre-initial step on CSX.
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for getting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        backend = self._backend
        run_context = backend.run_context

        if not (
            run_context.is_checkpoint_step
            or run_context.is_pre_initial_step
            or run_context.is_final_step
        ):
            raise RuntimeError(
                "DataLoader state can only be requested at a checkpoint step. Please "
                "ensure that `state_dict` is called on the `cstorch.utils.DataLoader` "
                "at a checkpoint step. If you're calling it inside of a method, please "
                "decorate it with the `cstorch.checkpoint_closure` method decorator."
            )

        # If the state is not known, query it, cache it, and return it.
        # Identity checks avoid invoking ``__eq__`` on user-provided state.
        if (
            self.cached_state is self.STATE_UNKNOWN
            or self.cached_state is self.STATE_UNAVAILABLE
        ):
            self._configure_worker_state(run_context.user_iteration)

            if backend.backend_type.is_csx:
                if run_context.is_pre_initial_step:
                    if self.cached_state is not self.STATE_UNKNOWN:
                        raise RuntimeError(
                            "Invalid dataloader cached state! At the pre-initial step, the cached "
                            "state should not be STATE_UNAVAILABLE."
                        )
                    # If no state was loaded but a state is requested before
                    # the run has started, call the dataloader directly to
                    # return its current state (not cached).
                    return deepcopy(
                        self.dataloader.aggregate_state_dict(
                            [self.dataloader.state_dict()]
                        )
                    )

                # Fetch state from the appliance workers
                worker_states: List[DataLoaderCheckpoint] = (
                    backend.appliance.grpc_client.fetch_dataloader_state(
                        run_context.user_iteration
                    )
                )
                # For aggregation, we only pass the per WRK state dict users
                # explicitly chose to save in their `state_dict` implementation.
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [worker_state.user_state_dict for worker_state in worker_states]
                )
            else:
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [self.dataloader.state_dict()]
                )

        return deepcopy(self.cached_state)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads dataloader state from the provided `state_dict`
        by invoking the loading mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        On a CSX backend, the state is only stored in ``cached_state``. On any
        other backend, it is de-aggregated and loaded into the wrapped
        dataloader immediately.

        Args:
            state_dict: dict capturing dataloader state loaded from a
                checkpoint

        Raises:
            RuntimeError: If the dataloader is not restartable, or if called
                inside a run context after the pre-initial step.
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for setting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        backend = self._backend
        if backend.in_run_context and not backend.run_context.is_pre_initial_step:
            raise RuntimeError(
                "DataLoader state can only be loaded onto before execution. "
                "Please make sure to call `load_state_dict()` only before "
                "iterating the data executor."
            )

        if backend.backend_type.is_csx:
            self.cached_state = state_dict
        else:
            self._configure_worker_state(0)
            self.dataloader.load_state_dict(
                self.dataloader.deaggregate_state_dict(state_dict)
            )

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the wrapped dataloader, tracking ``batch_size``."""
        self._configure_worker_state(0)
        for batch in self.dataloader:
            self.batch_size = infer_batch_size(batch, self.batch_size)
            yield batch

    def _configure_worker_state(self, step: int):
        """Configure ``WorkerState`` as a single worker on a single CSX at ``step``.

        NOTE(review): for ``step > 0`` this multiplies by ``self.batch_size``,
        which stays ``None`` until ``__iter__`` has yielded a batch, so calling
        this beforehand raises ``TypeError``.
        """
        from cerebras.pytorch.distributed.worker_state import WorkerState

        WorkerState.configure(
            DataLoaderCheckpoint(
                local_worker_id=0,
                num_workers_per_csx=1,
                num_csx=1,
                wse_id=0,
                appliance_step=step,
                worker_step=step,
                samples_streamed=step * self.batch_size if step > 0 else 0,
                user_state_dict=None,
            )
        )
@runtime_checkable
class RestartableDataLoader(Protocol):
    """Defines interface for the restartable dataloader protocol.

    This protocol is runtime-checkable. ``isinstance(obj, RestartableDataLoader)``
    is True when ``obj`` provides all four methods below: ``state_dict``,
    ``load_state_dict``, ``aggregate_state_dict`` and
    ``deaggregate_state_dict``. Only the presence of the methods is checked,
    not their signatures.
    """

    def state_dict(self) -> Dict[str, Any]:
        """Use this method to specify what state information should be saved
        by each CSX Worker.

        Returns:
            dict holding state information for the CSX Worker

        In order to access Cerebras internal data checkpoint info per
        CSX Worker at some checkpoint step, follow the steps in the example
        below. Cerebras internal data checkpoint format is recorded in the
        :py:class:`~cerebras.pytorch.utils.data.DataLoaderCheckpoint` dataclass.

        Usage:
        ::

            import cerebras.pytorch as cstorch
            ...
            def state_dict(self) -> Dict[str, Any]:
                worker_state = cstorch.distributed.get_worker_state()
                state_dict = {}
                if worker_state:
                    state_dict["worker_step"] = worker_state.worker_step
                    state_dict["worker_id"] = worker_state.global_worker_id
                return state_dict

        .. note::
            The call to :py:func:`~cerebras.pytorch.distributed.get_worker_state`
            is well-defined only inside of the `state_dict` method; using this
            anywhere else will result in a RuntimeError exception. See linked
            docs for more details.
        """

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Use this method to load CSX Worker state for the dataloader instance,
        as captured from a previous run.

        Args:
            state_dict: dict holding worker state info, specified in
                :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.deaggregate_state_dict`

        Usage:
        ::

            def load_state_dict(self, state_dict):
                wrk_state_dict = state_dict.get("worker_0", {})
                worker_step = wrk_state_dict.get("worker_step", 0)
                worker_id = wrk_state_dict.get("worker_id")
                print(f"WRK {worker_id} loaded step: {worker_step}")
        """

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Use this method to specify how to combine the list of CSX Worker state dicts.

        Args:
            worker_states: One state dict per CSX Worker, each as returned by
                :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.state_dict`.
                The list may contain any number of entries (one when running
                without multiple workers).

        Returns:
            The consolidated state dict that will be saved in a checkpoint.

        Usage:
        ::

            def aggregate_state_dict(self, worker_states):
                return {
                    f"worker_{i}": worker_state
                    for i, worker_state in enumerate(worker_states)
                }
        """

    def deaggregate_state_dict(
        self, aggregated_state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Use this method to specify how to load an individual CSX Worker state given
        a consolidated list of state dicts, as specified in
        :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.aggregate_state_dict`.

        Args:
            aggregated_state_dict: The consolidated state dict, as produced by
                ``aggregate_state_dict`` and loaded from a checkpoint.

        Returns:
            The state dict will be passed to the above-defined
            :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.load_state_dict` method.

        Usage:
        ::

            def deaggregate_state_dict(self, aggregated_state_dict):
                return {
                    "worker_0": aggregated_state_dict.get("worker_0", {})
                }
        """
@dataclass
class DataLoaderCheckpoint:
    """Dataclass representing the Cerebras internal dataloader checkpoint format.

    Each CSX Worker captures its state information via this class at a checkpoint
    step.

    Attributes:
        global_worker_id:
            ID of this worker amongst all other workers across all boxes.
            Derived property: ``wse_id * num_workers_per_csx + local_worker_id``.
        local_worker_id:
            ID of this worker amongst all other workers across the same box
        total_num_workers:
            The total number of workers for the run across all boxes.
            Derived property: ``num_workers_per_csx * num_csx``.
        num_workers_per_csx:
            The total number of workers per box for the run
        num_csx:
            The total number of CSXs (boxes) for the run
        wse_id:
            ID of the Wafer-Scale Engine (CSX) to which this worker streams data
        appliance_step:
            The appliance step at which this checkpoint state info is captured
        worker_step:
            The worker step at which this state info is captured. Note that this
            is simply equal to `appliance_step` if `num_workers_per_csx = 1`;
            for the multi-worker scenario, the appliance step is distributed
            across workers on a single box in a round-robin fashion based on
            the local worker id
        samples_streamed:
            The total number of samples streamed by this worker at checkpoint
            step. This is simply `worker_step` * `per_box_batch_size`
        user_state_dict:
            User-defined state dict for the CSX Worker, i.e. the state the
            user's ``RestartableDataLoader.state_dict`` chose to save. May be
            ``None`` when no user state has been captured. This object must
            be picklable.

    .. note::
        `appliance_step`, `worker_step` and `samples_streamed` are the attributes
        that vary across different steps; whereas the other attributes provide
        constant state information for the current run.
    """

    local_worker_id: int
    num_workers_per_csx: int
    num_csx: int
    wse_id: int
    appliance_step: int
    worker_step: int
    samples_streamed: int
    # User-defined state dict for the CSX Worker, or None when absent.
    # This object must be picklable.
    user_state_dict: Union[Dict[str, Any], None]

    @property
    def global_worker_id(self) -> int:
        """ID of this worker across all boxes (box-major ordering)."""
        return self.wse_id * self.num_workers_per_csx + self.local_worker_id

    @property
    def total_num_workers(self) -> int:
        """Total number of workers across all boxes."""
        return self.num_workers_per_csx * self.num_csx