# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is adapted from
# https://github.com/google-research/bert/blob/master/run_squad.py
#
# Copyright 2022 Cerebras Systems.
#
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import json
import os
import six
import tqdm
from modelzoo.transformers.data_processing.utils import (
    convert_to_unicode,
    whitespace_tokenize,
)
class SquadExample(object):
    """
    A single training/test example for simple sequence classification.
    For examples without an answer, the start and end position are -1.
    """
    def __init__(
        self,
        qas_id,
        question_text,
        doc_tokens,
        orig_answer_text=None,
        start_position=None,
        end_position=None,
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible 
    def __str__(self):
        return self.__repr__()
    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (convert_to_unicode(self.qas_id))
        s += ", question_text: %s" % (convert_to_unicode(self.question_text))
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position is not None:
            s += ", start_position: %d" % (self.start_position)
        if self.end_position is not None:
            s += ", end_position: %d" % (self.end_position)
        if self.is_impossible:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s 
# A single set of features for one document span of a single example
InputFeatures = collections.namedtuple(
    "InputFeatures",
    [
        "unique_id",
        "example_index",
        "doc_span_index",
        "tokens",
        "token_to_orig_map",
        "token_is_max_context",
        "input_ids",
        "input_mask",
        "segment_ids",
        "start_position",
        "end_position",
        "is_impossible",
    ],
)
def read_squad_examples(input_file, is_training, version_2_with_negative):
    """
    Read a SQuAD json file into a list of SquadExample.
    """
    with open(input_file, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]
    def is_whitespace(c):
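        # Treat space, tab, CR, LF and U+202F (narrow no-break space) as
        # word separators.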
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False
    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)
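            # For example, for the context "Hi  there" this loop produces
            # doc_tokens = ["Hi", "there"] and
            # char_to_word_offset = [0, 0, 0, 0, 1, 1, 1, 1, 1], so an answer's
            # character offset can be mapped to its whitespace-token index.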
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[
                            answer_offset + answer_length - 1
                        ]
                        # Only add answers where the text can be exactly
                        # recovered from the document. If it cannot, that is
                        # likely due to unusual Unicode handling, so we simply
                        # skip the example.
                        #
                        # Note that this means that in training mode not every
                        # example is guaranteed to be preserved.
                        actual_text = " ".join(
                            doc_tokens[start_position : (end_position + 1)]
                        )
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text)
                        )
                        if actual_text.find(cleaned_answer_text) == -1:
                            print(
                                "Warning: Could not find answer: '%s' vs. '%s'"
                                % (actual_text, cleaned_answer_text,)
                            )
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""
                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible,
                )
                examples.append(example)
    return examples 
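# Example usage sketch (the file path below is illustrative, not part of this
# module): for SQuAD v2.0, unanswerable questions are kept as examples with
# `is_impossible=True` and start/end positions of -1:
#
#   examples = read_squad_examples(
#       "train-v2.0.json", is_training=True, version_2_with_negative=True
#   )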
def check_tokenizer_scheme(tokenizer_scheme):
    valid_schemes = ["bert", "t5"]
    if tokenizer_scheme not in valid_schemes:
        raise ValueError(
            f"Tokenizer scheme '{tokenizer_scheme}' is not supported. "
            f"Valid schemes are: {valid_schemes}"
        )
def convert_examples_to_features(
    examples,
    tokenize_fn,
    convert_tokens_to_ids_fn,
    max_seq_length,
    doc_stride,
    max_query_length,
    tokenizer_scheme,
    is_training,
    output_fn,
):
    """
    Loads a data file into a list of `InputBatch`s.
    """
    check_tokenizer_scheme(tokenizer_scheme)
    num_samples = 0
    unique_id = 1000000000
    total_examples = len(examples)
    for (example_index, example) in tqdm.tqdm(
        enumerate(examples), total=total_examples
    ):
        query_tokens = tokenize_fn(example.question_text)
        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenize_fn(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)
        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = (
                    orig_to_tok_index[example.end_position + 1] - 1
                )
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens,
                tok_start_position,
                tok_end_position,
                tokenize_fn,
                example.orig_answer_text,
            )
        # For BERT, the -3 accounts for [CLS], [SEP] and [SEP]; schemes that
        # add no special tokens (e.g. T5) need no adjustment.
        special_token_adjuster = 3 if tokenizer_scheme == 'bert' else 0
        max_tokens_for_doc = (
            max_seq_length - len(query_tokens) - special_token_adjuster
        )
        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we take a sliding window approach, where we take
        # chunks of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"]
        )
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)
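        # For example, with len(all_doc_tokens) == 10, max_tokens_for_doc == 4
        # and doc_stride == 2, the spans are (start=0, length=4),
        # (start=2, length=4), (start=4, length=4) and (start=6, length=4),
        # so every document token is covered by at least one span.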
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            # For T5 and other non-BERT models we do not add the [CLS] and
            # [SEP] tokens; for BERT the original interface is kept unchanged
            # for backward compatibility.
            if tokenizer_scheme == 'bert':
                tokens.append("[CLS]")
                segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            if tokenizer_scheme == 'bert':
                tokens.append("[SEP]")
                segment_ids.append(0)
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[
                    split_token_index
                ]
                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            if tokenizer_scheme == 'bert':
                tokens.append("[SEP]")
                segment_ids.append(1)
            input_ids = convert_tokens_to_ids_fn(tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)
            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (
                    tok_start_position >= doc_start
                    and tok_end_position <= doc_end
                ):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    special_toks_offset = 2 if tokenizer_scheme == 'bert' else 0
                    doc_offset = len(query_tokens) + special_toks_offset
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset
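            # For the BERT layout [CLS] q_1 ... q_Q [SEP] d_1 ... d_L [SEP],
            # a document token at span-relative index j sits at absolute
            # position Q + 2 + j, which is why doc_offset above is
            # len(query_tokens) + 2 (and just len(query_tokens) for schemes
            # that add no special tokens).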
            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
            if example_index < 3:
                print("*** Example ***")
                print("unique_id: %s" % (unique_id))
                print("example_index: %s" % (example_index))
                print("doc_span_index: %s" % (doc_span_index))
                print(
                    "tokens: %s"
                    % " ".join([convert_to_unicode(x) for x in tokens])
                )
                print(
                    "token_to_orig_map: %s"
                    % " ".join(
                        [
                            "%d:%d" % (x, y)
                            for (x, y) in six.iteritems(token_to_orig_map)
                        ]
                    )
                )
                print(
                    "token_is_max_context: %s"
                    % " ".join(
                        [
                            "%d:%s" % (x, y)
                            for (x, y) in six.iteritems(token_is_max_context)
                        ]
                    )
                )
                print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                print(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids])
                )
                if is_training and example.is_impossible:
                    print("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(
                        tokens[start_position : (end_position + 1)]
                    )
                    print("start_position: %d" % (start_position))
                    print("end_position: %d" % (end_position))
                    print("answer: %s" % (convert_to_unicode(answer_text)))
            features = InputFeatures(
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                start_position=start_position,
                end_position=end_position,
                is_impossible=example.is_impossible,
            )
            # Run callback
            output_fn(features)
            unique_id += 1
            num_samples += 1
    return num_samples 
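# Example usage sketch (the `tokenizer` object is an assumption; any pair of
# callables with the same behavior works): features can be collected in memory
# by passing a list's append method as the `output_fn` callback:
#
#   collected = []
#   convert_examples_to_features(
#       examples,
#       tokenize_fn=tokenizer.tokenize,
#       convert_tokens_to_ids_fn=tokenizer.convert_tokens_to_ids,
#       max_seq_length=384,
#       doc_stride=128,
#       max_query_length=64,
#       tokenizer_scheme="bert",
#       is_training=True,
#       output_fn=collected.append,
#   )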
def convert_examples_to_features_and_write(
    examples,
    tokenize_fn,
    convert_tokens_to_ids_fn,
    max_seq_length,
    doc_stride,
    max_query_length,
    output_dir,
    file_prefix,
    num_output_files,
    tokenizer_scheme,
    is_training=True,
    return_features=False,
):
    meta_data = collections.defaultdict(int)
    total_num_samples = 0
    num_output_files = max(num_output_files, 1)
    output_files = [
        os.path.join(output_dir, "%s-%04i.csv" % (file_prefix, fidx + 1))
        for fidx in range(num_output_files)
    ]
    divided_examples = _divide_list(examples, num_output_files)
    all_features = list()
    for _examples, _output_file in zip(divided_examples, output_files):
        with open(_output_file, "w") as csvfile:
            writer = csv.DictWriter(
                csvfile,
                fieldnames=InputFeatures._fields,
                quoting=csv.QUOTE_MINIMAL,
            )
            writer.writeheader()
            def write_fn(features):
                features_dict = features._asdict()
                writer.writerow(features_dict)
                if return_features:
                    all_features.append(features)
            num_samples = convert_examples_to_features(
                examples=_examples,
                tokenize_fn=tokenize_fn,
                convert_tokens_to_ids_fn=convert_tokens_to_ids_fn,
                max_seq_length=max_seq_length,
                doc_stride=doc_stride,
                max_query_length=max_query_length,
                tokenizer_scheme=tokenizer_scheme,
                is_training=is_training,
                output_fn=write_fn,
            )
            output_file = os.path.basename(_output_file)
            meta_data[output_file] += num_samples
            total_num_samples += num_samples
    if return_features:
        return total_num_samples, meta_data, all_features
    else:
        return total_num_samples, meta_data 
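# Example usage sketch (the tokenizer, output directory and shard count below
# are assumptions, not defaults of this module):
#
#   total, meta = convert_examples_to_features_and_write(
#       examples,
#       tokenize_fn=tokenizer.tokenize,
#       convert_tokens_to_ids_fn=tokenizer.convert_tokens_to_ids,
#       max_seq_length=384,
#       doc_stride=128,
#       max_query_length=64,
#       output_dir="./squad_csv",
#       file_prefix="train",
#       num_output_files=16,
#       tokenizer_scheme="bert",
#       is_training=True,
#   )
#
# `meta` maps each output CSV file name to the number of samples it contains.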
def _divide_list(li, n):
    """
    Yields n successive sublists of nearly equal size; the first
    len(li) % n sublists receive one extra element.
    Example:
        >>> a = list(range(10))
        >>> list(_divide_list(a, 3))
        [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
    """
    start = 0
    for i in range(n):
        stop = start + len(li[i::n])
        yield li[start:stop]
        start = stop
def _improve_answer_span(
    doc_tokens, input_start, input_end, tokenize_fn, orig_answer_text
):
    """
    Returns tokenized answer spans that better match the annotated answer.
    """
    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    # TODO not sure if this will be compatible
    tok_answer_text = " ".join(tokenize_fn(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)
    return (input_start, input_end)
def _check_is_max_context(doc_spans, cur_span_index, position):
    """
    Check if this is the 'max context' doc span for the token.
    """
    # Because of the sliding window approach taken to scoring documents, a
    # single token can appear in multiple document spans. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
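    # Worked example of the scoring below: for a token at position 5 inside a
    # span covering positions 2..7 (length 6), num_left_context = 3 and
    # num_right_context = 2, giving score = min(3, 2) + 0.01 * 6 = 2.06; the
    # span with the highest such score is treated as the token's max-context
    # span.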
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = (
            min(num_left_context, num_right_context) + 0.01 * doc_span.length
        )
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index
    return cur_span_index == best_span_index