# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common pre-processing functions taken from:
https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/LanguageModeling/BERT/run_ner.py 
with minor modifications 
"""
import argparse
import json
import os
import pickle
from modelzoo.transformers.data_processing.utils import convert_to_unicode
class NERProcessor:
    """
    Reads one split of a BIO-tagged NER dataset stored as ``<split>.tsv``.

    The input file is expected to hold one token per line, where the first
    whitespace-separated field is the word and the last field is its label.
    Sentences are separated by blank lines.
    """

    # Sentences with more words than this are broken into shorter pieces.
    _MAX_WORDS = 30

    def __init__(self, split) -> None:
        """
        Args:
            split: Name of the dataset split (for example "train"). The data
                is read from the file "<split>.tsv".
        """
        self.split = split
        self.file_name = ".".join([split, "tsv"])

    def get_split_examples(self, data_dir):
        """
        Load every example of this split from ``data_dir``.

        Args:
            data_dir: Directory that contains "<split>.tsv".

        Returns:
            List of ``InputExample`` objects; empty when the file is missing.
        """
        input_file = os.path.join(data_dir, self.file_name)
        print(f"*** Processing {self.split} exmples: {input_file} ***")
        return self._create_example(self._read_data(input_file))

    def get_labels(self):
        """
        Return the label vocabulary in id order.
        """
        # NOTE: "[PAD]" must stay first so that it is assigned id=0.
        return ["[PAD]", "B", "I", "O", "X", "[CLS]", "[SEP]"]

    def _create_example(self, lines):
        """
        Convert ``[labels, text]`` pairs into ``InputExample`` objects.

        Args:
            lines: Iterable of two-item sequences ``(labels, text)`` as
                produced by ``_read_data``.

        Returns:
            List of ``InputExample``. Entries whose text contains
            "-DOCSTART-" are skipped; the running index used in the guid
            still counts the skipped entries.
        """
        examples = []
        for i, (label, text) in enumerate(lines):
            text = convert_to_unicode(text)
            label = convert_to_unicode(label)
            if "-DOCSTART-" in text:
                # The JNLPBA dataset has entries carrying this demarcation.
                continue
            examples.append(InputExample(f"{self.split}-{i}", text, label))
        return examples

    def _read_data(self, input_file):
        """
        Read 'B', 'I', 'O' data.

        Args:
            input_file: Path of the ``.tsv`` file to read.

        Returns:
            List of ``[labels, words]`` pairs, each item being a single
            space-joined string. Returns an empty list when ``input_file``
            does not exist.
        """
        if not os.path.exists(input_file):
            return []

        lines, words, labels = [], [], []
        with open(input_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank line closes the sentence collected so far.
                    self._flush_sentence(words, labels, lines)
                    words, labels = [], []
                    continue
                words.append(fields[0])
                labels.append(fields[-1])

        # A file that does not end with a blank line still has a pending
        # sentence in the buffers; emit it instead of silently dropping it.
        self._flush_sentence(words, labels, lines)
        return lines

    def _flush_sentence(self, words, labels, lines):
        """
        Append the buffered sentence to ``lines``, splitting it when it has
        more than ``_MAX_WORDS`` words. Nothing is appended for an empty
        buffer.
        """
        while len(words) > self._MAX_WORDS:
            cut = self._split_point(labels)
            lines.append([" ".join(labels[:cut]), " ".join(words[:cut])])
            words, labels = words[cut:], labels[cut:]
        if words:
            lines.append([" ".join(labels), " ".join(words)])

    def _split_point(self, labels):
        """
        Return how many leading items of an over-long sentence go into the
        next piece: everything up to and including the last 'O' label found
        within the first ``_MAX_WORDS`` labels, or a single item when that
        window holds no 'O' label.
        """
        window = labels[: self._MAX_WORDS]
        for idx in range(len(window) - 1, -1, -1):
            if window[idx] == "O":
                return idx + 1
        return 1
# NOTE(review): a class named NERProcessor is also defined earlier in this
# file; this later definition rebinds the name at import time. Confirm which
# of the two is meant to be kept.
class NERProcessor:
    """
    Reads BIO-tagged NER data from ``train.tsv`` / ``dev.tsv`` / ``test.tsv``.

    Each input file is expected to hold one token per line, where the first
    whitespace-separated field is the word and the last field is its label.
    Sentences are separated by blank lines.
    """

    def get_train_examples(self, data_dir, file_name="train.tsv"):
        """
        Load the training examples from ``data_dir/file_name``.
        """
        return self._get_examples(data_dir, file_name, "train")

    def get_dev_examples(self, data_dir, file_name="dev.tsv"):
        """
        Load the dev examples from ``data_dir/file_name``.
        """
        return self._get_examples(data_dir, file_name, "dev")

    def get_test_examples(self, data_dir, file_name="test.tsv"):
        """
        Load the test examples from ``data_dir/file_name``.
        """
        return self._get_examples(data_dir, file_name, "test")

    def get_labels(self, data_split_type=None):
        """
        Return the label vocabulary in id order.

        Args:
            data_split_type: Unused; accepted for call compatibility.
        """
        # NOTE: "[PAD]" must stay first so that it is assigned id=0.
        return ["[PAD]", "B", "I", "O", "X", "[CLS]", "[SEP]"]

    def _get_examples(self, data_dir, file_name, set_type):
        """
        Shared body of the ``get_*_examples`` methods: announce the file,
        read it, and turn its sentences into examples tagged ``set_type``.
        """
        input_file = os.path.join(data_dir, file_name)
        print(f"**** Processing {set_type} examples: {input_file}")
        return self._create_example(self._read_data(input_file), set_type)

    def _create_example(self, lines, set_type):
        """
        Convert ``[labels, text]`` pairs into ``InputExample`` objects.

        Args:
            lines: Sequence of ``[labels, text]`` pairs as produced by
                ``_read_data``.
            set_type: Prefix used for the guid of every example.

        Returns:
            List of ``InputExample``. Entries whose text contains
            "-DOCSTART-" are skipped; the running index used in the guid
            still counts the skipped entries.
        """
        examples = []
        for i, line in enumerate(lines):
            guid = f"{set_type}-{i}"
            text = convert_to_unicode(line[1])
            label = convert_to_unicode(line[0])
            if "-DOCSTART-" in text:
                # The JNLPBA dataset has entries carrying this demarcation.
                continue
            examples.append(InputExample(guid=guid, text=text, label=label))
        return examples

    @classmethod
    def _read_data(cls, input_file):
        """
        Read BIO data.

        Args:
            input_file: Path of the ``.tsv`` file to read.

        Returns:
            List of ``[labels, words]`` pairs, each item being a single
            space-joined string. Returns an empty list when ``input_file``
            does not exist.
        """
        if not os.path.exists(input_file):
            return []

        lines = []
        words = []
        labels = []
        with open(input_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if fields:
                    words.append(fields[0])
                    labels.append(fields[-1])
                    continue
                # A blank line closes the sentence collected so far.
                lines.extend(cls._chunk_sentence(words, labels))
                words = []
                labels = []

        # A file that does not end with a blank line still has a pending
        # sentence in the buffers; emit it instead of silently dropping it.
        lines.extend(cls._chunk_sentence(words, labels))
        return lines

    @staticmethod
    def _chunk_sentence(words, labels, max_words=30):
        """
        Turn one buffered sentence into a list of ``[labels, words]`` pairs.

        While the sentence has more than ``max_words`` words, a piece is cut
        off after the last 'O' label found within the first ``max_words``
        labels; when that window holds no 'O' label the piece is a single
        word. An empty sentence yields an empty list.
        """
        chunks = []
        while len(words) > max_words:
            head = labels[:max_words]
            cut = 1
            for pos in range(len(head), 0, -1):
                if head[pos - 1] == "O":
                    cut = pos
                    break
            chunks.append([" ".join(labels[:cut]), " ".join(words[:cut])])
            words, labels = words[cut:], labels[cut:]
        if words:
            chunks.append([" ".join(labels), " ".join(words)])
        return chunks
def write_label_map_files(label_list, out_dir):
    """
    Build the label -> id mapping and store it under ``out_dir``.

    The mapping is written to "label2id.pkl" (pickle) and "label2id.json".
    A file that already exists is left untouched, even if its content
    differs from the mapping built here.

    Args:
        label_list: Labels in id order; the position of a label is its id.
        out_dir: Directory that receives the two files. It is created when
            missing; an empty string means the current directory.

    Returns:
        Dict mapping each label to its integer id.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    label2id_file = os.path.join(out_dir, 'label2id.pkl')
    if not os.path.exists(label2id_file):
        with open(label2id_file, 'wb') as w:
            pickle.dump(label_map, w)

    label2id_json_file = os.path.join(out_dir, 'label2id.json')
    if not os.path.exists(label2id_json_file):
        with open(label2id_json_file, 'w') as w:
            json.dump(label_map, w)

    return label_map
