cerebras.modelzoo.data.common.h5_map_dataset.dataset.RestartableDataLoader#
- class cerebras.modelzoo.data.common.h5_map_dataset.dataset.RestartableDataLoader[source]#
Bases:
torch.utils.data.DataLoader
The state needed to deterministically restart an instance of HDF5Dataset is the total number of samples streamed globally, which is consumed by the sampler. Accordingly, each worker saves the number of samples it has streamed in state_dict(). These per-worker counts are aggregated by summation to give the global number of samples streamed across all workers, and that global total is what sets the sampler state on state-dict load.
Methods
- aggregate_state_dict: Sum samples streamed across all workers to get the global number of samples streamed.
- deaggregate_state_dict: No deaggregation needed, since the sampler needs the global number of samples streamed.
- load_state_dict: Set the sampler state from the total number of samples streamed globally.
- state_dict: Save the number of samples streamed by the current worker.
- validate_state_dict
- load_state_dict(state_dict)[source]#
Set the sampler state from the total number of samples streamed globally.
- aggregate_state_dict(worker_states)[source]#
Sum samples streamed across all workers to get the global number of samples streamed.
- deaggregate_state_dict(aggregated_state_dict)[source]#
No deaggregation needed, since the sampler needs the global number of samples streamed.
- __call__(*args: Any, **kwargs: Any) Any #
Call self as a function.
- static __new__(cls, *args: Any, **kwargs: Any) Any #