cerebras.modelzoo.data.common.HDF5IterableDataset.RestartableDataLoader
- class cerebras.modelzoo.data.common.HDF5IterableDataset.RestartableDataLoader
Bases: torch.utils.data.DataLoader
This custom dataloader class defines the ‘state_dict’, ‘aggregate_state_dict’, ‘load_state_dict’ and ‘deaggregate_state_dict’ methods. Together, these methods determine which worker state information is stored (local or global streaming information) and how it is aggregated and later retrieved. To restart an instance of HDF5IterableDataset deterministically, the dataloader needs the number of samples already seen in the previous run. This count is stored under the samples_streamed key of the worker state dict. On restart, the load_state_dict method sets the samples_seen class variable, which determines how many samples are skipped.
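The restart flow described above can be sketched in pure Python as follows. This is a hypothetical model, not the Model Zoo implementation: it uses neither the Cerebras nor the PyTorch classes, and the names ToyStream and ToyRestartableLoader are invented for illustration. It models a single worker; summing the per-worker counts in aggregate_state_dict and passing the global state unchanged to every worker in deaggregate_state_dict are assumptions about the multi-worker case.

```python
from typing import Dict, Iterator, List


class ToyStream:
    """Hypothetical stand-in for HDF5IterableDataset: streams the integers
    0..num_samples-1 and skips the samples a previous run already consumed."""

    samples_seen = 0  # class variable set on restart, as the text describes

    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples
        # Position in the stream, counting samples skipped on restart.
        self.samples_streamed = 0

    def __iter__(self) -> Iterator[int]:
        # Resume from the class-level checkpoint position, then stream the rest.
        for idx in range(type(self).samples_seen, self.num_samples):
            self.samples_streamed = idx + 1
            yield idx


class ToyRestartableLoader:
    """Hypothetical model of the four checkpoint methods; not the Cerebras API."""

    def __init__(self, dataset: ToyStream) -> None:
        self.dataset = dataset

    def state_dict(self) -> Dict[str, int]:
        # Per-worker (local) state: how far this worker's stream has advanced.
        return {"samples_streamed": self.dataset.samples_streamed}

    def aggregate_state_dict(self, worker_states: List[Dict[str, int]]) -> Dict[str, int]:
        # Assumption: workers split one global stream, so the global count is the sum.
        return {"samples_streamed": sum(s["samples_streamed"] for s in worker_states)}

    def deaggregate_state_dict(self, aggregated_state: Dict[str, int]) -> Dict[str, int]:
        # Assumption: every worker receives the same global view.
        return dict(aggregated_state)

    def load_state_dict(self, state: Dict[str, int]) -> None:
        # Set the class variable so that new dataset instances skip seen samples.
        type(self.dataset).samples_seen = state["samples_streamed"]


# First run: consume 3 of 10 samples, then checkpoint.
first_ds = ToyStream(10)
first_loader = ToyRestartableLoader(first_ds)
stream = iter(first_ds)
first_run = [next(stream) for _ in range(3)]
checkpoint = first_loader.aggregate_state_dict([first_loader.state_dict()])

# Restart: a fresh dataset resumes where the previous run stopped.
resumed_ds = ToyStream(10)
resumed_loader = ToyRestartableLoader(resumed_ds)
resumed_loader.load_state_dict(resumed_loader.deaggregate_state_dict(checkpoint))
resumed = list(resumed_ds)
print(first_run, resumed)  # [0, 1, 2] [3, 4, 5, 6, 7, 8, 9]
```

Because samples_seen is a class variable, setting it once in load_state_dict affects every dataset instance created afterwards in the same process, which is what lets a freshly constructed dataset resume without extra constructor arguments.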
Methods
aggregate_state_dict
deaggregate_state_dict
load_state_dict
state_dict
- __call__(*args: Any, **kwargs: Any) → Any
Call self as a function.
- static __new__(cls, *args: Any, **kwargs: Any) → Any