# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import spacy
from tqdm import tqdm
from cerebras.modelzoo.data_preparation.nlp.tokenizers.Tokenization import (
    FullTokenizer,
)
from cerebras.modelzoo.data_preparation.utils import (
    convert_to_unicode,
    create_masked_lm_predictions,
    pad_instance_to_max_seq_length,
    text_to_tokenized_documents,
)
class SentencePairInstance:
    """
    A single training (sentence-pair) instance.

    :param list tokens: List of tokens for the sentence pair
    :param list segment_ids: List of segment ids for the sentence pair
    :param list masked_lm_positions: List of masked LM positions for the
        sentence pair
    :param list masked_lm_labels: List of masked LM labels for the sentence
        pair
    :param bool is_random_next: Whether the pair is a negative example
        (second segment from a random document for NSP, or segments in
        swapped order for SOP)
    """

    def __init__(
        self,
        tokens,
        segment_ids,
        masked_lm_positions,
        masked_lm_labels,
        is_random_next,
    ):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.masked_lm_labels = masked_lm_labels
        self.masked_lm_positions = masked_lm_positions
        self.is_random_next = is_random_next

    def __str__(self):
        # Human-readable dump, one field per line; tokens/labels are passed
        # through convert_to_unicode before joining.
        tokens = " ".join([convert_to_unicode(x) for x in self.tokens])
        segment_ids = " ".join([str(x) for x in self.segment_ids])
        mlm_positions = " ".join([str(x) for x in self.masked_lm_positions])
        mlm_labels = " ".join(
            [convert_to_unicode(x) for x in self.masked_lm_labels]
        )
        s = ""
        s += f"tokens: {tokens}\n"
        s += f"segment_ids: {segment_ids}\n"
        s += f"is_random_next: {self.is_random_next}\n"
        s += f"masked_lm_positions: {mlm_positions}\n"
        s += f"masked_lm_labels: {mlm_labels}\n"
        s += "\n"
        return s

    def __repr__(self):
        return self.__str__()
