PyTorch DataLoader
To speed up data loading, PyTorch supports parallelized data loading, fetching indices in batches rather than one at a time, and streaming, which downloads datasets progressively.
PyTorch provides a data loading utility class, torch.utils.data.DataLoader. The most important argument of the DataLoader is dataset, the dataset object to load the data from. PyTorch supports two types of datasets.
Map-style datasets (Dataset) implement a mapping from indices/keys to data samples. For example, accessing dataset[idx] could read the idx-th image and its label from a directory on disk.
Iterable-style datasets (IterableDataset) represent an iterable over data samples. They are particularly suitable when random reads are expensive or even impossible, and when the batch size depends on the fetched data. For example, calling iter(dataset) could return a stream of data read from a database, a remote server, or even logs generated in real time.
We extend one of these two classes to create our data loading classes and implement additional functionality. For example, BertCSVDynamicMaskDataProcessor extends IterableDataset, and BertClassifierDataProcessor extends Dataset.
Data loading order
For an IterableDataset, the loading order is controlled entirely by the user-defined iterable, which makes chunked reading and dynamic batch sizes easy to implement.
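A sketch of an iterable that controls both the loading order and the batch size, assuming a hypothetical ChunkedDataset that reads records chunk by chunk and yields each chunk as a ready-made batch. Passing batch_size=None to the DataLoader disables automatic batching, so the batch sizes chosen by the iterable are kept as-is.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ChunkedDataset(IterableDataset):
    """Hypothetical example: reads records in chunks and yields each chunk
    as one variable-size batch."""

    def __init__(self, records, chunk_sizes):
        self.records = records
        self.chunk_sizes = chunk_sizes  # one size per emitted batch

    def __iter__(self):
        start = 0
        for size in self.chunk_sizes:
            chunk = self.records[start:start + size]
            if not chunk:
                break
            start += size
            # The iterable alone decides the order and the size of each batch.
            yield torch.tensor(chunk)


records = list(range(10))
ds = ChunkedDataset(records, chunk_sizes=[3, 5, 2])

# batch_size=None disables automatic batching: each yielded tensor is
# passed through unchanged instead of being grouped into larger batches.
loader = DataLoader(ds, batch_size=None)
batch_sizes = [len(batch) for batch in loader]
print(batch_sizes)  # [3, 5, 2]
```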
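For a map-style Dataset, the loading order is determined by a Sampler, an iterable that yields, one at a time, the indices/keys to fetch. The DataLoader creates a default sampler (sequential, or random when shuffle=True), and you can also pass a custom Sampler object.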
There is no iterable-style dataset sampler, since such datasets have no notion of a key or an index.
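The sketch below checks which default sampler the DataLoader creates and then passes a custom one. EvenThenOddSampler is a hypothetical sampler written for illustration; a custom sampler only needs __iter__ (and usually __len__).

```python
import torch
from torch.utils.data import (
    DataLoader,
    RandomSampler,
    Sampler,
    SequentialSampler,
    TensorDataset,
)

data = TensorDataset(torch.arange(6) * 10)  # values 0, 10, ..., 50

# Default samplers created by the DataLoader itself.
default_sampler = DataLoader(data).sampler
shuffled_sampler = DataLoader(data, shuffle=True).sampler
print(type(default_sampler).__name__)   # SequentialSampler
print(type(shuffled_sampler).__name__)  # RandomSampler


class EvenThenOddSampler(Sampler):
    """Hypothetical custom sampler: all even indices first, then all odd ones."""

    def __init__(self, data_source):
        self.n = len(data_source)

    def __iter__(self):
        yield from range(0, self.n, 2)
        yield from range(1, self.n, 2)

    def __len__(self):
        return self.n


loader = DataLoader(data, batch_size=3, sampler=EvenThenOddSampler(data))
order = [batch[0].tolist() for batch in loader]
print(order)  # [[0, 20, 40], [10, 30, 50]]
```

A custom sampler cannot be combined with shuffle=True; the DataLoader raises a ValueError in that case.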
Loading batched and non-batched data
Automatic batching is the most common (and default) case: the DataLoader fetches a minibatch of data samples and collates them into batched samples, where the first dimension of each tensor is usually the batch dimension.
When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments specify how the data loader obtains batches of dataset keys. When the number of samples in the dataset is not divisible by batch_size, drop_last=True drops the last, incomplete batch, while drop_last=False (the default) keeps it, so the last batch is smaller than batch_size.
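A quick way to see the effect of drop_last is to batch a 10-sample toy dataset with batch_size=4, as in this sketch built on TensorDataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))  # 10 samples

keep_last = DataLoader(dataset, batch_size=4, drop_last=False)
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)

# Each batch is a one-element list holding a tensor of sample values.
sizes_keep = [len(batch[0]) for batch in keep_last]
sizes_drop = [len(batch[0]) for batch in loader_drop]
print(sizes_keep)  # [4, 4, 2]
print(sizes_drop)  # [4, 4]

# len() of a map-style DataLoader is its number of batches.
print(len(keep_last), len(loader_drop))  # 3 2
```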
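With automatic batching, loading from a map-style dataset is roughly equivalent to the following illustrative generator (fetch_map_style is not a PyTorch API):
```python
def fetch_map_style(dataset, batch_sampler, collate_fn):
    # Each element of batch_sampler is a list of indices forming one batch.
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])
```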
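Loading from an iterable-style dataset is roughly equivalent to this illustrative generator (fetch_iterable_style is not a PyTorch API):
```python
def fetch_iterable_style(dataset, batch_sampler, collate_fn):
    dataset_iter = iter(dataset)
    # batch_sampler only sets the batch sizes; samples come from the iterator.
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])
```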
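Disabling automatic batching is done by passing batch_size=None. The sketch below uses a plain Python list of NumPy arrays as a toy map-style dataset and compares the shape of what the DataLoader yields in each mode.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset (it has __getitem__ and __len__).
samples = [np.array([i, i + 1], dtype=np.float32) for i in range(3)]

# Automatic batching (default batch_size=1): a batch dimension is prepended.
batched = next(iter(DataLoader(samples)))
print(batched.shape)  # torch.Size([1, 2])

# batch_size=None disables automatic batching: each sample is only
# converted from a NumPy array to a tensor, with no batch dimension added.
unbatched = next(iter(DataLoader(samples, batch_size=None)))
print(unbatched.shape)  # torch.Size([2])
```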
collate_fn
When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch tensors. When automatic batching is enabled, the default collate_fn:
Prepends a new dimension as the batch dimension.
Automatically converts NumPy arrays and Python numerical values into PyTorch tensors.
Preserves the data structure: for example, if each sample is a dictionary, it outputs a dictionary with the same keys and batched tensors as values.
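These default collation rules can be checked on samples that are dictionaries. In this sketch each sample, made up for illustration, holds a NumPy array and a Python int; after collation the dictionary keys are preserved and each value becomes a tensor with a leading batch dimension.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# Each sample is a dictionary mixing a NumPy array and a Python int.
samples = [
    {"features": np.full(3, i, dtype=np.float32), "label": i}
    for i in range(4)
]

batch = next(iter(DataLoader(samples, batch_size=4)))

# The dictionary structure is preserved; each value is now a batched tensor.
print(sorted(batch.keys()))      # ['features', 'label']
print(batch["features"].shape)   # torch.Size([4, 3])
print(batch["label"].tolist())   # [0, 1, 2, 3]
```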
A custom collate_fn can be used to customize collation, for example to pad variable-length sequences to the maximum length in a batch, to collate along a dimension other than the first, or to add support for custom data types.
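A custom collate_fn for padding might look like the following sketch. pad_collate is a hypothetical helper name; it relies on torch.nn.utils.rnn.pad_sequence and also returns the original lengths, which downstream code often needs in order to ignore the padded positions.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Variable-length token sequences (e.g., tokenized sentences).
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]


def pad_collate(batch, padding_value=0):
    """Hypothetical collate_fn: pad every sequence to the batch's max length."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=padding_value)
    return padded, lengths


loader = DataLoader(sequences, batch_size=3, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.tolist())   # [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
print(lengths.tolist())  # [3, 2, 1]
```

The default collate_fn would fail on this batch, because it stacks the tensors and stacking requires equal shapes.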
Multiple workers
PyTorch makes parallel data loading straightforward: setting the num_workers argument of the DataLoader loads data in parallel worker processes, which can increase throughput.
Under the hood, when num_workers > 0, the DataLoader starts num_workers worker processes. Each process reloads the dataset passed to the DataLoader and uses it to query examples. For a memory-mapped dataset, reloading the dataset inside a worker does not fill up your RAM, since the worker simply memory-maps the dataset again from disk.
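A minimal sketch of enabling worker processes, assuming a toy RangeDataset and a hypothetical helper load_all. For a map-style dataset, the batches come back in the same order whether loading happens in the main process (num_workers=0) or in worker processes.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Toy map-style dataset; each worker process works with its own copy."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx)


def load_all(num_workers):
    """Hypothetical helper: load every batch with the given worker count."""
    loader = DataLoader(RangeDataset(8), batch_size=2, num_workers=num_workers)
    return [batch.tolist() for batch in loader]


if __name__ == "__main__":
    # The guard is required on platforms that start workers with "spawn"
    # (e.g., Windows and macOS), because each worker re-imports this module.
    single = load_all(num_workers=0)    # loading in the main process
    parallel = load_all(num_workers=2)  # loading in 2 worker processes
    # For a map-style dataset, batches come back in sampler order either way.
    print(single == parallel)  # True
```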
Stream data
Loading a dataset in streaming mode lets you progressively download the data you need while iterating over the dataset. If the dataset is split into several shards (i.e., it consists of multiple data files), you can stream the shards in parallel using num_workers.
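One common way to stream shards in parallel with plain PyTorch is to split them across workers inside IterableDataset.__iter__ using torch.utils.data.get_worker_info(). This is a sketch with in-memory lists standing in for shard files; ShardedStream and shards_for_worker are hypothetical names, and this is not the streaming API of any particular dataset library. Without such a split, every worker would iterate over the full dataset and yield duplicate samples.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


def shards_for_worker(shards, worker_id, num_workers):
    """Hypothetical helper: assign shards round-robin, disjoint per worker."""
    return shards[worker_id::num_workers]


class ShardedStream(IterableDataset):
    """Hypothetical streaming dataset: each "shard" stands in for a data file."""

    def __init__(self, shards):
        self.shards = shards

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # num_workers=0: the main process reads every shard.
            my_shards = self.shards
        else:
            my_shards = shards_for_worker(self.shards, info.id, info.num_workers)
        for shard in my_shards:
            # A real pipeline would download or read the shard lazily here.
            for record in shard:
                yield torch.tensor(record)


shards = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(shards_for_worker(shards, worker_id=0, num_workers=2))  # [[0, 1], [4, 5]]
print(shards_for_worker(shards, worker_id=1, num_workers=2))  # [[2, 3], [6, 7]]

stream = ShardedStream(shards)
values = [int(x) for x in DataLoader(stream, batch_size=None)]
print(values)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

With num_workers=2, each worker would read half of the shards, and the DataLoader would combine the samples produced by both workers.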
The DataLoader constructor, which supports both map-style and iterable-style datasets, has the following signature:
```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, multiprocessing_context=None, generator=None,
           *, prefetch_factor=2, persistent_workers=False)
```
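The keyword arguments of this signature can be combined as in the following sketch of a typical training configuration, built on random toy data. Constructing the DataLoader does not start any worker processes; they are started when iteration begins. The last lines show that a worker-only option such as persistent_workers is rejected when num_workers=0.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # makes the toy data and the shuffle order reproducible
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# prefetch_factor and persistent_workers come after `*` in the signature,
# so they must be passed by keyword.
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,             # a RandomSampler is created automatically
    num_workers=2,            # 2 worker processes, started on first iteration
    pin_memory=torch.cuda.is_available(),  # pinned memory speeds host-to-GPU copies
    drop_last=True,           # 100 samples -> 6 full batches, 4 samples dropped
    prefetch_factor=2,        # batches loaded in advance by each worker
    persistent_workers=True,  # keep workers alive between epochs
)
print(len(loader))  # 6

# Worker-only options are rejected when loading in the main process.
try:
    DataLoader(dataset, num_workers=0, persistent_workers=True)
    rejected = False
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```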