Source code for cerebras.modelzoo.data_preparation.nlp.gptj.split_trc_dataset

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Split input src file into multiple files.
Number of examples per file is controlled by `buffer_len` param.
"""

import argparse
import os

from tqdm import tqdm


[docs]def main(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--input_file", type=str, required=True, help="Path to the original source language dataset stored as one file.", ) parser.add_argument( "--out_dir", type=str, required=True, help="Path to output directory with source language dataset files.", ) parser.add_argument( "--buffer_len", type=int, default=10000, help="Number of examples to store in one file.", ) parser.add_argument( "--val_split_ratio", type=float, default=0.1, help="Ratio of the output number of files to be considered as " "validation dataset.", ) args = parser.parse_args() src_postfix = args.input_file.split("/")[-1] split_dir = os.path.join(args.out_dir, "split_files") if not os.path.isdir(args.out_dir): os.mkdir(args.out_dir) os.mkdir(split_dir) out_files_list = [] with open(args.input_file, "r") as src_fin: src_lines = src_fin.readlines() count = 0 start_idx = 0 end_idx = min(args.buffer_len, len(src_lines)) num_out_files = (len(src_lines) // args.buffer_len) + 1 pbar = tqdm(range(num_out_files)) while end_idx < len(src_lines): split_lines = src_lines[start_idx:end_idx] shard_out_file_name = os.path.join(split_dir, f"{src_postfix}-{count}") out_files_list.append(shard_out_file_name) with open(shard_out_file_name, "w") as src_fout: src_fout.writelines(split_lines) count += 1 start_idx += args.buffer_len end_idx = min(end_idx + args.buffer_len, len(src_lines)) pbar.update() val_count = int(args.val_split_ratio * len(out_files_list)) train_count = len(out_files_list) - val_count train_files = out_files_list[:train_count] val_files = out_files_list[train_count:] with open(os.path.join(args.out_dir, "train_meta.txt"), "w") as fid: fid.writelines([file + "\n" for file in train_files]) with open(os.path.join(args.out_dir, "val_meta.txt"), "w") as fid: fid.writelines([file + "\n" for file in val_files])
if __name__ == "__main__": main()