cerebras.modelzoo.data.vision.classification.sampler.RepeatedAugSampler
- class cerebras.modelzoo.data.vision.classification.sampler.RepeatedAugSampler
Bases: torch.utils.data.Sampler
Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample is visible to a different process (GPU). Heavily based on ‘torch.utils.data.DistributedSampler’.
This is borrowed from the DeiT repository: https://github.com/facebookresearch/deit/blob/main/samplers.py
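The following minimal pure-Python sketch illustrates the index assignment. It is modeled on the DeiT RASampler logic and is not the Cerebras implementation. The helper name `repeated_aug_indices` is hypothetical. Its use of `batch_size` is also an assumption: the sketch treats it as a replacement for the 256 that DeiT hard-codes when trimming each rank's epoch. Each index is repeated `num_repeats` times in a row, padded so the total divides evenly, and then split across ranks with a stride of `num_replicas`. As a result, consecutive copies of the same sample land on different ranks.

```python
import math
import random
from typing import List


def repeated_aug_indices(
    dataset_len: int,
    num_replicas: int,
    rank: int,
    num_repeats: int = 3,
    shuffle: bool = True,
    seed: int = 0,
    epoch: int = 0,
    batch_size: int = 256,
) -> List[int]:
    """Hypothetical sketch of DeiT-style repeated-augmentation sampling."""
    # Every rank builds the same permutation because seed + epoch is shared.
    indices = list(range(dataset_len))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)

    # Repeat each index consecutively: [a, a, a, b, b, b, ...].
    repeated = [i for i in indices for _ in range(num_repeats)]

    # Pad by wrapping around so the total divides evenly across replicas.
    num_samples = math.ceil(dataset_len * num_repeats / num_replicas)
    total_size = num_samples * num_replicas
    repeated += repeated[: total_size - len(repeated)]

    # Strided split: the copies of one sample go to different ranks.
    per_rank = repeated[rank:total_size:num_replicas]

    # Assumption: trim to one un-repeated epoch's worth of indices per rank,
    # rounded down to a multiple of batch_size (DeiT hard-codes 256 here).
    num_selected = (dataset_len // batch_size * batch_size) // num_replicas
    return per_rank[:num_selected]


if __name__ == "__main__":
    for r in range(4):
        print(r, repeated_aug_indices(8, 4, r, shuffle=False, batch_size=4))
```

With `dataset_len=8`, four replicas, `num_repeats=3` and no shuffling, sample 0 is assigned to ranks 0, 1 and 2. Each rank still receives only two indices per epoch, so the epoch length matches ordinary distributed sampling.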
Methods
set_epoch
- __init__(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, num_repeats=3, batch_size=256)
- __call__(*args: Any, **kwargs: Any) → Any
Call self as a function.
- static __new__(cls, *args: Any, **kwargs: Any) → Any