Limitations of PyTorch on Cerebras
Floating Point Precision
Only mixed precision is supported for training a model on the Cerebras system. Weights are stored as float32, but all computation other than the weight update happens in a combination of float32 and either bfloat16 or float16. Casts are inserted automatically, so users do not need to insert them manually. See Control numerical precision level for more information on switching between precision optimization levels.
Note
There are no plans to support other precision modes at this time.
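A small, framework-free sketch helps show why the weights and the weight update stay in float32. bfloat16 keeps float32's sign bit and 8-bit exponent but only the top 7 mantissa bits, so near 1.0 its spacing is 2**-7 (about 0.0078). The hypothetical helpers below emulate float32 and bfloat16 rounding with Python's standard struct module (round half to even, finite inputs in float32 range assumed). They only illustrate the number formats; they are not how the Cerebras stack performs its casts.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even (finite inputs assumed)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the upper 16 bits. Adding 0x7FFF plus the lowest kept bit
    # rounds the discarded lower 16 bits to nearest, with ties going to even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(update: float, steps: int, rounding) -> float:
    """Apply `steps` small updates to a weight of 1.0, rounding after each one."""
    weight = rounding(1.0)
    for _ in range(steps):
        weight = rounding(weight + update)
    return weight


# A 1e-3 update is smaller than half the bfloat16 spacing near 1.0 (2**-8),
# so bfloat16 rounds it away on every step, while float32 keeps it.
print(accumulate(1e-3, 10, to_float32))   # approximately 1.01
print(accumulate(1e-3, 10, to_bfloat16))  # 1.0
```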
Static Graphs
As of the 1.8.0 software release, the Wafer-Scale Engine cannot be reprogrammed after its initial programming. Multiple compiles are therefore not supported, so the PyTorch compute graph must not change between iterations.
As a result, there are several constraints on how the training loop can be constructed, all of which are already addressed in our custom PyTorch runner classes. Refer to our implementations of the various hooks described in PyTorch Runners.
Learning Rate Scheduler
Currently, we do not support the typical PyTorch learning rate scheduler paradigm. A typical PyTorch learning rate scheduler computes a scalar learning rate and writes it into the learning rate entries of the optimizer's parameter groups. Because the system currently requires a static graph, we cannot support this behavior.
Scheduling Learning Rate
We must specify the entire learning rate schedule as a function of the global step. The learning rate is therefore no longer a scalar value but a tensor whose value depends on the global step. See modelzoo.common.pytorch.optim.lr_scheduler for examples.
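As a sketch of what "a function of the global step" can look like, the hypothetical warmup_linear_decay_lr below (its name and parameters are illustrative and not part of modelzoo.common.pytorch.optim.lr_scheduler) computes a linear-warmup, linear-decay schedule from a global_step tensor using only tensor operations. torch.where evaluates both branches and selects between them elementwise, so no Python if statement depends on the step value; the assumption is that this keeps the traced graph identical from one iteration to the next.

```python
import torch


def warmup_linear_decay_lr(
    global_step: torch.Tensor,
    base_lr: float = 1e-3,
    warmup_steps: int = 100,
    total_steps: int = 1000,
) -> torch.Tensor:
    """Hypothetical schedule: the learning rate as a tensor function of the step."""
    step = global_step.to(torch.float32)
    # Branch 1: ramp linearly from 0 to base_lr over the first warmup_steps steps.
    warmup_lr = base_lr * step / warmup_steps
    # Branch 2: decay linearly from base_lr to 0 between warmup_steps and total_steps,
    # clamped so the learning rate never goes negative after total_steps.
    decay_frac = (total_steps - step) / (total_steps - warmup_steps)
    decay_lr = base_lr * torch.clamp(decay_frac, min=0.0)
    # Select per element instead of branching in Python on the step value.
    return torch.where(step < warmup_steps, warmup_lr, decay_lr)
```

Because the input is a tensor, the same function can also be evaluated over many steps at once, for example warmup_linear_decay_lr(torch.arange(0, 1000, 100)), to inspect the whole schedule.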
This also means that any optimizer used must be written so that it treats the learning rate not as a scalar value but as a tensor. See modelzoo.common.pytorch.optim.AdamBase for an example.
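Code that treats the learning rate as a Python number, for example by passing it as the alpha argument of Tensor.add_, does not fit this model. The hypothetical TensorLRSGD below is a minimal plain-SGD sketch, not modelzoo.common.pytorch.optim.AdamBase: it keeps its own global_step tensor, asks a user-supplied lr_fn for a learning-rate tensor on every step, and multiplies the gradient by that tensor instead of converting it to a float.

```python
import torch


class TensorLRSGD(torch.optim.Optimizer):
    """Hypothetical plain SGD whose learning rate is a tensor computed from a step tensor."""

    def __init__(self, params, lr_fn):
        # No scalar "lr" hyperparameter is stored in the parameter groups.
        super().__init__(params, defaults={})
        self.lr_fn = lr_fn
        # The global step is a tensor so the schedule can be part of the graph.
        self.global_step = torch.zeros((), dtype=torch.int64)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        lr = self.lr_fn(self.global_step)  # tensor, never converted to a Python float
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # Multiply by the lr tensor rather than passing a scalar alpha.
                    p.sub_(lr * p.grad)
        self.global_step += 1
        return loss
```

Any lr_fn that maps the step tensor to a tensor fits this interface, for example lambda step: 0.1 * (step + 1).to(torch.float32).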
Eval metrics
Eval metrics may return only a single result value, so only the final metric value should be returned. Intermediate state or values cannot be retrieved at this time.
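A minimal, framework-free sketch of a metric that fits this restriction: the hypothetical RunningAccuracy class below (not a Cerebras or Model Zoo API) accumulates its counts internally across eval batches and reports one final value through compute(). The running counts are internal state and are never returned.

```python
class RunningAccuracy:
    """Hypothetical accuracy metric that reports only its final value."""

    def __init__(self) -> None:
        # Internal state, accumulated across batches and never returned.
        self._correct = 0
        self._total = 0

    def update(self, predictions, labels) -> None:
        """Accumulate counts for one batch of predicted and true class labels."""
        if len(predictions) != len(labels):
            raise ValueError("predictions and labels must have the same length")
        self._correct += sum(int(p == y) for p, y in zip(predictions, labels))
        self._total += len(labels)

    def compute(self) -> float:
        """Return the single final metric value: overall accuracy."""
        return self._correct / self._total if self._total else 0.0
```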