# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides the VSLPretrainingTokenGenerator class, which extends
PretrainingTokenGenerator for advanced processing of tokenized text data
tailored to variable-length sequence language modeling (VSLLM). It includes
methods for processing chunks of tokenized text, optimizing the representation
of tokenized data by merging shorter sequences within a specified maximum
sequence length, and tokenizing text for auto-regressive language modeling.
"""
from typing import Any, Dict, List, Tuple
import numpy as np
from cerebras.modelzoo.data_preparation.data_preprocessing.pretraining_token_generator import (
PretrainingTokenGenerator,
)
from cerebras.modelzoo.data_preparation.data_preprocessing.utils import (
create_features_auto_lm_vsl,
)
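# A minimal usage sketch (hypothetical; the exact ``params`` schema and call
# order are dictated by the surrounding data preprocessing pipeline, and
# ``tokenizer`` is assumed to be a HuggingFace-style tokenizer instance):
#
#     generator = VSLPretrainingTokenGenerator(
#         params, tokenizer, eos_id=tokenizer.eos_token_id, pad_id=tokenizer.pad_token_id
#     )
#     tokenized, stats = generator.tokenize_data(semantic_data_array)
#     packed = generator.append_within_max_length(tokenized)
#     features, chunk_stats = generator.process_chunks(packed)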
class VSLPretrainingTokenGenerator(PretrainingTokenGenerator):
"""
Processes tokenized text data, specifically for VSLLM. Extends
PretrainingTokenGenerator by handling text tokenization, feature creation,
and optimizing representation of tokenized data for language modeling tasks.
Attributes:
        use_vsl (bool): Whether variable sequence length (VSL) logic is used.
fold_long_doc (bool): Whether to fold long documents.
position_ids_dtype (str): Data type for position IDs in tokenized output.
Args:
params (dict): Parameters for the dataset and model.
tokenizer: Tokenizer instance for text tokenization.
eos_id (int): End-of-sequence ID.
pad_id (int): Padding ID.
"""
use_vsl = True
def __init__(
self, params: Dict[str, Any], tokenizer: Any, eos_id: int, pad_id: int
):
"""
Initialize VSLPretrainingTokenGenerator with dataset parameters,
tokenizer, and token IDs.
"""
super().__init__(params, tokenizer, eos_id, pad_id)
self.fold_long_doc = params["dataset"].pop("fold_long_doc", True)
self.position_ids_dtype = params["dataset"].pop(
"position_ids_dtype", "int32"
)
self.pack_sequences = False
self.sample_features = [
"input_ids",
"attention_mask",
"labels",
"attention_span",
"position_ids",
]
    def create_features_vsl_mlm(
self,
bin,
):
"""Given a list of VSL sequences, generate input features and labels.
Args:
bin (list(sequence)): list of VSL sequences.
Returns:
Tuple containing features and labels
"""
input_ids, attention_span, position_ids = [], [], []
for sample in bin:
input_ids.extend(sample)
attention_span.extend(list(range(len(sample) - 1, -1, -1)))
position_ids.extend(list(range(len(sample))))
input_ids, masked_lm_positions, masked_lm_mask, labels = (
self.mask_single_sequence(input_ids)
)
num_pad = self.max_seq_length - len(input_ids)
input_mask = [1] * len(input_ids) + [0] * num_pad
attention_span = attention_span + [0] * num_pad
position_ids = position_ids + [0] * num_pad
input_ids = input_ids + [self.pad_id] * num_pad
if not self.mlm_with_gather:
labels = labels + [self.ignore_index] * num_pad
# assertions to ensure correct output shapes
assert (
len(input_ids) == self.max_seq_length
and len(input_mask) == self.max_seq_length
and len(attention_span) == self.max_seq_length
and len(position_ids) == self.max_seq_length
), "Wrong sequence length"
if self.inverted_mask:
input_mask = [elem == 0 for elem in input_mask]
return (
input_ids,
input_mask,
attention_span,
position_ids,
masked_lm_positions,
masked_lm_mask,
labels,
)
def process_chunks(
self, tokenized_data: List[List[Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
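        """
        Dispatches to the MLM or next-word-prediction chunk processor,
        depending on whether masked language modeling is enabled.
        """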
if self.mlm:
return self.process_chunks_mlm(tokenized_data)
else:
return self.process_chunks_nextwordpred(tokenized_data)
    def process_chunks_mlm(
        self, tokenized_data: List[List[Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
Processes chunks of tokenized text and returns processed features
along with statistics about padding and tokens.
Args:
tokenized_data (List[List[Any]]): Tokenized text chunks as a list.
Returns:
Tuple[List[Any], dict]: Processed results and statistics.
"""
        results = {
            "data": [],
            "labels": [],
        }  # Dict to store the processed results
stats = {
"loss_valid_tokens": 0,
"num_tokens": 0,
"num_pad_tokens": 0,
"non_pad_tokens": 0,
"num_masked_tokens": 0,
"processed": 0,
"discarded": 0,
"successful": 0,
}
for tokenized_text_chunks in tokenized_data:
(
input_ids,
input_mask,
attention_span,
position_ids,
masked_lm_positions,
masked_lm_mask,
labels,
) = self.create_features_vsl_mlm(tokenized_text_chunks)
lvt = len(labels) - labels.count(self.ignore_index)
stats["num_masked_tokens"] += len(labels) - lvt
data = np.stack(
[input_ids, input_mask, attention_span, position_ids], axis=0
).reshape(-1, 4, self.max_seq_length)
if self.mlm_with_gather:
labels = np.stack(
[labels, masked_lm_positions, masked_lm_mask], axis=0
).reshape(-1, 3, self.max_predictions)
else:
labels = np.stack([labels], axis=0).reshape(
-1, 1, self.max_seq_length
)
results["data"].append(data)
results["labels"].append(labels)
pad_index = np.where(np.array(input_ids) == self.pad_id)[0]
p_i = int(pad_index[0] if len(pad_index) > 0 else len(input_ids))
num_pad = self.max_seq_length - p_i
stats["loss_valid_tokens"] += lvt
stats["num_pad_tokens"] += num_pad
stats["non_pad_tokens"] += self.max_seq_length - num_pad
stats["num_tokens"] += self.max_seq_length
stats["processed"] += 1
if results["data"] == []:
stats["discarded"] += 1
data = {}
else:
stats["successful"] += 1
data = results
return data, stats
    def process_chunks_nextwordpred(
self, tokenized_data: List[List[Any]]
) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
Processes chunks of tokenized text and returns processed features
along with statistics about padding and tokens.
Args:
tokenized_data (List[List[Any]]): Tokenized text chunks as a list.
Returns:
Tuple[Dict[str, Any], Dict[str, int]]: Processed results and statistics.
"""
results = {"data": []} # List to store the processed results
stats = {
"loss_valid_tokens": 0,
"num_tokens": 0,
"num_pad_tokens": 0,
"non_pad_tokens": 0,
"num_masked_tokens": 0,
}
for tokenized_text_chunks in tokenized_data:
eos_len = 1 if self.eos_id is not None else 0
tokenized_text_chunks_len = sum(
(len(one_d_list) - eos_len)
for one_d_list in tokenized_text_chunks
)
num_pad = self.max_seq_length - tokenized_text_chunks_len
processed = create_features_auto_lm_vsl(
tokenized_text_chunks,
self.max_seq_length,
num_pad,
pad_id=self.pad_id,
inverted_mask=self.inverted_mask,
input_ids_dtype=self.input_ids_dtype,
input_mask_dtype=self.input_mask_dtype,
labels_dtype=self.input_ids_dtype,
attention_span_dtype=self.position_ids_dtype,
position_ids_dtype=self.position_ids_dtype,
)
if processed.size != 0:
loss_valid_tokens = int(processed[1, :].sum())
stats["num_pad_tokens"] += num_pad
stats["non_pad_tokens"] += self.max_seq_length - num_pad
stats["num_masked_tokens"] += (
self.max_seq_length - loss_valid_tokens
)
stats["loss_valid_tokens"] += loss_valid_tokens
stats["num_tokens"] += len(processed[0])
processed = np.expand_dims(processed, axis=0)
results["data"].append(processed)
if results["data"] == []:
data = {}
else:
data = results
return data, stats
    def tokenize_data(
        self, semantic_data_array: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""
Tokenizes the given text and creates features suitable for
auto-regressive language modeling. Handles end-of-sequence addition,
sequence length adjustments, and document folding for long documents.
Args:
            semantic_data_array (List[Dict[str, Any]]): The data to tokenize.
        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: Tokenized and processed text features
and statistics.
"""
text, raw_data_stats = self.parse_semantic_data_array(
semantic_data_array
)
if text == "":
return {"data": []}, raw_data_stats
if self.mlm:
tokenized_data = self.tokenizer(
text,
max_length=self.max_seq_length,
truncation=True,
padding=False,
return_attention_mask=True,
)
input_ids = tokenized_data['input_ids']
return {"data": [input_ids]}, raw_data_stats
# tokenize text
tokenized_text = self.tokenizer.encode(text)
if self.eos_id is not None:
tokenized_text += [self.eos_id]
tokenized_text_len = len(tokenized_text)
if tokenized_text_len < self.min_sequence_len:
raw_data_stats["discarded"] = 1
raw_data_stats["processed"] = 1
raw_data_stats["successful"] = 0
return {"data": []}, raw_data_stats
if self.rng.random() < self.short_seq_prob:
tokenized_text = tokenized_text[
0 : self.rng.randint(2, self.max_seq_length)
]
tokenized_text_len = len(tokenized_text)
if tokenized_text_len > self.max_seq_length + 1:
if not self.fold_long_doc:
raw_data_stats["discarded"] = 1
raw_data_stats["processed"] = 1
raw_data_stats["successful"] = 0
return {"data": []}, raw_data_stats
tokenized_text_chunks = (
[
tokenized_text[i : i + self.max_seq_length]
for i in range(0, len(tokenized_text), self.max_seq_length)
]
if self.mlm
else [
tokenized_text[i : i + self.max_seq_length + 1]
for i in range(0, len(tokenized_text), self.max_seq_length)
]
)
        # drop the last chunk if it is too short to form an input/label pair
num_tokens_last_chunk = len(tokenized_text_chunks[-1])
if num_tokens_last_chunk < 2:
_ = tokenized_text_chunks.pop(-1)
return {"data": tokenized_text_chunks}, raw_data_stats
    def append_within_max_length(
self, tokenized_data: Dict[str, List[List[Any]]]
) -> List[List[List[Any]]]:
"""
Optimizes representation of tokenized data by merging shorter sequences
within the specified maximum sequence length. Converts 3D list to a
modified 3D structure where each innermost list is treated as a separate
2D list, then merges these 2D lists if their combined length is within
the max sequence length.
Args:
tokenized_data (Dict[str, List[List[Any]]]): 3D list of tokenized text data.
Returns:
List[List[List[Any]]]: Optimized 3D list after merging shorter sequences.
"""
tokenized_data = tokenized_data["data"]
def convert_3d_to_modified_3d(tokenized_data):
# First, flatten the 3D list to a 2D list
flattened_2d_list = []
for two_d_list in tokenized_data:
for one_d_list in two_d_list:
flattened_2d_list.append(one_d_list)
# Then, convert each list in the flattened 2D list to a 2D list
# within a new 3D list
new_3d_list = []
for one_d_list in flattened_2d_list:
new_2d_list = [one_d_list]
new_3d_list.append(new_2d_list)
return new_3d_list
tokenized_data = convert_3d_to_modified_3d(tokenized_data)
# Precompute combined length of all lists in each 2D list
combined_lengths = [
sum(len(one_d_list) for one_d_list in two_d_list)
for two_d_list in tokenized_data
]
indices_to_remove = set()
# Iterate over each 2D list in the 3D list in reverse order
for i in range(len(tokenized_data) - 1, 0, -1):
# Use the precomputed length
current_combined_length = combined_lengths[i]
# Check if combined length of current 2D list is less than max_seq_length
if current_combined_length < self.max_seq_length:
# Look for a previous 2D list to merge with
for j in range(i - 1, -1, -1):
# Use the precomputed length
total_combined_length = (
current_combined_length + combined_lengths[j]
)
# Check if combined length of both 2D lists is within max_seq_length
if total_combined_length <= self.max_seq_length:
# If so, merge current 2D list into the previous 2D list
tokenized_data[j].extend(tokenized_data[i])
# Update combined length for merged 2D list
combined_lengths[j] += combined_lengths[i]
# Instead of deleting immediately, add the index to the set
indices_to_remove.add(i)
break # Exit inner loop as merge is done
# Delete the elements after the loop is done
# Convert indices_to_remove to a list and sort in reverse order to ensure indices remain correct while deleting
for index in sorted(indices_to_remove, reverse=True):
del tokenized_data[index]
del combined_lengths[index]
return tokenized_data