Taking checkpoints off a Cerebras System¶
Legacy Mode¶
In legacy mode, the checkpoint format is the same as in a typical PyTorch workflow, so checkpoints can be loaded using torch.load.
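A minimal sketch of loading a legacy-mode checkpoint (the path and the "model"/"optimizer" keys are illustrative assumptions; your checkpoint's layout may differ):
import torch
...
# Legacy-mode checkpoints are ordinary PyTorch pickles, so torch.load is enough.
state_dict = torch.load("path/to/checkpoint")
# Assuming the checkpoint stores "model" and "optimizer" entries.
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...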
Appliance Mode¶
In the new appliance mode, we define a new checkpoint format. This was necessary because the pre-existing PyTorch checkpoint format could not support saving the extremely large models that appliance mode was designed for.
The new checkpoint format is based on the HDF5 (H5) file format.
At a high level, we take the PyTorch state dict, flatten it, and store it in an H5 file. For example, the following state dict:
{
    "a": {
        "b": 0.1,
        "c": 0.001,
    },
    "d": [0.1, 0.2, 0.3]
}
would be flattened and stored in the H5 file as follows:
{
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}
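As an illustration of this flattening scheme (not the actual cbtorch implementation; the flatten helper and output file name are assumptions), the nested dict above could be written to an H5 file with h5py roughly as follows:
import h5py

def flatten(obj, prefix=""):
    # Recursively flatten nested dicts/lists into "dot-separated key" -> value pairs.
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield from flatten(value, f"{prefix}{key}.")
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            yield from flatten(value, f"{prefix}{index}.")
    else:
        yield prefix[:-1], obj  # strip the trailing "."

state_dict = {
    "a": {"b": 0.1, "c": 0.001},
    "d": [0.1, 0.2, 0.3],
}

with h5py.File("flattened.h5", "w") as f:
    for key, value in flatten(state_dict):
        f[key] = value  # each flattened key becomes its own H5 dataset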
A model/optimizer state dict can be saved in the new checkpoint format using the cbtorch.save method. For example:
import cerebras_pytorch as cbtorch
...
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
cbtorch.save(state_dict, "path/to/checkpoint")
...
A checkpoint saved using the above can be loaded using the cbtorch.load method. For example:
import cerebras_pytorch as cbtorch
...
state_dict = cbtorch.load("path/to/checkpoint")
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...
Note
If you are using the run.py scripts provided in the ModelZoo, the above is already taken care of by the ModelZoo runners.
Converting Checkpoint Formats¶
If cbtorch.load is not a sufficient solution for loading the checkpoint into memory, a simple conversion to the pickle format that PyTorch uses can be done as follows:
import torch
import cerebras_pytorch as cbtorch
state_dict = cbtorch.load("path/to/checkpoint")
torch.save(state_dict, "path/to/new/checkpoint")
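Once converted, the checkpoint is a standard PyTorch pickle and can be loaded with plain torch.load (the path below is illustrative):
import torch

# Load the converted checkpoint like any standard PyTorch checkpoint.
state_dict = torch.load("path/to/new/checkpoint")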
Warning
This will not work for extremely large models whose state dict does not fit into memory. Sufficient RAM must be available to load the full checkpoint in order to save it in the PyTorch pickle format.