(Early access) Port your code using Cerebras PyTorch API#

Note

This API is marked experimental. While the design has, for the most part, been finalized, backwards-incompatible changes may be introduced in a future release before the API is marked as stable.

High Level Overview#

To leverage the Cerebras PyTorch API for porting and running your model on a Cerebras Wafer-Scale cluster, there are a few high level steps that must be done:

  1. Compile the model

    In order to train or evaluate a model on a Cerebras Wafer-Scale cluster it must be compiled.

  2. Instantiate a Cerebras compliant optimizer (training-only)

    Due to the nature of lazy tensor tracing and execution, the core PyTorch optimizer implementations (having been designed for eager execution) are not compatible with running on Cerebras hardware. We provide a number of drop-in Cerebras compliant replacements for commonly used optimizers.

  3. Construct a training step function

    Also due to the nature of lazy tensor tracing and execution, we need to be able to capture the entirety of the computation graph. We capture this inside a training/evaluation step.

  4. Define the training loop

    The training loop is fully exposed and customizable now.

  5. Initialize the DataLoader

    The Cerebras dataloader must be initialized to distribute the dataloader across workers.

  6. Define an Execution run

    The parameters of the execution run must be specified.

There are further steps you can perform to supplement and improve the performance and numerics of your run, such as

  1. Gradient scaling

  2. Learning rate scheduling

There are also some caveats that one should be aware of when running on the Cerebras Wafer-Scale Cluster, such as

  1. Step Closures

  2. Saving/Loading Checkpoints

Importing cstorch#

Currently the experimental API can be accessed by importing:

import cerebras_pytorch.experimental as cstorch

Once stable, the experimental modules will be merged into the base of cerebras_pytorch.

Compiling a torch.nn.Module#

In order to prepare a model for compilation, we introduce the cstorch.compile function:

model: torch.nn.Module = ...  # Any PyTorch module
compiled_model = cstorch.compile(model, backend="CSX")

We modeled this function after the torch.compile function that was introduced in PyTorch 2.0.

Note

The call to cstorch.compile, much like torch.compile, does not actually compile the model. It only prepares the model for compilation: it moves the model parameters to the appropriate torch.device and prepares the internals of cstorch to trace it.

The actual compilation does not happen until the first iteration is complete and the batch size is known.

Note

After calling cstorch.compile the model parameters may no longer be changed or modified as the call to cstorch.compile freezes the parameters to optimize their transfer to the Cerebras Wafer Scale Cluster.

Please ensure that all model weight initialization is complete before the call to cstorch.compile.

Instantiating a backend class#

As can be seen above, you can pass in the backend type to the cstorch.compile function and a backend of that type will automatically be instantiated for you.

Having said that, if you want to pass in any additional backend parameters you can construct a backend object via the backend function and pass that into the compile function instead. For example,

backend = cstorch.backend("CSX", ...)

model: torch.nn.Module = ...  # Any PyTorch module
compiled_model = cstorch.compile(model, backend=backend)

Supported backend types include

Backend Type

Optional Parameters

Parameter name

Description

CSX

artifact_dir

The directory at which to store any Cerebras specific artifacts generated by the backend. Default: $cwd/cerebras_logs

compile_dir

The directory at which to store any compile related artifacts. These compile artifacts are used to cache the compile to avoid recompilation. Default: /opt/cerebras/cached_compile

compile_only

If True, then configure the CSX backend only for compilation. This means that all parameter data is immediately dropped as it isn’t required for compilation. Further, the data executor will not send an execution request to the wafer scale cluster. This mode is intended to verify that the model is able to be compiled. As such, no system is required in this mode. Default: False

validate_only

If True, then configure the CSX backend only for validation. This means that all parameter data is immediately dropped as it isn’t required for validation. Further, the data executor will not send compile and execute requests to the wafer scale cluster. This mode is intended to verify that the model is able to be traced. As such, no system is required in this mode. Default: False

drop_data

If True, all parameter data is immediately dropped even if in a non-compile-only run. In this case, a checkpoint containing values for all stateful tensors must be loaded in order to be able to run. Default: False

max_checkpoints

If provided, cstorch.save will automatically only keep the newest max_checkpoints checkpoints, removing the oldest when the number of checkpoints exceeds the specified number. Default: None

CPU

max_checkpoints

See above description.
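
As a rough illustration of the max_checkpoints retention rule described above, the standalone sketch below keeps only the newest N checkpoint names and reports which ones would be removed. It is not how cstorch.save is implemented: the CheckpointRetention class and its record method are hypothetical, and the sketch only tracks file names rather than deleting files.

```python
from collections import deque
from typing import Deque, List, Optional


class CheckpointRetention:
    """Toy model of a keep-the-newest-N policy (hypothetical helper)."""

    def __init__(self, max_checkpoints: Optional[int] = None):
        self.max_checkpoints = max_checkpoints
        self.kept: Deque[str] = deque()

    def record(self, path: str) -> List[str]:
        """Register a new checkpoint; return the paths that would be removed."""
        self.kept.append(path)
        removed: List[str] = []
        # With max_checkpoints=None nothing is ever removed.
        while (
            self.max_checkpoints is not None
            and len(self.kept) > self.max_checkpoints
        ):
            removed.append(self.kept.popleft())  # oldest checkpoint goes first
        return removed


retention = CheckpointRetention(max_checkpoints=2)
removed_log = [
    retention.record(f"checkpoint_{step}.mdl") for step in (100, 200, 300)
]
```

With max_checkpoints=2, saving a third checkpoint is the first point at which the oldest one would be dropped.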

Initializing parameters directly for the Cerebras Wafer Scale Cluster#

Initializing the model parameters on the CPU torch.device, whilst completely valid, can be slow for extremely large models and can cause memory issues as the parameters may not fit within RAM and may spill over into swap memory or outright fail to allocate more memory.

For this case you can explicitly instantiate a backend class that contains a Cerebras device that can be used as a context manager much like a torch.device, e.g.

backend = cstorch.backend("CSX", ...)

with backend.device:
    model: torch.nn.Module = ...

# compile the model the same way as before
compiled_model = cstorch.compile(model, backend)

What this does is automatically move the parameters to the Cerebras device which saves the parameter data that will be sent to the Cerebras Wafer Scale Cluster. This frees up memory for subsequent parameters and keeps the overall memory usage low.

Initializing the Optimizer#

Much like in our previous releases, we cannot support the vanilla PyTorch optimizers. This is because the implementations available in core PyTorch were designed with eager execution in mind and are fundamentally incompatible with traced lazy execution.

As such, we provide our own drop in replacements for these optimizers inside cstorch.optim. See the table below for a list of all of the optimizers that are currently available:

For convenience, we also include a configuration helper function:

cstorch.optim.configure_optimizer(
    optimizer_type="...",  # name of the optimizer
    params=...,  # The model parameters
    ...,  # kwargs to be passed into the optimizer class's init
)

This function is useful when you want to initialize an optimizer from some configuration dictionary. An example of its usage:

optimizer_params = {
    "optimizer_type": "SGD",
    "lr": 0.001,
    "momentum": 0.5,
}
optimizer = cstorch.optim.configure_optimizer(
    optimizer_type=optimizer_params.pop("optimizer_type"),
    params=model.parameters(),
    **optimizer_params
)

Defining a custom Cerebras optimizer#

In order to define a Cerebras compliant optimizer, one must create a subclass of cstorch.optim.Optimizer, e.g.

class CustomOptimizer(cstorch.optim.Optimizer):

    def __init__(self, params, ...):
        ...
        defaults = ...
        super().__init__(params, defaults, enable_global_step=...)

    ...

    def preinitialize(self):
        ...

    def step(self, closure=None):
        ...

    def state_names_to_sparsify(self):
        ...

As can be seen in the above example, similar to torch.optim.Optimizer, the base cstorch.optim.Optimizer class expects 3 arguments: the model parameters, the param group defaults, and an optional enable_global_step, which defines a global step state variable for each parameter.

In addition, there are 3 abstract methods that must be overridden:

  1. preinitialize

    This method is used to initialize any state variables that will be used by the optimizer. For example, cstorch.optim.SGD defines its momentum buffers in its preinitialize method.

    Note, in order to remain Cerebras compliant, no optimizer state variables may be initialized outside of the preinitialize method.

    For optimal performance, when initializing the state tensors that are filled with some constant value, you can use the creation ops that are available in the cstorch package to lazily initialize them. These ops will lazily initialize and fill the tensor, meaning that they take up very little memory and can be initialized much quicker than their torch counterparts when running on the Cerebras Wafer Scale cluster. Please see the source code for the optimizers in cerebras_pytorch for examples.

  2. step

    This method is where the optimizer step is implemented. Note, due to the nature of lazy tensor tracing and execution, there may not be any python level conditions or loops used to dynamically define the control flow. This means that only torch ops (such as torch.where) may be used.

    Having said this, static structures are allowed. For example, a loop with a fixed number of iterations, or a Python conditional whose condition involves only constant variables and no torch tensors.

  3. state_names_to_sparsify

    This method should return the names of the state variables that should be sparsified. Please see the existing optimizer implementations for examples.
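
To picture why lazily filled state tensors are cheap to create in preinitialize, the toy sketch below describes a constant buffer by its shape and fill value only, and allocates the elements on request. LazyFull is a hypothetical stand-in written in plain Python. It is not one of the cstorch creation ops, whose names and signatures are not listed in this section.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class LazyFull:
    """Hypothetical stand-in for a lazily filled constant 2-D tensor."""

    shape: Tuple[int, int]
    fill_value: float

    def numel(self) -> int:
        rows, cols = self.shape
        return rows * cols

    def materialize(self) -> List[List[float]]:
        # Only at this point is memory proportional to numel() allocated.
        rows, cols = self.shape
        return [[self.fill_value] * cols for _ in range(rows)]


# Describing a large momentum buffer costs only a shape and a scalar ...
momentum_buffer = LazyFull(shape=(4096, 4096), fill_value=0.0)
# ... while a small one can be materialized to inspect its values.
small = LazyFull(shape=(2, 3), fill_value=0.0).materialize()
```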

Initializing the Learning Rate Scheduler#

Similar to the optimizers, the vanilla PyTorch learning rate schedulers are not compatible with traced lazy execution.

As such, we provide our own drop in replacements for some common schedulers inside cstorch.optim.lr_scheduler. See the table below for a list of all of the learning rate schedulers that are currently available.

Similar to cstorch.optim.configure_optimizer, for convenience, we also include the following configuration helper function for learning rate schedulers.

cstorch.optim.configure_lr_scheduler(
    optimizer=...,  # the optimizer object
    learning_rate=..., # the learning rate configuration
)

The expected format for the learning_rate parameter is one of the following

  1. learning_rate is a python scalar (int or float)

    In this case, configure_lr_scheduler returns an instance of ConstantLR with the provided value as the constant learning rate.

  2. learning_rate is a dictionary

    In this case, the dictionary is expected to contain the key scheduler which contains the name of the scheduler you want to configure.

    The rest of the parameters in the dictionary are passed as keyword arguments to the specified scheduler's init method.

  3. learning_rate is a list of dictionaries

    In this case, we assume what is being configured is a SequentialLR unless any one of the dictionaries contains the key main_scheduler and the corresponding value is ChainedLR.

    In either case, each element of the list is expected to be a dictionary that follows the format outlined in case 2.

    If what is being configured is indeed a SequentialLR, each dictionary entry is also expected to contain the key total_iters specifying the total number of iterations each scheduler should be applied for.
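
The three accepted formats above can be summarized with a small dispatch sketch. This is only a reading of the rules as stated here, not the implementation of configure_lr_scheduler. The describe_learning_rate helper is hypothetical, and "ExampleLR" and "example_kwarg" are placeholders rather than real scheduler names or arguments.

```python
from typing import Any, Dict, List, Union

LearningRate = Union[int, float, Dict[str, Any], List[Dict[str, Any]]]


def describe_learning_rate(learning_rate: LearningRate) -> str:
    """Return the scheduler name that the three rules above would select."""
    if isinstance(learning_rate, (int, float)):
        return "ConstantLR"  # case 1: a Python scalar
    if isinstance(learning_rate, dict):
        return learning_rate["scheduler"]  # case 2: a single dictionary
    if isinstance(learning_rate, list):
        # case 3: SequentialLR unless some entry asks for ChainedLR
        chained = any(
            entry.get("main_scheduler") == "ChainedLR" for entry in learning_rate
        )
        return "ChainedLR" if chained else "SequentialLR"
    raise TypeError(f"Unsupported learning_rate: {type(learning_rate).__name__}")


constant = 0.001
single = {"scheduler": "ExampleLR", "example_kwarg": 1.0}
sequence = [
    {"scheduler": "ExampleLR", "example_kwarg": 1.0, "total_iters": 100},
    {"scheduler": "ExampleLR", "example_kwarg": 0.1, "total_iters": 900},
]
chained = [
    {"scheduler": "ExampleLR", "example_kwarg": 1.0, "main_scheduler": "ChainedLR"},
    {"scheduler": "ExampleLR", "example_kwarg": 0.1},
]
```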

Defining a custom Cerebras learning rate scheduler#

In order to define a Cerebras compliant learning rate scheduler, one must create a subclass of cstorch.optim.lr_scheduler.LRScheduler, e.g.

class CustomScheduler(cstorch.optim.lr_scheduler.LRScheduler):

    def __init__(self, optimizer, ...):
        ...
        super().__init__(optimizer, total_iters=..., last_epoch=...)

    ...

    def _get_closed_form_lr(self) -> torch.Tensor:
        ...

As can be seen in the above example, the base cstorch.optim.lr_scheduler.LRScheduler class expects 3 arguments: the optimizer whose learning rate is being scheduled and, optionally, the total number of iterations that the scheduler is scheduled for as well as the last epoch to start on.

In addition, there is one abstract method that must be overridden:

  1. _get_closed_form_lr

    This method is where the full scheduler is defined in closed form. Note, due to the nature of lazy tensor tracing and execution, there may not be any python level conditions or loops used to dynamically define the control flow. This means that only torch ops (such as torch.where) may be used.

    Having said this, static structures are allowed. For example, a loop with a fixed number of iterations, or a Python conditional whose condition involves only constant variables and no torch tensors.

    This method is expected to return a torch.Tensor that represents the full learning rate schedule as a computed tensor.

    Please see the existing LR scheduler implementations for examples on how to properly define the schedule.
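
As a minimal sketch of what "closed form" means here, the plain-Python example below expresses a linear warmup followed by a constant rate as a selection between two precomputed values. The where helper is a scalar stand-in for torch.where. In a real _get_closed_form_lr the step would be a tensor and the selection would go through torch.where itself. The warmup schedule, base_lr and warmup_steps are illustrative assumptions, not taken from an existing Cerebras scheduler.

```python
from typing import List


def where(condition: bool, if_true: float, if_false: float) -> float:
    """Scalar stand-in for torch.where: both candidates are already computed."""
    return if_true if condition else if_false


def closed_form_lr(step: int, base_lr: float = 0.1, warmup_steps: int = 4) -> float:
    """Linear warmup to base_lr, then constant, as one selection expression."""
    warmup_lr = base_lr * (step + 1) / warmup_steps  # value during warmup
    constant_lr = base_lr                            # value after warmup
    return where(step < warmup_steps, warmup_lr, constant_lr)


schedule: List[float] = [closed_form_lr(step) for step in range(6)]
```

Both candidate values are computed for every step and only the selection depends on the step, which is the shape a traced graph can represent.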

Initializing the DataLoader#

The Cerebras Wafer-Scale cluster makes use of worker nodes to stream data to the system in order to maximize utilization by keeping the input buffers saturated. The workers being their own nodes means that they cannot share a PyTorch dataloader. Hence, we required a mechanism for each worker to be able to initialize their own dataloader.

To facilitate this, we introduce a custom dataloader class:

dataloader = cstorch.utils.data.DataLoader(
    input_fn,
    ...,  # kwargs to be passed into the input_fn
)

It takes an input_fn parameter, which is a callable that takes in some parameters and returns a torch.utils.data.DataLoader:

def input_fn(...) -> torch.utils.data.DataLoader:
    ...

All other parameters that are passed into the DataLoader init are forwarded and passed into the input_fn.

Each worker will call this input function to construct their own dataloader object. This means that some data sharding scheme is required if the intent is for each worker to stream in a unique set of data.
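
One possible sharding scheme is a round-robin split of sample indices, sketched below in plain Python. The names worker_id and num_workers are hypothetical: how an input function learns which worker it is running on is not covered in this section, so treat them as values your own input_fn would have to obtain.

```python
from typing import List


def shard_indices(num_samples: int, worker_id: int, num_workers: int) -> List[int]:
    """Round-robin shard: worker k takes samples k, k + num_workers, ..."""
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return list(range(worker_id, num_samples, num_workers))


# Three workers splitting ten samples without overlap.
shards = [shard_indices(10, worker_id, 3) for worker_id in range(3)]
```

Inside an input_fn, such an index list could be used to select the subset of the dataset that the worker's torch.utils.data.DataLoader iterates over.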

Using Gradient Scaling#

Gradient scaling can improve convergence when training models with float16 gradients by minimizing gradient underflow. Please see the PyTorch docs for a more detailed explanation.
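
The underflow problem can be reproduced with the standard library alone. The sketch below rounds a small value to IEEE 754 half precision with and without a loss scale. The gradient value and the scale of 1024 are illustrative assumptions, not recommended settings.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


gradient = 1e-8       # smaller than the smallest positive float16 value
loss_scale = 1024.0   # illustrative static loss scale

unscaled = to_float16(gradient)              # underflows to zero
scaled = to_float16(gradient * loss_scale)   # survives as a small float16 value
recovered = scaled / loss_scale              # unscale in higher precision
```

Without scaling the gradient is lost entirely, while scaling before the cast and unscaling afterwards recovers it to within float16 rounding error.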

To facilitate gradient scaling, we introduce a Cerebras implementation of the AMP GradScaler class found in core PyTorch.

grad_scaler = cstorch.amp.GradScaler(
    loss_scale=...,
    init_scale=...,
    steps_per_increase=...,
    min_loss_scale=...,
    max_loss_scale=...,
    overflow_tolerance=...,
    max_gradient_norm=...,
)

It is designed to be as similar as possible to the API of the CUDA AMP GradScaler class.

See the table below for a description of each parameter:

Parameter name

Type

Description

loss_scale

str or float

If loss_scale == "dynamic", then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling. (Default: 0.0)

init_scale

float

The initial loss scale value if loss_scale == "dynamic" (Default: None)

steps_per_increase

int

The number of steps after which to increase the loss scale (Default: None)

min_loss_scale

float

The minimum loss scale value that can be chosen by dynamic loss scaling (Default: None)

max_loss_scale

float

The maximum loss scale value that can be chosen by dynamic loss scaling (Default: None)

overflow_tolerance

float

The maximum fraction of steps involving infinite or undefined values in the gradient we allow. We reduce the loss scale if the tolerance is exceeded (Default: 0.05)

max_gradient_norm

float

The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case. If GCC is not enabled, then this parameter has no effect (Default: 0.05) (Note: Only used in pipeline mode)

Its usage is practically identical to the usage of the CUDA AMP GradScaler:

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
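
For readers unfamiliar with the global gradient clipping step used above, the following plain-Python sketch shows the general recipe: measure the L2 norm over all gradients together, then scale every gradient by the same coefficient if that norm exceeds the limit. It operates on lists of floats instead of tensors, and the clip_grad_norm helper and its eps default are assumptions made for illustration.

```python
import math
from typing import List


def clip_grad_norm(
    grads: List[List[float]], max_norm: float, eps: float = 1e-6
) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = min(1.0, max_norm / (total_norm + eps))  # never scales up
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" with one gradient entry each
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in grads for g in grad))
```

Because one coefficient is applied to every gradient, the direction of the overall update is preserved and only its length changes.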

Using automatic mixed precision with bfloat16#

By default, automatic mixed precision uses float16.

In order to use automatic mixed precision with bfloat16 instead of float16 please call the following function:

cstorch.amp.use_bfloat16(True)

cstorch.amp.optimizer_step#

We introduce an optional helper function to take care of the details of gradient scaling:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=...,  # optionally perform gradient clipping by norm
    max_gradient_val=...,  # optionally perform gradient clipping by value
)

It is useful for quickly constructing typical examples that use gradient scaling without needing to type up the details or worry about whether the grad scaler is being used correctly.

This is completely optional and only covers the basic gradient scaler use case. For more complicated use cases, the grad scaler object must be used explicitly.

Constructing the training step#

In order to compile the full training graph, the entire training step must be captured in its entirety. To handle this we introduce the cstorch.compile_step decorator:

@cstorch.compile_step
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    cstorch.amp.optimizer_step(
        loss, optimizer, grad_scaler, max_gradient_norm=1.0
    )

    if lr_scheduler:
        lr_scheduler.step()

    return loss

This decorator should wrap some function that encapsulates the entirety of a single training iteration. That is to say, everything that is intended to run on a Cerebras system should be inside this wrapped function.

In addition, no tensor value may be eagerly evaluated at any point inside this training step. This means that no tensor may be printed, fetched via a debugger, or used as part of a Python conditional. Any operation that requires knowing the value of a tensor inside the training step will result in an error stating that it is not allowed to read a tensor’s contents outside of a step closure.

Another caveat with compile_step is that any variables that are not torch tensors will only see their first value. So, for example, if you have an int counter that you increment inside a cstorch.compile_step wrapper, you will see the first value on all iterations. This is because the training step graph is only captured once. Hence, any pure Python ops only run once.
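
The trace-once behavior can be imitated with a toy decorator. The sketch below is a conceptual model only and says nothing about how cstorch.compile_step is implemented: trace_once is hypothetical, and the "graph" is just a Python closure that is built on the first call and replayed afterwards.

```python
from typing import Callable, Dict, Optional

state: Dict[str, int] = {"counter": 0}


def trace_once(
    build_graph: Callable[[], Callable[[float], float]]
) -> Callable[[float], float]:
    """Toy tracer: run the Python body once, then replay what it produced."""
    graph: Optional[Callable[[float], float]] = None

    def run(x: float) -> float:
        nonlocal graph
        if graph is None:
            graph = build_graph()  # the Python body executes only here
        return graph(x)

    return run


@trace_once
def training_step() -> Callable[[float], float]:
    state["counter"] += 1         # pure Python: runs only while tracing
    scale = state["counter"]      # baked into the graph as a constant
    return lambda x: x * scale    # the part replayed on every iteration


results = [training_step(10.0) for _ in range(3)]
```

Even though the step is called three times, the counter is incremented once and every iteration sees the value captured during tracing.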

Running Compile/Execution on a Cerebras Wafer Scale Cluster#

In order to send a compilation or execution request, you must construct a data executor.

executor = cstorch.utils.data.DataExecutor(...)

The DataExecutor class takes the following arguments

Parameter name

Description

dataloader

The DataLoader object to use for the run (Required).

num_steps

The number of steps to run for. This parameter is required if not only compiling the model. Otherwise, it defaults to 1.

checkpoint_steps

The interval at which to schedule fetching checkpoints from the Cerebras Wafer Scale Cluster. If None, don’t fetch checkpoints. (Default: None)

cs_config

Optionally, a Cerebras configuration (CSConfig) object may be passed in to configure the Cerebras Wafer Scale Cluster. If not provided, the default configuration values will be used. (Default: None)

writer

The SummaryWriter object to be used to write any summarized scalars or tensors to tensorboard. (Default: None)

profiler_activities

The list of activities to profile. By default, the client side rate and global rate are tracked.

Once created, simply iterate through the executor to enter the execution context where compile and/or execution requests will be sent to the Cerebras Wafer Scale Cluster, e.g.

for i, batch in enumerate(executor):
    training_step(batch)
    ...

Note

As of 1.9, we don’t currently support multiple CS runs in a single process. This means that the above executor can only be run/iterated once. Any runs with different configurations must be run in separate processes.

Profiling the executor#

We provide tools through the executor to profile its performance during the run.

Currently the supported activities that can be profiled include

Activity

Description

rate

Client side smoothed samples/second of all the samples added since last queried

global_rate

Non-smoothed samples/second since the beginning of when the executor context was entered

These activities can be specified via the profiler_activities flag to the DataExecutor constructor and can be queried via their names through the executor’s profiler attribute, e.g.

executor = cstorch.utils.data.DataExecutor(
    ...,
    profiler_activities=["rate", "global_rate"],
)
...
print(f"Rate: {executor.profiler.rate()}")
print(f"Global rate: {executor.profiler.global_rate()}")

Step Closures#

By design, in the execution schema used by the Cerebras Wafer-Scale cluster, the client and the server run asynchronously to each other. This prevents the server from becoming bottlenecked by client processes such as disk IO or networking.

However, this means that a computed tensor may not be available to fetch from the server when the client requests it. For example, the call to compile happens in the first iteration of the training loop. Until compile is complete and execution starts on cluster, no tensor is available to fetch.

To handle this, we introduce the concept of a step closure via the step_closure decorator:

@cstorch.step_closure
def closure(...):
    ...

Any tensors that are passed into a “step closure” are fetched from the server and their value is materialized before the closure is actually called. If the tensor is not yet available, it waits until the server “catches up” to the current step and the tensor value is available to be fetched before actually calling the closure.

One caveat regarding values passed into step closures is that the value seen by the step closure is the last value set to that tensor, not the value at the point of definition. This means that if the tensor is updated in place after being passed into the step closure, the modified tensor is what gets materialized before being passed into the closure.
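
The "last value wins" caveat can be demonstrated with a toy deferred-call queue in plain Python. Here step_closure and run_pending are hypothetical stand-ins, lists play the role of tensors, and copying the list is used as an analogue of cloning a tensor before passing it to the closure.

```python
from typing import Callable, List, Tuple

pending: List[Tuple[Callable[[List[float]], None], List[float]]] = []
seen: List[List[float]] = []


def step_closure(fn: Callable[[List[float]], None]) -> Callable[[List[float]], None]:
    """Toy decorator: queue the call instead of running it immediately."""

    def wrapper(tensor: List[float]) -> None:
        pending.append((fn, tensor))  # keeps a reference, not a snapshot

    return wrapper


def run_pending() -> None:
    """Stand-in for the point where values are materialized and closures run."""
    while pending:
        fn, tensor = pending.pop(0)
        fn(tensor)


@step_closure
def log_value(tensor: List[float]) -> None:
    seen.append(list(tensor))


loss = [1.0]
log_value(loss)         # queued with a reference to `loss`
log_value(list(loss))   # queued with a copy taken at this point
loss[0] = 99.0          # in-place update before the closures run
run_pending()
```

The first closure observes the in-place update, while the second sees the value at the time the copy was made.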

Saving/Loading Checkpoints#

To save and load weights in a Cerebras run, we provide a custom Cerebras H5 based checkpoint format that is far more performant and efficient compared to the core PyTorch pickle based checkpoint format, especially when it comes to any models with extremely large weights, such as large language models.

To save a checkpoint, we provide a very familiar cstorch.save function that you can use in exactly the same way as torch.save:

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    ...
}
cstorch.save(state_dict, "<path to save checkpoint to>")

Similarly, we provide a very familiar cstorch.load function that can also be used in exactly the same way as torch.load:

state_dict = cstorch.load("<path to load checkpoint from>")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...

Checkpoint Closures#

Weights can only be fetched on the predetermined checkpoint steps configured through the checkpoint_steps argument of the DataExecutor. This restriction keeps training performant.

For example, if the configuration was checkpoint_steps=100, you are only allowed to fetch the weights to take a checkpoint every 100th step and at the very end on the last step.

To aid this, you can use the checkpoint_closure decorator which is a step closure that checks that the current step is a checkpoint step before calling the function. In addition, using this decorator ensures that the weights are available to fetch from the server before they can be saved to the checkpoint file.

@cstorch.checkpoint_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        ...
    }
    cstorch.save(state_dict, "<path to save checkpoint to>")
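
The checkpoint-step rule described above (every checkpoint_steps-th step, plus the final step) can be written as a small predicate. This is a sketch under stated assumptions, not the check that checkpoint_closure performs internally: is_checkpoint_step is a hypothetical helper, and steps are assumed to be counted from 1.

```python
from typing import List, Optional


def is_checkpoint_step(
    step: int, checkpoint_steps: Optional[int], num_steps: int
) -> bool:
    """True on every checkpoint_steps-th step and on the last step (1-based)."""
    if not checkpoint_steps:
        return False  # None (or 0): checkpoints are never fetched
    return step % checkpoint_steps == 0 or step == num_steps


# A 250-step run with checkpoint_steps=100.
checkpointed: List[int] = [
    step for step in range(1, 251) if is_checkpoint_step(step, 100, 250)
]
```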

Converting checkpoints to Pickle-based format#

If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:

state_dict = cstorch.load("<path to H5 checkpoint>", map_location="cpu")
torch.save(state_dict, "<path to torch checkpoint>")

Note, this will eagerly load the entirety of the checkpoint into memory. Thus, it may cause memory issues when loading checkpoints for very large models.

Training Loop#

The training loop is now fully exposed and customizable. A very basic example of a training loop could be

@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
    print("Loss: ", loss.item())

for i, batch in enumerate(executor):
    loss = training_step(batch)

    post_training_step(loss)

    if i % checkpoint_steps == 0:
        save_checkpoint()

Please see the Full Example (Training) for a comprehensive example of the full training API.

Evaluation Metrics#

We provide Cerebras compatible metrics that can be used during evaluation to measure how well the model has trained.

They are found in the metrics.Metric module. See the table below for a list of all of the metrics that are currently available:

These metric classes keep an internal state and will return the final computed value. Please see the Full Example (Evaluation) to see how these metrics may be used.

Full Example (Training)#

Shown below is a simple skeleton of a full training script. For a complete, executable example please see our sample training script.

import torch
import cerebras_pytorch.experimental as cstorch

backend = cstorch.backend("CSX", ...)

with backend.device:
    # user defined model
    model: torch.nn.Module = ...

compiled_model = cstorch.compile(model, backend)

loss_fn: torch.nn.Module = ...

optimizer: cstorch.optim.Optimizer = cstorch.optim.configure_optimizer(
    optimizer_type="...",
    params=model.parameters(),
    ...
)
lr_scheduler: cstorch.optim.lr_scheduler.LRScheduler = cstorch.optim.configure_lr_scheduler(
    optimizer, learning_rate=...,
)

grad_scaler = None
if loss_scale != 0.0:
    grad_scaler = cstorch.amp.GradScaler(...)

@cstorch.checkpoint_closure
def save_checkpoint(step):
    checkpoint_file = f"checkpoint_{step}.mdl"

    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if lr_scheduler:
        state_dict["lr_scheduler"] = lr_scheduler.state_dict()
    if grad_scaler:
        state_dict["grad_scaler"] = grad_scaler.state_dict()

    state_dict["global_step"] = step

    cstorch.save(state_dict, checkpoint_file)

global_step = 0

# Load checkpoint if provided
if checkpoint_path is not None:
    state_dict = cstorch.load(checkpoint_path)

    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    if lr_scheduler:
        lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
    if grad_scaler:
        grad_scaler.load_state_dict(state_dict["grad_scaler"])

    global_step = state_dict.get("global_step", 0)

@cstorch.compile_step
def training_step(batch):
    inputs, targets = batch
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    cstorch.amp.optimizer_step(
        loss, optimizer, grad_scaler, max_gradient_norm=1.0
    )

    return loss

@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
    print("Loss: ", loss.item())

dataloader = cstorch.utils.data.DataLoader(
    train_dataloader_fn,
    ...
)
executor = cstorch.utils.data.DataExecutor(
    dataloader=dataloader,
    num_steps=1000,
    checkpoint_steps=100,
    cs_config=cstorch.utils.CSConfig(...),
)

for i, batch in enumerate(executor):
    loss = training_step(batch)

    post_training_step(loss)

    # Always call save_checkpoint, but is only truly
    # run every 100 steps
    save_checkpoint(i)

Full Example (Evaluation)#

Evaluation is less complex compared to training. There is no optimizer or gradient scaler that needs to be initialized.

Shown below is a simple skeleton of a full evaluation script. For a complete, executable example please see our sample eval script.

import torch
import cerebras_pytorch.experimental as cstorch
import cerebras_pytorch.experimental.metrics as metrics

backend = cstorch.backend("CSX", ...)

with backend.device:
    model: torch.nn.Module = ...

compiled_model = cstorch.compile(model, backend)
compiled_model.eval()

loss_fn: torch.nn.Module = ...

accuracy = metrics.AccuracyMetric("accuracy", compute_on_system=True)

state_dict = cstorch.load("<path to checkpoint file>")
model.load_state_dict(state_dict["model"])

dataloader = cstorch.utils.data.DataLoader(
    eval_dataloader_fn,
    ...
)
executor = cstorch.utils.data.DataExecutor(
    dataloader,
    num_steps=100,
    cs_config=cstorch.utils.CSConfig(...)
)

@cstorch.compile_step
def evaluation_step(batch):
    inputs, targets = batch
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    accuracy(
        labels=targets.clone(),
        predictions=outputs.argmax(-1).int(),
    )

    return loss

total_loss = 0
total_steps = 0

@cstorch.step_closure
def post_eval_step(loss: torch.Tensor):
    global total_loss, total_steps

    total_loss += loss.item()
    total_steps += 1


with torch.no_grad():
    for batch in executor:
        loss = evaluation_step(batch)

        post_eval_step(loss)

print(f"Eval Accuracy: {float(accuracy)}")