# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json

import numpy as np
import six
from keras_preprocessing.text import text_to_word_sequence

from modelzoo.common.model_utils.count_lines import count_lines


def convert_to_unicode(text):
    """
    Converts `text` to unicode, assuming utf-8 input
    Returns text encoded in a way suitable for print or `tf.compat.v1.logging`
    """
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError(f"Unsupported string type: {type(text)}")
    else:
        raise ValueError(f"Not running Python3") 


def count_total_documents(metadata_files):
    """
    Counts total number of documents
    in metadata_files.
    :param str or list[str] metadata_files: Path or list of paths
        to metadata files.
    :returns: Number of documents whose paths are contained
        in the metadata files.
    """
    total_documents = 0
    if isinstance(metadata_files, str):
        metadata_files = [metadata_files]
    for _file in metadata_files:
        total_documents += count_lines(_file)
    return total_documents 
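

# Hedged usage sketch (not part of the original module): a metadata file is
# assumed to list one document path per line, so the total is the sum of line
# counts over all metadata files. The temporary file below is hypothetical.
def _example_count_total_documents():
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as fh:
        fh.write("/data/doc_0.txt\n/data/doc_1.txt\n/data/doc_2.txt\n")
    # Both a single path and a list of paths are accepted; the call below
    # should report 6 documents if count_lines counts newline-delimited entries.
    return count_total_documents([fh.name, fh.name])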


def whitespace_tokenize(text, lower=False):
    """
    Splits a piece of text based on whitespace characters \t\r\n
    """
    return text_to_word_sequence(text, filters='\t\n\r', lower=lower) 
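

# Hedged usage sketch (not part of the original module): with the filters above,
# `text_to_word_sequence` splits on spaces, tabs, carriage returns, and newlines
# while leaving punctuation attached; `lower=False` preserves case.
def _example_whitespace_tokenize():
    tokens = whitespace_tokenize("Hello\tworld, again\n", lower=False)
    # expected: ["Hello", "world,", "again"]
    return tokens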


def get_output_type_shapes(
    max_seq_length, max_predictions_per_seq, mlm_only=False
):
    # mapping from feature name to its output dtype and shape
    output = {
        "input_ids": {"output_type": "int32", "shape": [max_seq_length]},
        "input_mask": {"output_type": "int32", "shape": [max_seq_length]},
        "masked_lm_positions": {
            "output_type": "int32",
            "shape": [max_predictions_per_seq],
        },
        "masked_lm_ids": {
            "output_type": "int32",
            "shape": [max_predictions_per_seq],
        },
        "masked_lm_weights": {
            "output_type": "float32",
            "shape": [max_predictions_per_seq],
        },
    }
    if not mlm_only:
        output["segment_ids"] = {
            "output_type": "int32",
            "shape": [max_seq_length],
        }
    return output 
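

# Hedged usage sketch (not part of the original module): the returned mapping
# drives both padding and dtype conversion downstream; MLM-only mode simply
# drops the "segment_ids" entry.
def _example_get_output_type_shapes():
    spec = get_output_type_shapes(max_seq_length=128, max_predictions_per_seq=20)
    # spec["input_ids"] == {"output_type": "int32", "shape": [128]}
    # spec["masked_lm_weights"] == {"output_type": "float32", "shape": [20]}
    mlm_spec = get_output_type_shapes(128, 20, mlm_only=True)
    assert "segment_ids" in spec and "segment_ids" not in mlm_spec
    return spec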


def pad_instance_to_max_seq_length(
    instance,
    mlm_only,
    tokenizer,
    max_seq_length,
    max_predictions_per_seq,
    output_type_shapes,
    inverted_mask,
):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    input_mask = [1] * len(input_ids)
    # sanity check: the instance passed in must already fit within max_seq_length
    assert len(input_ids) <= max_seq_length
    # pad the lists above to max_seq_length with zeros
    length_diff = max_seq_length - len(input_ids)
    extended_list = [0] * length_diff
    input_ids.extend(extended_list)
    input_mask.extend(extended_list)
    # assertions to ensure correct output shapes
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    if not mlm_only:
        segment_ids = list(instance.segment_ids)
        segment_ids.extend(extended_list)
        assert len(segment_ids) == max_seq_length
    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)
    # sanity check: the instance must not exceed the masked LM prediction budget
    assert len(masked_lm_positions) <= max_predictions_per_seq
    # pad the masked LM lists to max_predictions_per_seq with zeros
    length_diff = max_predictions_per_seq - len(masked_lm_positions)
    extended_list = [0] * length_diff
    masked_lm_positions.extend(extended_list)
    masked_lm_ids.extend(extended_list)
    masked_lm_weights.extend(extended_list)
    # assertions to ensure correct output shapes
    assert len(masked_lm_positions) == max_predictions_per_seq
    assert len(masked_lm_ids) == max_predictions_per_seq
    assert len(masked_lm_weights) == max_predictions_per_seq
    # create feature dict
    features = dict()
    features["input_ids"] = input_ids
    features["input_mask"] = input_mask
    features["masked_lm_positions"] = masked_lm_positions
    features["masked_lm_ids"] = masked_lm_ids
    features["masked_lm_weights"] = masked_lm_weights
    if not mlm_only:
        features["segment_ids"] = segment_ids
    # cast each feature to the numpy dtype named in output_type_shapes
    feature = {
        k: getattr(np, output_type_shapes[k]["output_type"])(v)
        for k, v in features.items()
    }
    # if requested, invert the mask: 1 marks padding positions, 0 marks real tokens
    if inverted_mask:
        feature["input_mask"] = np.equal(feature["input_mask"], 0).astype(
            feature["input_mask"].dtype
        )
    if not mlm_only:
        # next sentence prediction label: 1 if the second segment was randomly chosen
        next_sentence_label = 1 if instance.is_random_next else 0
        # the label is always int32
        label = np.int32(next_sentence_label)
    else:
        # labels=None is currently not supported, so emit a deterministic dummy label
        label = np.int32(0)
    return feature, label 
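

# Hedged usage sketch (not part of the original module): `_StubTokenizer` and
# `_Instance` below are hypothetical stand-ins that expose only what this
# function touches (`convert_tokens_to_ids`, `tokens`, `segment_ids`,
# `masked_lm_positions`, `masked_lm_labels`, `is_random_next`).
def _example_pad_instance_to_max_seq_length():
    class _StubTokenizer:
        vocab = {"[CLS]": 0, "[MASK]": 1, "[SEP]": 2, "hello": 3, "world": 4}

        def convert_tokens_to_ids(self, tokens):
            return [self.vocab[t] for t in tokens]

    _Instance = collections.namedtuple(
        "_Instance",
        [
            "tokens",
            "segment_ids",
            "masked_lm_positions",
            "masked_lm_labels",
            "is_random_next",
        ],
    )
    instance = _Instance(
        tokens=["[CLS]", "hello", "[MASK]", "[SEP]"],
        segment_ids=[0, 0, 0, 0],
        masked_lm_positions=[2],
        masked_lm_labels=["world"],
        is_random_next=False,
    )
    feature, label = pad_instance_to_max_seq_length(
        instance=instance,
        mlm_only=False,
        tokenizer=_StubTokenizer(),
        max_seq_length=8,
        max_predictions_per_seq=2,
        output_type_shapes=get_output_type_shapes(8, 2),
        inverted_mask=False,
    )
    # feature["input_ids"].shape == (8,); label == 0 (not a random next sentence)
    return feature, label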


def text_to_tokenized_documents(
    data,
    tokenizer,
    multiple_docs_in_single_file,
    multiple_docs_separator,
    single_sentence_per_line,
    spacy_nlp,
):
    """
    Convert the input data into tokens
    :param str data: Contains data read from a text file
    :param tokenizer: Tokenizer object which contains functions to 
        convert words to tokens 
    :param bool multiple_docs_in_single_file: Indicates whether there 
        are multiple documents in the given data string
    :param str multiple_docs_separator: String used to separate documents 
        if there are multiple documents in data.
        Separator can be anything. It can be a new blank line or 
        some special string like "-----" etc. 
        There can only be one separator string for all the documents.
    :param bool single_sentence_per_line: Indicates whether the data contains
         one sentence in each line
    :param spacy_nlp: spaCy nlp module loaded with spacy.load()
        Used in segmenting a string into sentences
    :return List[List[List]] documents: Contains the tokens corresponding to
         sentences in documents.
         List of List of Lists [[[],[]], [[],[],[]]]
         documents[i][j] -> List of tokens in document i and sentence j
    """
    if "\\n" in multiple_docs_separator:
        multiple_docs_separator = multiple_docs_separator.replace("\\n", "\n")
    # helper: total number of tokens across the sentences of a document
    def get_length(sentences):
        return sum(len(s) for s in sentences)

    documents = []
    num_tokens = 0
    if multiple_docs_in_single_file:
        # "\n" is added since seperator is always in newline
        # <doc1>
        # multiple_docs_separator
        # <doc2>
        data = data.split("\n" + multiple_docs_separator)
        data = [x for x in data if x]  # data[i] -> document i
    else:
        data = [data]
    if single_sentence_per_line:
        # The document has already been split into sentences, one sentence per line
        for doc in data:
            documents.append([])
            # Get sentences by splitting on newlines, since each sentence is on its own line
            lines = doc.split("\n")
            for line in lines:
                if line:
                    tokens = tokenizer.tokenize(
                        line.strip()
                    )  # tokens : list of tokens
                    if tokens:
                        documents[-1].append(tokens)
                        num_tokens += len(tokens)
    else:
        # The document must be segmented into sentences with the spaCy model
        for doc in data:
            processed_doc = spacy_nlp(convert_to_unicode(doc.replace('\n', '')))
            sentences = [
                tokenizer.tokenize(s.text) for s in list(processed_doc.sents)
            ]
            sentences = [
                s for s in sentences if s
            ]  # sentences[i][j] -> token j of sentence i
            documents.append(sentences)
            num_tokens += get_length(sentences)
    # documents[i][j] -> list of tokens of sentence j in document i
    # Remove empty documents if any
    documents = [x for x in documents if x]
    return documents, num_tokens 
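

# Hedged usage sketch (not part of the original module): with
# `single_sentence_per_line=True` the spaCy model is never consulted, so a
# hypothetical whitespace tokenizer stub and `spacy_nlp=None` are enough to
# illustrate the output structure.
def _example_text_to_tokenized_documents():
    class _StubTokenizer:
        def tokenize(self, text):
            return whitespace_tokenize(text)

    data = "first sentence\nsecond sentence\n###\nanother document\n"
    documents, num_tokens = text_to_tokenized_documents(
        data,
        tokenizer=_StubTokenizer(),
        multiple_docs_in_single_file=True,
        multiple_docs_separator="###",
        single_sentence_per_line=True,
        spacy_nlp=None,
    )
    # documents == [[["first", "sentence"], ["second", "sentence"]],
    #               [["another", "document"]]]
    return documents, num_tokens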


MaskedLmInstance = collections.namedtuple(
    "MaskedLmInstance", ["index", "label"]
)


def create_masked_lm_predictions(
    tokens,
    vocab_words,
    mask_whole_word,
    max_predictions_per_seq,
    masked_lm_prob,
    rng,
    exclude_from_masking=None,
):
    """
    Creates the predictions for the masked LM objective
    :param list tokens: List of tokens to process
    :param list vocab_words: List of all words present in the vocabulary
    :param bool mask_whole_word: If true, mask all the subtokens of a word
    :param int max_predictions_per_seq: Maximum number of masked LM predictions per sequence
    :param float masked_lm_prob: Masked LM probability
    :param rng: random.Random object with shuffle function
    :param Optional[list] exclude_from_masking: List of tokens to exclude from masking. Defaults to ["[CLS]", "[SEP]"]
    :returns: tuple of (tokens with masking applied,
        the positions of the masked tokens,
        the original tokens at those positions, used as training labels)
    """
    if exclude_from_masking is not None:
        if not isinstance(exclude_from_masking, list):
            exclude_from_masking = list(exclude_from_masking)
    else:
        exclude_from_masking = ["[CLS]", "[SEP]"]
    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token in exclude_from_masking:
            continue
        # Whole word masking means that we mask all of the wordpieces
        # corresponding to an original word. When a word has been split
        # into WordPieces, the first token does not have any marker and
        # any subsequent tokens are prefixed with ##. So whenever we see
        # the ## prefix, we append the token to the previous set of word
        # indexes.
        # Note that whole word masking does not change the training code
        # at all -- we still predict each WordPiece independently,
        # softmaxed over the entire vocabulary.
        if (
            mask_whole_word
            and len(cand_indexes) >= 1
            and token.startswith("##")
        ):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
    rng.shuffle(cand_indexes)
    output_tokens = list(tokens)
    # get number of tokens to mask and predict
    num_to_predict = min(
        max_predictions_per_seq,
        max(1, int(round(len(tokens) * masked_lm_prob))),
    )
    masked_lms = []
    covered_indexes = set()
    for index_set in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        # if adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        # Check if any index is covered already.
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            # The 80/10/10 split comes from
            # google-research/bert/create_pretraining_data.py
            masked_token = None
            random_value = rng.random()
            if random_value < 0.8:
                # 80% of the time, replace with [MASK]
                masked_token = "[MASK]"
            else:
                # 10% of the time, keep the original token
                if rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with a random word
                else:
                    masked_token = vocab_words[
                        rng.randint(0, len(vocab_words) - 1)
                    ]
            output_tokens[index] = masked_token
            masked_lms.append(
                MaskedLmInstance(index=index, label=tokens[index])
            )
    assert len(masked_lms) <= num_to_predict
    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    masked_lm_positions = []
    masked_lm_labels = []
    # create final masked_lm_positions, masked_lm_labels
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels) 
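

# Hedged usage sketch (not part of the original module): a seeded
# `random.Random` makes the masking reproducible; roughly 15% of the
# non-special tokens are selected, capped at `max_predictions_per_seq`.
def _example_create_masked_lm_predictions():
    import random

    tokens = ["[CLS]", "the", "quick", "brown", "fox", "[SEP]"]
    vocab_words = ["the", "quick", "brown", "fox", "jumps", "lazy", "dog"]
    output_tokens, positions, labels = create_masked_lm_predictions(
        tokens=tokens,
        vocab_words=vocab_words,
        mask_whole_word=False,
        max_predictions_per_seq=2,
        masked_lm_prob=0.15,
        rng=random.Random(0),
    )
    # positions[i] indexes a token in output_tokens that was replaced (usually
    # with "[MASK]"), and labels[i] is the original token at that position.
    return output_tokens, positions, labels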


def get_label_id_map(label_vocab_file):
    """
    Load the label-id mapping: Mapping between output labels and id
    :param str label_vocab_file: Path to the label vocab file
    """
    label_map = None
    if label_vocab_file is not None:
        with open(label_vocab_file, 'r') as fh:
            label_map = json.load(fh)
    return label_map 
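

# Hedged usage sketch (not part of the original module): the label vocab is
# assumed to be a JSON object mapping label strings to integer ids; the
# temporary file below is hypothetical.
def _example_get_label_id_map():
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fh:
        json.dump({"O": 0, "B-PER": 1, "I-PER": 2}, fh)
    assert get_label_id_map(fh.name) == {"O": 0, "B-PER": 1, "I-PER": 2}
    assert get_label_id_map(None) is None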


def convert_str_to_int_list(s):
    """
    Converts a string (e.g. from parsing CSV) of the form
        "[1, 5, 7, 2]"
    to a list of integers.
    """
    assert s.startswith("[")
    assert s.endswith("]")
    x = s.strip("[]")
    x = x.split(",")
    return [int(y.strip()) for y in x] 
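

# Hedged usage sketch (not part of the original module): round-trips the kind
# of bracketed integer list a CSV cell would contain.
def _example_convert_str_to_int_list():
    assert convert_str_to_int_list("[1, 5, 7, 2]") == [1, 5, 7, 2]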


def split_list(l, n):
    """
    Splits list/string into n sized chunks.
    :param List[str] l: List or string to split.
    :param int n: Number of chunks to split to.
    :returns List[List]: List of lists 
        containing split list/string.
    """
    return [l[i : i + n] for i in range(0, len(l), n)] 
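

# Hedged usage sketch (not part of the original module): chunking works for
# both lists and strings; the final chunk may be shorter than n.
def _example_split_list():
    assert split_list(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]
    assert split_list("abcde", 2) == ["ab", "cd", "e"]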


def get_vocab(vocab_file_path, do_lower):
    """
    Function to generate vocab from provided
    vocab_file_path.
    :param str vocab_file_path: Path to vocab file
    :param bool do_lower: If True, convert vocab words to
        lower case.
    :returns List[str]: list containing vocab words.
    """
    vocab = []
    with open(vocab_file_path, 'r') as reader:
        for line in reader:
            token = convert_to_unicode(line)
            if not token:
                break
            token = token.strip()
            vocab.append(token)
    vocab = list(map(lambda token: token.lower(), vocab)) if do_lower else vocab
    return vocab
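

# Hedged usage sketch (not part of the original module): the vocab file is
# assumed to hold one token per line; the temporary file below is hypothetical.
def _example_get_vocab():
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as fh:
        fh.write("[PAD]\n[CLS]\nHello\nWorld\n")
    assert get_vocab(fh.name, do_lower=False) == ["[PAD]", "[CLS]", "Hello", "World"]
    assert get_vocab(fh.name, do_lower=True) == ["[pad]", "[cls]", "hello", "world"]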