# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch HDF5 Dataset"""
import logging
import math
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch

from modelzoo.common.pytorch.input_utils import get_streaming_batch_size
from modelzoo.transformers.pytorch.input_utils import (
num_tasks,
shard_list_of_chunks_contiguous,
task_id,
)


class HDF5IterableDataset(torch.utils.data.IterableDataset):
"""
A HDF5 dataset processor. Loads data from HDF5 files.
:param dict params: dict containing training
input parameters for creating dataset.
Expects the following fields:
- "data_dir" (str or list of str): Path to dataset HDF5 files
- "batch_size" (int): Batch size.
- "shuffle" (bool): Flag to enable data shuffling.
- "shuffle_seed" (int): Shuffle seed.
- "num_workers" (int): How many subprocesses to use for data loading.
- "drop_last" (bool): If True and the dataset size is not divisible by the batch size, the last incomplete batch will be dropped.
"""

    def __init__(self, params):
super(HDF5IterableDataset, self).__init__()
self.data_dir = params["data_dir"]
self.batch_size = get_streaming_batch_size(params["batch_size"])
self.shuffle = params["shuffle"]
self.shuffle_seed = params.get("shuffle_seed", None)
self.num_workers = params.get("num_workers", 0)
self.drop_last = params.get("drop_last", True)
self.dataloader_state = params.get('cerebras', {})
self.features_list = params.get(
"features", ["input_ids", "attention_mask", "labels"]
)
assert self.batch_size > 0, "Batch size should be a positive number."
if not isinstance(self.data_dir, list):
self.data_dir = [self.data_dir]
files = []
for directory in self.data_dir:
p = Path(directory)
assert (
p.is_dir()
), f"The path {directory} does not exist or is not a directory."
files.extend(p.glob('*.h5'))
files = sorted(files)
if not files:
raise RuntimeError("No .h5 dataset files found.")
self.num_tasks = num_tasks()
self.task_id = task_id()
# Shard H5 files between the tasks and resolve the paths
files_in_this_task = [
str(file.resolve())
for file in files[self.task_id :: self.num_tasks]
]
self.files_in_this_task = []
self.num_examples_in_this_task = 0
for file_path in files_in_this_task:
with h5py.File(file_path, mode='r') as h5_file:
num_examples_in_file = h5_file.attrs["n_examples"]
self.files_in_this_task.append(
(file_path, num_examples_in_file)
)
self.num_examples_in_this_task += num_examples_in_file
if self.shuffle:
random.seed(self.shuffle_seed)
random.shuffle(self.files_in_this_task)
        # Load the dataloader state saved by a previous run (if provided) so
        # that iteration can restart from where it left off.
self.dataloader_state_path = self.dataloader_state.get(
'save_iter_state_path', None
)
self.num_workers_prev_state = self.dataloader_state.get(
'num_workers', None
)
if self.num_workers_prev_state is not None:
assert (
self.num_workers == self.num_workers_prev_state
), "num_workers should be the same at the restart"
self.prev_worker_iter_index = 0
if self.dataloader_state_path is not None:
if os.path.exists(self.dataloader_state_path):
                # Check whether a state file with the global step to resume
                # from is available.
if os.path.isfile(
os.path.join(
self.dataloader_state_path,
'data_iter_checkpoint_state_file_global',
)
):
with open(
os.path.join(
self.dataloader_state_path,
'data_iter_checkpoint_state_file_global',
),
'r',
) as f:
try:
global_ckpt_step = int(f.readline())
                        except (IOError, ValueError) as error:
                            logging.error(
                                f"Unable to read the data iterator checkpoint step: {repr(error)}"
                            )
                            raise
                    # Add the task id and global step to the file name so that
                    # each task loads its own saved state.
self.dataloader_state_file = os.path.join(
self.dataloader_state_path,
f'data_iter_state_file_worker_{self.task_id}_step_{global_ckpt_step}.txt',
)
                    if os.path.isfile(self.dataloader_state_file):
                        with open(self.dataloader_state_file, 'r') as f:
                            samples_seen = int(f.readline())
                            # Resume from the first batch that had not been
                            # fully consumed (ceiling division).
                            self.prev_worker_iter_index = math.ceil(
                                samples_seen / self.batch_size
                            )

    def _load_buffer(self, data_partitions):
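        """
        Generator that yields individual examples from the given data
        partitions, reading at most batch_size rows from an HDF5 file at a
        time and resuming mid-partition when restarting from a saved state.
        """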
        # Default to the first partition and offset 0 when there is no saved
        # dataloader state to resume from.
        restart_iter_partition_id = 0
        restart_iter_start_idx = 0
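        # When restarting, walk the partitions in order and accumulate how many
        # iterations each one contributes until the saved iteration index is
        # reached; this gives the partition and in-partition offset to resume
        # from.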
if self.prev_worker_iter_index > 0:
iters_until_current_partition = 0
prev_partition_offset_start_idx = 0
current_partition_offset_start_idx = 0
for partition_idx, partition_specs in enumerate(data_partitions):
start_idx = partition_specs[1]
num_examples = partition_specs[2]
if partition_idx > 0:
num_examples_prev_partition = (
data_partitions[partition_idx - 1][2]
- prev_partition_offset_start_idx
)
                    remainder_prev_partition = (
                        num_examples_prev_partition % self.batch_size
                    )
                    if remainder_prev_partition > 0:
                        # The previous partition ended mid-batch; the first
                        # examples of this partition complete that batch.
                        current_partition_offset_start_idx = (
                            self.batch_size - remainder_prev_partition
                        )
                    else:
                        current_partition_offset_start_idx = 0
prev_partition_offset_start_idx = (
current_partition_offset_start_idx
)
num_examples_curr_partition = (
num_examples - current_partition_offset_start_idx
)
else:
num_examples_curr_partition = num_examples
current_partition_offset_start_idx = 0
iters_until_current_partition += np.ceil(
num_examples_curr_partition / self.batch_size
)
if (
self.prev_worker_iter_index
<= iters_until_current_partition - 1
):
restart_iter_partition_id = partition_idx
restart_iter_start_idx = int(
self.batch_size
* (
self.prev_worker_iter_index
- (
iters_until_current_partition
- np.ceil(
num_examples_curr_partition
/ self.batch_size
)
)
)
)
restart_iter_start_idx += current_partition_offset_start_idx
break
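        # Stream examples partition by partition, reading at most batch_size
        # rows from each HDF5 file per slice and yielding them one at a time.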
for partition_idx, partition_specs in enumerate(
data_partitions[restart_iter_partition_id:]
):
            file_path = partition_specs[0]
            start_idx_org = partition_specs[1]
            num_examples = partition_specs[2]
            # Only the first partition after a restart begins at the computed
            # restart offset; every other partition starts at its beginning.
            if self.prev_worker_iter_index > 0 and partition_idx == 0:
                start_idx = restart_iter_start_idx
            else:
                start_idx = start_idx_org
with h5py.File(file_path, mode='r') as h5_file:
for idx in range(
start_idx, start_idx_org + num_examples, self.batch_size
):
load_len = min(
self.batch_size, start_idx_org + num_examples - idx
)
load_data = h5_file["data"][idx : idx + load_len]
for i in range(load_len):
yield load_data[i]

    def __iter__(self):
"""
Iterating over the data to construct input features.
"""
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_info.id
num_workers = worker_info.num_workers
else:
# Single-process
worker_id = 0
num_workers = 1
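        # Give each dataloader worker a contiguous slice of this task's file
        # partitions.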
data_partitions = shard_list_of_chunks_contiguous(
self.files_in_this_task, worker_id, num_workers
)
for example in self._load_buffer(data_partitions):
yield {
feature: np.array(example[i], np.int32)
for i, feature in enumerate(self.features_list)
}

    def __len__(self):
        """
        Returns the number of examples assigned to this task process.
        """
return self.num_examples_in_this_task