# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
# isort: off
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
# isort: on
from cerebras.modelzoo.data_preparation.nlp.slimpajama.dedup import (
    dedup_train,
    generate_connected_components,
    generate_duplicate_pairs,
    generate_duplicates_dict,
    to_hash,
)
from cerebras.modelzoo.data_preparation.nlp.slimpajama.preprocessing import (
    datasets,
    filter,
    normalize_text,
    shuffle_holdout,
)
# Top-level RedPajama source directories. ``main`` keeps every entry except
# "common_crawl", which it replaces with one sub-directory per snapshot below.
ds_names = "arxiv stackexchange book wikipedia github c4 common_crawl".split()

# Common Crawl snapshot identifiers; each becomes "common_crawl/<snapshot>".
cc_years = "2019-30 2020-05 2021-04 2022-05 2023-06".split()

# Per-directory document counts, indexed positionally against the directory
# list ``main`` builds: the six non-Common-Crawl sources in ``ds_names`` order,
# then one entry per ``cc_years`` snapshot. ``main`` also sums them for the
# corpus-wide filter step, so keep this list in sync with both lists above.
n_documents = [
    1558306,  # arxiv
    29825086,  # stackexchange
    205744,  # book
    29834171,  # wikipedia
    28793312,  # github
    364868892,  # c4
    81085420,  # common_crawl/2019-30
    90850492,  # common_crawl/2020-05
    98878523,  # common_crawl/2021-04
    94058868,  # common_crawl/2022-05
    111402716,  # common_crawl/2023-06
]
# Number of chunks iterated over by the train/holdout split and the
# train-vs-holdout dedup stages (chunk ids 1.._NUM_CHUNKS).
# NOTE(review): presumably must match the number of chunks pass_1_shuffle
# writes into pass1/ — confirm in shuffle_holdout before changing.
_NUM_CHUNKS = 20


def _dataset_dirs():
    """Return the dataset sub-directories to process, in pipeline order.

    Every entry of ``ds_names`` except ``"common_crawl"`` is kept in order;
    Common Crawl is then appended as one ``"common_crawl/<snapshot>"`` entry
    per item of ``cc_years``. ``n_documents`` is indexed against this order.
    """
    dirs = ds_names.copy()
    dirs.remove("common_crawl")
    dirs.extend("common_crawl/" + cc for cc in cc_years)
    return dirs


def _normalize_all(input_dir, red_pj_norm, ds_dirs):
    """Normalize each dataset directory from ``input_dir`` into ``red_pj_norm``."""
    for dataset in ds_dirs:
        normalize_text.normalize_text(
            argparse.Namespace(
                data_dir=os.path.join(input_dir, dataset),
                target_dir=os.path.join(red_pj_norm, dataset),
                # The zst flag is set only for Common Crawl snapshot dirs.
                zst="common_crawl" in dataset,
                # presumably -1 means "all files"; confirm in normalize_text.
                idx=-1,
            )
        )


def _filter_short_docs(red_pj_norm, short_docs):
    """Run the corpus-wide filter over ``red_pj_norm``, writing ``short_docs``."""
    filter.filter_dataset(
        argparse.Namespace(
            input_dir=red_pj_norm,
            output_file=short_docs,
            n_docs=sum(n_documents),
            dataset_name="all",
            threshold=200,
        )
    )


def _generate_minhashes(red_pj_norm, ds_dirs):
    """Generate MinHash signatures in place for every normalized dataset dir.

    ``ds_dirs`` must be aligned with ``n_documents`` (same length and order).
    """
    for dataset, n_docs in zip(ds_dirs, n_documents):
        dataset_dir = os.path.join(red_pj_norm, dataset)
        to_hash.generate_hashes(
            argparse.Namespace(
                # All Common Crawl snapshots share one dataset name.
                dataset_name=(
                    "common_crawl" if "common_crawl" in dataset else dataset
                ),
                input_dir=dataset_dir,
                output_dir=dataset_dir,
                n_docs=n_docs,
                iter=0,
                index_start=0,
                index_end=None,
                w=13,
                k=10000,
            )
        )


