# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pickle
import queue
import time
from collections import defaultdict
from glob import glob
from multiprocessing import Process, Queue
from datasketch.lean_minhash import LeanMinHash
from more_itertools import divide
def _H(hs):
    """Turn one band of hash values into a hashable ``bytes`` key.

    Args:
        hs: Array exposing ``byteswap()`` whose result offers a ``data``
            buffer; in this file it is a slice of
            ``LeanMinHash.hashvalues``.

    Returns:
        bytes: The raw bytes of a byte-swapped copy of ``hs``.
    """
    swapped = hs.byteswap()
    return bytes(swapped.data)
def split_files(input_dir, n_proc):
    """Collect all MinHash files under ``input_dir`` and shard them.

    Files are looked up at ``<input_dir>/<dataset>/minhash_nfc/*`` for each
    known dataset; ``common_crawl`` has one extra directory level
    (``<input_dir>/common_crawl/*/minhash_nfc/*``).

    Args:
        input_dir: Root directory that contains the per-dataset folders.
        n_proc: Number of shards to produce (one per reader process).

    Returns:
        list[list[str]]: ``n_proc`` lists of file paths, as split by
        ``more_itertools.divide`` over the sorted path list.
    """
    datasets = [
        "arxiv",
        "stackexchange",
        "book",
        "wikipedia",
        "github",
        "c4",
        "common_crawl",
    ]
    files = []
    for dataset in datasets:
        # common_crawl nests its minhash folders one level deeper.
        subdir = f"{dataset}/*" if dataset == "common_crawl" else dataset
        files.extend(glob(f"{input_dir}/{subdir}/minhash_nfc/*"))
    # Sort so the sharding does not depend on directory listing order.
    files.sort()
    return [list(part) for part in divide(n_proc, files)]
def get_hashes(files, doc_queues, r):
    """Read pickled MinHash records and push each band to its queue.

    For every record, the document key ``"<file_name>@<doc_id>"`` is built
    and, for the ``i``-th queue in ``doc_queues``, the hash values
    ``[i * r, (i + 1) * r)`` are serialized with ``_H`` and enqueued as
    ``(key, band_bytes)``.

    Args:
        files: Iterable of paths to pickle files; each file unpickles to an
            iterable of dicts with ``file_name``, ``doc_id`` and ``hash``.
        doc_queues: One queue per band; must support ``put``.
        r: Number of hash values per band.
    """
    for fp in files:
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at files produced by a trusted pipeline.
        with open(fp, "rb") as fin:
            records = pickle.load(fin)
        # The file is closed before enqueueing so no handle is held while
        # ``put`` blocks on a full queue.
        for item in records:
            key = f"{item['file_name']}@{item['doc_id']}"
            hashvalues = LeanMinHash(item["hash"]).hashvalues
            for band, doc_queue in enumerate(doc_queues):
                H = _H(hashvalues[band * r : (band + 1) * r])
                doc_queue.put((key, H))
def lsh(out_file, doc_queue, idx, *, timeout=30, total_docs=931361530):
    """Consume one band's queue and write candidate duplicate pairs.

    The first key seen for a given band hash is remembered; every later key
    with the same band hash is written as ``"<key> :: <first_key>"`` on its
    own line.

    Args:
        out_file: Output path template; ``".txt"`` is replaced by
            ``"-<idx>.txt"`` to obtain this band's file.
        doc_queue: Queue yielding ``(key, band_bytes)`` tuples; ``get`` must
            raise ``queue.Empty`` on timeout.
        idx: Band index, used in the file name and progress messages.
        timeout: Seconds to wait for a new item before assuming the stream
            has ended.
        total_docs: Denominator for the progress percentage (presumably the
            corpus size of the original run -- confirm before relying on it).
    """
    # Maps band hash -> first document key seen with that hash.
    first_seen = {}
    processed = 0
    start_time = time.time()
    # NOTE(review): if ``out_file`` does not contain ".txt", the path is
    # unchanged and every band writes to the same file.
    band_path = out_file.replace(".txt", f"-{idx}.txt")
    with open(band_path, "w") as f:
        while True:
            # NOTE(review): end of input is inferred from an idle queue, so
            # a producer stalling longer than ``timeout`` ends this consumer
            # early.
            try:
                key, H = doc_queue.get(timeout=timeout)
            except queue.Empty:
                break
            cand = first_seen.get(H)
            if cand is not None:
                f.write(f'{key} :: {cand}\n')
            else:
                first_seen[H] = key
            if processed % 100000 == 0:
                print(
                    f"{idx}: Processed {processed / total_docs * 100}%.",
                    time.time() - start_time,
                )
            processed += 1
    print(f"Total number of documents: {processed}")
def generate_pairs(args):
    """Run the producer/consumer pipeline that emits duplicate pairs.

    Starts ``args.processes`` reader processes running ``get_hashes`` (each
    on its own shard of files) and ``args.bands`` consumer processes running
    ``lsh`` (one per band queue), then waits for all of them to exit.

    Args:
        args: Namespace with ``input_dir``, ``out_file``, ``range`` (hash
            values per band), ``bands`` and ``processes``.
    """
    # size of the queue was tuned for optimal perf and memory constraints.
    doc_queues = [Queue(1000000) for _ in range(args.bands)]
    files = split_files(args.input_dir, args.processes)

    processes = []
    # Readers: each one feeds every band queue from its shard of files.
    for reader_id in range(args.processes):
        p = Process(
            target=get_hashes,
            args=(
                files[reader_id],
                doc_queues,
                args.range,
            ),
        )
        processes.append(p)
        p.start()

    # Consumers: one per band, each draining only its own queue.
    for band_id in range(args.bands):
        p = Process(
            target=lsh,
            args=(
                args.out_file,
                doc_queues[band_id],
                band_id,
            ),
        )
        processes.append(p)
        p.start()

    for p in processes:
        p.join()
if __name__ == "__main__":
    # Every option is needed by generate_pairs; marking them required turns
    # a late TypeError/AttributeError into an immediate usage message.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Root directory containing <dataset>/minhash_nfc/* files.",
    )
    parser.add_argument(
        "--out_file",
        required=True,
        help="Output path; '.txt' is replaced by '-<band>.txt' per band.",
    )
    parser.add_argument(
        "--range",
        type=int,
        required=True,
        help="Number of hash values per band.",
    )
    parser.add_argument(
        "--bands",
        type=int,
        required=True,
        help="Number of bands (one queue and one consumer process each).",
    )
    parser.add_argument(
        "--processes",
        type=int,
        required=True,
        help="Number of reader processes.",
    )
    args = parser.parse_args()
    generate_pairs(args)