def get_tokens_and_labels(example, tokenizer, max_seq_length):
    """
    Tokenize an example into sub-word tokens with one label per token.

    Args:
        example: Object with ``text`` and ``label`` attributes, each a
            single-space separated string (one label per word).
        tokenizer: Object whose ``tokenize(word)`` returns the list of
            sub-word tokens for one word.
        max_seq_length: Maximum sequence length; the result is cut to at
            most ``max_seq_length - 2`` tokens (presumably leaving room for
            the "[CLS]" and "[SEP]" markers added by the caller).

    Returns:
        Tuple ``(tokens, labels)`` of two equally long lists.

    Raises:
        IndexError: If ``example.label`` holds fewer labels than
            ``example.text`` holds words.
    """
    textlist = example.text.split(' ')
    labellist = example.label.split(' ')
    tokens = []
    labels = []
    for i, word in enumerate(textlist):
        pieces = tokenizer.tokenize(word)
        tokens.extend(pieces)
        word_label = labellist[i]
        if pieces:
            # If a word is split into sub-words during tokenization, the
            # first sub-word gets the label of the word and the remaining
            # sub-words are marked with the label "X".
            labels.append(word_label)
            labels.extend(["X"] * (len(pieces) - 1))

    if len(tokens) >= max_seq_length - 1:
        limit = max_seq_length - 2
        tokens = tokens[:limit]
        labels = labels[:limit]
    return tokens, labels
def create_parser():
    """
    Build the command-line argument parser.

    Returns:
        An ``argparse.ArgumentParser`` that accepts ``--data_split_type``,
        ``--data_dir``, ``--vocab_file``, ``--do_lower_case`` and
        ``--max_seq_length``. The arguments are not parsed here.
    """
    parser = argparse.ArgumentParser()

    # Which split(s) to process.
    parser.add_argument(
        "--data_split_type",
        choices=["train", "dev", "test", "all"],
        default="all",
        help="Dataset split, choose from 'train', 'test', 'dev' or 'all'.",
    )

    # Input locations; both are mandatory.
    parser.add_argument(
        "--data_dir",
        required=True,
        help="Directory containing train.tsv, test.tsv, dev.tsv",
    )
    parser.add_argument(
        "--vocab_file",
        required=True,
        help="The vocabulary file that the BERT Pretrained model was trained on.",
    )

    # Tokenization options.
    parser.add_argument(
        "--do_lower_case",
        required=False,
        action="store_true",
        help="Whether to convert tokens to lowercase",
    )
    parser.add_argument(
        "--max_seq_length",
        required=False,
        type=int,
        default=128,
        help="The maximum total input sequence length after WordPiece tokenization.",
    )
    return parser