# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
# 2 pass shuffling algorithm: https://blog.janestreet.com/how-to-shuffle-a-big-dataset/
import queue
import random
import time
from multiprocessing import Process, Queue
import lm_dataformat as lmd
from more_itertools import chunked
from tqdm import tqdm
# isort: off
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../.."))
# isort: on
from modelzoo.transformers.data_processing.slimpajama.preprocessing.datasets import (
    RedPajamaReplication,
    redpj_datasets,
)
from modelzoo.transformers.data_processing.slimpajama.utils.utils import (
    rm_if_exists,
    write_lmd_dataset,
)
def write_docs(q, ar, archive_name, timeout=60):
    """Drain documents from ``q`` into the lm_dataformat archive ``ar``.

    Documents are added one by one; the archive is committed every 10000
    documents and once more when no document arrives within ``timeout``
    seconds, at which point the function returns.

    Args:
        q: queue whose items are dicts with ``"doc"`` (text) and ``"meta"``
            entries.
        ar: archive object exposing ``add_data(doc, meta)``,
            ``commit(archive_name=...)`` and an integer attribute ``i`` that
            is appended to every committed archive name.
        archive_name: prefix of every committed archive name (periodic and
            final commits alike).
        timeout: seconds to wait for a new document before treating the
            producers as finished. Defaults to 60.
    """
    n_docs = 0
    start_time = time.time()
    while True:
        try:
            doc = q.get(timeout=timeout)
        except queue.Empty:
            # Nothing arrived within ``timeout``: flush the remainder with the
            # same name prefix as the periodic commits below.
            ar.commit(archive_name=archive_name + str(ar.i))
            break
        ar.add_data(doc["doc"], doc["meta"])
        n_docs += 1
        if n_docs % 10000 == 0:
            ar.commit(archive_name=archive_name + str(ar.i))
            print(
                f"Total number of processed documents: {n_docs} ",
                f"Total time: {time.time() - start_time}",
            )
    print("Finished writing documents.")
def pass_1_shuffle(
    redpajama_dataset, output_dir_path="./", archive_name="redpajama"
):
    """Pass 1 of the two-pass shuffle: scatter documents into on-disk piles.

    ``output_dir_path`` is deleted and recreated, then one lm_dataformat
    archive ``{output_dir_path}/chunk{i}`` and one bounded queue are created
    per writer. ``redpajama_dataset.documents`` receives the list of queues
    and returns the reading processes; one ``write_docs`` process per queue
    drains it into its archive. All processes are started and joined.

    Args:
        redpajama_dataset: dataset object whose ``documents(queues)`` returns
            ``(reading_processes, manager)``.
        output_dir_path: directory that will hold the piles. It is WIPED, so
            it must not be the current working directory or an ancestor.
        archive_name: prefix for the archive names committed by the writers.

    Raises:
        ValueError: if ``output_dir_path`` resolves to the current working
            directory or one of its parents (this includes the ``"./"``
            default).
    """
    # rm_if_exists below wipes the target; on the "./" default (or any parent
    # of the working directory) that would destroy unrelated data.
    target = os.path.normcase(os.path.abspath(output_dir_path))
    cwd = os.path.normcase(os.getcwd())
    if cwd == target or cwd.startswith(os.path.join(target, "")):
        raise ValueError(
            f"Refusing to delete {output_dir_path!r}: it is the current "
            "working directory or one of its parents."
        )

    # We create piles of the dataset and store them as lmd;
    rm_if_exists(output_dir_path)
    os.mkdir(output_dir_path)
    n_process = 20
    ars = [
        lmd.Archive(f"{output_dir_path}/chunk{i}", threads=10)
        for i in range(n_process)
    ]
    # One queue per writer collects documents from the reading processes.
    docs_queue = [Queue(64 * 10000) for _ in range(n_process)]
    # ``manager`` is never used here, but the reference is kept until the
    # processes are joined. NOTE(review): presumably a multiprocessing
    # manager that backs state shared with the readers — confirm.
    r_procs, manager = redpajama_dataset.documents(docs_queue)
    w_procs = [
        Process(target=write_docs, args=(docs_q, ar, archive_name))
        for docs_q, ar in zip(docs_queue, ars)
    ]
    # Run read and write processes in parallel.
    procs = r_procs + w_procs
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print("Pass 1 finished...")
def _read_pile(chunk_path, start_time):
    """Return every ``(text, meta)`` document of the pile at ``chunk_path``.

    Progress is printed every 10000 documents, timed from ``start_time``.
    """
    lines = []
    reader = lmd.Reader(chunk_path)
    # NOTE: uses lm_dataformat's private ``_stream_data``; with get_meta=True
    # each item is unpacked as a (text, meta) pair.
    for doc_id, (text, meta) in enumerate(reader._stream_data(get_meta=True)):
        lines.append((text, meta))
        if doc_id % 10000 == 0:
            print(f"Processed doc {doc_id} after {time.time() - start_time}")
    return lines


def _write_buckets(out_dir, prefix, lines, buckets, first_index):
    """Write one ``{prefix}_{k}.jsonl.zst`` file per bucket into ``out_dir``.

    Each bucket is a list of indices into ``lines``. File numbering starts at
    ``first_index``; the next unused index is returned.
    """
    next_index = first_index
    for bucket in buckets:
        file_name = f"{out_dir}/{prefix}_{next_index}.jsonl.zst"
        with open(file_name, "wb") as fout:
            write_lmd_dataset(fout, lines, bucket)
        next_index += 1
    return next_index


