"""Dataset classes for use with PyTorch DataLoaders."""
import itertools
import numbers
from typing import (
    Callable,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    OrderedDict,
    Tuple,
    Union,
)

import torch
from torch.utils._pytree import SUPPORTED_NODES, tree_flatten, tree_unflatten
from torch.utils.data import IterDataPipe
# A leaf of a sample specification: either a concrete tensor, or a callable
# that receives the integer index of the sample being generated and returns
# a tensor.
LeafT = Union[torch.Tensor, Callable[[int], torch.Tensor]]
# A (possibly nested) specification of a sample: a single leaf, or a
# container of further specifications. The quoted "SampleSpecT" forward
# references make this alias recursive.
# NOTE(review): `NamedTuple` is not parameterizable, so the field types of
# named-tuple containers are not expressed by this alias.
SampleSpecT = Union[
    LeafT,
    List["SampleSpecT"],
    Tuple["SampleSpecT", ...],
    Dict[str, "SampleSpecT"],
    OrderedDict[str, "SampleSpecT"],
    NamedTuple,
]
# A generated sample: same container structure as `SampleSpecT`, but every
# leaf is a tensor (tensor leaves are cloned and callable leaves are invoked
# when samples are produced).
SampleT = Union[
    torch.Tensor,
    List["SampleT"],
    Tuple["SampleT", ...],
    Dict[str, "SampleT"],
    OrderedDict[str, "SampleT"],
    NamedTuple,
]
# pylint: disable=abstract-method
class SyntheticDataset(IterDataPipe):
    """A synthetic dataset that generates samples from a `SampleSpec`.

    Every generated sample has the same nested container structure as the
    `sample_spec` given to the constructor. Tensor leaves are cloned for each
    sample, and callable leaves are invoked with the index of the sample being
    generated.
    """

    def __init__(
        self, sample_spec: "SampleSpecT", num_samples: Optional[int] = None
    ):
        """Constructs a `SyntheticDataset` instance.

        A synthetic dataset can be used to generate samples on the fly with
        an expected dtype/shape but without needing to create a full-blown
        dataset. This is especially useful for compile validation.

        Args:
            sample_spec: Specification of the samples to generate. This can be
                a nested structure of one of the following types:
                    - `torch.Tensor`: A tensor to be cloned.
                    - `Callable`: A callable that takes the sample index and
                        returns a tensor.
                Supported data structures for holding the above leaf nodes are
                `list`, `tuple`, `dict`, `OrderedDict`, and `NamedTuple`.
            num_samples: Total size of the dataset. Any integral value (e.g. a
                Python `int` or a NumPy integer) is accepted, except `bool`.
                If None, the dataset will generate samples indefinitely.

        Raises:
            ValueError: If `sample_spec` has no leaves, if it contains a leaf
                that is neither a `torch.Tensor` nor a callable, or if
                `num_samples` is not positive.
            TypeError: If `num_samples` is neither None nor an integer.
        """
        super().__init__()
        self._leaf_nodes, self._spec_tree = tree_flatten(sample_spec)
        if not self._leaf_nodes:
            raise ValueError(
                "`sample_spec` must be a non-empty python tree of "
                "`torch.Tensor` or `Callable`."
            )
        for item in self._leaf_nodes:
            if not (isinstance(item, torch.Tensor) or callable(item)):
                raise ValueError(
                    f"`sample_spec` is expected to contain a python tree of "
                    f"`torch.Tensor`, or `Callable`, but got an item of type "
                    f"`{type(item)}`. Note that supported data structures for "
                    f"holding leaf nodes in the tree are "
                    f"{', '.join(str(x) for x in SUPPORTED_NODES)}."
                )

        if num_samples is None:
            self._num_samples = None
        elif isinstance(num_samples, bool) or not isinstance(
            num_samples, numbers.Integral
        ):
            # `bool` subclasses `int`, but `True`/`False` are almost certainly
            # a caller mistake rather than an intended sample count.
            raise TypeError(
                f"`num_samples` must be a positive integer or None, but got a "
                f"value of type `{type(num_samples)}`."
            )
        elif num_samples <= 0:
            raise ValueError(
                f"`num_samples` must be a positive integer, but got "
                f"`{num_samples}`."
            )
        else:
            # Normalize e.g. NumPy integers so `__len__` returns a plain int.
            self._num_samples = int(num_samples)

    @staticmethod
    def _make_leaf(leaf: "LeafT", index: int) -> torch.Tensor:
        """Materializes one leaf of the spec for the sample at `index`."""
        if isinstance(leaf, torch.Tensor):
            # Clone so callers cannot mutate the spec through a sample.
            return leaf.clone()
        if callable(leaf):
            return leaf(index)
        # Unreachable after constructor validation; kept as a safeguard.
        raise TypeError(f"Invalid type for leaf node: {type(leaf)}.")

    def __iter__(self) -> Iterator["SampleT"]:
        """Returns an iterator for generating samples.

        Samples are produced for indices 0, 1, 2, ... up to `num_samples - 1`,
        or without bound when `num_samples` is None.
        """
        indices = (
            itertools.count()
            if self._num_samples is None
            else range(self._num_samples)
        )
        for index in indices:
            sample_flat = [
                self._make_leaf(leaf, index) for leaf in self._leaf_nodes
            ]
            yield tree_unflatten(sample_flat, self._spec_tree)

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        Raises:
            TypeError: If the dataset is unbounded (`num_samples` is None).
        """
        if self._num_samples is None:
            raise TypeError(
                f"`{self.__class__.__name__}` does not have a length because "
                f"`num_samples` was not provided."
            )
        return self._num_samples