def data_generator(
    metadata_files,
    vocab_file,
    do_lower,
    split_num,
    max_seq_length,
    short_seq_prob,
    mask_whole_word,
    max_predictions_per_seq,
    masked_lm_prob,
    dupe_factor,
    output_type_shapes,
    min_short_seq_length=None,
    multiple_docs_in_single_file=False,
    multiple_docs_separator="\n",
    single_sentence_per_line=False,
    inverted_mask=False,
    seed=None,
    spacy_model="en_core_web_sm",
    input_files_prefix="",
    sop_labels=False,
):
    """
    Generator function used to create the input dataset for MLM + NSP/SOP.

    1. Generate raw examples by concatenating two parts
       'tokens-a' and 'tokens-b' as follows:
       [CLS] <tokens-a> [SEP] <tokens-b> [SEP]
       where:
         tokens-a: list of tokens taken from the current document, of random
         length (less than msl).
         tokens-b: list of tokens chosen based on the randomly set
         "next_sentence_labels", of length msl - len(<tokens-a>) - 3
         (to account for the [CLS] and two [SEP] tokens).
       For NSP, if "next_sentence_labels" is 1 (set with 0.5 probability),
       tokens-b is a list of tokens from sentences chosen randomly from a
       different document; otherwise tokens-b is taken from the same document
       and is a continuation of tokens-a. For SOP, see ``sop_labels``.
       The number of raw tokens also depends on ``short_seq_prob``.
    2. Mask the raw examples based on ``max_predictions_per_seq``.
    3. Pad the masked example to ``max_seq_length`` if it is shorter than msl.

    :param str or list[str] metadata_files: A string or strings list each
        pointing to a metadata file. A metadata file contains file paths for
        flat text cleaned documents. It has one file path per line; blank
        lines are ignored.
    :param str vocab_file: Vocabulary file to build the tokenizer from
    :param bool do_lower: Whether words should be converted to lowercase
    :param int split_num: Number of input files to read at a given time for
        processing. A value <= 0 means all files at once.
    :param int max_seq_length: Maximum length of the sequence to generate
    :param float short_seq_prob: Probability of a short sequence. Sometimes
        we want to use shorter sequences to minimize the mismatch between
        pre-training and fine-tuning.
    :param bool mask_whole_word: If True, all subtokens corresponding to a
        word will be masked.
    :param int max_predictions_per_seq: Maximum number of masked tokens in a
        sequence
    :param float masked_lm_prob: Proportion of tokens to be masked
    :param int dupe_factor: Number of times to duplicate the dataset with
        different static masks
    :param dict output_type_shapes: Dictionary indicating the shapes of
        different outputs
    :param int min_short_seq_length: When short_seq_prob > 0, this number
        indicates the least number of tokens that each example should have,
        i.e. the num_tokens (excluding pad) would be in the range
        [min_short_seq_length, max_seq_length - 3]. Defaults to 2; must lie
        in [2, max_seq_length - 3].
    :param bool multiple_docs_in_single_file: True when a single text file
        contains multiple documents separated by ``multiple_docs_separator``
    :param str multiple_docs_separator: String which separates multiple
        documents in a single text file.
    :param bool single_sentence_per_line: True when the document is already
        split into sentences with one sentence per line, so no further
        sentence segmentation is required
    :param bool inverted_mask: If set to False, has 0's on padded positions
        and 1's elsewhere. Otherwise, "inverts" the mask, so that 1's are on
        padded positions and 0's elsewhere.
    :param int seed: Random seed.
    :param str spacy_model: spaCy model to load, i.e. shortcut link, package
        name or path. Used to segment text into sentences.
    :param str input_files_prefix: Prefix to be added to paths of the input
        files.
    :param bool sop_labels: If true, negative examples of the dataset will be
        two consecutive sentences in reversed order. Otherwise, uses regular
        (NSP) labels (where negative examples are from different documents).
    :raises ValueError: If ``min_short_seq_length`` is outside
        [2, max_seq_length - 3]. Because this is a generator, the error is
        raised on the first ``next()``.
    :returns: yields training examples (feature, label), where label refers
        to the next-sentence (or sentence-order) prediction label
    """
    if min_short_seq_length is None:
        min_short_seq_length = 2
    elif not 2 <= min_short_seq_length <= max_seq_length - 3:
        # The upper bound is inclusive: it is used as the upper bound of an
        # inclusive rng.randint in _create_sentence_instances_from_document.
        raise ValueError(
            f"The min_short_seq_len param {min_short_seq_length} is invalid.\n"
            f"Allowed values are [2, {max_seq_length - 3}]"
        )

    # define tokenizer
    vocab_file = os.path.abspath(vocab_file)
    tokenizer = FullTokenizer(vocab_file, do_lower)
    vocab_words = tokenizer.get_vocab_words()
    rng = random.Random(seed)

    # get all text files by reading metadata files
    if isinstance(metadata_files, str):
        metadata_files = [metadata_files]
    input_files = []
    for _file in metadata_files:
        with open(_file, "r") as _fin:
            input_files.extend(_fin.readlines())
    # Strip *before* filtering: blank lines ("\n") are truthy, and keeping
    # them as "" would make os.path.join resolve to a directory, which
    # cannot be opened as a text file.
    input_files = [x.strip() for x in input_files]
    input_files = [x for x in input_files if x]
    rng.shuffle(input_files)
    split_num = len(input_files) if split_num <= 0 else split_num

    # for better performance load spacy model once here
    nlp = spacy.load(spacy_model)
    for i in range(0, len(input_files), split_num):
        current_input_files = input_files[i : i + split_num]
        all_documents = []
        for _file in tqdm(current_input_files):
            _fin_path = os.path.abspath(os.path.join(input_files_prefix, _file))
            with open(_fin_path, "r") as _fin:
                _fin_data = _fin.read()
            processed_doc, _ = text_to_tokenized_documents(
                _fin_data,
                tokenizer,
                multiple_docs_in_single_file,
                multiple_docs_separator,
                single_sentence_per_line,
                nlp,
            )
            all_documents.extend(processed_doc)
        rng.shuffle(all_documents)

        # Build SentencePairInstances, repeating the pass `dupe_factor`
        # times so each repetition gets a different static mask.
        instances = []
        for _ in range(dupe_factor):
            for document_index in range(len(all_documents)):
                instances.extend(
                    _create_sentence_instances_from_document(
                        all_documents,
                        document_index,
                        vocab_words,
                        max_seq_length,
                        short_seq_prob,
                        min_short_seq_length,
                        mask_whole_word,
                        max_predictions_per_seq,
                        masked_lm_prob,
                        rng,
                        sop_labels,
                    )
                )
        rng.shuffle(instances)
        for instance in instances:
            feature, label = pad_instance_to_max_seq_length(
                instance=instance,
                mlm_only=False,
                tokenizer=tokenizer,
                max_seq_length=max_seq_length,
                max_predictions_per_seq=max_predictions_per_seq,
                output_type_shapes=output_type_shapes,
                inverted_mask=inverted_mask,
            )
            yield (feature, label)
