# Source code for data_processing.scripts.hdf5_preprocessing.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
from itertools import repeat
from math import ceil
from multiprocessing import Pool, cpu_count

import h5py
import yaml
from tqdm import tqdm

from modelzoo.transformers.data_processing.utils import split_list

logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)


def add_common_args(parser):
    """
    For the argparse to parse arguments for subcommands, we add common
    command line arguments to each subcommand parser here.
    """
    parser.add_argument(
        "--params",
        type=str,
        default=None,
        help="Path to the YAML config file for setting dataset "
        "preprocessing hyper-parameters.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        help="Directory where raw data is stored.",
    )
    parser.add_argument(
        "--metadata_files",
        type=str,
        default=None,
        help="Path to text file containing a list of file names "
        "corresponding to the raw input documents to be "
        "processed and stored; can handle multiple metadata files "
        "separated by comma.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory where HDF5 files will be stored.",
    )
    parser.add_argument(
        "--processes",
        type=int,
        help="Number of processes to use.",
    )
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["GPT2Tokenizer", "NeoXTokenizer"],
        help=(
            "Type of tokenizer to use for HDF5 dataset generation. "
            "Can be one of `GPT2Tokenizer` or `NeoXTokenizer`."
        ),
    )
    parser.add_argument(
        "--vocab_file", type=str, help="Path to the vocabulary file."
    )
    parser.add_argument(
        "--encoder_file", type=str, help="Path to the encoder file."
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.0,
        help=(
            "Probability of creating sequences which are shorter than the "
            "maximum sequence length."
        ),
    )
    parser.add_argument(
        "--ftfy",
        type=str,
        choices=["True", "False"],
        help="Whether to fix text with ftfy.",
    )
    parser.add_argument(
        "--ftfy_normalizer",
        type=str,
        choices=["NFC", "None"],
        help=(
            "Choose what kind of unicode normalization is applied. Usually, we "
            "apply `NFC` normalization, so that letters followed by combining "
            "characters become single combined characters. Using `None` "
            "applies no normalization while fixing text."
        ),
    )
    parser.add_argument(
        "--wikitext_detokenize",
        type=str,
        choices=["True", "False"],
        help="Whether to use the wikitext detokenizer to fix text.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="examples",
        help=(
            "Name of the dataset, i.e. the prefix to use for HDF5 file names. "
            "Defaults to `examples`."
        ),
    )
    parser.add_argument(
        "--files_per_record",
        type=int,
        help="Number of text files to write per HDF5 file.",
    )
    parser.add_argument(
        "--write_in_batch",
        type=str,
        choices=["True", "False"],
        help="Whether to write the samples in batch for the HDF5 format; "
        "setting this to False saves memory but is a bit slower.",
    )
    parser.add_argument(
        "--write_remainder",
        type=str,
        choices=["True", "False"],
        help="Write the remainder files when data is left over from "
        "processing.",
    )
    parser.add_argument(
        "--pack_sequences",
        type=str,
        choices=["True", "False"],
        help="Concatenate a document smaller than the maximum sequence length "
        "with other documents, instead of filling it with padding tokens.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        choices=["True", "False"],
        help="Resume record writing from a given checkpoint.",
    )
    parser.add_argument(
        "--display_pbar",
        type=str,
        choices=["True", "False"],
        help="Display a progress bar while the script runs.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed.",
    )


def get_parser(desc):
    """Argparser definition for command line arguments from user.

    Returns:
        Argparse namespace object with command line arguments.
    """
    parser = argparse.ArgumentParser(description=desc)
    subparser = parser.add_subparsers(
        description="Sub command for HDF5 conversion.",
        dest="mode",
        required=True,
        help="Sub command to choose saving the raw text into HDF5 files or "
        "pre-processed text converted into token ids at the desired maximum "
        "sequence length.",
    )
    lm_parser = subparser.add_parser(
        "LMData",
        help="Language modeling dataset in `.jsonl` or `.txt` format.",
    )
    add_common_args(lm_parser)
    lm_parser.add_argument(
        "--jsonl_key",
        type=str,
        default="text",
        help="The key name in input jsonl files from which the raw text will "
        "be extracted in order to further process it.",
    )
    summarization_parser = subparser.add_parser(
        "Summarization", help="Fine-tuning dataset in plain text format."
    )
    add_common_args(summarization_parser)
    summarization_parser.add_argument(
        "--input_sep",
        type=str,
        help="String that separates prompts from their completions in the "
        "input data files.",
    )
    summarization_parser.add_argument(
        "--sep_token",
        type=str,
        default="<|sep|>",
        help="Token added between prompt and completion in preprocessed "
        "sequences.",
    )
    json_pair_parser = subparser.add_parser(
        "JsonPair", help="Fine-tuning dataset in `.json` format."
    )
    add_common_args(json_pair_parser)
    json_pair_parser.add_argument(
        "--prompt",
        type=str,
        help="Json key for the prompt.",
    )
    json_pair_parser.add_argument(
        "--completion",
        type=str,
        help="Json key for the completion.",
    )
    json_pair_parser.add_argument(
        "--sep_token",
        type=str,
        default="<|sep|>",
        help="Token added between prompt and completion in preprocessed "
        "sequences.",
    )
    custom_parser = subparser.add_parser(
        "Customize", help="Provide a customized dataset processor."
    )
    add_common_args(custom_parser)
    custom_parser.add_argument(
        "--module",
        type=str,
        help="Name of the Python file that contains the custom dataset "
        "processor.",
    )
    custom_parser.add_argument(
        "--dataset_processor",
        type=str,
        help="Name of the custom dataset processor.",
    )
    return parser.parse_args()
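

# A hypothetical invocation sketch (the script name and file paths below are
# placeholders, not taken from this module): the sub-command selects the
# dataset flavor, and the common flags configure tokenization and output.
#
#   python preprocess_data.py LMData \
#       --params ./configs/preprocessing.yaml \
#       --input_dir ./raw_data \
#       --output_dir ./hdf5_data \
#       --tokenizer_type GPT2Tokenizer \
#       --max_seq_length 2048
#
# Only the sub-command names ("LMData", "Summarization", "JsonPair",
# "Customize") and the flag names come from `get_parser` / `add_common_args`.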


def update_params(params, args):
    """Update config parameters with CLI arguments."""
    setup_params = [
        "input_dir",
        "metadata_files",
        "output_dir",
        "processes",
        "module",
        "dataset_processor",
    ]
    processing_params = [
        "tokenizer_type",
        "vocab_file",
        "encoder_file",
        "max_seq_length",
        "short_seq_prob",
        "ftfy",
        "ftfy_normalizer",
        "wikitext_detokenize",
        "output_name",
        "files_per_record",
        "write_in_batch",
        "write_remainder",
        "pack_sequences",
        "resume_from_checkpoint",
        "display_pbar",
        "seed",
    ]
    processor_map = {
        "lmdata": "LMDataPreprocessor",
        "summarization": "SummarizationPreprocessor",
    }

    mode = args.pop("mode").lower()
    if mode != "customize":
        params["setup"]["dataset_processor"] = processor_map[mode]

    for key, value in args.items():
        if value in ["True", "False"]:
            value = value == "True"
        if value is not None:
            if key in setup_params:
                params["setup"][key] = value
            elif key in processing_params:
                params["processing"][key] = value
            else:
                params["dataset"][key] = value
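

# A minimal sketch (hypothetical helper, not part of the original module)
# showing how `update_params` routes CLI overrides: keys in `setup_params`
# land in params["setup"], keys in `processing_params` land in
# params["processing"], and everything else falls through to
# params["dataset"]; "True"/"False" strings are converted to booleans first.
# The argument values below are made up for illustration.
def _example_update_params():
    params = {"setup": {}, "processing": {}, "dataset": {}}
    args = {
        "mode": "LMData",
        "input_dir": "./raw",
        "ftfy": "True",
        "jsonl_key": "text",
    }
    update_params(params, args)
    assert params["setup"]["dataset_processor"] == "LMDataPreprocessor"
    assert params["setup"]["input_dir"] == "./raw"
    assert params["processing"]["ftfy"] is True
    assert params["dataset"]["jsonl_key"] == "text"
    return params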


def get_params(desc):
    """Retrieve configuration parameters.

    Returns:
        params (Dict): Dictionary containing the parameters used to configure
            the data processing.
    """
    args = get_parser(desc)
    args = vars(args)

    params_file = args.pop("params", None)
    if params_file:
        with open(params_file, 'r') as stream:
            params = yaml.safe_load(stream)
    else:
        params = {}

    for section in ["setup", "processing", "dataset"]:
        if not params.get(section, None):
            params[section] = {}

    update_params(params, args)
    return params
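

# A minimal usage sketch (hypothetical driver code, not part of this module).
# `get_params` parses sys.argv itself, so it is normally called once from a
# script entry point; the returned dictionary always contains the "setup",
# "processing", and "dataset" sections, even when no YAML file is supplied.
#
#   params = get_params(desc="Create HDF5 dataset")
#   output_dir = params["setup"].get("output_dir")
#   max_seq_length = params["processing"].get("max_seq_length")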


def dump_args(args, json_params_file):
    """Write the input params to file."""
    logger.info(f"User arguments can be found at {json_params_file}.")

    # write initial params to file
    with open(json_params_file, "w") as _fout:
        json.dump(args, _fout, indent=4, sort_keys=True)


def dump_result(
    results, json_params_file, eos_id=None, pad_id=None, vocab_size=None
):
    """Write the outputs of execution to file."""
    with open(json_params_file, "r") as _fin:
        data = json.load(_fin)

    post_process = {}
    post_process["discarded_files"] = results["discarded"]
    post_process["processed_files"] = results["processed"]
    post_process["successful_files"] = results["successful"]
    post_process["n_examples"] = results["examples"]
    if eos_id:
        post_process["eos_id"] = (
            eos_id[0] if isinstance(eos_id, list) else eos_id
        )
    if pad_id:
        post_process["pad_id"] = pad_id
    if vocab_size:
        post_process["vocab_size"] = vocab_size

    data["post-process"] = post_process
    with open(json_params_file, "w") as _fout:
        json.dump(data, _fout, indent=4, sort_keys=True)


def get_verification_args(params):
    """Assemble the arguments needed to verify the generated HDF5 files."""
    args = argparse.Namespace()
    args.processes = params["setup"].get("processes", 0)
    if args.processes == 0:
        args.processes = cpu_count()

    args.files_per_record = params["processing"].get("files_per_record", 50000)
    args.max_seq_length = params["processing"].get("max_seq_length", 2048)
    return args


def process_dataset(files, dataset_processor, processes):
    """Process a dataset and write it into HDF5 format.

    Args:
        files (list): List of files to process.
        dataset_processor: Class containing methods that specify how the
            dataset will be processed and written into HDF5 files.
        processes (int): Number of processes to use.

    Returns:
        Dictionary containing the results of execution, specifically the
        number of processed, discarded, and successful files, as well as the
        number of examples from all processes.
    """
    if processes < 2:
        # Run as a single process only, with the process number set to 0.
        return dataset_processor.create_dataset((files, 0))

    try:
        n_proc = processes
        n_chunks = ceil(len(files) / n_proc)
        remain = len(files) % n_proc
        if n_chunks == 1 and remain:
            n_proc = remain
            logger.warning(
                f"There aren't enough files to distribute to {processes} "
                f"processes, resetting it to {n_proc}. If you're working with "
                "a small number of compressed archives and could extract them "
                "into txt files, you might be able to get more benefit from "
                f"the available {processes} processes."
            )
        files = split_list(files, n_chunks)
    except ValueError as e:
        # We hit errors in two potential scenarios:
        # 1) `files` is an empty list, in which case there is nothing to split.
        # 2) There are more processes than files, in which case we cannot split
        #    the files across processes correctly, as there would be many idle
        #    processes which are not doing anything.
        logger.error(e)
        raise

    with Pool(processes=n_proc) as pool:
        pbar = tqdm(
            pool.imap(
                dataset_processor.create_dataset,
                zip(files, range(len(files))),
            ),
            total=len(files),
        )
        meta = {"discarded": 0, "processed": 0, "successful": 0, "examples": 0}
        for results in pbar:
            pbar.update()
            for k, v in results.items():
                meta[k] += v

        return meta
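

# A minimal sketch (hypothetical, not part of the original module) of the
# contract `process_dataset` expects from a dataset processor: its
# `create_dataset` method takes a `(files, process_no)` tuple and returns
# counts keyed by "discarded", "processed", "successful", and "examples",
# which `process_dataset` sums across worker processes. The toy processor
# below writes nothing; a real one would tokenize and write HDF5 files.
def _example_process_dataset():
    class _ToyProcessor:
        def create_dataset(self, files_and_rank):
            files, _process_no = files_and_rank
            return {
                "discarded": 0,
                "processed": len(files),
                "successful": len(files),
                "examples": 10 * len(files),
            }

    # With processes < 2 the processor runs directly in the current process.
    return process_dataset(["a.txt", "b.txt"], _ToyProcessor(), processes=1)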


def verify_saved_hdf5_files(params):
    """
    This function is used to do sanity checks at the end of the creation
    of HDF5 files. It loads every generated `.h5` file and checks:
        1. The data type
        2. The shape of the dataset
        3. That the labels and inputs are as expected
    """
    h5_files_path, args = params
    for h5_file_path in h5_files_path:
        with h5py.File(h5_file_path, mode="r") as h5_file:
            n_examples = h5_file.attrs["n_examples"]
            dataset = h5_file["data"]
            expected_dtype = "i4"
            assert dataset.dtype == expected_dtype, (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"datatype is unexpected. Expected: {expected_dtype}, "
                f"received {dataset.dtype}."
            )
            data_shape = dataset[()].shape
            assert (
                n_examples <= args.files_per_record
                or args.files_per_record == -1
            ), (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"number of examples in file is unexpected. Expected:"
                f" {args.files_per_record}, received {n_examples}."
            )
            assert (
                data_shape[1:] == (3, args.max_seq_length)
                or args.max_seq_length == -1
            ), (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"shape of example is unexpected. Expected:"
                f" {(3, args.max_seq_length)}, received {data_shape[1:]}."
            )
    return True
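

# A minimal sketch (hypothetical helper, not part of the original module) of
# the file layout that `verify_saved_hdf5_files` checks for: a "data" dataset
# of dtype `i4` with shape `(n_examples, 3, max_seq_length)`, plus an
# `n_examples` attribute on the file. The path and sizes are placeholders.
def _example_expected_hdf5_layout(path="example_0.h5", max_seq_length=2048):
    import numpy as np

    n_examples = 4
    with h5py.File(path, mode="w") as h5_file:
        h5_file.attrs["n_examples"] = n_examples
        h5_file.create_dataset(
            "data",
            data=np.zeros((n_examples, 3, max_seq_length), dtype="i4"),
        )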


def verify_saved_hdf5_files_mp(files, args):
    """Verify the generated HDF5 dataset.

    Args:
        files (list): List of files to process.
        args (argparse namespace): Arguments for verifying HDF5 dataset.
    """
    if args.processes == 1:
        verify_saved_hdf5_files((files, args))
        return

    try:
        n_proc = args.processes
        n_chunks = ceil(len(files) / n_proc)
        remain = len(files) % n_proc
        if n_chunks == 1 and remain:
            n_proc = remain
            logger.warning(
                f"There aren't enough files to distribute to {args.processes} "
                f"processes, resetting it to {n_proc}."
            )
        files = split_list(files, n_chunks)
    except ValueError as e:
        # We hit an error in one potential scenario:
        # `files` is an empty list, in which case there is nothing to split.
        logger.error(e)
        raise

    with Pool(processes=n_proc) as pool:
        pbar = tqdm(
            pool.imap(verify_saved_hdf5_files, zip(files, repeat(args))),
            total=len(files),
        )
        for test in pbar:
            if test:
                continue