Work with Cerebras checkpoints
With extremely large models, checkpoint files become very large, and reading, writing, and manipulating them can become a bottleneck for training. To address this, Cerebras stores checkpoints in an HDF5-based file format. A Cerebras checkpoint contains the same content as other common model checkpoint formats, so it can easily be converted to another format if needed. The relevant utilities and interfaces are described separately for each supported framework.
PyTorch Checkpoint Format
Our checkpoint format, optimized for large models, is based on the standard HDF5 file format. At a high level, when saving a checkpoint, the Cerebras stack takes a PyTorch state dict, flattens it, and stores it in an HDF5 file. For example, the following state dict:
{
    "a": {
        "b": 0.1,
        "c": 0.001,
    },
    "d": [0.1, 0.2, 0.3]
}
would be flattened and stored in the HDF5 file as follows:
{
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}
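To make the flattening step concrete, here is a minimal pure-Python sketch that reproduces the mapping above by joining dict keys and list indices with ".". The function name flatten_state_dict is hypothetical: this illustrates the idea only and is not the Cerebras stack's actual implementation, which may treat tensors, key escaping, and metadata differently.

```python
from typing import Any, Dict


def flatten_state_dict(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: flatten nested dicts/lists/tuples into dotted keys.

    Dict keys and sequence indices are joined with "." (e.g. "a.b", "d.0").
    Anything that is not a dict, list, or tuple is treated as a leaf value.
    Note: empty dicts/lists produce no entries in this sketch.
    """
    if isinstance(obj, dict):
        items = [(str(k), v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        items = [(str(i), v) for i, v in enumerate(obj)]
    else:
        # Leaf value: store it under the key built so far.
        return {prefix: obj}

    flat: Dict[str, Any] = {}
    for key, value in items:
        full_key = f"{prefix}.{key}" if prefix else key
        flat.update(flatten_state_dict(value, full_key))
    return flat


state_dict = {"a": {"b": 0.1, "c": 0.001}, "d": [0.1, 0.2, 0.3]}
print(flatten_state_dict(state_dict))
# {'a.b': 0.1, 'a.c': 0.001, 'd.0': 0.1, 'd.1': 0.2, 'd.2': 0.3}
```

Integer dict keys, such as those in an optimizer's per-parameter state, are converted to strings the same way list indices are.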
A model/optimizer state dict can be saved in the HDF5-based checkpoint format using the cbtorch.save method. For example:
import cerebras.pytorch as cbtorch
...
# Bundle the model and optimizer state into a single state dict
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
cbtorch.save(state_dict, "path/to/checkpoint")
...
A checkpoint saved this way can be loaded using the cbtorch.load method. For example:
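import cerebras.pytorch as cbtorch
...
# Read the checkpoint written by cbtorch.save
state_dict = cbtorch.load("path/to/checkpoint")
# Restore the model and optimizer from their entries in the loaded state dict
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...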
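Loading has to undo the flattening. The hypothetical unflatten_state_dict below sketches the inverse of the flattened mapping shown earlier under one stated assumption: a level whose keys are exactly "0" through "n-1" is rebuilt as a list, and every other level as a dict. This is not how cbtorch.load is documented to work; it only shows why a real loader needs more structural information than the dotted keys alone.

```python
from typing import Any, Dict


def unflatten_state_dict(flat: Dict[str, Any]) -> Any:
    """Hypothetical helper: rebuild a nested structure from dotted keys.

    Assumption: a level whose keys are exactly "0", "1", ..., "n-1" is
    restored as a list; every other level is restored as a dict.
    """
    root: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        parts = dotted_key.split(".")
        node = root
        # Walk/create intermediate dicts for every part except the last.
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return _restore_lists(root)


def _restore_lists(node: Any) -> Any:
    """Recursively turn dicts with keys "0".."n-1" into lists."""
    if not isinstance(node, dict):
        return node
    converted = {k: _restore_lists(v) for k, v in node.items()}
    n = len(converted)
    if n and set(converted) == {str(i) for i in range(n)}:
        return [converted[str(i)] for i in range(n)]
    return converted


flat = {"a.b": 0.1, "a.c": 0.001, "d.0": 0.1, "d.1": 0.2, "d.2": 0.3}
print(unflatten_state_dict(flat))
# {'a': {'b': 0.1, 'c': 0.001}, 'd': [0.1, 0.2, 0.3]}
```

Two real PyTorch cases break this guessing: `optimizer.state_dict()["state"]` is a dict keyed by integer parameter IDs, which this sketch would turn into a list, and module parameter names such as "layer.weight" already contain dots, so they would be split into an extra nesting level.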
Note
If you use the run.py scripts provided in the ModelZoo, all of the above is already handled by the runners used in the ModelZoo.
Converting Checkpoint Formats
If loading the checkpoint into memory with cbtorch.load is not sufficient for your use case, you can convert it to the pickle-based format that PyTorch uses, as follows:
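import torch
import cerebras.pytorch as cbtorch

# Load the Cerebras HDF5 checkpoint fully into memory...
state_dict = cbtorch.load("path/to/checkpoint")
# ...then write it back out in PyTorch's pickle-based format
torch.save(state_dict, "path/to/new/checkpoint")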
Warning
This does not work for extremely large models whose state dict is too large to fit in memory. There must be enough RAM to load the entire checkpoint before it can be saved in the PyTorch pickle format.
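To judge whether a conversion will fit in memory, a rough estimate of the state dict size can be computed from parameter shapes. The helper below is hypothetical and assumes that every parameter uses one dtype and that the optimizer keeps optimizer_slots extra buffers shaped like each parameter (2 for Adam's first- and second-moment estimates). It ignores scalar optimizer state, Python object overhead, and any temporary copies made while saving, so treat the result as a lower bound.

```python
from math import prod
from typing import Dict, Tuple

# Bytes per element for a few common parameter dtypes.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def estimate_state_dict_bytes(
    param_shapes: Dict[str, Tuple[int, ...]],
    dtype: str = "float32",
    optimizer_slots: int = 2,
) -> int:
    """Hypothetical helper: rough size of weights plus optimizer buffers.

    Assumes all tensors share `dtype` and that the optimizer stores
    `optimizer_slots` buffers shaped like each parameter (2 for Adam).
    """
    num_elements = sum(prod(shape) for shape in param_shapes.values())
    bytes_per_copy = num_elements * BYTES_PER_ELEMENT[dtype]
    # One copy for the weights themselves, plus one per optimizer slot.
    return bytes_per_copy * (1 + optimizer_slots)


# Hypothetical single linear layer: 1024x1024 weight and 1024-element bias.
shapes = {"layer.weight": (1024, 1024), "layer.bias": (1024,)}
print(estimate_state_dict_bytes(shapes))
# 12595200
```

At scale the same arithmetic adds up quickly: a hypothetical model with 1.3 billion float32 parameters trained with Adam needs about 1.3e9 * 4 bytes * 3, or roughly 15.6 GB, just to hold the state dict.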