# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
# isort: off
import sys
# isort: on
import subprocess
import warnings
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../.."))
from modelzoo.common.input.utils import check_and_create_output_dirs

def parse_args():
    """Argparser definition for command line arguments from user.
    Returns:
        Argparse namespace object with command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Download the raw Pile data and associated vocabulary for pre-processing."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Base directory where raw data is to be downloaded.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="pile",
        help=(
            "Sub-directory where raw data is to be downloaded."
            + " Defaults to `pile`."
        ),
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help=(
            "Only check that the files for a given split exist at the remote"
            " location; do not download anything."
        ),
    )
    return parser.parse_args()

def get_urls_from_split(split):
    """Get URLs for a given split of the dataset.

    Args:
        split (str): Split of the dataset to get URLs for.

    Returns:
        List of URLs of the `jsonl.zst` files to download.
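
    Example (values taken from the `val` branch below):
        >>> get_urls_from_split("val")
        ['https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst']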
    """
    if split == "train":
        warnings.warn(
            message=(
                "Starting a large download process for the full training"
                " data. This process takes time and needs at least 500GB"
                " of free storage space."
            ),
            category=UserWarning,
        )
        urls = [
            f"https://mystic.the-eye.eu/public/AI/pile/train/{i:02}.jsonl.zst"
            for i in range(30)
        ]
    elif split == "val":
        urls = ["https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst"]
    elif split == "test":
        urls = ["https://mystic.the-eye.eu/public/AI/pile/test.jsonl.zst"]
    else:
        # guard against returning an unbound `urls` for an unrecognized split
        raise ValueError(
            f"Unknown split `{split}`, expected one of `train`, `val`, `test`."
        )
    return urls

def get_urls_for_tokenizer_files():
    """Get URLs for downloading files for tokenization.

    Returns:
        A dictionary mapping local file names to URLs for the original GPT-2
        tokenization and GPT-NeoX tokenization schemes.
    """
    return {
        "gpt2-vocab.bpe": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
        "gpt2-encoder.json": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
        "neox-20B-tokenizer.json": "https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json",
    } 
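
# Note on the mapping above: the dictionary keys are the local filenames the
# downloads are saved under. `gpt2-vocab.bpe` points at the upstream
# `gpt2-merges.txt` (the BPE merge rules) and `gpt2-encoder.json` points at
# `gpt2-vocab.json` (the token-to-ID mapping), so the local names differ from
# the upstream ones.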

def debug_or_download_individual_file(url, filepath, debug=False):
    """Download a single file from a URL to the specified filepath.

    Args:
        url (str): URL to download the data from.
        filepath (str): Filename (with path) to download the data to.
        debug (bool): Only check if the remote file exists, without
            downloading. Defaults to `False`.
    """
    if debug:
        # use --no-check-certificate as the-eye.eu throws the below error:
        # `cannot verify mystic.the-eye.eu's certificate, issued by ‘/C=US/O=Let's Encrypt/CN=R3’,
        #    Issued certificate has expired.`
        cmd = f"wget --no-check-certificate --spider {url}"
        subprocess.run(cmd.split(" "), check=True)
        return
    execute = False
    # check each individual file, because the train split has 30 individual
    # files, and in a previous download attempt some of them may not have
    # been downloaded to the specified path.
    if not os.path.isfile(filepath):
        execute = True
    elif os.stat(filepath).st_size == 0:
        # A previous attempt at downloading the file failed, but wget
        # (invoked with `-O`) still created an empty output file. If the
        # file size is 0, delete it and run the download again.
        execute = True
        print(f"Got empty file at {filepath}, deleting and downloading again.")
        os.remove(filepath)
    else:
        print(
            f"{os.path.basename(filepath)} exists at {os.path.dirname(filepath)}"
            + ", skipping download."
        )
    # use --no-check-certificate as the-eye.eu throws the below error:
    # `cannot verify mystic.the-eye.eu's certificate, issued by ‘/C=US/O=Let's Encrypt/CN=R3’,
    #    Issued certificate has expired.`
    if execute:
        cmd = f"wget --no-check-certificate {url} -O {filepath}"
        subprocess.run(cmd.split(" "), check=True)
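
# Example call (the URL is the real `val` URL from above; the local path is
# hypothetical):
#   debug_or_download_individual_file(
#       "https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst",
#       "/tmp/pile/val/val.jsonl.zst",
#       debug=True,  # only verify that the remote file exists
#   )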

def download_pile(args, split):
    """Download The Pile dataset from The Eye (the-eye.eu).

    Args:
        args (argparse namespace): Arguments for downloading the dataset.
        split (str): The split of the Pile dataset to download.
    """
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, split), filetype="jsonl.zst",
    )
    urls = get_urls_from_split(split)
    for url in urls:
        filepath = os.path.join(
            args.data_dir, args.name, split, os.path.basename(url)
        )
        debug_or_download_individual_file(url, filepath, args.debug) 
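
# Resulting on-disk layout (with placeholder `<data_dir>`/`<name>` values):
#   <data_dir>/<name>/train/00.jsonl.zst ... 29.jsonl.zst
#   <data_dir>/<name>/val/val.jsonl.zst
#   <data_dir>/<name>/test/test.jsonl.zst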

def download_tokenizer_files(args):
    """Download the files needed for tokenization during dataset creation.

    Args:
        args (argparse namespace): Arguments for downloading the tokenizer files.
    """
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, "vocab"), filetype="json",
    )
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, "vocab"), filetype="bpe",
    )
    urls_to_download = get_urls_for_tokenizer_files()
    for key, value in urls_to_download.items():
        if args.debug:
            cmd = f"wget --no-check-certificate --spider {value}"
            subprocess.run(cmd.split(" "), check=True)
            # continue, since in debug mode we only verify that each entry
            # in the URL dictionary exists remotely
            continue
        filepath = os.path.join(args.data_dir, args.name, "vocab", key)
        cmd = f"wget --no-check-certificate {value} -O {filepath}"
        subprocess.run(cmd.split(" "), check=True)
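
# Resulting on-disk layout for the tokenizer files:
#   <data_dir>/<name>/vocab/gpt2-vocab.bpe
#   <data_dir>/<name>/vocab/gpt2-encoder.json
#   <data_dir>/<name>/vocab/neox-20B-tokenizer.json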

def main():
    """Main function for execution."""
    args = parse_args()
    # download all subsets and the corresponding tokenizer files
    for split in ["train", "val", "test"]:
        download_pile(args, split)
    download_tokenizer_files(args) 
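
# Example invocation (assuming this script is saved as `download.py`; the data
# directory is hypothetical):
#   python download.py --data_dir /path/to/raw_data --name pile
# To only verify that the remote files exist, without downloading:
#   python download.py --data_dir /path/to/raw_data --debug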
if __name__ == "__main__":
    main()