# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
from itertools import repeat
from math import ceil
from multiprocessing import Pool, cpu_count

import h5py
import yaml
from tqdm import tqdm

from modelzoo.transformers.data_processing.utils import split_list

logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)


def add_common_args(parser):
    """
    Add command line arguments shared by all subcommands. Since argparse
    parses arguments per subcommand, these common options are registered
    on each subcommand parser individually.
    """
parser.add_argument(
"--params",
type=str,
default=None,
help="Path to the YAML config file for setting dataset preprocessing hyper-parameters.",
)
parser.add_argument(
"--input_dir", type=str, help="Directory where raw data is stored.",
)
parser.add_argument(
"--metadata_files",
type=str,
default=None,
help="Path to text file containing a list of file names "
"corresponding to the raw input documents to be "
"processed and stored; can handle multiple metadata files "
"separated by comma.",
)
parser.add_argument(
"--output_dir",
type=str,
help="Directory where HDF5 files will be stored.",
)
parser.add_argument(
"--processes", type=int, help="Number of processes to use.",
)
parser.add_argument(
"--tokenizer_type",
type=str,
choices=["GPT2Tokenizer", "NeoXTokenizer"],
help=(
"Type of tokenizer to use for HDF5 dataset generation. "
"Can be one of `GPT2Tokenizer` or `NeoXTokenizer`."
),
)
parser.add_argument(
"--vocab_file", type=str, help="Path to the vocabulary file."
)
parser.add_argument(
"--encoder_file", type=str, help="Path to the encoder file."
)
parser.add_argument(
"--max_seq_length", type=int, help="Maximum sequence length.",
)
parser.add_argument(
"--short_seq_prob",
type=float,
default=0.0,
help=(
"Probability of creating sequences which are shorter than the"
+ " maximum sequence length."
),
)
parser.add_argument(
"--ftfy",
type=str,
choices=["True", "False"],
help="Whether to fix text with ftfy.",
)
parser.add_argument(
"--ftfy_normalizer",
type=str,
choices=["NFC", "None"],
help=(
"Choose what kind of unicode normalization is applied. Usually, we"
+ " apply `NFC` normalization, so that letters followed by combining"
+ " characters become single combined characters. Using `None`"
+ " applies no normalization while fixing text."
),
)
parser.add_argument(
"--wikitext_detokenize",
type=str,
choices=["True", "False"],
help="Whether to use wikitext detokenizer to fix text.",
)
parser.add_argument(
"--output_name",
type=str,
default="examples",
help=(
"Name of the dataset; i.e. prefix to use for HDF5 file names."
+ "Defaults to `examples`."
),
)
parser.add_argument(
"--files_per_record",
type=int,
help="Text files to write per HDF5 file.",
)
parser.add_argument(
"--write_in_batch",
type=str,
choices=["True", "False"],
help="Whether to write the samples in batch for the HDF5 format, "
"setting to false will save memory but a bit slower.",
)
parser.add_argument(
"--write_remainder",
type=str,
choices=["True", "False"],
help="Write the remainder files when data is left over from "
"processing.",
)
parser.add_argument(
"--pack_sequences",
type=str,
choices=["True", "False"],
help="Concatenate a document smaller than maximum sequence length with "
"other documents, instead of filling it with Padding token.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
choices=["True", "False"],
help="Resume record writing from a given checkpoint.",
)
parser.add_argument(
"--display_pbar",
type=str,
choices=["True", "False"],
help="Display progress while runs.",
)
parser.add_argument(
"--seed", type=int, help="Random seed.",
)
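

# Note: the boolean-style options above are parsed as the literal strings
# "True"/"False" and are only converted to real booleans later, in
# `update_params`. A minimal sketch of how a few of these flags might be
# combined on the command line (the values shown are illustrative, not
# defaults):
#
#   --ftfy True --ftfy_normalizer NFC --pack_sequences True --seed 0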


def get_parser(desc):
    """Argparser definition for command line arguments from user.

    Returns:
        Argparse namespace object with the parsed command line arguments.
    """
parser = argparse.ArgumentParser(description=desc)
subparser = parser.add_subparsers(
description="Sub command for HDF5 conversion.",
dest="mode",
required=True,
help="Sub command to choose saving the raw text into HDF5 files or "
"pre-processed text converted into token ids at desired maximum "
"sequence length.",
)
lm_parser = subparser.add_parser(
"LMData", help="Language modeling dataset in `.jsonl` or `.txt` format."
)
add_common_args(lm_parser)
lm_parser.add_argument(
"--jsonl_key",
type=str,
default="text",
help="The key name in input jsonl files from which the raw text will be "
"extracted in order to further process it.",
)
summarization_parser = subparser.add_parser(
"Summarization", help="Fine-tuning dataset in plane text format."
)
add_common_args(summarization_parser)
summarization_parser.add_argument(
"--input_sep",
type=str,
help="String that separates between prompts and their completions in input data files.",
)
summarization_parser.add_argument(
"--sep_token",
type=str,
default="<|sep|>",
help="Token added between prompt and completion in preprocessed sequences.",
)
json_pair_parser = subparser.add_parser(
"JsonPair", help="Fine-tuning dataset in `.json` format."
)
add_common_args(json_pair_parser)
json_pair_parser.add_argument(
"--prompt", type=str, help="Json key for the prompt.",
)
json_pair_parser.add_argument(
"--completion", type=str, help="Json key for the completion.",
)
json_pair_parser.add_argument(
"--sep_token",
type=str,
default="<|sep|>",
help="Token added between prompt and completion in preprocessed sequences.",
)
custom_parser = subparser.add_parser(
"Customize", help="Provide customized dataset processor."
)
add_common_args(custom_parser)
custom_parser.add_argument(
"--module",
type=str,
help="Python file name contains the custom dataset processor.",
)
custom_parser.add_argument(
"--dataset_processor",
type=str,
help="Name of the custom dataset processor.",
)
return parser.parse_args()
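

# A hedged example of invoking the resulting CLI. The driver script name
# `create_hdf5_dataset.py` and all paths below are illustrative assumptions,
# not taken from this module:
#
#   python create_hdf5_dataset.py LMData \
#       --input_dir ./raw_data \
#       --output_dir ./hdf5_data \
#       --tokenizer_type NeoXTokenizer \
#       --encoder_file ./tokenizer.json \
#       --max_seq_length 2048 \
#       --ftfy True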


def update_params(params, args):
    """
    Update config parameters with CLI arguments.
    """
setup_params = [
"input_dir",
"metadata_files",
"output_dir",
"processes",
"module",
"dataset_processor",
]
processing_params = [
"tokenizer_type",
"vocab_file",
"encoder_file",
"max_seq_length",
"short_seq_prob",
"ftfy",
"ftfy_normalizer",
"wikitext_detokenize",
"output_name",
"files_per_record",
"write_in_batch",
"write_remainder",
"pack_sequences",
"resume_from_checkpoint",
"display_pbar",
"seed",
]
processor_map = {
"lmdata": "LMDataPreprocessor",
"summarization": "SummarizationPreprocessor",
}
mode = args.pop("mode").lower()
if mode != "customize":
params["setup"]["dataset_processor"] = processor_map[mode]
for key, value in args.items():
if value in ["True", "False"]:
value = value == "True"
if value is not None:
if key in setup_params:
params["setup"][key] = value
elif key in processing_params:
params["processing"][key] = value
else:
params["dataset"][key] = value


def get_params(desc):
    """Retrieve configuration parameters.

    Returns:
        params (dict): Dictionary containing the parameters used to configure
            the data processing.
    """
args = get_parser(desc)
args = vars(args)
params_file = args.pop("params", None)
if params_file:
with open(params_file, 'r') as stream:
params = yaml.safe_load(stream)
else:
params = {}
for section in ["setup", "processing", "dataset"]:
if not params.get(section, None):
params[section] = {}
update_params(params, args)
return params
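

# A YAML file passed via `--params` is expected to mirror the same three
# sections that `update_params` fills in. A minimal sketch (the keys and
# values below are illustrative, not a complete or authoritative config):
#
#   setup:
#     input_dir: ./raw_data
#     output_dir: ./hdf5_data
#     processes: 8
#   processing:
#     tokenizer_type: NeoXTokenizer
#     max_seq_length: 2048
#   dataset:
#     jsonl_key: text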


def dump_args(args, json_params_file):
    """
    Write the input params to a file.
    """
logger.info(f"User arguments can be found at {json_params_file}.")
# write initial params to file
with open(json_params_file, "w") as _fout:
json.dump(args, _fout, indent=4, sort_keys=True)


def dump_result(
    results, json_params_file, eos_id=None, pad_id=None, vocab_size=None
):
    """
    Append the post-processing results of execution to the params file.
    """
with open(json_params_file, "r") as _fin:
data = json.load(_fin)
post_process = {}
post_process["discarded_files"] = results["discarded"]
post_process["processed_files"] = results["processed"]
post_process["successful_files"] = results["successful"]
post_process["n_examples"] = results["examples"]
    if eos_id is not None:
        post_process["eos_id"] = (
            eos_id[0] if isinstance(eos_id, list) else eos_id
        )
    if pad_id is not None:
        post_process["pad_id"] = pad_id
    if vocab_size is not None:
        post_process["vocab_size"] = vocab_size
data["post-process"] = post_process
with open(json_params_file, "w") as _fout:
json.dump(data, _fout, indent=4, sort_keys=True)
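

# After `dump_result`, the params JSON gains a "post-process" section roughly
# of the following shape (the numbers are illustrative):
#
#   "post-process": {
#       "discarded_files": 0,
#       "processed_files": 100,
#       "successful_files": 100,
#       "n_examples": 52000,
#       "eos_id": 50256,
#       "pad_id": 50256,
#       "vocab_size": 50257
#   }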


def get_verification_args(params):
    """Construct an argparse.Namespace with the fields needed for HDF5 verification."""
    args = argparse.Namespace()
args.processes = params["setup"].get("processes", 0)
if args.processes == 0:
args.processes = cpu_count()
args.files_per_record = params["processing"].get("files_per_record", 50000)
args.max_seq_length = params["processing"].get("max_seq_length", 2048)
return args


def process_dataset(files, dataset_processor, processes):
    """Process a dataset and write it into HDF5 format.

    Args:
        files (list): List of files to process.
        dataset_processor: Class containing methods that specify how the
            dataset will be processed and written into HDF5 files.
        processes (int): Number of processes to use.

    Returns:
        Dictionary containing the results of execution, specifically the
        number of processed, discarded, and successful files, as well as the
        number of examples from all processes.
    """
if processes < 2:
        # Run in a single process, with the process number set to 0.
return dataset_processor.create_dataset((files, 0))
try:
n_proc = processes
n_chunks = ceil(len(files) / n_proc)
remain = len(files) % n_proc
if n_chunks == 1 and remain:
n_proc = remain
            logger.warning(
                f"There aren't enough files to distribute across {processes} "
                f"processes, so the process count is reduced to {n_proc}. If "
                "you're working with a small number of compressed archives, "
                "extracting them into txt files may let you make better use "
                f"of the available {processes} processes."
            )
files = split_list(files, n_chunks)
except ValueError as e:
        # We hit errors in two potential scenarios:
        # 1) `files` is an empty list, in which case there is nothing to split.
        # 2) There are more processes than files, in which case we cannot
        #    split the files across processes correctly, as many processes
        #    would sit idle doing nothing.
logger.error(e)
raise
with Pool(processes=n_proc) as pool:
pbar = tqdm(
pool.imap(
dataset_processor.create_dataset,
zip(files, range(len(files)),),
),
total=len(files),
)
meta = {"discarded": 0, "processed": 0, "successful": 0, "examples": 0}
for results in pbar:
pbar.update()
for k, v in results.items():
meta[k] += v
return meta
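

# Sketch of the interface `process_dataset` assumes of `dataset_processor`,
# which is relevant when supplying a custom processor via the `Customize`
# subcommand. The class name and body are hypothetical; only the
# `create_dataset` signature and the returned keys are inferred from the
# code above:
#
#   class MyDatasetProcessor:
#       def create_dataset(self, files_and_rank):
#           files, process_no = files_and_rank
#           # ... tokenize `files` and write the HDF5 shards for this rank ...
#           return {
#               "discarded": 0,
#               "processed": len(files),
#               "successful": len(files),
#               "examples": 0,
#           }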


def verify_saved_hdf5_files(params):
    """
    Run sanity checks at the end of HDF5 file creation.

    This function loads every generated `.h5` file and checks:
        1. The data type
        2. The shape of the dataset
        3. That the labels and inputs are as expected
    """
h5_files_path, args = params
for h5_file_path in h5_files_path:
with h5py.File(h5_file_path, mode="r") as h5_file:
n_examples = h5_file.attrs["n_examples"]
dataset = h5_file["data"]
expected_dtype = "i4"
assert dataset.dtype == expected_dtype, (
f"Error in {h5_file}, conversion is corrupted as the "
f"datatype is unexpected. Expected: {expected_dtype}, "
f"received {dataset.dtype}."
)
            data_shape = dataset.shape
assert (
n_examples <= args.files_per_record
or args.files_per_record == -1
), (
f"Error in {h5_file}, conversion is corrupted as the "
f"number of examples in file is unexpected. Expected:"
f" {args.files_per_record}, received {n_examples}."
)
assert (
data_shape[1:] == (3, args.max_seq_length)
or args.max_seq_length == -1
), (
f"Error in {h5_file}, conversion is corrupted as the "
f"shape of example is unexpected. Expected:"
f" {(3, args.max_seq_length)}, received {data_shape[1:]}."
)
return True
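

# Based on the checks above, each generated HDF5 file is expected to contain
# a "data" dataset of dtype int32 ("i4") with shape
# (n_examples, 3, max_seq_length) plus an "n_examples" attribute. A quick
# manual inspection might look like this (the file name is illustrative):
#
#   with h5py.File("examples_0.h5", mode="r") as f:
#       print(f.attrs["n_examples"], f["data"].shape, f["data"].dtype)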


def verify_saved_hdf5_files_mp(files, args):
    """Verify the generated HDF5 dataset across multiple processes.

    Args:
        files (list): List of HDF5 files to verify.
        args (argparse.Namespace): Arguments for verifying the HDF5 dataset.
    """
if args.processes == 1:
verify_saved_hdf5_files((files, args))
return
try:
n_proc = args.processes
n_chunks = ceil(len(files) / n_proc)
remain = len(files) % n_proc
if n_chunks == 1 and remain:
n_proc = remain
logger.warning(
f"There aren't enough files to distribute to {args.processes} "
f"processes, resetting it to {n_proc}."
)
files = split_list(files, n_chunks)
except ValueError as e:
        # We hit errors in one potential scenario:
        # `files` is an empty list, in which case there is nothing to split.
logger.error(e)
raise
with Pool(processes=n_proc) as pool:
pbar = tqdm(
pool.imap(verify_saved_hdf5_files, zip(files, repeat(args))),
total=len(files),
)
        # Iterating over the pool results drives the verification; any
        # corrupted file surfaces as an assertion error in a worker process.
        for test in pbar:
            if test:
                continue
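

# Hedged end-to-end sketch of the verification flow defined above (the
# description string, glob pattern, and directory are illustrative):
#
#   import glob
#   params = get_params("Create HDF5 dataset")
#   verify_saved_hdf5_files_mp(
#       glob.glob("./hdf5_data/*.h5"), get_verification_args(params)
#   )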