def _find_duplicates(red_pj_norm):
    """Build the duplicate-document index under ``<red_pj_norm>/dup``.

    Runs three steps in order: duplicate pairs -> connected components ->
    duplicates dict. Returns the path of the resulting ``duplicates.pickle``.
    """
    dup_dir = os.path.join(red_pj_norm, "dup")
    os.makedirs(dup_dir, exist_ok=True)
    components_file = os.path.join(dup_dir, "connected_components.pickle")
    dup_docs = os.path.join(dup_dir, "duplicates.pickle")

    generate_duplicate_pairs.generate_pairs(
        argparse.Namespace(
            input_dir=red_pj_norm,
            out_file=os.path.join(dup_dir, "duplicate_pairs.txt"),
            range=13,
            bands=9,
            processes=45,
        )
    )
    generate_connected_components.generate_connected_components_mp(
        argparse.Namespace(input_dir=dup_dir, out_file=components_file)
    )
    generate_duplicates_dict.generate_duplicates(
        argparse.Namespace(input_file=components_file, out_file=dup_docs)
    )
    return dup_docs


def _shuffle_and_split(red_pj_norm, dup_docs, short_docs):
    """Interleave/shuffle into ``pass1`` and split each chunk into train/holdout."""
    pass1_dir = os.path.join(red_pj_norm, "pass1")
    shuffle_holdout.pass_1_shuffle(
        datasets.RedPajamaReplication(
            # redpj_datasets is given the root with a trailing "/", as before.
            datasets.redpj_datasets(red_pj_norm + "/"), dup_docs, short_docs
        ),
        output_dir_path=pass1_dir,
    )
    train_dir = os.path.join(red_pj_norm, "train")
    holdout_dir = os.path.join(red_pj_norm, "holdout")
    for chunk_id in range(1, _NUM_CHUNKS + 1):
        shuffle_holdout.pass_2_shuffle_holdout(
            input_dir_path=pass1_dir,
            output_dir_path=train_dir,
            output_holdout_dir_path=holdout_dir,
            start_index=chunk_id - 1,
            end_index=chunk_id,
            chunk_id=chunk_id,
        )


def _dedup_train_against_holdout(red_pj_norm):
    """Deduplicate each train chunk against holdout into ``train_deduped``."""
    train_dir = os.path.join(red_pj_norm, "train")
    holdout_dir = os.path.join(red_pj_norm, "holdout")
    deduped_dir = os.path.join(red_pj_norm, "train_deduped")
    for chunk_id in range(1, _NUM_CHUNKS + 1):
        dedup_train.deduplicate_train_holdout_sets(
            train_dir, holdout_dir, deduped_dir, chunk_id
        )


def main(input_dir):
    """Run the full SlimPajama preprocessing pipeline on ``input_dir``.

    Stages, in order: text normalization, corpus-wide filtering, MinHash
    generation, duplicate detection, interleave + shuffle, train/holdout
    split, and train-vs-holdout deduplication. All outputs are written under
    ``<input_dir>/RedPajama_norm`` (normalized data, ``red_pj_filter.pickle``,
    ``dup/``, ``pass1/``, ``train/``, ``holdout/``, ``train_deduped/``).

    Args:
        input_dir: Root directory containing one sub-directory per entry of
            ``ds_names`` (Common Crawl split into ``cc_years`` snapshots).

    Raises:
        ValueError: If ``n_documents`` does not have exactly one entry per
            dataset directory. This is checked before any stage runs.
    """
    ds_dirs = _dataset_dirs()
    # n_documents is matched to ds_dirs by position; catch a mismatch before
    # the expensive stages instead of failing (or silently misaligning) later.
    if len(ds_dirs) != len(n_documents):
        raise ValueError(
            f"n_documents has {len(n_documents)} entries but there are "
            f"{len(ds_dirs)} dataset directories: {ds_dirs}"
        )

    red_pj_norm = os.path.join(input_dir, "RedPajama_norm")
    short_docs = os.path.join(red_pj_norm, "red_pj_filter.pickle")

    _normalize_all(input_dir, red_pj_norm, ds_dirs)
    _filter_short_docs(red_pj_norm, short_docs)
    _generate_minhashes(red_pj_norm, ds_dirs)
    dup_docs = _find_duplicates(red_pj_norm)
    _shuffle_and_split(red_pj_norm, dup_docs, short_docs)
    _dedup_train_against_holdout(red_pj_norm)
if __name__ == "__main__":
    # Command-line entry point: the only argument is the dataset root.
    cli = argparse.ArgumentParser()
    cli.add_argument("input_dir", help="Dataset input directory.")
    main(cli.parse_args().input_dir)