# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PretrainingTokenGenerator Module
This module provides the PretrainingTokenGenerator class which is designed to process
text data and create features suitable for language modeling tasks.
Usage:
    tokenizer = PretrainingTokenGenerator(dataset_params, max_sequence_length, tokenizer)
    tokenized_features = tokenizer.encode("Sample text for processing.")
"""
import os
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import numpy as np
from cerebras.modelzoo.data_preparation.data_preprocessing.utils import (
    append_eos_to_multiple_semantic_regions,
    clean_text,
    find_token_range,
    get_data_stats,
    setup_warning_logging,
    split_text_and_tokenize,
)
class PretrainingTokenGenerator:
    def __init__(
        self, params: Dict[str, Any], tokenizer: Any, eos_id: int, pad_id: int
    ):
        """
        Initialize the PretrainingTokenGenerator class.

        Options are consumed with ``pop`` from ``params["dataset"]``,
        ``params["processing"]`` and ``params["setup"]``, so those
        dictionaries are modified in place. Both ``numpy``'s global random
        state and a private ``random.Random`` instance are seeded with the
        configured ``seed``.

        Args:
            params (Dict[str, Any]): Parameters for the dataset and processing.
                Must contain ``"processing"`` and ``"setup"`` entries;
                ``"dataset"`` is optional.
            tokenizer (Any): Tokenizer to use for tokenization.
            eos_id (int): End-of-sequence token ID.
            pad_id (int): Padding token ID.

        Raises:
            KeyError: If ``params`` has no ``"processing"`` or ``"setup"``
                entry.
        """
        dataset_params = params.get("dataset", {})
        processing_params = params["processing"]
        setup_params = params["setup"]

        # Warnings are logged under <output_dir>/logs when an output
        # directory is configured, otherwise under a local fallback directory.
        output_dir = setup_params.get("output_dir")
        warning_log_dir = (
            os.path.join(output_dir, "logs")
            if output_dir
            else "./data_preprocessing_logs"
        )
        self.logger = setup_warning_logging(warning_log_dir, __name__)
        self.tokenizer = tokenizer

        self.training_objective = dataset_params.pop("training_objective", None)
        # A missing objective (None) compares unequal to 'mlm', giving False.
        self.mlm = self.training_objective == 'mlm'
        self.use_vsl = dataset_params.pop("use_vsl", False)

        self.use_ftfy = processing_params.pop("use_ftfy", True)
        self.ftfy_normalizer = processing_params.pop("ftfy_normalizer", "NFC")
        self.wikitext_detokenize = processing_params.pop(
            "wikitext_detokenize", False
        )
        self.min_sequence_len = processing_params.pop("min_sequence_len", 10)
        self.input_ids_dtype = processing_params.pop("input_ids_dtype", "int32")
        self.input_mask_dtype = processing_params.pop(
            "input_mask_dtype", "int32"
        )
        self.inverted_mask = processing_params.pop("inverted_mask", False)
        self.seed = processing_params.pop("seed", 0)
        np.random.seed(self.seed)
        self.max_seq_length = processing_params.pop("max_seq_length", 2048)
        self.short_seq_prob = processing_params.pop("short_seq_prob", 0.0)

        self.split_text_to_tokenize = dataset_params.pop(
            "split_text_to_tokenize", False
        )
        # The chunking options are only consumed when splitting is enabled.
        if self.split_text_to_tokenize:
            self.chunk_len_to_split = dataset_params.pop(
                "chunk_len_to_split", 2000
            )
            self.remove_bos_in_chunks = dataset_params.pop(
                "remove_bos_in_chunks", False
            )

        self.eos_id = eos_id
        self.pad_id = pad_id
        # Compare against None rather than truthiness: 0 is a valid pad id.
        self.pad_token = (
            self.tokenizer.convert_ids_to_tokens(self.pad_id)
            if self.pad_id is not None
            else None
        )

        self.rng = random.Random()
        self.rng.seed(self.seed)
        self.prefix = []
        self.prefix_doc = None

        # Multimodal parameters initialization
        self.is_multimodal = dataset_params.pop("is_multimodal", False)
        self.features = (
            [
                "text_input_ids",
                "loss_mask",
                "labels",
                "key_padding_mask",
                "token_modality_idx",
            ]
            if self.is_multimodal
            else [
                "input_ids",
                "attention_mask",
                "labels",
            ]
        )

        # MLM fields
        if self.mlm:
            import math

            self.mlm_fraction = dataset_params.pop("mlm_fraction", 0.15)
            self.max_predictions = math.ceil(
                self.mlm_fraction * self.max_seq_length
            )
            self.mlm_with_gather = dataset_params.pop("mlm_with_gather", False)
            self.ignore_index = dataset_params.pop(
                "ignore_index", -100
            )  # default value for torch.nn.CrossEntropyLoss
            self.excluded_tokens = dataset_params.pop(
                "excluded_tokens",
                ['<cls>', '<pad>', '<eos>', '<unk>', '<null_1>', '<mask>'],
            )
            self.allowable_token_ids = self.get_allowable_token_ids()
            self.special_tokens_ids = {
                self.tokenizer.cls_token_id,
                self.tokenizer.pad_token_id,
                self.tokenizer.eos_token_id,
                self.tokenizer.unk_token_id,
            }
            if self.mlm_with_gather:
                self.features.extend(["masked_lm_positions", "masked_lm_mask"])

        # Packing defaults to on for text-only data and off for multimodal.
        self.pack_sequences = dataset_params.pop(
            "pack_sequences", not self.is_multimodal
        )
        self.image_token = dataset_params.pop("image_token", "<image_token>")
        self.image_dir = setup_params.pop("image_dir", None)
        self.max_num_img = dataset_params.pop("max_num_img", 1)
        self.num_patches = dataset_params.pop("num_patches", 1)

        # NOTE(review): image_token_id is only resolved when the image token
        # had to be added to the vocabulary; if the token already exists it
        # stays at -1. Confirm this is intended.
        self.image_token_id = -1
        if (
            self.is_multimodal
            and self.image_token
            and self.image_token not in self.tokenizer.get_vocab()
        ):
            self.tokenizer.add_special_tokens(
                {'additional_special_tokens': [self.image_token]}
            )
            self.image_token_id = self.tokenizer.convert_tokens_to_ids(
                self.image_token
            )
        self.image_ids = (
            [pad_id] * self.num_patches if self.is_multimodal else []
        )

        # Each semantic_* option is popped exactly once. Popping
        # "semantic_drop_mask" a second time would always yield the default
        # and silently discard the configured value.
        self.semantic_loss_weight = processing_params.pop(
            "semantic_loss_weight", {}
        )
        self.semantic_drop_mask = processing_params.pop(
            "semantic_drop_mask", {}
        )
        self.semantic_attention_mask = processing_params.pop(
            "semantic_attention_mask", {}
        )

        self.include_image_tag = False
        self.data_ranges = []
        # NOTE(review): when eos_id is None this stores the tokenizer's pad
        # token *id*, not a token string — confirm downstream expectations.
        self.eos_token = (
            self.tokenizer.pad_token_id
            if self.eos_id is None
            else self.tokenizer.convert_ids_to_tokens(self.eos_id)
        )
    def create_features_pretraining(
        self,
        doc,
        token_modality_idx=None,
    ):
        input_ids = doc.get("input_ids")
        total_len = len(input_ids)
        if total_len < self.min_sequence_len:
            self.logger.warning(
                "Length of token ids < min_sequence_len, skipping this example..."
            )
            return []
        if not self.is_multimodal and self.rng.random() < self.short_seq_prob:
            input_ids = input_ids[
                0 : self.rng.randint(2, self.max_seq_length - 1)
            ]
            input_mask = input_mask[0 : len(input_ids)]
            attention_mask = attention_mask[0 : len(input_ids)]
        input_mask, attention_mask = doc.get("loss_mask"), doc.get(
            "attention_mask"
        )
        labels = input_ids[1:]
        input_ids = input_ids[:-1]
        attention_mask = attention_mask[:-1]
        input_mask = input_mask[1:]
        # Add padding
        num_pad = self.max_seq_length - len(input_ids)
        padding = [self.pad_id] * num_pad
        input_ids.extend(padding)
        labels.extend(padding)
        padding = [0] * num_pad
        input_mask.extend(padding)
        attention_mask.extend([0] * num_pad)
        assert (
            len(input_ids) == self.max_seq_length
            and len(labels) == self.max_seq_length
            and len(input_mask) == self.max_seq_length
            and len(attention_mask) == self.max_seq_length
        ), "Wrong sequence length"
        # Create features dictionary
        features = {
            "input_ids": getattr(np, self.input_ids_dtype)(input_ids),
            "labels": getattr(np, self.input_ids_dtype)(labels),
        }
        input_mask = getattr(np, self.input_mask_dtype)(input_mask)
        attention_mask = getattr(np, self.input_ids_dtype)(attention_mask)
        if self.inverted_mask:
            input_mask = np.equal(input_mask, 0).astype(self.input_mask_dtype)
        # NOTE this is because our internal stack requires the inverted mask and
        # doesn't do the inversion internally
        if self.is_multimodal:
            key_padding_mask = np.equal(attention_mask, 0).astype(
                input_mask.dtype
            )
        return (
            np.stack(
                [
                    features["input_ids"],
                    input_mask,
                    features["labels"],
                    key_padding_mask,
                    token_modality_idx,
                ]
            )
            if self.is_multimodal
            else np.stack(
                [
                    features["input_ids"],
                    input_mask,
                    features["labels"],
                ]
            )
        )
[docs]    def create_features_auto_lm(
        self,
        token_ids: List[int],
    ) -> np.ndarray:
        """Given a list of token_ids, generate input sequence and labels.
        Args:
            token_ids (List[int]): List containing token ids for creating features,
                labels and input mask from.
        Returns:
            np.ndarray: Array containing features, labels, and input mask.
        """
        if not len(token_ids) >= self.min_sequence_len:
            self.logger.warning(
                f"token_ids must have at least {self.min_sequence_len} elements, skipping this example..."
            )
            return []
        if self.rng.random() < self.short_seq_prob:
            token_ids = token_ids[
                0 : self.rng.randint(2, self.max_seq_length - 1)
            ]
        input_ids = token_ids[:-1]
        labels = token_ids[1:]
        input_mask = [1] * len(labels)
        # padding
        num_pad = self.max_seq_length - len(input_ids)
        padding = [self.pad_id] * num_pad
        input_ids.extend(padding)
        labels.extend(padding)
        input_mask.extend([0] * num_pad)
        # assertions to ensure correct output shapes
        assert (
            len(input_ids) == self.max_seq_length
            and len(labels) == self.max_seq_length
            and len(input_mask) == self.max_seq_length
        ), "Wrong sequence length"
        # create feature dict
        features = dict()
        features["input_ids"] = getattr(np, self.input_ids_dtype)(input_ids)
        features["input_mask"] = getattr(np, self.input_mask_dtype)(input_mask)
        if self.inverted_mask:
            features["input_mask"] = np.equal(features["input_mask"], 0).astype(
                features["input_mask"].dtype
            )
        labels = getattr(np, self.input_ids_dtype)(labels)
        return np.stack([features["input_ids"], features["input_mask"], labels]) 
    def chop_doc_into_msl(self, data) -> List[Dict[str, Any]]:
        """
        Split one tokenized document into chunks of at most
        ``self.max_seq_length + 1`` tokens, keeping the loss mask, attention
        mask and image bookkeeping aligned with every chunk.

        Text regions are concatenated with the pending partial chunk and cut
        into fixed-size pieces. Image regions are never split: they are either
        appended to the pending chunk or used to start a new one.

        Args:
            data: Mapping read with ``.get``. Keys used here:
                ``"tokenized_data"`` (its ``"input_ids"`` entry is sliced per
                region), ``"tokenized_semantic_regions"`` (sequence of dicts
                with ``"region_modality"``, ``"loss_weight"``,
                ``"attention_mask"`` and ``"indices"``) and, optionally,
                ``"image_paths"`` (indexed once per image region, in order).

        Returns:
            List of dicts with keys ``input_ids``, ``loss_mask``,
            ``attention_mask``, ``image_paths``, ``has_img`` and
            ``image_data_positions``. When ``self.pack_sequences`` is set, a
            trailing partial chunk is not returned but stored in
            ``self.prefix_doc`` and prepended on the next call.

        Raises:
            ValueError: If an image region has ``self.max_seq_length`` or more
                tokens.

        Side effects:
            Rewrites ``"indices"`` of every entry in
            ``data["tokenized_semantic_regions"]`` (the first region's start
            is forced to 0) and reads/resets ``self.prefix_doc``.
        """
        doc_list = []
        tokenized_data = data.get("tokenized_data")
        tokenized_semantic_regions = data.get("tokenized_semantic_regions")
        image_paths = data.get("image_paths", [])
        image_index = 0  # next entry of image_paths to consume
        max_len = self.max_seq_length + 1  # Including space for EOS token
        # State of the pending (not yet full) chunk carried between regions.
        last_chunk_tokens = []
        last_chunk_loss_mask = []
        last_chunk_attention_mask = []
        last_chunk_img_paths = []
        last_chunk_has_img = False
        last_chunk_image_positions = []
        # When packing, resume from the partial chunk left by the previous
        # document and clear it so it is consumed only once.
        if self.pack_sequences and self.prefix_doc:
            last_chunk_tokens = self.prefix_doc['input_ids']
            last_chunk_loss_mask = self.prefix_doc['loss_mask']
            last_chunk_attention_mask = self.prefix_doc['attention_mask']
            last_chunk_img_paths = self.prefix_doc['image_paths']
            last_chunk_has_img = self.prefix_doc['has_img']
            last_chunk_image_positions = self.prefix_doc['image_data_positions']
            self.prefix_doc = None
        for idx, region in enumerate(tokenized_semantic_regions):
            modality = region["region_modality"]
            loss_weight = region["loss_weight"]
            attn_mask_value = region["attention_mask"]
            start_idx, end_idx = region["indices"]
            orig_idx = (start_idx, end_idx)
            # The first region always starts at token 0, whatever its
            # recorded start index is; the adjusted range is written back.
            start_idx = 0 if idx == 0 else start_idx
            tokenized_semantic_regions[idx]['indices'] = (start_idx, end_idx)
            tokens = tokenized_data["input_ids"][start_idx:end_idx]
            region_len = len(tokens)
            # Generate loss_mask and attention_mask for the region
            region_loss_mask = [loss_weight] * region_len
            region_attention_mask = [attn_mask_value] * region_len
            if modality != "image":
                # Combine last_chunk_* variables with current tokens
                tokens = last_chunk_tokens + tokens
                region_loss_mask = last_chunk_loss_mask + region_loss_mask
                region_attention_mask = (
                    last_chunk_attention_mask + region_attention_mask
                )
                region_len = len(
                    tokens
                )  # Update region_len after concatenation
                # Split text region into chunks fitting max_len
                chunks = [
                    (
                        tokens[i : i + max_len],
                        region_loss_mask[i : i + max_len],
                        region_attention_mask[i : i + max_len],
                    )
                    for i in range(0, region_len, max_len)
                ]
                # Determine the number of complete chunks
                # NOTE(review): if the region and the pending chunk are both
                # empty, `chunks` is empty and `chunks[-1]` raises IndexError.
                num_chunks = len(chunks)
                if len(chunks[-1][0]) < max_len:
                    num_complete_chunks = num_chunks - 1
                else:
                    num_complete_chunks = num_chunks
                # Process complete chunks
                # NOTE(review): this loop reuses the name `idx` of the
                # enclosing region loop. The outer value is rebound by
                # enumerate() on the next iteration and is not read again in
                # this branch, but a distinct name would be safer.
                for idx in range(num_complete_chunks):
                    chunk_tokens, chunk_loss_mask, chunk_attention_mask = (
                        chunks[idx]
                    )
                    # `last_chunk_tokens` still refers to the pending tokens
                    # from before this region, so this is true only for the
                    # chunk that absorbed them.
                    if idx == 0 and last_chunk_tokens:
                        assert len(chunk_tokens) == len(
                            chunk_loss_mask
                        ), "Length of input ids and loss is different"
                        # First chunk may have image data from previous last_chunk
                        doc_list.append(
                            {
                                "input_ids": chunk_tokens,
                                "loss_mask": chunk_loss_mask,
                                "attention_mask": chunk_attention_mask,
                                "image_paths": last_chunk_img_paths,
                                "has_img": last_chunk_has_img,
                                "image_data_positions": last_chunk_image_positions,
                            }
                        )
                        # Reset image data after first chunk
                        last_chunk_img_paths = []
                        last_chunk_has_img = False
                        last_chunk_image_positions = []
                    else:
                        assert len(chunk_tokens) == len(
                            chunk_loss_mask
                        ), "Length of input ids and loss is different"
                        # Subsequent chunks without image data
                        doc_list.append(
                            {
                                "input_ids": chunk_tokens,
                                "loss_mask": chunk_loss_mask,
                                "attention_mask": chunk_attention_mask,
                                "image_paths": [],  # No images in subsequent chunks
                                "has_img": False,
                                "image_data_positions": [],
                            }
                        )
                # Update last_chunk_* variables for the next iteration
                if num_complete_chunks < num_chunks:
                    # The last chunk is incomplete; store it for the next iteration
                    last_chunk_tokens = chunks[-1][0]
                    last_chunk_loss_mask = chunks[-1][1]
                    last_chunk_attention_mask = chunks[-1][2]
                    # The image bookkeeping is left untouched here: it was
                    # cleared above only if a complete chunk absorbed it.
                else:
                    # All chunks are complete; reset last_chunk_* variables
                    last_chunk_tokens = []
                    last_chunk_loss_mask = []
                    last_chunk_attention_mask = []
                    # The image bookkeeping is not modified in this branch.
            else:
                # Handle image region
                image_path = image_paths[image_index]
                image_index += 1
                image_tokens = tokens  # Image tokens should not be split
                image_loss_mask = region_loss_mask
                image_attention_mask = region_attention_mask
                image_len = len(image_tokens)
                combined_len = len(last_chunk_tokens) + image_len
                # The image is appended only while the pending chunk stays
                # strictly shorter than max_seq_length (max_len - 1).
                if combined_len < max_len - 1:
                    # Add image tokens to last_chunk
                    start_idx = len(last_chunk_tokens)
                    last_chunk_tokens.extend(image_tokens)
                    last_chunk_loss_mask.extend(image_loss_mask)
                    last_chunk_attention_mask.extend(image_attention_mask)
                    end_idx = len(last_chunk_tokens)
                    last_chunk_img_paths += [image_path]
                    last_chunk_has_img = True
                    # For the first region the original (unadjusted) indices
                    # are recorded; otherwise the position inside the chunk.
                    image_indices = (
                        orig_idx if idx == 0 else (start_idx, end_idx)
                    )
                    last_chunk_image_positions.append(image_indices)
                else:
                    # Finalize last_chunk
                    if last_chunk_tokens:
                        assert len(last_chunk_tokens) == len(
                            last_chunk_loss_mask
                        ), "Length of input ids and loss is different"
                        doc_list.append(
                            {
                                "input_ids": last_chunk_tokens,
                                "loss_mask": last_chunk_loss_mask,
                                "attention_mask": last_chunk_attention_mask,
                                "image_paths": last_chunk_img_paths,
                                "has_img": last_chunk_has_img,
                                "image_data_positions": last_chunk_image_positions,
                            }
                        )
                    # Start new last_chunk with image tokens if they fit
                    if image_len < max_len - 1:
                        last_chunk_tokens = image_tokens
                        last_chunk_loss_mask = image_loss_mask
                        last_chunk_attention_mask = image_attention_mask
                        last_chunk_img_paths = [image_path]
                        last_chunk_has_img = True
                        last_chunk_image_positions = [(0, image_len)]
                    else:
                        # Image tokens exceed max_len; cannot split images
                        raise ValueError(
                            "Image tokens exceed maximum sequence length."
                        )
        # Append any remaining last_chunk to doc_list
        if last_chunk_tokens:
            assert len(last_chunk_tokens) == len(
                last_chunk_loss_mask
            ), "Length of input ids and loss is different"
            doc = {
                "input_ids": last_chunk_tokens,
                "loss_mask": last_chunk_loss_mask,
                "attention_mask": last_chunk_attention_mask,
                "image_paths": last_chunk_img_paths,
                "has_img": last_chunk_has_img,
                "image_data_positions": last_chunk_image_positions,
            }
            # Without packing the remainder is emitted as its own (short)
            # document; with packing it is held back for the next call.
            if not self.pack_sequences:
                doc_list.append(doc)
            else:
                self.prefix_doc = doc
        return doc_list
[docs]    def get_segment_indices(
        self,
        tokenized_data: Dict[str, int],
        semantic_region_list: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Get segment indices for the data ranges.
        Args:
            tokenized_data (List[Tuple[int, int]]): Tokenized data with offset mappings.
            text_semantic_region_list (List[Dict[str, Any]]): List of text semantic regions with region details.
        Returns:
            List[Dict[str, Any]]: List of tokenized semantic regions and image regions with their indices.
        """
        tokenized_semantic_region_list = []
        tokenized_semantic_region = None
        starting_offset_index = 0
        for region in semantic_region_list:
            region_name = region.get("region_modality")
            tokenized_semantic_region = find_token_range(
                region,
                tokenized_data["offset_mapping"],
                starting_offset_index,
            )
            tokenized_semantic_region["region_modality"] = region_name
            starting_offset_index = tokenized_semantic_region["indices"][1]
            tokenized_semantic_region_list.append(tokenized_semantic_region)
        return tokenized_semantic_region_list 
