Source code for data_processing.h5_map_dataset.samplers

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import torch

from modelzoo.transformers.pytorch.input_utils import cluster_config


class CBSampler(torch.utils.data.Sampler):
    """
    A sampler to handle sharding, batching, and skipping of map style
    datasets intended for use on CSX.

    Sharding is performed in such a way that data order is independent of
    the number of systems being used and the number of workers per system.
    """

    def __init__(
        self,
        data_source,
        shuffle=True,
        seed=None,
        start_index=0,
        shard=True,
        batch_size=None,
        drop_last=True,
    ):
        """
        Create a sampler to handle shuffling in a deterministic and
        restartable way as well as sharding.

        Args:
            data_source (torch.utils.data.Dataset): dataset to sample from
            shuffle (bool): whether or not to shuffle the dataset
            seed (int): The seed used to make shuffling deterministic
            start_index (int): The index of the first sample to yield
            shard (bool): Whether or not to shard the dataset across
                Cerebras data streamer nodes
            batch_size (int): The batch size to use to compute sharded
                indices and group samples into batches. If `None`, no
                batching will be performed. This is the global batch size
                visible to the dataset rather than the microbatch size.
            drop_last (bool): Whether to drop the final incomplete batch of
                each epoch. Must be `True` when using more than one CS-2.
        """
        cluster_spec, _ = cluster_config()
        _num_systems = cluster_spec.num_csx
        if batch_size is not None and batch_size % _num_systems:
            raise ValueError(
                f"The global batch size must be a multiple of the number of "
                f"CS-2s being used. Got global batch size {batch_size} and "
                f"number of systems {_num_systems}."
            )
        if _num_systems > 1 and not drop_last:
            raise ValueError(
                "`drop_last=False` is only supported on GPU. Please re-run "
                "with `drop_last=True`."
            )
        microbatch_size = (
            batch_size // _num_systems if batch_size is not None else None
        )
        self.sampler = BaseSampler(
            data_source, shuffle=shuffle, seed=seed, start_index=start_index
        )
        if batch_size is not None:
            self.sampler = BatchSampler(
                self.sampler, microbatch_size, drop_last
            )
        if shard:
            self.sampler = Sharder(self.sampler)
        if batch_size is not None and _num_systems > 1:
            self.sampler = BatchAccumulator(self.sampler, _num_systems)

    def __iter__(self):
        return iter(self.sampler)

    def __len__(self):
        return len(self.sampler)
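

# Illustrative sketch, not part of the original module: one typical way to
# wire CBSampler into a `torch.utils.data.DataLoader`. Because the sampler
# yields whole batches of indices when `batch_size` is given, it would go
# through the `batch_sampler` argument. This only runs in a Cerebras cluster
# environment where `cluster_config()` resolves; `dataset` and the batch size
# below are hypothetical placeholders.
def _example_cb_sampler_usage(dataset, global_batch_size=96):
    sampler = CBSampler(
        dataset,
        shuffle=True,
        seed=0,
        start_index=0,
        shard=True,
        batch_size=global_batch_size,
        drop_last=True,
    )
    # `batch_sampler` hands the DataLoader complete index batches, so
    # `batch_size` and `shuffle` must not also be set on the DataLoader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=sampler)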


class BaseSampler(torch.utils.data.Sampler):
    """
    Handle shuffling and skipping.
    """

    def __init__(
        self,
        data_source,
        num_samples=None,
        shuffle=True,
        seed=None,
        start_index=0,
    ):
        self.data_source = data_source
        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(self.num_samples)
            )
        self._num_samples_frozen = self.num_samples
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = start_index // self.num_samples
        self.start_index = start_index - self.num_samples * self.epoch

    @property
    def num_samples(self):
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        if self.num_samples != self._num_samples_frozen:
            raise RuntimeError(
                f"Data source passed into Sampler must have the same length "
                f"every epoch. Original length was {self._num_samples_frozen}, "
                f"new length is {self.num_samples}"
            )
        if self.shuffle:
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            if self.num_samples >= len(self.data_source):
                perm = torch.randperm(self.num_samples, generator=gen)
            else:
                perm = torch.randperm(len(self.data_source), generator=gen)
                perm = perm[: self.num_samples]
            perm = perm[self.start_index :]
        else:
            perm = torch.arange(self.start_index, self.num_samples)
        yield from perm.tolist()
        self.epoch += 1
        self.start_index = 0

    def __len__(self):
        return self.num_samples - self.start_index
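

# Illustrative sketch, not part of the original module: with `shuffle=False`
# the sampler yields sequential indices, and a nonzero `start_index` resumes
# mid-epoch, so the first pass after a restart is shorter than later ones.
def _example_base_sampler_restart():
    sampler = BaseSampler(range(8), shuffle=False, start_index=3)
    assert list(iter(sampler)) == [3, 4, 5, 6, 7]  # finish epoch 0
    assert list(iter(sampler)) == [0, 1, 2, 3, 4, 5, 6, 7]  # full epoch 1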


class Sharder(torch.utils.data.Sampler):
    """
    Shard an index stream across data streamer tasks in a round-robin
    fashion, rotating the starting task each epoch so that the global
    sample order does not depend on the number of tasks.
    """

    def __init__(self, data_source):
        self.data_source = data_source
        cluster_spec, worker_spec = cluster_config()
        self.task_id = (
            worker_spec.local_rank * cluster_spec.num_csx + worker_spec.wse_id
        )
        self.num_tasks = cluster_spec.num_workers_per_csx * cluster_spec.num_csx
        self.first_task = 0

    def __iter__(self):
        n = len(self.data_source)
        effective_task_id = (self.task_id - self.first_task) % self.num_tasks
        for i, x in enumerate(self.data_source):
            if i % self.num_tasks == effective_task_id:
                yield x
        self.first_task = (
            self.first_task + (n % self.num_tasks)
        ) % self.num_tasks

    def __len__(self):
        effective_task_id = (self.task_id - self.first_task) % self.num_tasks
        n = len(self.data_source)
        l = n // self.num_tasks
        if n % self.num_tasks > effective_task_id:
            l += 1
        return l
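

# Illustrative sketch, not part of the original module: the sharding is a
# round-robin deal of the incoming index stream over tasks. The helper below
# is a hypothetical standalone mirror of Sharder's assignment rule, used only
# to show how advancing `first_task` by `len(data_source) % num_tasks` keeps
# the deal seamless across epoch boundaries.
def _example_round_robin_sharding():
    def shard(indices, task_id, num_tasks, first_task):
        effective = (task_id - first_task) % num_tasks
        return [x for i, x in enumerate(indices) if i % num_tasks == effective]

    # Epoch 0: task 0 of 3 sees every third sample starting at index 0.
    assert shard(range(8), task_id=0, num_tasks=3, first_task=0) == [0, 3, 6]
    # Epoch 1: first_task has advanced by 8 % 3 == 2, so task 0 picks up the
    # round-robin exactly where epoch 0 left off.
    assert shard(range(8), task_id=0, num_tasks=3, first_task=2) == [1, 4, 7]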


class BatchSampler(torch.utils.data.Sampler):
    """
    A slight modification of the PyTorch batch sampler such that any samples
    not yielded at the end of an epoch when `drop_last=True` will be yielded
    at the start of the next epoch. This is necessary for shard-invariance.

    Adapted from the PyTorch batch sampler.
    """

    def __init__(self, sampler, batch_size, drop_last):
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.leftover_samples = []

    def __iter__(self):
        # Implemented based on the benchmarking in
        # https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = itertools.chain(self.leftover_samples, self.sampler)
            while True:
                try:
                    batch = []
                    for _ in range(self.batch_size):
                        batch.append(next(sampler_iter))
                    yield batch
                except StopIteration:
                    self.leftover_samples = batch
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self):
        if self.drop_last:
            return (
                len(self.sampler) + len(self.leftover_samples)
            ) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
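

# Illustrative sketch, not part of the original module: with `drop_last=True`,
# samples that do not fill the final batch of an epoch are not discarded but
# carried over to lead the next epoch. A plain list stands in for the
# underlying sampler here.
def _example_batch_sampler_carryover():
    batches = BatchSampler(list(range(10)), batch_size=4, drop_last=True)
    assert list(iter(batches)) == [[0, 1, 2, 3], [4, 5, 6, 7]]
    # The held-over [8, 9] starts the next pass through the data.
    assert list(iter(batches)) == [[8, 9, 0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]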


class BatchAccumulator(torch.utils.data.Sampler):
    """
    Accumulate neighboring batches into one single larger batch.

    This is the inverse operation to the splitting of batches into
    microbatches that happens when using multiple CSX systems.
    """

    def __init__(
        self,
        data_source,
        n_accum,
    ):
        """
        Assumes `data_source` is an iterator of batches where each batch has
        the same length (i.e. `drop_last=True`).
        """
        self.data_source = data_source
        self._n = n_accum
        self._next_batch = []

    def __iter__(self):
        data_iter = itertools.chain(self._next_batch, self.data_source)
        self._next_batch = []
        while True:
            try:
                for _ in range(self._n):
                    self._next_batch.append(next(data_iter))
                yield [x for batch in self._next_batch for x in batch]
                self._next_batch = []
            except StopIteration:
                break

    def __len__(self):
        return (len(self.data_source) + len(self._next_batch)) // self._n
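

# Illustrative sketch, not part of the original module: neighboring
# microbatches are concatenated back into the global batch, e.g. with two CSX
# systems every pair of per-system microbatches becomes one global batch.
def _example_batch_accumulator():
    microbatches = [[0, 1], [2, 3], [4, 5], [6, 7]]
    accumulator = BatchAccumulator(microbatches, n_accum=2)
    assert list(iter(accumulator)) == [[0, 1, 2, 3], [4, 5, 6, 7]]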