# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import shutil
import ujson as json
import zstandard
def utf8len(s):
    """Return how many bytes ``s`` occupies once encoded as UTF-8.

    The encoding is strict, so a string that UTF-8 cannot represent (for
    example one containing a lone surrogate) raises ``UnicodeEncodeError``.
    """
    encoded = s.encode('utf-8')
    return len(encoded)
def cycle_documents(dataset, process_id, n_process, dup_sh, short_sh):
    """Yield truthy documents from ``dataset`` repeatedly, pass after pass.

    Each pass calls ``dataset.documents(process_id, n_process, dup_sh,
    short_sh)``. All four arguments are forwarded unchanged. Falsy items such
    as empty strings or ``None`` are dropped. When a pass finishes, a new pass
    is started.

    If an entire pass yields no truthy document, the generator stops. The
    alternative would be to loop forever without producing anything. This
    mirrors ``itertools.cycle`` on an empty iterable.

    Args:
        dataset: Object exposing a ``documents(...)`` method that returns an
            iterable of documents.
        process_id, n_process, dup_sh, short_sh: Passed through verbatim to
            ``dataset.documents``.

    Yields:
        Each truthy document, in the order ``dataset.documents`` produces it.
    """
    while True:
        # https://github.com/EleutherAI/the-pile/blob/df97f8651ae3da658b19659b3ceaa6a34b0fc014/the_pile/utils.py#L104
        produced_any = False
        for doc in filter(
            None, dataset.documents(process_id, n_process, dup_sh, short_sh)
        ):
            produced_any = True
            yield doc
        if not produced_any:
            # A whole pass yielded nothing. Another pass would spin forever
            # at full CPU without ever yielding, so stop instead.
            return
def sha256str(s):
    """Return the hexadecimal SHA-256 digest of ``s`` encoded as UTF-8.

    If ``s`` cannot be encoded strictly, for example because it holds lone
    surrogate code points, it is re-encoded with the ``"replace"`` error
    handler. Each unencodable code point then becomes ``b"?"`` before
    hashing, so such strings still hash instead of raising.
    """
    try:
        payload = s.encode("utf-8")
    except UnicodeEncodeError:
        # to avoid things like \ud809\udc50\ud808\udefc\ud808\udedb
        payload = s.encode("utf-8", "replace")
    return hashlib.sha256(payload).hexdigest()
def rm_if_exists(path):
    """Delete ``path`` if anything exists there; do nothing if it is missing.

    - A real directory (not a symlink to one) is removed recursively with
      ``shutil.rmtree``.
    - Anything else present at ``path`` is unlinked with ``os.remove``. This
      covers regular files and symbolic links, including dangling links.
    - The target of a symlink is never touched.

    Args:
        path: Filesystem path to remove.
    """
    # Test islink explicitly: isdir follows symlinks, and shutil.rmtree
    # refuses to operate on a symlink (it raises OSError).
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    # lexists (unlike exists) is True for dangling symlinks, so those are
    # cleaned up too instead of being silently left behind.
    # NOTE(review): on Windows a symlink/junction to a directory may need
    # os.rmdir instead of os.remove; only POSIX behavior is exercised here.
    elif os.path.lexists(path):
        os.remove(path)
def write_lmd_dataset(
    fh, lines, indices=None, return_total_written=False, *, level=3, threads=10
):
    """Write ``(text, meta)`` records to ``fh`` as zstd-compressed JSON lines.

    Each record is serialized as ``{"text": text, "meta": meta}``. It is
    encoded as UTF-8 and terminated by a newline byte. All records go through
    a single ``zstandard`` stream writer, and one frame is ended after the
    last record.

    Args:
        fh: Writable binary file object that receives the compressed bytes.
        lines: ``(text, meta)`` pairs. Any iterable works when ``indices`` is
            None. Otherwise it must support ``lines[index]``.
        indices: Optional iterable of positions into ``lines``. When given,
            only those records are written, in the order of ``indices``. They
            are looked up lazily, so no subset list is materialized in memory.
        return_total_written: If True, return the number of records written.
        level: zstd compression level passed to ``ZstdCompressor``.
            Defaults to 3.
        threads: ``threads`` argument passed to ``ZstdCompressor``.
            Defaults to 10.

    Returns:
        The number of records written if ``return_total_written`` is True,
        otherwise ``None``.
    """
    cctx = zstandard.ZstdCompressor(level=level, threads=threads)
    compressor = cctx.stream_writer(fh)
    # One record source for both modes. The index path is a generator so
    # large selections are never copied into a new list.
    records = lines if indices is None else (lines[index] for index in indices)
    total_written = 0
    for text, meta in records:
        compressor.write(
            json.dumps({"text": text, "meta": meta}).encode("UTF-8") + b"\n"
        )
        total_written += 1
    # End the zstd frame so everything written so far is flushed to fh as a
    # complete frame. NOTE(review): close() is presumably avoided so that fh
    # stays open for the caller; confirm against the zstandard version in use.
    compressor.flush(zstandard.FLUSH_FRAME)
    return total_written if return_total_written else None