Weight Streaming Execution
The execution mode refers to how the Cerebras runtime loads your neural network model onto the Cerebras Wafer-Scale Engine (WSE). The Cerebras Wafer-Scale Cluster supports Weight Streaming mode. In this mode, the runtime loads one layer of the neural network model onto the WSE at a time. This layer-by-layer mode is used to run large models: models for which any single layer's weights fit in memory, but the weights of the entire model do not.
To illustrate how Weight Streaming works, we will use the example shown below of a 3-layer FC-MNIST network:
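As a concrete stand-in for that figure, here is a minimal PyTorch sketch of such a network. The layer widths (784-256-128-10) are illustrative assumptions, not values taken from the figure:

```python
import torch.nn as nn

class FCMNIST(nn.Module):
    """A three-layer fully connected network for MNIST (illustrative sizes)."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)   # layer 1
        self.fc2 = nn.Linear(256, 128)   # layer 2
        self.fc3 = nn.Linear(128, 10)    # layer 3 (output logits)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)        # flatten 28x28 images to 784 features
        a1 = self.relu(self.fc1(x))      # layer 1 activations
        a2 = self.relu(self.fc2(a1))     # layer 2 activations
        return self.fc3(a2)              # layer 3 logits
```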
Weight Streaming Execution during runtime
At runtime, one layer is loaded onto the WSE at each step, as shown in Fig. 3. In forward propagation, the weights are streamed from the MemoryX server to the WSE. In backpropagation, the weights are streamed from MemoryX to the WSE again, and the weight gradients are computed on the WSE and then streamed back to MemoryX. There they are stored and used for learning, in which the weights are adjusted using the weight gradients and the learning algorithm.
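The following is a minimal sketch of this loop in plain Python/NumPy. The memoryx dictionary, the layer widths, and the plain-SGD update are illustrative assumptions standing in for the actual runtime; the point is the access pattern: one layer's weights on the wafer at a time, with activations resident in WSE memory.

```python
import numpy as np

rng = np.random.default_rng(0)
SIZES = [784, 256, 128, 10]  # illustrative layer widths (matching the sketch above)

# "memoryx" stands in for MemoryX: the master copy of every layer's weights.
memoryx = {i: rng.normal(scale=0.1, size=(SIZES[i - 1], SIZES[i])) for i in (1, 2, 3)}

def forward(x):
    """Forward pass: stream one layer's weights at a time onto the 'wafer';
    the computed activations stay resident in WSE memory."""
    acts = {0: x}  # the input batch plays the role of layer-0 activations
    for i in (1, 2, 3):
        w = memoryx[i]                                 # stream: MemoryX -> WSE
        z = acts[i - 1] @ w
        acts[i] = np.maximum(z, 0) if i < 3 else z     # ReLU on hidden layers only
    return acts

def backward(acts, delta, lr=0.01):
    """Backward pass: stream each layer's weights in again, compute its weight
    gradient on the 'wafer', and stream the gradient back for the update."""
    for i in (3, 2, 1):
        w = memoryx[i]                                 # stream: MemoryX -> WSE
        grad_w = acts[i - 1].T @ delta                 # weight gradient, on the WSE
        if i > 1:
            delta = (delta @ w.T) * (acts[i - 1] > 0)  # delta for the earlier layer
        memoryx[i] = w - lr * grad_w                   # stream out; update in MemoryX
```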
During the compilation job:
The Cerebras compiler extracts the graph of operations from the code and maps the operations onto the supported kernels of the Cerebras Software Platform. Each matched kernel constitutes a layer in the network dataflow graph. A mathematical representation of these layers is shown in Fig. 4. If you are interested in seeing the result of this lightweight phase of compilation, you can use the --validate_only flag.
The Cerebras compiler plans the mapping of one kernel/layer at a time onto the whole WSE, first for the forward pass and then for the backward pass. If multiple CS-2 systems are requested in the Cerebras Wafer-Scale Cluster, then for every layer the same mapping is used across all CS-2 systems. If you are interested in precompiling without running, you can use the --compile_only flag. Example invocations appear below.
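For example, with a Model Zoo-style training script (the script name and the remaining arguments here are assumptions for illustration; only the two flags come from this page):

```
python run.py --mode train --validate_only   # lightweight compile: kernel matching only
python run.py --mode train --compile_only    # full precompilation, without execution
```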
Training starts:
Forward propagation, as shown in Fig. 5
Layer 1 is loaded onto the WSE first. Input pre-processing servers process and stream training data to the WSE. MemoryX streams layer 1’s weights into the WSE. If multiple CS-2s are requested in the training job, SwarmX broadcasts the weights from MemoryX to the WSEs. The batch of training samples is sharded into equally large subsets of training examples, with one shard going to each of the CS-2s. This technique is known as data parallelism.
Each WSE performs the layer 1 forward computation in parallel with the others.
The computed activations for layer 1 remain in WSE memory.
Next, MemoryX broadcasts the weights of layer 2 to the WSEs.
Each WSE performs the layer 2 forward computation using its stored layer 1 activations.
The same process repeats for layer 3.
In this manner, the forward computation for each layer uses the stored activations of the prior layer. In turn, the activations computed by the current layer are stored in WSE memory for use by the next loaded layer.
At the loss layer, the ground-truth labels from the training data are used to compute the network loss delta: the gradient of the scalar loss with respect to the activations of the output layer (layer 3). This loss delta seeds the layer-by-layer delta and weight-gradient computations of the backward pass; a worked example follows below.
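For instance, with a softmax cross-entropy loss the loss delta has a simple closed form; here is a minimal NumPy sketch (the helper name and the mean-over-batch convention are assumptions):

```python
import numpy as np

def loss_delta(logits, labels):
    """Gradient of the mean softmax cross-entropy loss with respect to the
    output-layer activations (logits): probs - one_hot(labels), over the batch."""
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
    probs /= probs.sum(axis=1, keepdims=True)
    probs[np.arange(len(labels)), labels] -= 1.0                # subtract one-hot labels
    return probs / len(labels)
```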
Backward propagation, as shown in Fig. 6
The layer 3 weights are broadcast from MemoryX to the WSEs, which perform the gradient and delta computations for layer 3. (The implementation can, of course, keep this output layer's weights resident to save time.)
The layer 3 gradients are streamed out of the WSEs to SwarmX, which reduces (adds together) the gradients from the multiple WSEs and presents their sum to MemoryX. MemoryX then uses the aggregate gradient in the learning algorithm to update its stored copy of the layer weights.
Next, the layer 2 weights are streamed from MemoryX to the WSEs, which similarly perform the gradient and delta computations for layer 2. The layer 2 gradients are then streamed out of the WSEs to MemoryX, where the weight updates occur. If multiple CS-2s are requested in the training job, SwarmX broadcasts the weights from MemoryX to the WSEs and reduces the gradients from the WSEs back to MemoryX.
The backward pass continues this way until the layer 1 gradients are streamed out to MemoryX and the layer 1 weights are updated. The forward pass for the next training batch then begins with the updated weights. A sketch of the reduce-then-update step follows below.
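Here is a minimal NumPy sketch of that reduce-then-update step for one layer. The function name, the plain-SGD learning rule, and the two-WSE example are illustrative assumptions:

```python
import numpy as np

def reduce_and_update(master_w, per_wse_grads, lr=0.01):
    """SwarmX-style reduction: sum one layer's weight gradients across the
    WSEs, then let MemoryX apply the learning rule to its master weights."""
    aggregate = np.sum(per_wse_grads, axis=0)   # reduce (add) across WSEs
    return master_w - lr * aggregate            # MemoryX updates its stored copy

# A batch sharded across two CS-2s yields one gradient per WSE for this layer.
w = np.zeros((4, 3))
grads = [np.full((4, 3), 0.5), np.full((4, 3), 0.25)]
w = reduce_and_update(w, grads)
```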
Meanwhile, on the user node:
As loss values are computed on the WSEs, they are reduced by SwarmX and sent to the user node.
All the weights can be downloaded from MemoryX to the user node at specified intervals to save checkpoints.
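A small sketch of that user-node view follows (the loop structure, the interval, and the npz checkpoint format are assumptions; memoryx is the weight dictionary from the earlier sketch):

```python
import numpy as np

def user_node_loop(memoryx, losses, checkpoint_interval=100):
    """User-node view of training: log the SwarmX-reduced loss for each step,
    and periodically pull all weights from MemoryX to save a checkpoint."""
    for step, loss in enumerate(losses, start=1):
        print(f"step {step}: loss {loss:.4f}")
        if step % checkpoint_interval == 0:
            # "Download" every layer's weights from MemoryX and write a checkpoint.
            np.savez(f"checkpoint_{step}.npz",
                     **{f"layer{i}": w for i, w in memoryx.items()})
```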