# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import pickle
import re
import string
from itertools import repeat
from multiprocessing import Pool, cpu_count
import jsonlines
from datasketch import MinHash
from lm_dataformat import Reader
from more_itertools import chunked
from nltk import ngrams
from tqdm import tqdm
[docs]def get_features(s, width):
    # lower cased
    s = s.lower()
    # remove punctuation
    s = s.translate(str.maketrans("", "", string.punctuation))
    # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end
    s = re.sub(r"\s+", " ", s.strip())
    return map(lambda x: "".join(x), ngrams(s, width)) 
[docs]def get_documents(input_dir, index_start, index_end, output_dir, dataset_name):
    gc.collect()
    files = sorted(os.listdir(input_dir))
    files = list(filter(lambda file_: '.jsonl' in file_, files))
    for i, input_file in tqdm(enumerate(files[index_start:index_end])):
        file_path = f"{input_dir}/{input_file}"
        file_name = file_path.split("/")[-1]
        if dataset_name == "common_crawl":
            dir_2 = file_path.split("/")[-2]
            output_name = f"{dataset_name}/{dir_2}/{file_name}"
        else:
            output_name = f"{dataset_name}/{file_name}"
        if dataset_name == "common_crawl":
            reader = Reader(file_path)
            for doc_id, doc in enumerate(reader._stream_data(jsonl_key="text")):
                yield doc, file_path, doc_id
        else:
            with jsonlines.open(file_path) as rdr:
                for doc_id, doc in enumerate(rdr):
                    yield doc["text"], file_path, doc_id 
[docs]def to_minhash(chunks):
    gc.collect()
    buckets = []
    documents, output_dir, width, dataset_name, n_docs = chunks
    for doc in tqdm(documents, total=n_docs):
        text, file_path, doc_id = doc[0], doc[1], doc[2]
        file_name = file_path.split("/")[-1]
        if dataset_name == "common_crawl":
            dir_2 = file_path.split("/")[-2]
            output_name = f"{dataset_name}/{dir_2}/{file_name}"
        else:
            output_name = f"{dataset_name}/{file_name}"
        m = MinHash(num_perm=128)
        [m.update(x.encode('utf8')) for x in get_features(text, width)]
        buckets.append(
            {
                "file_name": output_name,
                "doc_id": doc_id,
                "hash": m,
            }
        )
    return buckets 
[docs]def output_results(output_dir, results, chunk_id, iter):
    with open(
        f"{output_dir}/minhash_nfc/{iter}-{chunk_id}.pickle", "wb"
    ) as fout:
        pickle.dump(results, fout) 
[docs]def generate_hashes(args):
    if not os.path.exists(f"{args.output_dir}/minhash_nfc"):
        os.mkdir(f"{args.output_dir}/minhash_nfc")
    documents = get_documents(
        args.input_dir,
        args.index_start,
        args.index_end,
        args.output_dir,
        args.dataset_name,
    )
    results = []
    chunk_id = 0
    gc.collect()
    with Pool(processes=cpu_count()) as pool:
        for i, chunks in enumerate(
            tqdm(
                pool.imap(
                    to_minhash,
                    zip(
                        chunked(documents, args.n_docs // cpu_count()),
                        repeat(args.output_dir),
                        repeat(args.w),
                        repeat(args.dataset_name),
                        repeat(args.n_docs // cpu_count()),
                    ),
                ),
                total=cpu_count(),
            )
        ):
            for chunk in chunks:
                if len(results) == args.k:
                    output_results(
                        args.output_dir, results, chunk_id, args.iter
                    )
                    del results
                    gc.collect()
                    results = []
                    chunk_id += 1
                results.append(chunk)
    if results:
        output_results(args.output_dir, results, chunk_id, args.iter) 
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_name")
    parser.add_argument("input_dir", help="Input directory with documents.")
    parser.add_argument(
        "output_dir", help="Output directory to output minhash files to."
    )
    parser.add_argument(
        "n_docs", type=int, help="Number of documents located in the dataset."
    )
    parser.add_argument("iter", help="Job id")
    parser.add_argument(
        "index_start",
        type=int,
        help="Start indexing documents from input directory after ls.",
    )
    parser.add_argument(
        "index_end",
        type=int,
        help="End indexing documents from input directory after ls.",
    )
    parser.add_argument(
        "-w", type=int, default=6, help="The window size", required=False
    )
    parser.add_argument(
        "-k",
        type=int,
        default=10000,
        help="Number of batches to output with.",
        required=False,
    )
    args = parser.parse_args()
    generate_hashes(args)