# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PretrainingTokenGenerator Module
This module provides the PretrainingTokenGenerator class which is designed to process
text data and create features suitable for language modeling tasks.
Usage:
tokenizer = PretrainingTokenGenerator(dataset_params, max_sequence_length, tokenizer)
tokenized_features = tokenizer.encode("Sample text for processing.")
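
Example ``params`` layout (an illustrative sketch; only keys read by this module
are shown, and the nested values other than ``output_dir`` mirror the defaults
assumed when a key is omitted):

    params = {
        "setup": {"output_dir": "./output"},
        "processing": {"max_seq_length": 2048, "use_ftfy": True, "seed": 0},
        "dataset": {"use_vsl": False, "pack_sequences": True},
    }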
"""
import math
import os
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import numpy as np
from cerebras.modelzoo.data_preparation.data_preprocessing.utils import (
append_eos_to_multiple_semantic_regions,
clean_text,
find_token_range,
get_data_stats,
setup_warning_logging,
split_text_and_tokenize,
)
class PretrainingTokenGenerator:
def __init__(
self, params: Dict[str, Any], tokenizer: Any, eos_id: int, pad_id: int
):
"""
Initialize the PretrainingTokenGenerator class.
Args:
params (Dict[str, Any]): Parameters for the dataset and processing.
tokenizer (Any): Tokenizer to use for tokenization.
eos_id (int): End-of-sequence token ID.
pad_id (int): Padding token ID.
"""
dataset_params = params.get("dataset", {})
processing_params = params["processing"]
setup_params = params["setup"]
warning_log_dir = (
os.path.join(setup_params.get("output_dir"), "logs")
if setup_params.get("output_dir")
else "./data_preprocessing_logs"
)
self.logger = setup_warning_logging(warning_log_dir, __name__)
self.tokenizer = tokenizer
self.training_objective = dataset_params.pop("training_objective", None)
self.mlm = (
(self.training_objective == 'mlm')
if self.training_objective is not None
else False
)
self.use_vsl = dataset_params.pop("use_vsl", False)
self.use_ftfy = processing_params.pop("use_ftfy", True)
self.ftfy_normalizer = processing_params.pop("ftfy_normalizer", "NFC")
self.wikitext_detokenize = processing_params.pop(
"wikitext_detokenize", False
)
self.min_sequence_len = processing_params.pop("min_sequence_len", 10)
self.input_ids_dtype = processing_params.pop("input_ids_dtype", "int32")
self.input_mask_dtype = processing_params.pop(
"input_mask_dtype", "int32"
)
self.inverted_mask = processing_params.pop("inverted_mask", False)
self.seed = processing_params.pop("seed", 0)
np.random.seed(self.seed)
self.max_seq_length = processing_params.pop("max_seq_length", 2048)
self.short_seq_prob = processing_params.pop("short_seq_prob", 0.0)
self.split_text_to_tokenize = dataset_params.pop(
"split_text_to_tokenize", False
)
if self.split_text_to_tokenize:
self.chunk_len_to_split = dataset_params.pop(
"chunk_len_to_split", 2000
)
self.remove_bos_in_chunks = dataset_params.pop(
"remove_bos_in_chunks", False
)
self.eos_id = eos_id
self.pad_id = pad_id
        if self.pad_id is not None:
self.pad_token = self.tokenizer.convert_ids_to_tokens(self.pad_id)
self.rng = random.Random()
self.rng.seed(self.seed)
self.prefix = []
self.prefix_doc = None
# Multimodal parameters initialization
self.is_multimodal = dataset_params.pop("is_multimodal", False)
self.features = (
[
"text_input_ids",
"loss_mask",
"labels",
"key_padding_mask",
"token_modality_idx",
]
if self.is_multimodal
else [
"input_ids",
"attention_mask",
"labels",
]
)
        # MLM-specific fields
        if self.mlm:
self.mlm_fraction = dataset_params.pop("mlm_fraction", 0.15)
self.max_predictions = math.ceil(
self.mlm_fraction * self.max_seq_length
)
self.mlm_with_gather = dataset_params.pop("mlm_with_gather", False)
self.ignore_index = dataset_params.pop(
"ignore_index", -100
) # default value for torch.nn.CrossEntropyLoss
self.excluded_tokens = dataset_params.pop(
"excluded_tokens",
['<cls>', '<pad>', '<eos>', '<unk>', '<null_1>', '<mask>'],
)
self.allowable_token_ids = self.get_allowable_token_ids()
self.special_tokens_ids = {
self.tokenizer.cls_token_id,
self.tokenizer.pad_token_id,
self.tokenizer.eos_token_id,
self.tokenizer.unk_token_id,
}
if self.mlm_with_gather:
self.features.extend(["masked_lm_positions", "masked_lm_mask"])
self.pack_sequences = dataset_params.pop(
"pack_sequences", False if self.is_multimodal else True
)
self.image_token = dataset_params.pop("image_token", "<image_token>")
self.image_dir = params["setup"].pop("image_dir", None)
self.max_num_img = dataset_params.pop("max_num_img", 1)
self.num_patches = dataset_params.pop("num_patches", 1)
self.image_token_id = -1
if (
self.is_multimodal
and self.image_token
and self.image_token not in self.tokenizer.get_vocab()
):
self.tokenizer.add_special_tokens(
{'additional_special_tokens': [self.image_token]}
)
self.image_token_id = self.tokenizer.convert_tokens_to_ids(
self.image_token
)
self.image_ids = (
[pad_id] * self.num_patches if self.is_multimodal else []
)
self.semantic_loss_weight = processing_params.pop(
"semantic_loss_weight", {}
)
self.semantic_drop_mask = processing_params.pop(
"semantic_drop_mask", {}
)
self.semantic_attention_mask = processing_params.pop(
"semantic_attention_mask", {}
)
self.include_image_tag = False
self.data_ranges = []
self.eos_token = (
self.tokenizer.pad_token_id
if self.eos_id is None
else self.tokenizer.convert_ids_to_tokens(self.eos_id)
)
def create_features_pretraining(
self,
doc,
token_modality_idx=None,
):
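        """
        Create pretraining features from a pre-chunked document.
        Args:
            doc (dict): Document dict with "input_ids", "loss_mask" and
                "attention_mask" lists, as produced by ``chop_doc_into_msl``.
            token_modality_idx (np.ndarray, optional): Per-position modality
                indicator used for multimodal data.
        Returns:
            np.ndarray: Stacked features (input ids, loss mask, labels, plus
            key padding mask and modality index for multimodal data), or an
            empty list if the document is shorter than ``min_sequence_len``.
        """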
input_ids = doc.get("input_ids")
total_len = len(input_ids)
if total_len < self.min_sequence_len:
self.logger.warning(
"Length of token ids < min_sequence_len, skipping this example..."
)
return []
        # Fetch the masks before any truncation so they stay aligned with input_ids.
        input_mask, attention_mask = (
            doc.get("loss_mask"),
            doc.get("attention_mask"),
        )
        if not self.is_multimodal and self.rng.random() < self.short_seq_prob:
            input_ids = input_ids[
                0 : self.rng.randint(2, self.max_seq_length - 1)
            ]
            input_mask = input_mask[0 : len(input_ids)]
            attention_mask = attention_mask[0 : len(input_ids)]
labels = input_ids[1:]
input_ids = input_ids[:-1]
attention_mask = attention_mask[:-1]
input_mask = input_mask[1:]
# Add padding
num_pad = self.max_seq_length - len(input_ids)
padding = [self.pad_id] * num_pad
input_ids.extend(padding)
labels.extend(padding)
padding = [0] * num_pad
input_mask.extend(padding)
attention_mask.extend([0] * num_pad)
assert (
len(input_ids) == self.max_seq_length
and len(labels) == self.max_seq_length
and len(input_mask) == self.max_seq_length
and len(attention_mask) == self.max_seq_length
), "Wrong sequence length"
# Create features dictionary
features = {
"input_ids": getattr(np, self.input_ids_dtype)(input_ids),
"labels": getattr(np, self.input_ids_dtype)(labels),
}
input_mask = getattr(np, self.input_mask_dtype)(input_mask)
attention_mask = getattr(np, self.input_ids_dtype)(attention_mask)
if self.inverted_mask:
input_mask = np.equal(input_mask, 0).astype(self.input_mask_dtype)
# NOTE this is because our internal stack requires the inverted mask and
# doesn't do the inversion internally
if self.is_multimodal:
key_padding_mask = np.equal(attention_mask, 0).astype(
input_mask.dtype
)
return (
np.stack(
[
features["input_ids"],
input_mask,
features["labels"],
key_padding_mask,
token_modality_idx,
]
)
if self.is_multimodal
else np.stack(
[
features["input_ids"],
input_mask,
features["labels"],
]
)
)
    def create_features_auto_lm(
self,
token_ids: List[int],
) -> np.ndarray:
"""Given a list of token_ids, generate input sequence and labels.
Args:
token_ids (List[int]): List containing token ids for creating features,
labels and input mask from.
Returns:
np.ndarray: Array containing features, labels, and input mask.
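        Example (an illustrative sketch; assumes the generator was constructed
        with ``max_seq_length=8`` and a ``min_sequence_len`` small enough for
        the input):
            features = token_generator.create_features_auto_lm([12, 44, 5, 9, 2])
            # features has shape (3, 8): rows are input_ids, input_mask, labels.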
"""
        if len(token_ids) < self.min_sequence_len:
self.logger.warning(
f"token_ids must have at least {self.min_sequence_len} elements, skipping this example..."
)
return []
if self.rng.random() < self.short_seq_prob:
token_ids = token_ids[
0 : self.rng.randint(2, self.max_seq_length - 1)
]
input_ids = token_ids[:-1]
labels = token_ids[1:]
input_mask = [1] * len(labels)
# padding
num_pad = self.max_seq_length - len(input_ids)
padding = [self.pad_id] * num_pad
input_ids.extend(padding)
labels.extend(padding)
input_mask.extend([0] * num_pad)
# assertions to ensure correct output shapes
assert (
len(input_ids) == self.max_seq_length
and len(labels) == self.max_seq_length
and len(input_mask) == self.max_seq_length
), "Wrong sequence length"
# create feature dict
features = dict()
features["input_ids"] = getattr(np, self.input_ids_dtype)(input_ids)
features["input_mask"] = getattr(np, self.input_mask_dtype)(input_mask)
if self.inverted_mask:
features["input_mask"] = np.equal(features["input_mask"], 0).astype(
features["input_mask"].dtype
)
labels = getattr(np, self.input_ids_dtype)(labels)
return np.stack([features["input_ids"], features["input_mask"], labels])
def chop_doc_into_msl(self, data):
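        """
        Split a tokenized document into chunks of at most ``max_seq_length + 1``
        tokens, carrying the per-region loss mask, attention mask and image
        positions along with each chunk. When ``pack_sequences`` is enabled, a
        trailing short chunk is stored in ``self.prefix_doc`` so it can be
        packed together with the next document.
        Args:
            data (dict): Dict with "tokenized_data", "tokenized_semantic_regions"
                and optionally "image_paths".
        Returns:
            List[dict]: Per-chunk dicts with "input_ids", "loss_mask",
            "attention_mask", "image_paths", "has_img" and "image_data_positions".
        """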
doc_list = []
tokenized_data = data.get("tokenized_data")
tokenized_semantic_regions = data.get("tokenized_semantic_regions")
image_paths = data.get("image_paths", [])
image_index = 0
max_len = self.max_seq_length + 1 # Including space for EOS token
last_chunk_tokens = []
last_chunk_loss_mask = []
last_chunk_attention_mask = []
last_chunk_img_paths = []
last_chunk_has_img = False
last_chunk_image_positions = []
if self.pack_sequences and self.prefix_doc:
last_chunk_tokens = self.prefix_doc['input_ids']
last_chunk_loss_mask = self.prefix_doc['loss_mask']
last_chunk_attention_mask = self.prefix_doc['attention_mask']
last_chunk_img_paths = self.prefix_doc['image_paths']
last_chunk_has_img = self.prefix_doc['has_img']
last_chunk_image_positions = self.prefix_doc['image_data_positions']
self.prefix_doc = None
for idx, region in enumerate(tokenized_semantic_regions):
modality = region["region_modality"]
loss_weight = region["loss_weight"]
attn_mask_value = region["attention_mask"]
start_idx, end_idx = region["indices"]
orig_idx = (start_idx, end_idx)
start_idx = 0 if idx == 0 else start_idx
tokenized_semantic_regions[idx]['indices'] = (start_idx, end_idx)
tokens = tokenized_data["input_ids"][start_idx:end_idx]
region_len = len(tokens)
# Generate loss_mask and attention_mask for the region
region_loss_mask = [loss_weight] * region_len
region_attention_mask = [attn_mask_value] * region_len
if modality != "image":
# Combine last_chunk_* variables with current tokens
tokens = last_chunk_tokens + tokens
region_loss_mask = last_chunk_loss_mask + region_loss_mask
region_attention_mask = (
last_chunk_attention_mask + region_attention_mask
)
region_len = len(
tokens
) # Update region_len after concatenation
# Split text region into chunks fitting max_len
chunks = [
(
tokens[i : i + max_len],
region_loss_mask[i : i + max_len],
region_attention_mask[i : i + max_len],
)
for i in range(0, region_len, max_len)
]
# Determine the number of complete chunks
num_chunks = len(chunks)
if len(chunks[-1][0]) < max_len:
num_complete_chunks = num_chunks - 1
else:
num_complete_chunks = num_chunks
# Process complete chunks
                for chunk_idx in range(num_complete_chunks):
                    chunk_tokens, chunk_loss_mask, chunk_attention_mask = (
                        chunks[chunk_idx]
                    )
                    if chunk_idx == 0 and last_chunk_tokens:
assert len(chunk_tokens) == len(
chunk_loss_mask
), "Length of input ids and loss is different"
# First chunk may have image data from previous last_chunk
doc_list.append(
{
"input_ids": chunk_tokens,
"loss_mask": chunk_loss_mask,
"attention_mask": chunk_attention_mask,
"image_paths": last_chunk_img_paths,
"has_img": last_chunk_has_img,
"image_data_positions": last_chunk_image_positions,
}
)
# Reset image data after first chunk
last_chunk_img_paths = []
last_chunk_has_img = False
last_chunk_image_positions = []
else:
assert len(chunk_tokens) == len(
chunk_loss_mask
), "Length of input ids and loss is different"
# Subsequent chunks without image data
doc_list.append(
{
"input_ids": chunk_tokens,
"loss_mask": chunk_loss_mask,
"attention_mask": chunk_attention_mask,
"image_paths": [], # No images in subsequent chunks
"has_img": False,
"image_data_positions": [],
}
)
# Update last_chunk_* variables for the next iteration
if num_complete_chunks < num_chunks:
# The last chunk is incomplete; store it for the next iteration
last_chunk_tokens = chunks[-1][0]
last_chunk_loss_mask = chunks[-1][1]
last_chunk_attention_mask = chunks[-1][2]
# last_chunk_img_paths and last_chunk_has_img remain reset
else:
# All chunks are complete; reset last_chunk_* variables
last_chunk_tokens = []
last_chunk_loss_mask = []
last_chunk_attention_mask = []
# last_chunk_img_paths and last_chunk_has_img remain reset
else:
# Handle image region
image_path = image_paths[image_index]
image_index += 1
image_tokens = tokens # Image tokens should not be split
image_loss_mask = region_loss_mask
image_attention_mask = region_attention_mask
image_len = len(image_tokens)
combined_len = len(last_chunk_tokens) + image_len
if combined_len < max_len - 1:
# Add image tokens to last_chunk
start_idx = len(last_chunk_tokens)
last_chunk_tokens.extend(image_tokens)
last_chunk_loss_mask.extend(image_loss_mask)
last_chunk_attention_mask.extend(image_attention_mask)
end_idx = len(last_chunk_tokens)
last_chunk_img_paths += [image_path]
last_chunk_has_img = True
image_indices = (
orig_idx if idx == 0 else (start_idx, end_idx)
)
last_chunk_image_positions.append(image_indices)
else:
# Finalize last_chunk
if last_chunk_tokens:
assert len(last_chunk_tokens) == len(
last_chunk_loss_mask
), "Length of input ids and loss is different"
doc_list.append(
{
"input_ids": last_chunk_tokens,
"loss_mask": last_chunk_loss_mask,
"attention_mask": last_chunk_attention_mask,
"image_paths": last_chunk_img_paths,
"has_img": last_chunk_has_img,
"image_data_positions": last_chunk_image_positions,
}
)
# Start new last_chunk with image tokens if they fit
if image_len < max_len - 1:
last_chunk_tokens = image_tokens
last_chunk_loss_mask = image_loss_mask
last_chunk_attention_mask = image_attention_mask
last_chunk_img_paths = [image_path]
last_chunk_has_img = True
last_chunk_image_positions = [(0, image_len)]
else:
# Image tokens exceed max_len; cannot split images
raise ValueError(
"Image tokens exceed maximum sequence length."
)
# Append any remaining last_chunk to doc_list
if last_chunk_tokens:
assert len(last_chunk_tokens) == len(
last_chunk_loss_mask
), "Length of input ids and loss is different"
doc = {
"input_ids": last_chunk_tokens,
"loss_mask": last_chunk_loss_mask,
"attention_mask": last_chunk_attention_mask,
"image_paths": last_chunk_img_paths,
"has_img": last_chunk_has_img,
"image_data_positions": last_chunk_image_positions,
}
if not self.pack_sequences:
doc_list.append(doc)
else:
self.prefix_doc = doc
return doc_list
    def get_segment_indices(
self,
        tokenized_data: Dict[str, Any],
semantic_region_list: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Get segment indices for the data ranges.
Args:
            tokenized_data (Dict[str, Any]): Tokenizer output containing the offset mappings.
            semantic_region_list (List[Dict[str, Any]]): List of semantic regions with region details.
Returns:
List[Dict[str, Any]]: List of tokenized semantic regions and image regions with their indices.
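        Example (an illustrative sketch; character indices come from
        ``parse_semantic_data_array`` and offsets from the tokenizer):
            tokenized = self.tokenizer("Hello world", return_offsets_mapping=True)
            regions = [{"indices": (0, 11), "region_modality": "text",
                        "loss_weight": 1, "attention_mask": 1}]
            tokenized_regions = self.get_segment_indices(tokenized, regions)
            # Each entry now carries token-level "indices" into input_ids.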
"""
tokenized_semantic_region_list = []
tokenized_semantic_region = None
starting_offset_index = 0
for region in semantic_region_list:
region_name = region.get("region_modality")
tokenized_semantic_region = find_token_range(
region,
tokenized_data["offset_mapping"],
starting_offset_index,
)
tokenized_semantic_region["region_modality"] = region_name
starting_offset_index = tokenized_semantic_region["indices"][1]
tokenized_semantic_region_list.append(tokenized_semantic_region)
return tokenized_semantic_region_list
    def get_allowable_token_ids(self) -> List[int]:
"""Generate a list of token IDs that can be masked."""
excluded_token_ids = {
self.tokenizer.convert_tokens_to_ids(tok)
for tok in self.excluded_tokens
if tok in self.tokenizer.get_vocab()
}
allowable_token_ids = [
tok_id
for tok, tok_id in self.tokenizer.get_vocab().items()
if tok_id not in excluded_token_ids
]
return list(allowable_token_ids)
    def mask_single_sequence(
self, input_ids: List[int]
) -> Tuple[List[int], List[int], List[int], List[int]]:
"""
Masks tokens in a single sequence according to the MLM strategy.
        When self.mlm_with_gather is False, the returned labels satisfy len(labels) == len(input_ids).
        When self.mlm_with_gather is True, the returned labels satisfy len(labels) == self.max_predictions.
Args:
input_ids (List[int]): Original sequence of token IDs.
Returns:
Tuple[List[int], List[int], List[int], List[int]]:
- input_ids: Modified sequence with masked tokens.
- masked_lm_positions: Positions of the masked tokens, empty if not self.mlm_with_gather.
- masked_lm_mask: Binary indicators (1s) for positions that were masked, empty if not self.mlm_with_gather.
- labels: Original token IDs of the masked tokens for label purposes.
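        Example (an illustrative sketch; the masked positions depend on the
        seeded RNG and ``mlm_fraction``):
            input_ids, positions, mask, labels = self.mask_single_sequence(token_ids)
            # Of the selected positions, ~80% become tokenizer.mask_token_id,
            # ~10% a random allowable token, and ~10% keep the original token.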
"""
sequence = np.array(input_ids.copy())
masked_lm_positions = []
masked_lm_mask = []
labels = (
[] if self.mlm_with_gather else [self.ignore_index] * len(input_ids)
)
indices_can_be_masked = [
i
for i, token_id in enumerate(input_ids)
if token_id not in self.special_tokens_ids
]
# Calculate the number of tokens to mask
num_tokens_to_mask = min(
int(self.mlm_fraction * len(indices_can_be_masked)),
self.max_predictions,
)
if num_tokens_to_mask > 0:
# Randomly select tokens to mask
indices_to_mask = sorted(
self.rng.sample(indices_can_be_masked, k=num_tokens_to_mask)
)
for pos in indices_to_mask:
original_token_id = sequence[pos].copy()
prob = self.rng.random()
if prob < 0.8: # 80% of the time, replace with [MASK]
sequence[pos] = self.tokenizer.mask_token_id
elif prob < 0.9: # 10% of the time, replace with a random token
# Ensure selected token is not a special token
masked_token_id = np.random.choice(self.allowable_token_ids)
sequence[pos] = masked_token_id
                else:  # 10% of the time, keep the original token
                    pass
# Store the original token ID as the label
if self.mlm_with_gather:
masked_lm_positions.append(pos)
masked_lm_mask.append(1)
labels.append(original_token_id)
else:
labels[pos] = original_token_id
if self.mlm_with_gather:
# Pad the lists to reach max_predictions length
num_paddings = self.max_predictions - len(masked_lm_positions)
masked_lm_positions = masked_lm_positions + [0] * num_paddings
masked_lm_mask = masked_lm_mask + [0] * num_paddings
labels = labels + [self.ignore_index] * num_paddings
return list(sequence), masked_lm_positions, masked_lm_mask, labels
    def process_chunks(
self, tokenized_text_chunks: List[List[int]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
        Processes chunks of tokenized text and returns processed features along with dataset statistics.
        Args:
            tokenized_text_chunks (List[List[int]]): A list of tokenized text chunks, where each chunk is represented as a list of integers.
        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: A dict of processed results and the dataset stats.
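        Example (an illustrative sketch):
            results, stats = self.process_chunks(tokenized_text_chunks)
            # results["data"] is a list of (3, max_seq_length) arrays and stats
            # aggregates the per-chunk counters returned by ``get_data_stats``.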
"""
results = {"data": []} # List to store the processed results
stats = defaultdict(int)
# Iterate over each chunk in the tokenized text chunks
for chunk in tokenized_text_chunks:
# Process the chunk and get the processed result and number of padding tokens added
processed = self.create_features_auto_lm(
chunk,
)
# If the processed chunk is not empty, add the results to the list and update the total padding
if len(processed) != 0:
processed_stats = get_data_stats(
processed, self.pad_id, self.eos_id, self.max_seq_length
)
for key in processed_stats:
stats[key] += processed_stats[key]
results["data"].append(processed)
# Return the list of processed results and data stats
return results, stats
    def process_chunks_mlm(
self, tokenized_text_chunks: List[List[int]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
        Processes chunks of tokenized text and returns masked-LM features along with dataset statistics.
        Args:
            tokenized_text_chunks (List[List[int]]): A list of tokenized text chunks, where each chunk is represented as a list of integers.
        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: A dict of processed results and the dataset stats.
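        Example (an illustrative sketch):
            results, stats = self.process_chunks_mlm(tokenized_text_chunks)
            # results["data"] has shape (num_chunks, 2, max_seq_length) holding
            # input_ids and attention_mask; results["labels"] stacks the MLM
            # labels (plus positions and mask when ``mlm_with_gather`` is set).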
"""
        results = {
            'data': [],
            'labels': [],
        }  # Dict to store the processed results
stats = defaultdict(int)
masked_lm_positions_list = []
masked_lm_mask_list = []
input_id_list = []
labels_list = []
attention_mask_list = []
# Iterate over each chunk in the tokenized text chunks
for chunk in tokenized_text_chunks:
input_ids, masked_lm_positions, masked_lm_mask, labels = (
self.mask_single_sequence(chunk)
)
num_pad = self.max_seq_length - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * num_pad
input_ids = input_ids + [self.pad_id] * num_pad
input_id_list.append(input_ids)
attention_mask_list.append(attention_mask)
labels_list.append(labels)
masked_lm_positions_list.append(masked_lm_positions)
            masked_lm_mask_list.append(masked_lm_mask)
lvt = len(labels) - labels.count(self.ignore_index)
processed_stats = get_data_stats(
np.expand_dims(np.array(input_ids), 0),
self.pad_id,
self.eos_id,
self.max_seq_length,
lvt,
)
for key in processed_stats:
stats[key] += processed_stats[key]
if len(tokenized_text_chunks) > 0:
results['data'] = np.stack(
[np.array(input_id_list), np.array(attention_mask_list)], axis=1
)
if self.mlm_with_gather:
results['labels'] = np.stack(
[
np.array(labels_list),
np.array(masked_lm_positions_list),
                        np.array(masked_lm_mask_list),
],
axis=1,
)
else:
results['labels'] = np.stack(
[np.array(labels_list)],
axis=1,
)
# Return the list of processed results and data stats
return results, stats
def process_mlm(self, text_data, raw_data_stats):
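        """
        Tokenize a single text example and build MLM features.
        Args:
            text_data (str): Raw text to tokenize; it is truncated and padded to
                ``max_seq_length``.
            raw_data_stats (dict): Character/byte statistics of the raw text.
        Returns:
            Tuple[dict, dict]: Feature dict with "data" and "labels" arrays and
            the merged dataset statistics.
        """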
tokenized_data = self.tokenizer(
text_data,
max_length=self.max_seq_length,
truncation=True,
padding='max_length',
return_attention_mask=True,
)
input_ids, attention_mask = (
tokenized_data['input_ids'],
tokenized_data['attention_mask'],
)
tokenized_data_stats = dict()
results = dict()
tokenized_data_stats["processed"] = 1
tokenized_data_stats["successful"] = 0
if input_ids == []:
tokenized_data_stats["discarded"] = 1
return {"data": [], "labels": []}, tokenized_data_stats
tokenized_data_stats["successful"] = 1
input_ids, masked_lm_positions, masked_lm_mask, labels = (
self.mask_single_sequence(input_ids)
)
results['data'] = np.stack(
[np.array(input_ids), np.array(attention_mask)], axis=0
).reshape(1, 2, self.max_seq_length)
if self.mlm_with_gather:
results['labels'] = np.stack(
[
np.array(labels),
np.array(masked_lm_positions),
np.array(masked_lm_mask),
],
axis=0,
).reshape(1, 3, self.max_predictions)
else:
results['labels'] = np.stack(
[
np.array(labels),
],
axis=0,
).reshape(1, 1, self.max_seq_length)
tokenized_data_stats["non_pad_tokens"] = sum(
1 for id in input_ids if id != self.pad_id
)
tokenized_data_stats["num_pad_tokens"] = (
self.max_seq_length - tokenized_data_stats["non_pad_tokens"]
)
tokenized_data_stats["num_tokens"] = self.max_seq_length
tokenized_data_stats["num_masked_tokens"] = input_ids.count(
self.tokenizer.mask_token_id
)
tokenized_data_stats["loss_valid_tokens"] = len(labels) - labels.count(
self.ignore_index
)
tokenized_data_stats.update(raw_data_stats)
return results, tokenized_data_stats
def process_single_semantic_region(self, text_data, raw_data_stats):
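        """
        Tokenize a plain-text document, optionally packing it with the running
        token prefix, and split the result into chunks for auto-regressive LM
        features.
        Args:
            text_data (str): Raw text of the document.
            raw_data_stats (dict): Character/byte statistics of the raw text.
        Returns:
            Tuple[dict, dict]: Processed chunks and the merged dataset statistics.
        """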
discarded_files = 0
# tokenize text
if self.split_text_to_tokenize:
tokenized_text = split_text_and_tokenize(
text_data,
self.tokenizer,
max_tok_len=self.chunk_len_to_split,
remove_bos_in_chunks=self.remove_bos_in_chunks,
)
else:
tokenized_text = self.tokenizer.encode(text_data)
if self.eos_id is not None:
tokenized_text += [self.eos_id]
all_text = self.prefix + tokenized_text
tokenized_text_chunks = [
all_text[i : i + self.max_seq_length + 1]
for i in range(0, len(all_text), self.max_seq_length)
]
# reset prefix
self.prefix = []
# update prefix if last chunk is < max_seq_length
num_tokens_last_chunk = len(tokenized_text_chunks[-1])
        if self.pack_sequences:
            if num_tokens_last_chunk < self.max_seq_length + 1:
                last_chunk = tokenized_text_chunks.pop(-1)
                self.prefix.extend(last_chunk)
        elif num_tokens_last_chunk < 2:
            _ = tokenized_text_chunks.pop(-1)
results, tokenized_data_stats = self.process_chunks(
tokenized_text_chunks
)
if results["data"] == [] and self.prefix == []:
discarded_files += 1
tokenized_data_stats["discarded"] = discarded_files
tokenized_data_stats["processed"] = 1
tokenized_data_stats["successful"] = (
tokenized_data_stats["processed"]
- tokenized_data_stats["discarded"]
)
tokenized_data_stats.update(raw_data_stats)
return results, tokenized_data_stats
def tokenize_data(self, semantic_data_array):
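        """
        Parse, tokenize and featurize a semantic data array, dispatching to the
        MLM, single-region or multi-region (optionally multimodal) code paths.
        Args:
            semantic_data_array (List[Dict[str, Any]]): Raw semantic regions to encode.
        Returns:
            Tuple[dict, dict]: Tokenized features and dataset statistics.
        """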
region_data, raw_data_stats = self.parse_semantic_data_array(
semantic_data_array
)
if not region_data:
return {}, raw_data_stats
semantic_regions = region_data.get("semantic_regions")
data, image_paths = (
region_data.get("data"),
region_data.get("image_paths"),
)
if data == "":
return {}, raw_data_stats
if self.mlm:
return self.process_mlm(data, raw_data_stats)
if (self.split_text_to_tokenize) and not self.is_multimodal:
return self.process_single_semantic_region(data, raw_data_stats)
        if self.split_text_to_tokenize:
            raise ValueError(
                "Multiple semantic regions are not supported with `split_text_to_tokenize`."
            )
tokenized_data = self.tokenizer(
data,
return_offsets_mapping=True,
)
if self.is_multimodal:
# Convert input_ids to numpy array
input_ids = np.array(tokenized_data['input_ids'])
# Replace image_token_id with pad_id
input_ids[input_ids == self.image_token_id] = self.pad_id
# Convert back to list
tokenized_data['input_ids'] = input_ids.tolist()
if len(semantic_regions) > 0:
append_eos_to_multiple_semantic_regions(
data,
self.data_ranges,
self.eos_token,
self.image_token,
False,
)
tokenized_semantic_region_list = self.get_segment_indices(
tokenized_data,
semantic_regions,
)
data = {
"tokenized_data": tokenized_data,
"image_paths": image_paths,
"tokenized_semantic_regions": tokenized_semantic_region_list,
}
doc_list = self.chop_doc_into_msl(data)
results, tokenized_data_stats = self.process_docs(doc_list)
data_stats = {
"discarded": tokenized_data_stats["discarded"],
"processed": tokenized_data_stats["processed"],
"successful": tokenized_data_stats["successful"],
"raw_chars_count": raw_data_stats["raw_chars_count"],
"raw_bytes_count": raw_data_stats["raw_bytes_count"],
"normalized_chars_count": raw_data_stats["normalized_chars_count"],
"normalized_bytes_count": raw_data_stats["normalized_bytes_count"],
"num_pad_tokens": tokenized_data_stats["num_pad_tokens"],
"non_pad_tokens": tokenized_data_stats["non_pad_tokens"],
"num_masked_tokens": tokenized_data_stats["num_masked_tokens"],
"loss_valid_tokens": tokenized_data_stats["loss_valid_tokens"],
"num_tokens": tokenized_data_stats["num_tokens"],
}
return results, data_stats
def parse_semantic_data_array(
self, semantic_data_array: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
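        """
        Flatten a semantic data array into a single text string, collecting
        per-region loss weights, attention masks and image paths.
        Each entry is expected to carry a "content" list of single-key dicts
        mapping a region modality (e.g. "text" or "image") to its value, and may
        optionally provide "semantic_loss_weight", "semantic_drop_mask" and
        "semantic_attention_mask" lists of the same length as "content".
        Example entry (an illustrative sketch; the image path is hypothetical
        and is resolved relative to ``image_dir``):
            {
                "content": [
                    {"text": "A caption describing the image."},
                    {"image": "images/0001.jpg"},
                ],
            }
        Returns:
            Tuple[dict, dict]: A dict with "data", "image_paths" and
            "semantic_regions", plus raw character/byte statistics.
        """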
if not semantic_data_array:
return {}, {}
image_paths = []
semantic_regions = []
stats = {
"raw_chars_count": 0,
"raw_bytes_count": 0,
"normalized_chars_count": 0,
"normalized_bytes_count": 0,
}
formatted_data = ""
formatted_data_length = 0
for entry in semantic_data_array:
semantic_loss_weight = entry.get("semantic_loss_weight")
semantic_drop_mask = entry.get("semantic_drop_mask")
semantic_attention_mask = entry.get("semantic_attention_mask")
if semantic_loss_weight is not None and len(
semantic_loss_weight
) != len(entry["content"]):
raise ValueError(
"The length of semantic loss mask must match the number of regions"
)
if semantic_drop_mask is not None and len(
semantic_drop_mask
) != len(entry["content"]):
raise ValueError(
"The length of semantic drop mask must match the number of regions"
)
if semantic_attention_mask is not None and len(
semantic_attention_mask
) != len(entry["content"]):
raise ValueError(
"The length of semantic attention mask must match the number of regions"
)
content_parts = []
global_idx = 0
for i, part in enumerate(entry["content"]):
region_key = list(part.keys())[0]
region_val = part[region_key]
if not region_val:
continue
if region_key != "image":
cleaned_region_val = clean_text(
region_val,
self.use_ftfy,
self.wikitext_detokenize,
self.ftfy_normalizer,
)
stats["raw_chars_count"] += len(region_val)
stats["raw_bytes_count"] += len(region_val.encode("utf-8"))
stats["normalized_chars_count"] += len(cleaned_region_val)
stats["normalized_bytes_count"] += len(
cleaned_region_val.encode("utf-8")
)
else:
cleaned_region_val = region_val
include_tags = part.pop("include_tags", False)
if not semantic_loss_weight:
loss_weight = self.semantic_loss_weight.get(region_key)
                    if loss_weight is None:
                        # set default weights
                        loss_weight = 1 if region_key != "image" else 0
else:
loss_weight = semantic_loss_weight[i]
if not semantic_drop_mask:
drop_region = self.semantic_drop_mask.get(region_key, False)
else:
drop_region = semantic_drop_mask[i]
if not semantic_attention_mask:
attention_mask = self.semantic_attention_mask.get(
region_key, True
)
else:
attention_mask = semantic_attention_mask[i]
attention_mask = 1 if attention_mask else 0
if region_key != "image": ## hardcoding name of image
if not drop_region and cleaned_region_val != "":
if include_tags:
cleaned_region_val = (
f"<|{region_key}|>"
+ cleaned_region_val
+ f"<|{region_key}|>"
)
semantic_regions.append(
{
"indices": (
formatted_data_length,
formatted_data_length
+ len(cleaned_region_val),
),
"region_modality": region_key,
"region_len": len(cleaned_region_val),
"loss_weight": loss_weight,
"attention_mask": attention_mask,
}
)
formatted_data_length += len(cleaned_region_val)
content_parts.append(cleaned_region_val)
else:
if not drop_region:
image_paths.append(cleaned_region_val)
# Store the pad id for image region and handle `include_tags`
patches = self.num_patches * [self.image_token]
patches = ''.join(patches)
if include_tags:
patches = (
f"<|{region_key}|>"
+ patches
+ f"<|{region_key}|>"
)
patch_len = len(patches)
semantic_regions.append(
{
"indices": (
formatted_data_length,
formatted_data_length + patch_len,
),
"region_modality": region_key,
"loss_weight": loss_weight,
"attention_mask": attention_mask,
}
)
formatted_data_length += patch_len
content_parts.append(patches)
global_idx += 1
formatted_data += ''.join(content_parts)
if (
self.is_multimodal
and self.max_num_img
and len(image_paths) > self.max_num_img
):
self.logger.warning(
f"Document more images than max_num_img. Skipping this doc..."
)
stats["raw_docs_skipped"] = 1
return {}, stats
# Validate image paths
for i, path in enumerate(image_paths):
if path:
full_path = os.path.join(self.image_dir, path)
if not os.path.exists(full_path):
self.logger.warning(
f"Image with path - {full_path} does not exist. Hence skipping this."
)
stats["raw_docs_skipped"] = 1
return {}, stats
else:
image_paths[i] = path.encode(encoding='utf-8')
transformed_data = {
"data": formatted_data,
"image_paths": image_paths,
"semantic_regions": semantic_regions,
}
return transformed_data, stats
def process_docs(self, doc_list):
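        """
        Convert a list of chunked documents into stacked feature arrays,
        appending the EOS token to the final chunk and accumulating statistics.
        Args:
            doc_list (List[dict]): Chunk dicts produced by ``chop_doc_into_msl``.
        Returns:
            Tuple[dict, dict]: Feature lists keyed by output name and the
            accumulated dataset statistics.
        """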
results = defaultdict(list)
tokenized_data_stats = defaultdict(int)
if doc_list == []:
tokenized_data_stats["processed"] += 1
if self.prefix_doc is None:
tokenized_data_stats["discarded"] += 1
else:
tokenized_data_stats["successful"] += 1
return {}, tokenized_data_stats
# Add eos at the end.
# TODO: Find a better way to handle this?
last_doc = doc_list[-1]
if last_doc.get("input_ids", []) != [] and self.eos_id != None:
if last_doc['input_ids'][-1] != self.eos_id:
if len(last_doc['input_ids']) < self.max_seq_length + 1:
last_doc['input_ids'].append(self.eos_id)
last_doc['loss_mask'].append(last_doc['loss_mask'][-1])
last_doc['attention_mask'].append(
last_doc['attention_mask'][-1]
)
else:
last_doc['input_ids'][-1] = self.eos_id
for doc_idx, doc in enumerate(doc_list):
has_img = False
if doc.get("input_ids", []) == []:
continue
token_modality_idx = (
np.zeros(self.max_seq_length) if self.is_multimodal else None
)
image_paths, image_data_positions = doc.pop(
"image_paths", None
), doc.pop("image_data_positions", None)
has_img = doc.pop("has_img", None)
img_data_loc = None
if self.is_multimodal:
img_data_loc = np.full(
(self.max_num_img, self.num_patches), self.max_seq_length
)
assert (
len(image_data_positions) <= self.max_num_img
), "Number of images should be <= max_num_images"
# Preallocate img_data_loc as a list of arrays to avoid dynamic resizing
for image_index, (start_img_pos, end_img_pos) in enumerate(
image_data_positions
):
img_data_loc[image_index] = np.arange(
start_img_pos, end_img_pos
)
# Efficiently update the token_modality_idx using vectorized assignment
token_modality_idx[start_img_pos:end_img_pos] = 1
sample = self.create_features_pretraining(
doc,
token_modality_idx,
)
            if len(sample) == 0:
continue
if self.is_multimodal:
if image_paths:
num_images = len(image_paths)
image_paths += [None] * (self.max_num_img - num_images)
has_img = True
else:
image_paths = [None] * (self.max_num_img)
sample_stats = get_data_stats(
sample, self.pad_id, self.eos_id, self.max_seq_length
)
for key in sample_stats:
tokenized_data_stats[key] += sample_stats[key]
data = (
{
"data": sample,
"img_path": np.array(image_paths, dtype="S"),
"has_img": np.array([has_img], dtype=np.bool_),
"img_data_loc": img_data_loc,
}
if self.is_multimodal
else {
"data": sample,
}
)
for key, value in data.items():
results[key].append(value)
tokenized_data_stats["processed"] += 1
if results.get("data", []) == []:
tokenized_data_stats["discarded"] += 1
else:
tokenized_data_stats["successful"] += 1
return results, tokenized_data_stats
    def encode(
self, semantic_data_array: List[Dict[str, Any]]
) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
Tokenize and encode the data for auto-regressive language modeling.
Args:
            semantic_data_array (List[Dict[str, Any]]): Data to encode.
Returns:
Tuple[Dict[str, Any], Dict[str, int]]: Tuple of encoded features for auto-regressive language modeling and dataset stats.
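        Example (an illustrative sketch; the "text" region key and its value are
        hypothetical):
            sample = [{"content": [{"text": "Sample text for processing."}]}]
            data, stats = token_generator.encode(sample)
            # data["data"] holds the stacked features; stats counts processed,
            # discarded and successful docs along with token/padding counts.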
"""
tokenized_data, data_stats = self.tokenize_data(semantic_data_array)
if tokenized_data.get("data", []) == []:
return {}, data_stats
else:
            if self.is_multimodal:
                data = tokenized_data
            else:
                data = {'data': tokenized_data['data']}
                if self.mlm and not self.use_vsl:
                    data['labels'] = tokenized_data['labels']
return data, data_stats
    def encode_leftover_prefix(
self, prefix: List[np.ndarray]
) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
Processes the leftover prefix which is a list of ndarray tokens into chunks based
on max sequence length.
        The last chunk is handled specially if it is shorter than the max sequence
        length; if it has fewer than two tokens, it is discarded.
Args:
prefix (List[np.ndarray]): The prefix list of token arrays to process.
Returns:
Tuple[Dict[str, Any], Dict[str, int]]: A tuple containing the processed token chunks as
a list of ndarrays and the dataset stats.
"""
if (self.split_text_to_tokenize or self.mlm) and not self.is_multimodal:
tokenized_text_chunks = (
[
prefix[i : i + self.max_seq_length]
for i in range(0, len(prefix), self.max_seq_length)
]
if self.mlm
else [
prefix[i : i + self.max_seq_length + 1]
for i in range(0, len(prefix), self.max_seq_length)
]
)
            # Discard the last chunk only if it has fewer than two tokens;
            # shorter chunks are padded downstream when features are created.
            num_tokens_last_chunk = len(tokenized_text_chunks[-1])
            if num_tokens_last_chunk < 2:
                _ = tokenized_text_chunks.pop(-1)
results, stats = self.process_chunks(tokenized_text_chunks)
if results["data"] == []:
return {}, stats
data = results
return data, stats
# Handle prefix doc
if not prefix:
return {}, {}
doc_list = prefix
results, tokenized_data_stats = self.process_docs(doc_list)
if results == {}:
return {}, {}
data_stats = {
"num_pad_tokens": tokenized_data_stats["num_pad_tokens"],
"non_pad_tokens": tokenized_data_stats["non_pad_tokens"],
"num_masked_tokens": tokenized_data_stats["num_masked_tokens"],
"loss_valid_tokens": tokenized_data_stats["loss_valid_tokens"],
"num_tokens": tokenized_data_stats["num_tokens"],
}
return results, data_stats