Source code for data_processing.HDF5IterableDataset

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch HDF5 Dataset"""

import logging
import math
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch

from modelzoo.common.pytorch.input_utils import get_streaming_batch_size
from modelzoo.transformers.pytorch.input_utils import (
    num_tasks,
    shard_list_of_chunks_contiguous,
    task_id,
)


class HDF5IterableDataset(torch.utils.data.IterableDataset):
    """
    A HDF5 dataset processor. Loads data from HDF5 files.

    :param dict params: dict containing training input parameters for creating
        dataset. Expects the following fields:

        - "data_dir" (str or list of str): Path to dataset HDF5 files.
        - "batch_size" (int): Batch size.
        - "shuffle" (bool): Flag to enable data shuffling.
        - "shuffle_seed" (int): Shuffle seed.
        - "num_workers" (int): How many subprocesses to use for data loading.
        - "drop_last" (bool): If True and the dataset size is not divisible
          by the batch size, the last incomplete batch will be dropped.
    """
    def __init__(self, params):
        super(HDF5IterableDataset, self).__init__()

        self.data_dir = params["data_dir"]
        self.batch_size = get_streaming_batch_size(params["batch_size"])
        self.shuffle = params["shuffle"]
        self.shuffle_seed = params.get("shuffle_seed", None)
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        self.dataloader_state = params.get("cerebras", {})

        self.features_list = params.get(
            "features", ["input_ids", "attention_mask", "labels"]
        )

        assert self.batch_size > 0, "Batch size should be a positive number."

        if not isinstance(self.data_dir, list):
            self.data_dir = [self.data_dir]

        files = []
        for directory in self.data_dir:
            p = Path(directory)
            assert (
                p.is_dir()
            ), f"The path {directory} does not exist or is not a directory."
            files.extend(p.glob("*.h5"))
        files = sorted(files)
        if not files:
            raise RuntimeError("No .h5 dataset files found.")

        self.num_tasks = num_tasks()
        self.task_id = task_id()

        # Shard H5 files between the tasks and resolve the paths
        files_in_this_task = [
            str(file.resolve())
            for file in files[self.task_id :: self.num_tasks]
        ]

        self.files_in_this_task = []
        self.num_examples_in_this_task = 0
        for file_path in files_in_this_task:
            with h5py.File(file_path, mode="r") as h5_file:
                num_examples_in_file = h5_file.attrs["n_examples"]
                self.files_in_this_task.append(
                    (file_path, num_examples_in_file)
                )
                self.num_examples_in_this_task += num_examples_in_file

        if self.shuffle:
            random.seed(self.shuffle_seed)
            random.shuffle(self.files_in_this_task)

        # Single worker:
        # load dataloader state from a previous run for restart
        self.dataloader_state_path = self.dataloader_state.get(
            "save_iter_state_path", None
        )
        self.num_workers_prev_state = self.dataloader_state.get(
            "num_workers", None
        )
        if self.num_workers_prev_state is not None:
            assert (
                self.num_workers == self.num_workers_prev_state
            ), "num_workers should be the same at the restart"

        self.prev_worker_iter_index = 0
        if self.dataloader_state_path is not None:
            if os.path.exists(self.dataloader_state_path):
                # Check if a state file is available that contains the iter to start from
                global_state_file = os.path.join(
                    self.dataloader_state_path,
                    "data_iter_checkpoint_state_file_global",
                )
                if os.path.isfile(global_state_file):
                    with open(global_state_file, "r") as f:
                        try:
                            global_ckpt_step = int(f.readline())
                        except (IOError, ValueError) as error:
                            # Also catch ValueError: int() raises it when the
                            # file does not contain a parseable step number.
                            logging.error(
                                "Caught this error: "
                                + repr(error)
                                + ", not able to read data iter ckpt step"
                            )
                    # Add worker_id suffix to correctly load state for the current task/worker
                    self.dataloader_state_file = os.path.join(
                        self.dataloader_state_path,
                        f"data_iter_state_file_worker_{self.task_id}_step_{global_ckpt_step}.txt",
                    )
                    if os.path.isfile(self.dataloader_state_file):
                        with open(self.dataloader_state_file, "r") as f:
                            samples_seen = int(f.readline())
                            if samples_seen % self.batch_size == 0:
                                self.prev_worker_iter_index = int(
                                    samples_seen / self.batch_size
                                )
                            else:
                                self.prev_worker_iter_index = (
                                    math.floor(samples_seen / self.batch_size)
                                    + 1
                                )
    def _load_buffer(self, data_partitions):
        # Partition id should default to 0 if not reading the iter from file
        restart_iter_partition_id = 0
        restart_iter_start_idx = 0  # start_idx should default to 0
        # Only applies if "dataloader_state_file" was readable at init
        if self.prev_worker_iter_index > 0:
            iters_until_current_partition = 0
            prev_partition_offset_start_idx = 0
            current_partition_offset_start_idx = 0
            for partition_idx, partition_specs in enumerate(data_partitions):
                start_idx = partition_specs[1]
                num_examples = partition_specs[2]
                if partition_idx > 0:
                    num_examples_prev_partition = (
                        data_partitions[partition_idx - 1][2]
                        - prev_partition_offset_start_idx
                    )
                    if (
                        num_examples_prev_partition
                        - (num_examples_prev_partition // self.batch_size)
                        * self.batch_size
                    ) > 0:
                        current_partition_offset_start_idx = self.batch_size - (
                            num_examples_prev_partition
                            - (num_examples_prev_partition // self.batch_size)
                            * self.batch_size
                        )
                    else:
                        current_partition_offset_start_idx = 0
                    prev_partition_offset_start_idx = (
                        current_partition_offset_start_idx
                    )
                    num_examples_curr_partition = (
                        num_examples - current_partition_offset_start_idx
                    )
                else:
                    num_examples_curr_partition = num_examples
                    current_partition_offset_start_idx = 0

                iters_until_current_partition += np.ceil(
                    num_examples_curr_partition / self.batch_size
                )
                if (
                    self.prev_worker_iter_index
                    <= iters_until_current_partition - 1
                ):
                    restart_iter_partition_id = partition_idx
                    restart_iter_start_idx = int(
                        self.batch_size
                        * (
                            self.prev_worker_iter_index
                            - (
                                iters_until_current_partition
                                - np.ceil(
                                    num_examples_curr_partition
                                    / self.batch_size
                                )
                            )
                        )
                    )
                    restart_iter_start_idx += current_partition_offset_start_idx
                    break

        for partition_idx, partition_specs in enumerate(
            data_partitions[restart_iter_partition_id:]
        ):
            file_path = partition_specs[0]
            start_idx_org = partition_specs[1]
            num_examples = partition_specs[2]
            if self.prev_worker_iter_index > 0:
                if restart_iter_partition_id >= 0 and partition_idx == 0:
                    start_idx = restart_iter_start_idx
                else:
                    start_idx = start_idx_org
            else:
                start_idx = start_idx_org

            with h5py.File(file_path, mode="r") as h5_file:
                for idx in range(
                    start_idx, start_idx_org + num_examples, self.batch_size
                ):
                    load_len = min(
                        self.batch_size, start_idx_org + num_examples - idx
                    )
                    load_data = h5_file["data"][idx : idx + load_len]
                    for i in range(load_len):
                        yield load_data[i]

    def __iter__(self):
        """
        Iterating over the data to construct input features.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process
            worker_id = 0
            num_workers = 1

        data_partitions = shard_list_of_chunks_contiguous(
            self.files_in_this_task, worker_id, num_workers
        )

        for example in self._load_buffer(data_partitions):
            yield {
                feature: np.array(example[i], np.int32)
                for i, feature in enumerate(self.features_list)
            }

    def __len__(self):
        """
        Returns the len of dataset on the task process
        """
        return self.num_examples_in_this_task
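
A minimal usage sketch, not part of the original source: it assumes each .h5 file stores a "data" dataset indexed as (example, feature, position) with an "n_examples" attribute, matching how _load_buffer and __iter__ read it above, and that get_streaming_batch_size simply returns the configured batch size outside of a Cerebras run. The params values and the directory path are illustrative only; the example reuses the torch import at the top of this module.

if __name__ == "__main__":
    # Hypothetical params dict; keys follow the class docstring above.
    params = {
        "data_dir": "/path/to/hdf5/dir",  # directory containing *.h5 files
        "batch_size": 8,
        "shuffle": True,
        "shuffle_seed": 42,
        "num_workers": 0,
        "drop_last": True,
    }
    dataset = HDF5IterableDataset(params)

    # The dataset yields one example at a time; the DataLoader does the batching.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params["batch_size"],
        drop_last=params["drop_last"],
        num_workers=params["num_workers"],
    )
    for batch in dataloader:
        # Each batch is a dict of int32 tensors keyed by the features list,
        # e.g. "input_ids", "attention_mask", "labels".
        print({name: tensor.shape for name, tensor in batch.items()})
        break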