def pass_2_shuffle_holdout(
    input_dir_path,
    output_dir_path,
    output_holdout_dir_path,
    start_index,
    end_index,
    chunk_id,
):
    """Pass 2 of the two-pass shuffle: shuffle piles and split off a holdout.

    The entries of ``input_dir_path`` are listed in sorted order and
    ``[start_index:end_index]`` of them are processed. Each pile is loaded
    fully into memory and shuffled. Its first ``int(len * 0.0017)`` documents
    form the holdout set and the remainder the train set. Each set is split
    into buckets sized so that there are about as many output files as
    entries in the pile's directory.

    Files are written as ``example_train_{k}.jsonl.zst`` into
    ``{output_dir_path}/chunk{chunk_id}`` and ``example_holdout_{k}.jsonl.zst``
    into ``{output_holdout_dir_path}/chunk{chunk_id}``. Both ``chunk{chunk_id}``
    directories are deleted and recreated once per call, and only when at
    least one pile is selected. When several piles are processed in one
    call, their files are numbered consecutively instead of overwriting
    each other.

    Args:
        input_dir_path: directory containing the pass-1 piles.
        output_dir_path: root directory for train output.
        output_holdout_dir_path: root directory for holdout output.
        start_index: first pile (inclusive) of the sorted listing.
        end_index: last pile (exclusive) of the sorted listing.
        chunk_id: suffix of the ``chunk{chunk_id}`` output directories.
    """
    # both eval and test set contain 0.17% of the data.
    holdout_ratio = 0.0017
    # We shuffle each pile of documents in memory here
    random.seed(42)
    print("Pass 2 started, going through pile documents...")
    # os.listdir order is arbitrary; sorting makes the index range given to
    # each (separately launched) invocation select reproducible piles.
    chunks = sorted(os.listdir(input_dir_path))[start_index:end_index]
    print(chunks)
    if not chunks:
        return

    train_output_chunk = f"{output_dir_path}/chunk{chunk_id}"
    holdout_output_chunk = f"{output_holdout_dir_path}/chunk{chunk_id}"
    os.makedirs(output_dir_path, exist_ok=True)
    os.makedirs(output_holdout_dir_path, exist_ok=True)
    # Prepare the outputs once, so that piles processed in the same call do
    # not delete each other's files.
    for chunk_dir in (train_output_chunk, holdout_output_chunk):
        rm_if_exists(chunk_dir)
        os.mkdir(chunk_dir)

    next_train_file = 0
    next_holdout_file = 0
    start_time = time.time()
    for chunk in tqdm(chunks, total=len(chunks)):
        print(f"Started processing chunk {chunk_id} in pass 2...")
        chunk_path = f"{input_dir_path}/{chunk}"
        lines = _read_pile(chunk_path, start_time)
        # shuffling each output file.
        random.shuffle(lines)
        # The first ``pivot`` shuffled documents form the holdout subset.
        pivot = int(len(lines) * holdout_ratio)
        # max(1, ...) guards against an empty pile directory (division by
        # zero) and against zero-size buckets on small piles, for which
        # ``chunked`` would yield no buckets at all.
        n = max(1, len(os.listdir(chunk_path)))
        buckets_train = list(
            chunked(
                range(pivot, len(lines)), max(1, (len(lines) - pivot) // n)
            )
        )
        buckets_holdout = list(chunked(range(0, pivot), max(1, pivot // n)))
        next_train_file = _write_buckets(
            train_output_chunk,
            "example_train",
            lines,
            buckets_train,
            next_train_file,
        )
        next_holdout_file = _write_buckets(
            holdout_output_chunk,
            "example_holdout",
            lines,
            buckets_holdout,
            next_holdout_file,
        )
        # Release the in-memory pile before loading the next one.
        del lines
        gc.collect()
    print("Pass 2 is finished.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparser = parser.add_subparsers(dest="stage", required=True)

    # pass1: scatter the source datasets into on-disk piles.
    pass1_parser = subparser.add_parser("pass1")
    pass1_parser.add_argument("--input_dir", type=str, required=True)
    pass1_parser.add_argument("--duplicates", type=str)
    pass1_parser.add_argument("--short_docs", type=str)
    pass1_parser.add_argument("--out_dir", type=str, required=True)

    # pass2: shuffle a range of piles into train/holdout sets.
    pass2_parser = subparser.add_parser("pass2")
    pass2_parser.add_argument("start_index", type=int)
    pass2_parser.add_argument("end_index", type=int)
    pass2_parser.add_argument("chunk_id", type=int)
    pass2_parser.add_argument("--input_dir", type=str, required=True)
    pass2_parser.add_argument("--train_dir", type=str, required=True)
    pass2_parser.add_argument("--holdout_dir", type=str, required=True)

    args = parser.parse_args()
    if args.stage == "pass1":
        inputdir = args.input_dir
        if not inputdir:
            parser.error("--input_dir must not be empty")
        # redpj_datasets is handed the directory with a trailing slash;
        # endswith (unlike indexing [-1]) cannot raise on odd input.
        if not inputdir.endswith("/"):
            inputdir += "/"
        pass_1_shuffle(
            RedPajamaReplication(
                redpj_datasets(inputdir), args.duplicates, args.short_docs
            ),
            output_dir_path=args.out_dir,
        )
    elif args.stage == "pass2":
        pass_2_shuffle_holdout(
            input_dir_path=args.input_dir,
            output_dir_path=args.train_dir,
            output_holdout_dir_path=args.holdout_dir,
            start_index=args.start_index,
            end_index=args.end_index,
            chunk_id=args.chunk_id,
        )
    else:
        # Not reachable while the subparsers are required; kept as a guard
        # that exits with a usage error instead of silently succeeding.
        parser.error("Please specify either pass1 or pass2")