def _create_sentence_instances_from_document(
    all_documents,
    document_index,
    vocab_words,
    max_seq_length,
    short_seq_prob,
    min_short_seq_length,
    mask_whole_word,
    max_predictions_per_seq,
    masked_lm_prob,
    rng,
    sop_labels=False,
):
    """
    Create sentence-pair instances from a single document.

    :param list all_documents: List of documents; each document is a list of
        tokenized sentences (each a list of tokens)
    :param int document_index: Index of the document to process
    :param list vocab_words: List of all words present in the vocabulary
    :param int max_seq_length: Maximum sequence length, including the [CLS]
        and the two [SEP] tokens
    :param float short_seq_prob: Probability of using a shorter target length
        for this document
    :param int min_short_seq_length: Lower bound (inclusive) of the shorter
        target length
    :param bool mask_whole_word: Forwarded to create_masked_lm_predictions
    :param int max_predictions_per_seq: Forwarded to
        create_masked_lm_predictions
    :param float masked_lm_prob: Forwarded to create_masked_lm_predictions
    :param random.Random rng: Random generator used for all sampling
    :param bool sop_labels: If true, negative examples of the dataset will be
        two consecutive sentences in reversed order. Otherwise, uses regular
        (NSP) labels (where negative examples are from different documents).
    :returns: List of SentencePairInstance objects
    """
    # document = [line1, line2, ...], where each line is a list of tokens
    document = all_documents[document_index]
    # account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3
    # We usually want to fill up the entire sequence since we pad to
    # `max_seq_length` anyway, but sometimes (with `short_seq_prob`) use a
    # shorter target to reduce the pre-training / fine-tuning mismatch.
    # `target_seq_len` is only a rough target; `max_seq_length` is the hard
    # limit.
    target_seq_len = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_len = rng.randint(min_short_seq_length, max_num_tokens)

    def flatten(segments):
        # Concatenate a list of token lists into a single token list.
        return [token for segment in segments for token in segment]

    # Segments `A` and `B` are split on the actual sentences rather than at an
    # arbitrary token position, since the latter makes NSP too easy.
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        line = document[i]
        current_chunk.append(line)
        current_length += len(line)
        if i == len(document) - 1 or current_length >= target_seq_len:
            if current_chunk:
                # `a_end` is how many segments of `current_chunk` go into `A`.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)
                tokens_a = flatten(current_chunk[:a_end])
                tokens_b = []

                if len(current_chunk) == 1 or (
                    not sop_labels and rng.random() < 0.5
                ):
                    # Random next: `B` comes from another document. A
                    # single-segment chunk cannot be split into a pair, so
                    # this path is also taken in SOP mode in that case.
                    is_random_next = True
                    target_b_length = target_seq_len - len(tokens_a)
                    # Try (at most 10 times) to pick a different document;
                    # for large corpora one draw is almost always enough.
                    for _ in range(10):
                        random_document_index = rng.randint(
                            0, len(all_documents) - 1
                        )
                        if random_document_index != document_index:
                            break
                    random_document = all_documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # The segments after `a_end` were not used; step `i`
                    # back so they are reconsidered for later instances.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                elif sop_labels and rng.random() < 0.5:
                    # SOP negative: consecutive segments in swapped order.
                    # This is the only place the order is swapped, so the
                    # label always matches the actual order.
                    is_random_next = True
                    tokens_b = flatten(current_chunk[a_end:])
                    tokens_a, tokens_b = tokens_b, tokens_a
                else:
                    # Actual next, in original order.
                    is_random_next = False
                    tokens_b = flatten(current_chunk[a_end:])

                # truncate seq pair tokens to max_num_tokens
                _truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                # [CLS] A [SEP] B [SEP]; segment 0 covers [CLS], A and the
                # first [SEP], segment 1 covers B and the final [SEP].
                tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
                segment_ids = [0] * (len(tokens_a) + 2) + [1] * (
                    len(tokens_b) + 1
                )

                (
                    tokens,
                    masked_lm_positions,
                    masked_lm_labels,
                ) = create_masked_lm_predictions(
                    tokens,
                    vocab_words,
                    mask_whole_word,
                    max_predictions_per_seq,
                    masked_lm_prob,
                    rng,
                )
                instance = SentencePairInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels,
                )
                instances.append(instance)
            # reset buffers
            current_chunk = []
            current_length = 0
        # move on to next segment
        i += 1
    return instances
def _truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """
    Trim a pair of token lists in place until their combined length is at
    most ``max_num_tokens``.

    Each step removes a single token from whichever list is currently longer
    (``tokens_b`` when the lengths are equal). ``rng`` decides whether the
    first or the last token of that list is dropped.

    :param list tokens_a: First list of tokens; modified in place
    :param list tokens_b: Second list of tokens; modified in place
    :param int max_num_tokens: Upper bound on
        ``len(tokens_a) + len(tokens_b)``
    :param rng: Random generator providing ``random()``
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        longer = tokens_b if len(tokens_b) >= len(tokens_a) else tokens_a
        assert len(longer) >= 1
        drop_front = rng.random() < 0.5
        if drop_front:
            longer.pop(0)
        else:
            del longer[-1]