Source code for experimental.utils.data.dataloader

"""The Cerebras dataloader class"""
from copy import deepcopy
from typing import Callable

import torch

from cerebras_pytorch.experimental.utils.data.utils import infer_batch_size


class DataLoader:
    """Thin wrapper around a ``torch.utils.data.DataLoader``.

    The wrapped loader is produced by a user-supplied factory. While
    iterating, the wrapper records the batch size of every batch it yields,
    which is intended to help move the generated data to a Cerebras system.

    Args:
        input_fn: A callable that builds and returns a
            ``torch.utils.data.DataLoader`` instance.
        *args, **kwargs: Extra positional and keyword arguments for
            ``input_fn``. They are forwarded to ``input_fn`` once, at
            construction time. A deep copy of them is also kept in
            ``input_fn_params``. Presumably this lets other consumers (e.g.
            per-worker code) rebuild the dataloader; confirm against callers.
    """

    def __init__(
        self,
        input_fn: Callable[..., torch.utils.data.DataLoader],
        *args,
        **kwargs,
    ):
        # Reject anything we could not possibly call before storing state.
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "torch.utils.data.DataLoader."
            )

        self.input_fn = input_fn

        # Snapshot the arguments *before* input_fn runs, so that any in-place
        # mutation input_fn performs on them is not reflected in the copy.
        frozen_call_args = deepcopy((args, kwargs))
        self.input_fn_params = frozen_call_args

        self.dataloader = input_fn(*args, **kwargs)

        # Batch size of the most recently yielded batch, as reported by
        # infer_batch_size; stays 0 until the first batch is produced.
        self.batch_size = 0

    def __len__(self):
        """Return the length reported by the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the wrapped dataloader unchanged.

        Before each batch is handed to the consumer, ``self.batch_size`` is
        refreshed with the value returned by ``infer_batch_size`` for that
        batch.
        """
        for current_batch in self.dataloader:
            self.batch_size = infer_batch_size(current_batch)
            yield current_batch