[docs]    def get_allowable_token_ids(self) -> List[int]:
        """Generate a list of token IDs that can be masked."""
        excluded_token_ids = {
            self.tokenizer.convert_tokens_to_ids(tok)
            for tok in self.excluded_tokens
            if tok in self.tokenizer.get_vocab()
        }
        allowable_token_ids = [
            tok_id
            for tok, tok_id in self.tokenizer.get_vocab().items()
            if tok_id not in excluded_token_ids
        ]
        return list(allowable_token_ids) 
[docs]    def mask_single_sequence(
        self, input_ids: List[int]
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        """
        Masks tokens in a single sequence according to the MLM strategy.
        When self.mlm_with_gather is False, the returning len(labels) == len(input_ids)
        When self.mlm_with_gather is True, the returning len(labels) == self.max_predictions
        Args:
            input_ids (List[int]): Original sequence of token IDs.
        Returns:
            Tuple[List[int], List[int], List[int], List[int]]:
                - input_ids: Modified sequence with masked tokens.
                - masked_lm_positions: Positions of the masked tokens, empty if not self.mlm_with_gather.
                - masked_lm_mask: Binary indicators (1s) for positions that were masked, empty if not self.mlm_with_gather.
                - labels: Original token IDs of the masked tokens for label purposes.
        """
        sequence = np.array(input_ids.copy())
        masked_lm_positions = []
        masked_lm_mask = []
        labels = (
            [] if self.mlm_with_gather else [self.ignore_index] * len(input_ids)
        )
        indices_can_be_masked = [
            i
            for i, token_id in enumerate(input_ids)
            if token_id not in self.special_tokens_ids
        ]
        # Calculate the number of tokens to mask
        num_tokens_to_mask = min(
            int(self.mlm_fraction * len(indices_can_be_masked)),
            self.max_predictions,
        )
        if num_tokens_to_mask > 0:
            # Randomly select tokens to mask
            indices_to_mask = sorted(
                self.rng.sample(indices_can_be_masked, k=num_tokens_to_mask)
            )
            for pos in indices_to_mask:
                original_token_id = sequence[pos].copy()
                prob = self.rng.random()
                if prob < 0.8:  # 80% of the time, replace with [MASK]
                    sequence[pos] = self.tokenizer.mask_token_id
                elif prob < 0.9:  # 10% of the time, replace with a random token
                    # Ensure selected token is not a special token
                    masked_token_id = np.random.choice(self.allowable_token_ids)
                    sequence[pos] = masked_token_id
                elif prob <= 1.0:
                    pass  # 10% of the time, keep the original token
                # Store the original token ID as the label
                if self.mlm_with_gather:
                    masked_lm_positions.append(pos)
                    masked_lm_mask.append(1)
                    labels.append(original_token_id)
                else:
                    labels[pos] = original_token_id
        if self.mlm_with_gather:
            # Pad the lists to reach max_predictions length
            num_paddings = self.max_predictions - len(masked_lm_positions)
            masked_lm_positions = masked_lm_positions + [0] * num_paddings
            masked_lm_mask = masked_lm_mask + [0] * num_paddings
            labels = labels + [self.ignore_index] * num_paddings
        return list(sequence), masked_lm_positions, masked_lm_mask, labels 
[docs]    def process_chunks(
        self, tokenized_text_chunks: List[List[int]]
    ) -> Tuple[List[np.ndarray], Dict[str, int]]:
        """
        Processes chunks of tokenized text and returns processed features along with the total padding added.
        Args:
            tokenized_text_chunks (List[List[int]]): A list of tokenized text chunks, where each chunk is represented as a list of integers.
        Returns:
            Tuple[List[np.ndarray], Dict[str, int]]: A tuple containing a list of processed results and dataset stats.
        """
        results = {"data": []}  # List to store the processed results
        stats = defaultdict(int)
        # Iterate over each chunk in the tokenized text chunks
        for chunk in tokenized_text_chunks:
            # Process the chunk and get the processed result and number of padding tokens added
            processed = self.create_features_auto_lm(
                chunk,
            )
            # If the processed chunk is not empty, add the results to the list and update the total padding
            if len(processed) != 0:
                processed_stats = get_data_stats(
                    processed, self.pad_id, self.eos_id, self.max_seq_length
                )
                for key in processed_stats:
                    stats[key] += processed_stats[key]
                results["data"].append(processed)
        # Return the list of processed results and data stats
        return results, stats 
[docs]    def process_chunks_mlm(
        self, tokenized_text_chunks: List[List[int]]
    ) -> Tuple[List[Any], Dict]:
        """
        Processes chunks of tokenized text and returns processed features along with the total padding added.
        Args:
        tokenized_text_chunks (List[List[int]]): A list of tokenized text chunks, where each chunk is represented as a list of integers.
        Returns:
        Tuple[List[Any], Dict]: A tuple containing a list of processed results and dataset stats.
        """
        results = {
            'data': [],
            'labels': [],
        }  # List to store the processed result
        stats = defaultdict(int)
        masked_lm_positions_list = []
        masked_lm_mask_list = []
        input_id_list = []
        labels_list = []
        attention_mask_list = []
        # Iterate over each chunk in the tokenized text chunks
        for chunk in tokenized_text_chunks:
            input_ids, masked_lm_positions, masked_lm_mask, labels = (
                self.mask_single_sequence(chunk)
            )
            num_pad = self.max_seq_length - len(input_ids)
            attention_mask = [1] * len(input_ids) + [0] * num_pad
            input_ids = input_ids + [self.pad_id] * num_pad
            input_id_list.append(input_ids)
            attention_mask_list.append(attention_mask)
            labels_list.append(labels)
            masked_lm_positions_list.append(masked_lm_positions)
            masked_lm_weights_list.append(masked_lm_weights)
            lvt = len(labels) - labels.count(self.ignore_index)
            processed_stats = get_data_stats(
                np.expand_dims(np.array(input_ids), 0),
                self.pad_id,
                self.eos_id,
                self.max_seq_length,
                lvt,
            )
            for key in processed_stats:
                stats[key] += processed_stats[key]
        if len(tokenized_text_chunks) > 0:
            results['data'] = np.stack(
                [np.array(input_id_list), np.array(attention_mask_list)], axis=1
            )
            if self.mlm_with_gather:
                results['labels'] = np.stack(
                    [
                        np.array(labels_list),
                        np.array(masked_lm_positions_list),
                        np.array(masked_lm_weights_list),
                    ],
                    axis=1,
                )
            else:
                results['labels'] = np.stack(
                    [np.array(labels_list)],
                    axis=1,
                )
        # Return the list of processed results and data stats
        return results, stats 
    def process_mlm(self, text_data, raw_data_stats):
        tokenized_data = self.tokenizer(
            text_data,
            max_length=self.max_seq_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
        )
        input_ids, attention_mask = (
            tokenized_data['input_ids'],
            tokenized_data['attention_mask'],
        )
        tokenized_data_stats = dict()
        results = dict()
        tokenized_data_stats["processed"] = 1
        tokenized_data_stats["successful"] = 0
        if input_ids == []:
            tokenized_data_stats["discarded"] = 1
            return {"data": [], "labels": []}, tokenized_data_stats
        tokenized_data_stats["successful"] = 1
        input_ids, masked_lm_positions, masked_lm_mask, labels = (
            self.mask_single_sequence(input_ids)
        )
        results['data'] = np.stack(
            [np.array(input_ids), np.array(attention_mask)], axis=0
        ).reshape(1, 2, self.max_seq_length)
        if self.mlm_with_gather:
            results['labels'] = np.stack(
                [
                    np.array(labels),
                    np.array(masked_lm_positions),
                    np.array(masked_lm_mask),
                ],
                axis=0,
            ).reshape(1, 3, self.max_predictions)
        else:
            results['labels'] = np.stack(
                [
                    np.array(labels),
                ],
                axis=0,
            ).reshape(1, 1, self.max_seq_length)
        tokenized_data_stats["non_pad_tokens"] = sum(
            1 for id in input_ids if id != self.pad_id
        )
        tokenized_data_stats["num_pad_tokens"] = (
            self.max_seq_length - tokenized_data_stats["non_pad_tokens"]
        )
        tokenized_data_stats["num_tokens"] = self.max_seq_length
        tokenized_data_stats["num_masked_tokens"] = input_ids.count(
            self.tokenizer.mask_token_id
        )
        tokenized_data_stats["loss_valid_tokens"] = len(labels) - labels.count(
            self.ignore_index
        )
        tokenized_data_stats.update(raw_data_stats)
        return results, tokenized_data_stats
    def process_single_semantic_region(self, text_data, raw_data_stats):
        discarded_files = 0
        # tokenize text
        if self.split_text_to_tokenize:
            tokenized_text = split_text_and_tokenize(
                text_data,
                self.tokenizer,
                max_tok_len=self.chunk_len_to_split,
                remove_bos_in_chunks=self.remove_bos_in_chunks,
            )
        else:
            tokenized_text = self.tokenizer.encode(text_data)
        if self.eos_id is not None:
            tokenized_text += [self.eos_id]
        all_text = self.prefix + tokenized_text
        tokenized_text_chunks = [
            all_text[i : i + self.max_seq_length + 1]
            for i in range(0, len(all_text), self.max_seq_length)
        ]
        # reset prefix
        self.prefix = []
        # update prefix if last chunk is < max_seq_length
        num_tokens_last_chunk = len(tokenized_text_chunks[-1])
        if self.pack_sequences:
            if num_tokens_last_chunk < self.max_seq_length + 1:
                last_chunk = tokenized_text_chunks.pop(-1)
                self.prefix.extend(last_chunk)
        elif num_tokens_last_chunk < 2:
            _ = tokenized_text_chunks.pop(-1)
        results, tokenized_data_stats = self.process_chunks(
            tokenized_text_chunks
        )
        if results["data"] == [] and self.prefix == []:
            discarded_files += 1
        tokenized_data_stats["discarded"] = discarded_files
        tokenized_data_stats["processed"] = 1
        tokenized_data_stats["successful"] = (
            tokenized_data_stats["processed"]
            - tokenized_data_stats["discarded"]
        )
        tokenized_data_stats.update(raw_data_stats)
        return results, tokenized_data_stats
    def tokenize_data(self, semantic_data_array):
        region_data, raw_data_stats = self.parse_semantic_data_array(
            semantic_data_array
        )
        if not region_data:
            return {}, raw_data_stats
        semantic_regions = region_data.get("semantic_regions")
        data, image_paths = (
            region_data.get("data"),
            region_data.get("image_paths"),
        )
        if data == "":
            return {}, raw_data_stats
        if self.mlm:
            return self.process_mlm(data, raw_data_stats)
        if (self.split_text_to_tokenize) and not self.is_multimodal:
            return self.process_single_semantic_region(data, raw_data_stats)
        if self.split_text_to_tokenize:
            raise ValueError(
                f"Multiple semantic region is not supported with `split_text_to_tokenize`."
            )
        tokenized_data = self.tokenizer(
            data,
            return_offsets_mapping=True,
        )
        if self.is_multimodal:
            # Convert input_ids to numpy array
            input_ids = np.array(tokenized_data['input_ids'])
            # Replace image_token_id with pad_id
            input_ids[input_ids == self.image_token_id] = self.pad_id
            # Convert back to list
            tokenized_data['input_ids'] = input_ids.tolist()
        if len(semantic_regions) > 0:
            append_eos_to_multiple_semantic_regions(
                data,
                self.data_ranges,
                self.eos_token,
                self.image_token,
                False,
            )
        tokenized_semantic_region_list = self.get_segment_indices(
            tokenized_data,
            semantic_regions,
        )
        data = {
            "tokenized_data": tokenized_data,
            "image_paths": image_paths,
            "tokenized_semantic_regions": tokenized_semantic_region_list,
        }
        doc_list = self.chop_doc_into_msl(data)
        results, tokenized_data_stats = self.process_docs(doc_list)
        data_stats = {
            "discarded": tokenized_data_stats["discarded"],
            "processed": tokenized_data_stats["processed"],
            "successful": tokenized_data_stats["successful"],
            "raw_chars_count": raw_data_stats["raw_chars_count"],
            "raw_bytes_count": raw_data_stats["raw_bytes_count"],
            "normalized_chars_count": raw_data_stats["normalized_chars_count"],
            "normalized_bytes_count": raw_data_stats["normalized_bytes_count"],
            "num_pad_tokens": tokenized_data_stats["num_pad_tokens"],
            "non_pad_tokens": tokenized_data_stats["non_pad_tokens"],
            "num_masked_tokens": tokenized_data_stats["num_masked_tokens"],
            "loss_valid_tokens": tokenized_data_stats["loss_valid_tokens"],
            "num_tokens": tokenized_data_stats["num_tokens"],
        }
        return results, data_stats
    def parse_semantic_data_array(
        self, semantic_data_array: List[Dict[str, Any]]
    ) -> Tuple[Tuple[List[str], List[Dict[str, str]]], Dict[str, int]]:
        if not semantic_data_array:
            return {}, {}
        image_paths = []
        semantic_regions = []
        stats = {
            "raw_chars_count": 0,
            "raw_bytes_count": 0,
            "normalized_chars_count": 0,
            "normalized_bytes_count": 0,
        }
        formatted_data = ""
        formatted_data_length = 0
        for entry in semantic_data_array:
            semantic_loss_weight = entry.get("semantic_loss_weight")
            semantic_drop_mask = entry.get("semantic_drop_mask")
            semantic_attention_mask = entry.get("semantic_attention_mask")
            if semantic_loss_weight is not None and len(
                semantic_loss_weight
            ) != len(entry["content"]):
                raise ValueError(
                    "The length of semantic loss mask must match the number of regions"
                )
            if semantic_drop_mask is not None and len(
                semantic_drop_mask
            ) != len(entry["content"]):
                raise ValueError(
                    "The length of semantic drop mask must match the number of regions"
                )
            if semantic_attention_mask is not None and len(
                semantic_attention_mask
            ) != len(entry["content"]):
                raise ValueError(
                    "The length of semantic attention mask must match the number of regions"
                )
            content_parts = []
            global_idx = 0
            for i, part in enumerate(entry["content"]):
                region_key = list(part.keys())[0]
                region_val = part[region_key]
                if not region_val:
                    continue
                if region_key != "image":
                    cleaned_region_val = clean_text(
                        region_val,
                        self.use_ftfy,
                        self.wikitext_detokenize,
                        self.ftfy_normalizer,
                    )
                    stats["raw_chars_count"] += len(region_val)
                    stats["raw_bytes_count"] += len(region_val.encode("utf-8"))
                    stats["normalized_chars_count"] += len(cleaned_region_val)
                    stats["normalized_bytes_count"] += len(
                        cleaned_region_val.encode("utf-8")
                    )
                else:
                    cleaned_region_val = region_val
                include_tags = part.pop("include_tags", False)
                if not semantic_loss_weight:
                    loss_weight = self.semantic_loss_weight.get(region_key)
                    if not loss_weight:
                        ## set default weights
                        loss_weight = 1 if region_key != "image" else 0
                else:
                    loss_weight = semantic_loss_weight[i]
                if not semantic_drop_mask:
                    drop_region = self.semantic_drop_mask.get(region_key, False)
                else:
                    drop_region = semantic_drop_mask[i]
                if not semantic_attention_mask:
                    attention_mask = self.semantic_attention_mask.get(
                        region_key, True
                    )
                else:
                    attention_mask = semantic_attention_mask[i]
                attention_mask = 1 if attention_mask else 0
                if region_key != "image":  ## hardcoding name of image
                    if not drop_region and cleaned_region_val != "":
                        if include_tags:
                            cleaned_region_val = (
                                f"<|{region_key}|>"
                                + cleaned_region_val
                                + f"<|{region_key}|>"
                            )
                        semantic_regions.append(
                            {
                                "indices": (
                                    formatted_data_length,
                                    formatted_data_length
                                    + len(cleaned_region_val),
                                ),
                                "region_modality": region_key,
                                "region_len": len(cleaned_region_val),
                                "loss_weight": loss_weight,
                                "attention_mask": attention_mask,
                            }
                        )
                        formatted_data_length += len(cleaned_region_val)
                        content_parts.append(cleaned_region_val)
                else:
                    if not drop_region:
                        image_paths.append(cleaned_region_val)
                        # Store the pad id for image region and handle `include_tags`
                        patches = self.num_patches * [self.image_token]
                        patches = ''.join(patches)
                        if include_tags:
                            patches = (
                                f"<|{region_key}|>"
                                + patches
                                + f"<|{region_key}|>"
                            )
                        patch_len = len(patches)
                        semantic_regions.append(
                            {
                                "indices": (
                                    formatted_data_length,
                                    formatted_data_length + patch_len,
                                ),
                                "region_modality": region_key,
                                "loss_weight": loss_weight,
                                "attention_mask": attention_mask,
                            }
                        )
                        formatted_data_length += patch_len
                        content_parts.append(patches)
                global_idx += 1
            formatted_data += ''.join(content_parts)
        if (
            self.is_multimodal
            and self.max_num_img
            and len(image_paths) > self.max_num_img
        ):
            self.logger.warning(
                f"Document more images than max_num_img. Skipping this doc..."
            )
            stats["raw_docs_skipped"] = 1
            return {}, stats
        # Validate image paths
        for i, path in enumerate(image_paths):
            if path:
                full_path = os.path.join(self.image_dir, path)
                if not os.path.exists(full_path):
                    self.logger.warning(
                        f"Image with path - {full_path} does not exist. Hence skipping this."
                    )
                    stats["raw_docs_skipped"] = 1
                    return {}, stats
                else:
                    image_paths[i] = path.encode(encoding='utf-8')
        transformed_data = {
            "data": formatted_data,
            "image_paths": image_paths,
            "semantic_regions": semantic_regions,
        }
        return transformed_data, stats
    def process_docs(self, doc_list):
        results = defaultdict(list)
        tokenized_data_stats = defaultdict(int)
        if doc_list == []:
            tokenized_data_stats["processed"] += 1
            if self.prefix_doc is None:
                tokenized_data_stats["discarded"] += 1
            else:
                tokenized_data_stats["successful"] += 1
            return {}, tokenized_data_stats
        # Add eos at the end.
        # TODO: Find a better way to handle this?
        last_doc = doc_list[-1]
        if last_doc.get("input_ids", []) != [] and self.eos_id != None:
            if last_doc['input_ids'][-1] != self.eos_id:
                if len(last_doc['input_ids']) < self.max_seq_length + 1:
                    last_doc['input_ids'].append(self.eos_id)
                    last_doc['loss_mask'].append(last_doc['loss_mask'][-1])
                    last_doc['attention_mask'].append(
                        last_doc['attention_mask'][-1]
                    )
                else:
                    last_doc['input_ids'][-1] = self.eos_id
        for doc_idx, doc in enumerate(doc_list):
            has_img = False
            if doc.get("input_ids", []) == []:
                continue
            token_modality_idx = (
                np.zeros(self.max_seq_length) if self.is_multimodal else None
            )
            image_paths, image_data_positions = doc.pop(
                "image_paths", None
            ), doc.pop("image_data_positions", None)
            has_img = doc.pop("has_img", None)
            img_data_loc = None
            if self.is_multimodal:
                img_data_loc = np.full(
                    (self.max_num_img, self.num_patches), self.max_seq_length
                )
                assert (
                    len(image_data_positions) <= self.max_num_img
                ), "Number of images should be <= max_num_images"
                # Preallocate img_data_loc as a list of arrays to avoid dynamic resizing
                for image_index, (start_img_pos, end_img_pos) in enumerate(
                    image_data_positions
                ):
                    img_data_loc[image_index] = np.arange(
                        start_img_pos, end_img_pos
                    )
                    # Efficiently update the token_modality_idx using vectorized assignment
                    token_modality_idx[start_img_pos:end_img_pos] = 1
            sample = self.create_features_pretraining(
                doc,
                token_modality_idx,
            )
            if sample == []:
                continue
            if self.is_multimodal:
                if image_paths:
                    num_images = len(image_paths)
                    image_paths += [None] * (self.max_num_img - num_images)
                    has_img = True
                else:
                    image_paths = [None] * (self.max_num_img)
            sample_stats = get_data_stats(
                sample, self.pad_id, self.eos_id, self.max_seq_length
            )
            for key in sample_stats:
                tokenized_data_stats[key] += sample_stats[key]
            data = (
                {
                    "data": sample,
                    "img_path": np.array(image_paths, dtype="S"),
                    "has_img": np.array([has_img], dtype=np.bool_),
                    "img_data_loc": img_data_loc,
                }
                if self.is_multimodal
                else {
                    "data": sample,
                }
            )
            for key, value in data.items():
                results[key].append(value)
        tokenized_data_stats["processed"] += 1
        if results.get("data", []) == []:
            tokenized_data_stats["discarded"] += 1
        else:
            tokenized_data_stats["successful"] += 1
        return results, tokenized_data_stats
[docs]    def encode(
        self, semantic_data_array: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Tokenize and encode the data for auto-regressive language modeling.
        Args:
            semantic_data_array (Union[Dict[str, Any], List[Dict[str, Any]]]): Data to encode.
        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: Tuple of encoded features for auto-regressive language modeling and dataset stats.
        """
        tokenized_data, data_stats = self.tokenize_data(semantic_data_array)
        if tokenized_data.get("data", []) == []:
            return {}, data_stats
        else:
            if self.is_multimodal:
                data = tokenized_data
            else:
                if not self.mlm:
                    data = {'data': tokenized_data['data']}
                else:
                    data = {'data': tokenized_data['data']}
                    if not self.use_vsl:
                        data['labels'] = tokenized_data['labels']
        return data, data_stats 
[docs]    def encode_leftover_prefix(
        self, prefix: List[np.ndarray]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Processes the leftover prefix which is a list of ndarray tokens into chunks based
        on max sequence length.
        The last chunk is handled specifically if it's shorter than the max sequence
        length. If the last chunk has less than two tokens, it's discarded.
        Args:
            prefix (List[np.ndarray]): The prefix list of token arrays to process.
        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: A tuple containing the processed token chunks as
            a list of ndarrays and the dataset stats.
        """
        if (self.split_text_to_tokenize or self.mlm) and not self.is_multimodal:
            tokenized_text_chunks = (
                [
                    prefix[i : i + self.max_seq_length]
                    for i in range(0, len(prefix), self.max_seq_length)
                ]
                if self.mlm
                else [
                    prefix[i : i + self.max_seq_length + 1]
                    for i in range(0, len(prefix), self.max_seq_length)
                ]
            )
            # Handle last chunk if shorter than max_seq_length
            num_tokens_last_chunk = len(tokenized_text_chunks[-1])
            if num_tokens_last_chunk < self.max_seq_length + 1:
                _ = tokenized_text_chunks.pop(-1)
            elif num_tokens_last_chunk < 2:
                _ = tokenized_text_chunks.pop(-1)
            results, stats = self.process_chunks(tokenized_text_chunks)
            if results["data"] == []:
                return {}, stats
            data = results
            return data, stats
        # Handle prefix doc
        if not prefix:
            return {}, {}
        doc_list = prefix
        results, tokenized_data_stats = self.process_docs(doc_list)
        if results == {}:
            return {}, {}
        data_stats = {
            "num_pad_tokens": tokenized_data_stats["num_pad_tokens"],
            "non_pad_tokens": tokenized_data_stats["non_pad_tokens"],
            "num_masked_tokens": tokenized_data_stats["num_masked_tokens"],
            "loss_valid_tokens": tokenized_data_stats["loss_valid_tokens"],
            "num_tokens": tokenized_data_stats["num_tokens"],
        }
        return results, data_stats