Step closures
Overview
By design, in the execution scheme used by the Cerebras Wafer-Scale cluster, the client and the server run asynchronously with respect to each other. This prevents the server from being bottlenecked by client-side work such as disk I/O or networking.
However, this means that a computed tensor may not be available to fetch from the server when the client requests it. For example, the call to compile happens in the first iteration of the training loop. Until compile is complete and execution starts on the cluster, no tensor is available to fetch.
To handle this, we introduce the concept of a step closure, created with the step_closure decorator. For example:
@cstorch.step_closure
def closure(loss):
    print(f"Loss: {loss}")
Tensors that are passed into a “step closure” are fetched from the server, and their values are materialized before the closure is actually called. If a tensor is not yet available, the call waits until the server “catches up” with the current step and the tensor value can be fetched. Only then is the closure invoked.
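The deferral described above can be illustrated with a small pure-Python toy. This is only a sketch of the idea, not how cstorch implements step closures: `ToyServer`, `ToyClient`, `advance`, and `flush` are hypothetical names invented for this illustration, and the toy polls with `flush()` where the real mechanism waits for the server.

```python
from collections import deque


class ToyServer:
    """Hypothetical stand-in for the remote side: a step's value exists only after advance()."""

    def __init__(self):
        self._results = {}

    def advance(self, step, value):
        # Simulate the server finishing `step` and producing a value for it.
        self._results[step] = value

    def is_ready(self, step):
        return step in self._results

    def fetch(self, step):
        return self._results[step]


class ToyClient:
    """Hypothetical client that queues closures until their step's value can be fetched."""

    def __init__(self, server):
        self.server = server
        self.pending = deque()

    def step_closure(self, fn):
        # Calling the decorated function only records the request; fn runs later.
        def wrapper(step):
            self.pending.append((step, fn))
        return wrapper

    def flush(self):
        # Run queued closures in order, stopping at the first step that is not ready.
        while self.pending and self.server.is_ready(self.pending[0][0]):
            step, fn = self.pending.popleft()
            fn(self.server.fetch(step))


server = ToyServer()
client = ToyClient(server)
seen = []


@client.step_closure
def record_loss(loss):
    seen.append(loss)


record_loss(1)          # client is ahead of the server: nothing to fetch yet
record_loss(2)
client.flush()          # no closure runs, because step 1 has not completed
server.advance(1, 0.9)
client.flush()          # the step 1 closure runs with the materialized value
server.advance(2, 0.7)
client.flush()          # the step 2 closure runs
```

The point of the toy is the ordering guarantee: a closure never sees a placeholder, only a value the server has already produced, and closures run in the order in which they were requested.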
Example Usage
You can call step closures either inside the traced step function or outside of it:
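import torch
import cerebras.pytorch as cstorch

# compiled_model, loss_fn, optimizer, and executor are assumed to be set up earlier.

@cstorch.step_closure
def check_nan(loss):
    # Runs only once the loss value for this step has been fetched.
    assert not torch.isnan(loss).any()

@cstorch.step_closure
def print_loss(loss):
    print(f"Loss: {loss}")

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    check_nan(loss)  # step closure called inside the traced function
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

for inputs, targets in executor:
    loss = training_step(inputs, targets)
    print_loss(loss)  # step closure called outside the traced function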