Source code for cerebras.modelzoo.data.nlp.bert.BertCSVDataProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor for PyTorch BERT training.
"""

import csv
import random
from typing import Any, List, Literal, Optional, Union

import numpy as np
import torch
from pydantic import Field, PositiveInt, field_validator

from cerebras.modelzoo.common.input_utils import (
    bucketed_batch,
    get_streaming_batch_size,
)
from cerebras.modelzoo.config import DataConfig
from cerebras.modelzoo.data.common.input_utils import (
    get_data_for_task,
    num_tasks,
    shard_list_interleaved,
    task_id,
)
from cerebras.modelzoo.data.nlp.bert.bert_utils import get_meta_data


class BertCSVDataProcessorConfig(DataConfig):
    data_processor: Literal["BertCSVDataProcessor"]

    data_dir: Union[str, List[str]] = ...
    "Path to the data files to use."

    batch_size: PositiveInt = ...
    "The batch size."

    disable_nsp: bool = False
    "Whether Next Sentence Prediction (NSP) objective is disabled."

    dynamic_mlm_scale: bool = False
    "Whether to dynamically scale the loss."

    buckets: Optional[List[int]] = None
    """
    A list of bucket boundaries. If set to None, then no bucketing will
    happen, and data will be batched normally. If set to a list, then data
    will be grouped into `len(buckets) + 1` buckets. A sample `s` will go
    into bucket `i` if `buckets[i-1] <= element_length_fn(s) < buckets[i]`,
    where 0 and inf are the implied lowest and highest boundaries
    respectively. `buckets` must be sorted and all elements must be non-zero.
    """

    shuffle: bool = False
    "Whether or not to shuffle the dataset."

    shuffle_seed: Optional[int] = None
    "The seed used for deterministic shuffling."

    shuffle_buffer: Optional[int] = None
    """
    Buffer size to shuffle samples across.
    If None and shuffle is enabled, 10*batch_size is used.
    """

    num_workers: int = 0
    "The number of PyTorch processes used in the dataloader."

    prefetch_factor: Optional[int] = 2
    "The number of batches to prefetch in the dataloader."

    persistent_workers: bool = False
    "Whether or not to keep workers persistent between epochs."

    drop_last: bool = True
    "Whether to drop the last batch of an epoch if it's an incomplete batch."

    # The following fields are deprecated and unused.
    # They will be removed in the future once all configs have been fixed.
    mixed_precision: Optional[Any] = Field(default=None, deprecated=True)
    vocab_size: Optional[Any] = Field(default=None, deprecated=True)
    vocab_file: Optional[Any] = Field(default=None, deprecated=True)
    whole_word_masking: Optional[Any] = Field(default=None, deprecated=True)
    max_predictions_per_seq: Optional[Any] = Field(
        default=None, deprecated=True
    )
    do_lower: Optional[Any] = Field(default=None, deprecated=True)
    masked_lm_prob: Optional[Any] = Field(default=None, deprecated=True)
    max_sequence_length: Optional[Any] = Field(default=None, deprecated=True)
    max_position_embeddings: Optional[Any] = Field(
        default=None, deprecated=True
    )

    def post_init(self, context):
        super().post_init(context)

        if not self.num_workers:
            self.prefetch_factor = None  # the default value in DataLoader
            self.persistent_workers = False

    @field_validator("disable_nsp", mode="after")
    @classmethod
    def get_disable_nsp(cls, disable_nsp, info):
        if info.context:
            model_config = info.context.get("model", {}).get("config")
            if hasattr(model_config, "disable_nsp"):
                return model_config.disable_nsp
        return disable_nsp
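
# Illustrative sketch (not part of the original source): in practice this
# config is usually populated from the train_input/eval_input section of a
# params YAML, but it can also be constructed directly. The path and sizes
# below are hypothetical placeholders.
#
#   config = BertCSVDataProcessorConfig(
#       data_processor="BertCSVDataProcessor",
#       data_dir="/path/to/preprocessed/csv/dir",  # hypothetical path
#       batch_size=256,
#       shuffle=True,
#       shuffle_seed=1,
#       num_workers=4,
#   )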

class BertCSVDataProcessor(torch.utils.data.IterableDataset):
    """Reads CSV files containing the input text tokens and MLM features."""

    def __init__(self, config: BertCSVDataProcessorConfig):
        super().__init__()

        self.meta_data = get_meta_data(config.data_dir)

        self.meta_data_values = list(self.meta_data.values())
        self.meta_data_filenames = list(self.meta_data.keys())
        # Note that [0] is prepended so that the cumulative sum starts at zero.
        self.meta_data_values_cum_sum = np.cumsum([0] + self.meta_data_values)

        self.num_examples = sum(map(int, self.meta_data.values()))
        self.disable_nsp = config.disable_nsp
        self.batch_size = get_streaming_batch_size(config.batch_size)

        self.num_batches = self.num_examples // self.batch_size
        assert self.num_batches > 0, (
            "Dataset does not contain enough samples for one batch. "
            "Please choose a smaller batch size."
        )

        self.num_tasks = num_tasks()
        self.task_id = task_id()

        self.num_batch_per_task = self.num_batches // self.num_tasks
        assert self.num_batch_per_task > 0, (
            "Dataset cannot be evenly distributed across the given tasks. "
            "Please choose fewer tasks to run with."
        )

        self.num_examples_per_task = self.num_batch_per_task * self.batch_size
        self.files_in_task = get_data_for_task(
            self.task_id,
            self.meta_data_values_cum_sum,
            self.num_examples_per_task,
            self.meta_data_values,
            self.meta_data_filenames,
        )

        self.shuffle = config.shuffle
        self.shuffle_seed = config.shuffle_seed
        if config.shuffle_buffer is None:
            self.shuffle_buffer = 10 * self.batch_size
        else:
            self.shuffle_buffer = config.shuffle_buffer

        self.dynamic_mlm_scale = config.dynamic_mlm_scale
        self.buckets = config.buckets

        # Multi-processing params.
        self.num_workers = config.num_workers
        self.drop_last = config.drop_last
        self.prefetch_factor = config.prefetch_factor
        self.persistent_workers = config.persistent_workers

        # Store params.
        self.data_buffer = []
        self.csv_files_per_task_per_worker = []
        self.processed_buffers = 0
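
    # Sharding arithmetic used above, as a worked example (illustrative only;
    # the numbers are hypothetical): with num_examples = 100_000,
    # batch_size = 256 and num_tasks = 2:
    #   num_batches           = 100_000 // 256 = 390
    #   num_batch_per_task    = 390 // 2       = 195
    #   num_examples_per_task = 195 * 256      = 49_920
    # so each task reads a contiguous slice of ~49.9K examples from the CSV
    # files selected by get_data_for_task.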

    def load_buffer(self):
        """
        Generator to read the data in chunks of size of `data_buffer`.

        :returns: Yields the data stored in the `data_buffer`.
        """
        self.processed_buffers = 0
        self.data_buffer = []

        while self.processed_buffers < len(self.csv_files_per_task_per_worker):
            (
                current_file_path,
                num_examples,
                start_id,
            ) = self.csv_files_per_task_per_worker[self.processed_buffers]

            with open(current_file_path, "r", newline="") as fin:
                data_reader = csv.DictReader(fin)

                for row_id, row in enumerate(data_reader):
                    if start_id <= row_id < start_id + num_examples:
                        self.data_buffer.append(row)
                    else:
                        continue

                    if len(self.data_buffer) == self.shuffle_buffer:
                        if self.shuffle:
                            self.rng.shuffle(self.data_buffer)

                        for ind in range(len(self.data_buffer)):
                            yield self.data_buffer[ind]
                        self.data_buffer = []

                self.processed_buffers += 1

        if self.shuffle:
            self.rng.shuffle(self.data_buffer)

        for ind in range(len(self.data_buffer)):
            yield self.data_buffer[ind]
        self.data_buffer = []

    def __len__(self):
        # Returns the length of the dataset on the task process.
        if not self.drop_last:
            return (
                self.num_examples_per_task + self.batch_size - 1
            ) // self.batch_size
        elif self.buckets is None:
            return self.num_examples_per_task // self.batch_size
        else:
            # Give an under-estimate in case we don't fully fill some buckets.
            length = self.num_examples_per_task // self.batch_size
            length -= self.batch_size * (len(self.buckets) + 1)
            return length
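
    # Worked example for the bucketed under-estimate above (hypothetical
    # numbers): with num_examples_per_task = 1_024_000, batch_size = 32 and
    # buckets = [128, 384] (two boundaries, i.e. three buckets), the reported
    # length is 1_024_000 // 32 - 32 * 3 = 32_000 - 96 = 31_904 batches,
    # conservatively discounting batches that may never fill up in some bucket.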

    def get_single_item(self):
        """
        Iterate over the data rows to construct input features.

        :return: A generator yielding a dict of training features:

            * np.array[int32] input_ids: Numpy array with input token indices.
              Shape: (`max_sequence_length`).
            * np.array[int32] labels: Numpy array with labels.
              Shape: (`max_sequence_length`).
            * np.array[int32] attention_mask: Numpy array with attention mask.
              Shape: (`max_sequence_length`).
            * np.array[int32] token_type_ids: Numpy array with segment indices.
              Shape: (`max_sequence_length`).
            * np.array[int32] next_sentence_label: Numpy array with labels for
              the NSP task. Shape: (1,).
            * np.array[int32] masked_lm_mask: Numpy array with a mask of
              predicted tokens. Shape: (`max_predictions`).
              `0` indicates a non-masked token, and `1` indicates a masked token.
        """

        def make_features(
            data_row,
            feature_names,
            required_features=False,
            dtype=np.int32,
        ):
            if required_features:
                absent_features = [
                    feature
                    for feature in feature_names
                    if feature not in data_row
                ]
                if absent_features:
                    raise ValueError(
                        f"{absent_features} are required features, "
                        f"but absent in the dataset"
                    )

            return {
                feature: np.array(eval(data_row[feature]), dtype=dtype)
                for feature in feature_names
                if feature in data_row
            }

        # Iterate over the data rows to create input features.
        for data_row in self.load_buffer():
            # `data_row` is a dict keyed by the CSV column names.
            features = make_features(
                data_row,
                [
                    "input_ids",
                    "attention_mask",
                    "labels",
                ],
                required_features=True,
            )

            features.update(
                make_features(
                    data_row, ["masked_lm_weights", "masked_lm_positions"]
                )
            )
            if "masked_lm_weights" in features:
                # Stored as masked_lm_weights, but really masked_lm_mask.
                features["masked_lm_mask"] = features["masked_lm_weights"]
                features.pop("masked_lm_weights")

            if not self.disable_nsp:
                features.update(
                    make_features(
                        data_row, ["next_sentence_label", "token_type_ids"]
                    )
                )

            yield features
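
    # Shape sketch of one yielded `features` dict (illustrative only; the
    # sequence and prediction lengths come from the preprocessing step, not
    # from this class). With max_sequence_length = 128 and max_predictions = 20
    # a row would map to roughly:
    #   {
    #       "input_ids":           int32[128],
    #       "attention_mask":      int32[128],
    #       "labels":              int32[128],
    #       "masked_lm_mask":      int32[20],   # if present in the CSV
    #       "masked_lm_positions": int32[20],   # if present in the CSV
    #       "token_type_ids":      int32[128],  # only when NSP is enabled
    #       "next_sentence_label": int32[1],    # only when NSP is enabled
    #   }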

    def __iter__(self):
        batched_dataset = bucketed_batch(
            self.get_single_item(),
            self.batch_size,
            buckets=self.buckets,
            element_length_fn=lambda feats: np.sum(feats["attention_mask"]),
            drop_last=self.drop_last,
            seed=self.shuffle_seed,
        )
        for batch in batched_dataset:
            if self.dynamic_mlm_scale:
                scale = self.batch_size / torch.sum(batch["masked_lm_mask"])
                batch["mlm_loss_scale"] = scale.expand(self.batch_size, 1)
            yield batch

    def _worker_init_fn(self, worker_id):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process data loading.
            worker_id = 0
            num_workers = 1

        self.processed_buffers = 0
        if self.shuffle_seed is not None:
            self.shuffle_seed += worker_id + 1
        self.rng = random.Random(self.shuffle_seed)

        # Shard the data across multiple processes.
        self.csv_files_per_task_per_worker = shard_list_interleaved(
            self.files_in_task, worker_id, num_workers
        )
        if self.shuffle:
            self.rng.shuffle(self.csv_files_per_task_per_worker)

    def create_dataloader(self):
        """
        Create the dataloader object for this dataset.
        """
        if self.num_workers:
            dataloader = torch.utils.data.DataLoader(
                self,
                batch_size=None,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
                worker_init_fn=self._worker_init_fn,
            )
        else:
            dataloader = torch.utils.data.DataLoader(self, batch_size=None)
            self._worker_init_fn(0)

        return dataloader
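
# Illustrative end-to-end usage (not part of the original source; the config
# values are hypothetical placeholders):
#
#   config = BertCSVDataProcessorConfig(
#       data_processor="BertCSVDataProcessor",
#       data_dir="/path/to/preprocessed/csv/dir",
#       batch_size=256,
#       shuffle=True,
#   )
#   dataloader = BertCSVDataProcessor(config).create_dataloader()
#   for batch in dataloader:
#       # `batch` is a dict of per-feature values batched by bucketed_batch,
#       # with a leading dimension of batch_size.
#       ...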