"""The Cerebras dataloader class"""
from copy import deepcopy
from typing import Callable
import torch
from cerebras_pytorch.experimental.utils.data.utils import infer_batch_size
class DataLoader:
    """
    Wrapper around torch.utils.data.DataLoader that facilitates
    moving data generated by the dataloader to a Cerebras system

    Args:
        input_fn: A callable that returns a torch.utils.data.DataLoader
            instance
        *args, **kwargs: Any other positional or keyword arguments
            are passed into the input_fn when each worker instantiates
            their respective dataloaders

    Attributes:
        input_fn: The callable supplied at construction.
        input_fn_params: A deep copy of ``(args, kwargs)`` taken before
            ``input_fn`` is first invoked.
        dataloader: The object returned by ``input_fn(*args, **kwargs)``.
        batch_size: ``0`` until iteration starts; afterwards, the value
            that ``infer_batch_size`` returned for the most recently
            yielded batch.
    """

    def __init__(
        self,
        input_fn: Callable[..., torch.utils.data.DataLoader],
        *args,
        **kwargs,
    ) -> None:
        """Validate ``input_fn``, record its arguments and build the dataloader.

        Raises:
            TypeError: If ``input_fn`` is not callable.
        """
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "torch.utils.data.DataLoader."
            )

        self.input_fn = input_fn
        # Snapshot the arguments *before* calling input_fn so the stored
        # copy is unaffected by anything input_fn (or the caller, later)
        # does to the original objects.
        self.input_fn_params = deepcopy((args, kwargs))
        self.dataloader = input_fn(*args, **kwargs)
        # No batch has been observed yet; updated lazily in __iter__.
        self.batch_size = 0

    def __len__(self) -> int:
        """Return ``len()`` of the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield each batch of the wrapped dataloader in order.

        ``batch_size`` is refreshed from every batch just before that
        batch is handed to the consumer, so it always describes the
        batch most recently yielded.
        """
        for batch in self.dataloader:
            self.batch_size = infer_batch_size(batch)
            yield batch