# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating synthetic data based on some specification."""
import torch
from torch.utils._pytree import (
_dict_flatten,
_dict_unflatten,
_register_pytree_node,
tree_flatten,
tree_unflatten,
)
from cerebras.modelzoo.data.common.tensor_spec import TensorSpec
try:
import cerebras.pytorch as cstorch
except:
cstorch = None


def custom_dict_flatten(d: dict):
    """Constructs TensorSpec instances to contain the leaf nodes
    of the tree structure before flattening.
    """
    # A dict carrying any spec key is treated as a single TensorSpec leaf
    # rather than being flattened key-by-key.
    if "shape" in d or "dtype" in d or "tensor_factory" in d:
        return [TensorSpec(**d)], "TensorSpec"
    return _dict_flatten(d)
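
# A hedged sketch of the flatten behavior (assuming TensorSpec simply stores
# its keyword arguments):
#
#   custom_dict_flatten({"shape": [2, 3], "dtype": "float32"})
#   # -> ([TensorSpec(shape=[2, 3], dtype="float32")], "TensorSpec")
#
#   custom_dict_flatten({"labels": {"shape": [1], "dtype": "int32"}})
#   # -> no spec keys at the top level, so this falls through to
#   #    _dict_flatten and tree_flatten recurses into the nested dict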


def custom_dict_unflatten(values, context):
    """After mapping the TensorSpecs to tensors/callables, return them directly
    as leaf nodes instead of reconstructing a dictionary when unflattening.
    """
    if context == "TensorSpec":
        return values[0]
    return _dict_unflatten(values, context)
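
# The unflatten side of the same round trip (hypothetical values): a leaf
# that was flattened with the "TensorSpec" context comes back as the mapped
# tensor/callable itself, not as a reconstructed dict.
#
#   custom_dict_unflatten([torch.zeros(2, 3)], "TensorSpec")
#   # -> the tensor of shape (2, 3) directly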


class SyntheticDataProcessor:
    """Creates a synthetic dataset.

    Constructs a SyntheticDataset from the user-provided nested structure of
    input tensor specifications and returns a torch.utils.data.DataLoader
    built from the SyntheticDataset and the regular
    torch.utils.data.DataLoader inputs specified in params.yaml. The
    torch.utils.data.DataLoader is returned by calling the
    create_dataloader() method.

    Args:
        params: Dictionary containing dataset inputs and specifications.
            Within this dictionary, the user provides the additional
            'synthetic_inputs' field that corresponds to a nested tree
            structure of input tensor specifications used to construct the
            SyntheticDataset.

    In params.yaml::

        data_processor: "SyntheticDataProcessor"  # must be set to use this class
        batch_size: int
        shuffle_seed: Optional[int] = None  # if not None,
            # torch.manual_seed(seed=shuffle_seed) is called when creating
            # the dataloader
        num_examples: Optional[int] = None  # if not None, it specifies the
            # number of examples/samples in the SyntheticDataset; otherwise
            # the SyntheticDataset generates samples indefinitely
        ...  # regular torch.utils.data.DataLoader inputs
        synthetic_inputs:
            ...
                shape: collection of positive ints
                dtype: PyTorch dtype
                # OR (mutually exclusive)
                tensor_factory: name of a PyTorch tensor-creation function
                args:
                    size: ...
                    dtype: ...
    """

    def __init__(self, params):
        if cstorch is None:
            raise RuntimeError(
                "Unable to import cerebras.pytorch. In order to use "
                "SyntheticDataProcessor, please ensure you have access to "
                "the cerebras_pytorch package."
            )

        # Regular torch.utils.data.DataLoader inputs
        self._batch_size = params.get("batch_size", None)
        if self._batch_size is None:
            raise ValueError(
                "No 'batch_size' field specified. Please enter a positive "
                "integer batch_size."
            )
        if not isinstance(self._batch_size, int) or self._batch_size <= 0:
            raise ValueError(
                f"Expected batch_size to be a positive integer but got "
                f"{self._batch_size}."
            )
        self._shuffle = params.get("shuffle", False)
        self._sampler = params.get("sampler", None)
        self._batch_sampler = params.get("batch_sampler", None)
        self._num_workers = params.get("num_workers", 0)
        self._pin_memory = params.get("pin_memory", False)
        self._drop_last = params.get("drop_last", False)
        self._timeout = params.get("timeout", 0)

        # SyntheticDataset-specific inputs
        self._seed = params.get("shuffle_seed", None)
        self._num_examples = params.get("num_examples", None)
        if self._num_examples is not None:
            if (
                not isinstance(self._num_examples, int)
                or self._num_examples <= 0
            ):
                raise ValueError(
                    f"Expected num_examples to be a positive integer but got "
                    f"{self._num_examples}."
                )
            if self._drop_last and self._num_examples < self._batch_size:
                raise ValueError(
                    f"This dataset does not return any batches because the "
                    f"number of examples in the dataset ({self._num_examples}) "
                    f"is less than the batch size ({self._batch_size}) and "
                    f"`drop_last` is True."
                )
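
        # Worked example of the check above: with num_examples=10 and
        # batch_size=8, drop_last=True yields a single batch of 8 and drops
        # the remaining 2 samples; with num_examples=4 no batch would ever
        # be produced, which is exactly the misconfiguration rejected here.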

        self._tensors = []
        synthetic_inputs = params.get("synthetic_inputs", {})
        if synthetic_inputs:
            # Temporarily register the custom flatten/unflatten pair so that
            # spec dicts are treated as TensorSpec leaves during traversal.
            _register_pytree_node(
                dict, custom_dict_flatten, custom_dict_unflatten
            )
            leaf_nodes, self._spec_tree = tree_flatten(synthetic_inputs)
            for tensor_spec in leaf_nodes:
                if not isinstance(tensor_spec, TensorSpec):
                    raise TypeError(
                        f"Expected all leaf nodes in 'synthetic_inputs' to be "
                        f"of type TensorSpec but got {type(tensor_spec)}. "
                        f"Please ensure that all leaf nodes under "
                        f"'synthetic_inputs' are instances of TensorSpec. "
                        f"These instances are created by specifying either "
                        f"'shape' and 'dtype' keys or a 'tensor_factory' "
                        f"key in a dict (mutually exclusive)."
                    )
                self._tensors.append(self._process_tensor(tensor_spec.specs))
            self._tensor_specs = tree_unflatten(self._tensors, self._spec_tree)
            # Restore the default dict flatten/unflatten behavior.
            _register_pytree_node(dict, _dict_flatten, _dict_unflatten)
        else:
            raise ValueError(
                "Expected 'synthetic_inputs' field but found none. Please "
                "specify this field and provide tensor information according "
                "to the documentation."
            )
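
        # For the hedged params sketch above, self._tensor_specs mirrors the
        # nesting of 'synthetic_inputs', e.g. {"input_ids": <zeros tensor of
        # shape (128,)>, "attention_mask": <callable wrapping torch.ones(...)>}.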

    def _torch_dtype_from_str(self, dtype):
        """Takes in the user input string for dtype and returns the
        corresponding torch.dtype.
        """
        torch_dtype = getattr(torch, dtype, None)
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"Invalid torch dtype '{dtype}'. Please ensure all tensors "
                f"use a valid torch dtype."
            )
        return torch_dtype
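
    # e.g. self._torch_dtype_from_str("float16") -> torch.float16, whereas
    # "float99" raises ValueError because getattr(torch, "float99", None)
    # does not yield a torch.dtype.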

    def _process_tensor(self, tensor_spec):
        """Parses the tensor_spec and returns a corresponding synthetic tensor."""
        if not tensor_spec:
            raise ValueError(
                "Empty TensorSpec found. Please provide at least a 'shape' "
                "and 'dtype' field to complete the tensor specification."
            )
        shape = tensor_spec.get("shape", None)
        dtype = tensor_spec.get("dtype", None)
        tensor_factory = tensor_spec.get("tensor_factory", None)

        # Enforce mutually exclusive inputs: either both 'shape' and 'dtype',
        # or 'tensor_factory' alone.
        mutex = shape and dtype and not tensor_factory
        mutex = mutex or (not shape and not dtype and tensor_factory)
        if not mutex:
            found = [
                name
                for name, value in (
                    ("shape", shape),
                    ("dtype", dtype),
                    ("tensor_factory", tensor_factory),
                )
                if value is not None
            ]
            raise ValueError(
                f"Expected either 'shape' and 'dtype' fields or a "
                f"'tensor_factory' field (mutually exclusive) but instead "
                f"found the following fields: {found}. Please ensure each "
                f"tensor has either 'shape' and 'dtype' fields OR a "
                f"'tensor_factory' field."
            )

        if shape and dtype:
            if not all(isinstance(e, int) and e > 0 for e in shape):
                raise ValueError(
                    f"Expected shape to be a collection of positive integers "
                    f"but got {shape}. Please ensure all tensor shapes are "
                    f"collections of positive integers."
                )
            torch_dtype = self._torch_dtype_from_str(dtype)
            return torch.zeros(shape, dtype=torch_dtype)
        elif tensor_factory:
            torch_args = tensor_spec.get("args", None)
            if not torch_args:
                raise ValueError(
                    f"Expected 'args' field but found none for the "
                    f"tensor_factory '{tensor_factory}'. Please specify this "
                    f"field and fill it with the arguments for the chosen "
                    f"tensor generation function."
                )
            if not torch_args.get("dtype", None):
                raise ValueError(
                    f"Expected a 'dtype' argument in the 'args' field for "
                    f"tensor_factory '{tensor_factory}', but found none. "
                    f"Please specify this argument with the desired tensor "
                    f"dtype."
                )
            torch_args["dtype"] = self._torch_dtype_from_str(
                torch_args["dtype"]
            )
            # getattr raises AttributeError for an unknown factory name; any
            # failure is re-raised as a ValueError below.
            try:
                test_tensor = getattr(torch, tensor_factory)(**torch_args)
            except Exception as e:
                raise ValueError(
                    f"Provided tensor_factory '{tensor_factory}' is invalid. "
                    f"Please ensure you are using a supported PyTorch "
                    f"callable that returns a torch tensor."
                ) from e
            if not isinstance(test_tensor, torch.Tensor):
                raise ValueError(
                    f"Expected tensor_factory '{tensor_factory}' to return a "
                    f"torch.Tensor but instead got {type(test_tensor)}. "
                    f"Please ensure that tensor_factory contains a valid "
                    f"PyTorch callable that returns a torch tensor."
                )
            # The returned callable ignores its argument and regenerates a
            # fresh tensor from the factory on every call.
            return lambda x: getattr(torch, tensor_factory)(**torch_args)
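
    # Two hedged spec examples this method accepts ("ones" is a real torch
    # factory; the shapes and dtypes are illustrative):
    #
    #   self._process_tensor({"shape": [4, 16], "dtype": "int32"})
    #   # -> torch.zeros([4, 16], dtype=torch.int32)
    #
    #   self._process_tensor(
    #       {"tensor_factory": "ones",
    #        "args": {"size": [4, 16], "dtype": "float32"}}
    #   )
    #   # -> a callable that ignores its argument and returns
    #   #    torch.ones(size=[4, 16], dtype=torch.float32)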

    def create_dataloader(self):
        """Returns a torch.utils.data.DataLoader that corresponds to the
        created SyntheticDataset.
        """
        if self._shuffle and self._seed is not None:
            torch.manual_seed(self._seed)
        return torch.utils.data.DataLoader(
            cstorch.utils.data.SyntheticDataset(
                self._tensor_specs, num_samples=self._num_examples
            ),
            batch_size=self._batch_size,
            shuffle=self._shuffle,
            sampler=self._sampler,
            batch_sampler=self._batch_sampler,
            num_workers=self._num_workers,
            pin_memory=self._pin_memory,
            drop_last=self._drop_last,
            timeout=self._timeout,
        )
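

# A minimal end-to-end usage sketch (the params dict mirrors the hedged
# sketch in the class body above; the tensor name is illustrative):
#
#   processor = SyntheticDataProcessor({
#       "batch_size": 8,
#       "num_examples": 64,
#       "synthetic_inputs": {
#           "input_ids": {"shape": [128], "dtype": "int32"},
#       },
#   })
#   dataloader = processor.create_dataloader()
#   batch = next(iter(dataloader))
#   # batch["input_ids"].shape == torch.Size([8, 128])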