# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
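"""TokenFlow: a Flask app for visually inspecting preprocessed HDF5 datasets.

Loads the HDF5 files produced by the data preprocessing pipeline together with
their data_params.json, decodes token ids back into strings on demand, and
serves the results (masks, image metadata, dataset stats, and a sequence-length
distribution plot) to a browser UI.
"""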
# isort: off
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../.."))
# isort: on
import json
import logging
import os
import traceback
from argparse import ArgumentParser
from copy import deepcopy
import h5py
import numpy as np
from flask import (
Flask,
jsonify,
render_template,
request,
send_from_directory,
url_for,
)
from cerebras.modelzoo.data_preparation.data_preprocessing.tokenflow import (
tokenizer,
)
from cerebras.modelzoo.data_preparation.data_preprocessing.tokenflow.utils import (
construct_attention_mask,
)
app = Flask(__name__)
logging.basicConfig(
level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s'
)
class TokenFlowDataProcessor:
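"""Loads one preprocessed HDF5 file and exposes per-sequence views of it.

Builds `self.datadict`, keyed by feature name (input ids, labels, masks,
image bitmaps, and placeholder string arrays), with special handling for
MLM, DPO, and multimodal datasets.
"""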
def __init__(self, filepath, data_params):
self.filepath = filepath
self.data_params = data_params
assert (
'processing' in data_params
), "'processing' key must be in data_params"
assert 'setup' in data_params, "'setup' key must be in data_params"
assert 'features' in data_params, "'features' must be in data_params"
assert (
'max_seq_length' in data_params['processing']
), "'max_seq_length' must be in 'processing'"
assert (
'pad_id' in data_params['processing']
), "'pad_id' must be in 'processing'"
assert (
'eos_id' in data_params['processing']
), "'eos_id' must be in 'processing'"
assert 'mode' in data_params['setup'], "'mode' must be in 'setup'"
self.msl = data_params["processing"].get("max_seq_length")
self.pad_id = data_params['processing'].get('pad_id')
self.eos_id = data_params['processing'].get('eos_id')
self.mode = data_params['setup'].get('mode')
self.features = data_params.get('features')
self.is_multimodal = self.data_params['dataset'].get(
'is_multimodal', False
)
self.input_ids = 'text_input_ids' if self.is_multimodal else 'input_ids'
self.tokenizer = tokenizer.GenericTokenizer(
deepcopy(data_params['processing']), filepath
)
self.datadict = {}
# Special handling for MLM mode while loading data.
if data_params['dataset'].get('training_objective') == 'mlm':
# Handle load_data() for MLM separately as it returns an extra array.
self.mlm_with_gather = data_params['dataset'].get('mlm_with_gather')
self.ignore_index = data_params['dataset'].get('ignore_index', -100)
self.data, labels, self.image_paths, image_data_locs = (
self.load_data()
)
if not self.mlm_with_gather:
self.datadict['labels'] = labels[:, 0, :].copy()
else:
self.datadict['labels'] = labels
else:
# image_paths and image_data_locs are used in multimodal datasets
self.data, self.image_paths, image_data_locs = self.load_data()
self.datadict['image_paths'] = self.image_paths
if self.mode == 'dpo':
for i, feature in enumerate(self.features):
if feature == "chosen_attention_mask":
feature = "chosen_loss_mask"
if feature == "rejected_attention_mask":
feature = "rejected_loss_mask"
self.datadict[feature] = self.data[:, i, :].copy()
chosen_nstrings = self.datadict['chosen_input_ids'].shape[0]
rejected_nstrings = self.datadict['rejected_input_ids'].shape[0]
self.nstrings = chosen_nstrings + rejected_nstrings
else:
# Special handling for MLM mode; labels were already constructed above.
# Also rename attention_mask to loss_mask, since the dataloader incorrectly names the loss mask as attention_mask.
if data_params['dataset'].get('training_objective') == 'mlm':
self.datadict['input_ids'] = self.data[:, 0, :].copy()
self.datadict['loss_mask'] = self.data[:, 1, :].copy()
else:
for i, feature in enumerate(self.features):
if feature == "attention_mask":
feature = "loss_mask"
self.datadict[feature] = self.data[:, i, :].copy()
self.nstrings = self.datadict[self.input_ids].shape[0]
# In multimodal data the inverted attention_mask is stored as key_padding_mask; otherwise attention_mask is not part of the data and is constructed below from the eos/pad ids.
if "key_padding_mask" in self.datadict:
self.datadict['attention_mask'] = (
1 - self.datadict['key_padding_mask']
).tolist()
else:
if self.mode == 'dpo':
self.datadict['chosen_attention_mask'] = (
construct_attention_mask(
self.datadict,
self.eos_id,
self.pad_id,
input_key='chosen_input_ids',
)
)
self.datadict['rejected_attention_mask'] = (
construct_attention_mask(
self.datadict,
self.eos_id,
self.pad_id,
input_key='rejected_input_ids',
)
)
else:
self.datadict['attention_mask'] = construct_attention_mask(
self.datadict,
self.eos_id,
self.pad_id,
input_key=self.input_ids,
)
# Build the images bitmap in the datadict, marking which image (if any) each token position belongs to.
if self.mode == 'dpo':
self.datadict['images_bitmap'] = np.zeros(
self.datadict['chosen_input_ids'].shape
)
else:
self.datadict['images_bitmap'] = np.zeros(
self.datadict[self.input_ids].shape
)
if image_data_locs.size:
for i in range(image_data_locs.shape[0]):
image_index = 1
for j in range(image_data_locs.shape[1]):
if image_data_locs[i][j][0] == self.msl:
break
for k in image_data_locs[i][j]:
self.datadict['images_bitmap'][i][k] = image_index
image_index += 1
# Construct placeholder input_strings and label_strings that are filled in once decoded on demand.
if self.mode == 'dpo':
# Chosen responses.
self.datadict['chosen_input_strings'] = np.full(
self.datadict['chosen_input_ids'].shape, '', dtype='<U20'
)
self.datadict['chosen_label_strings'] = np.full(
self.datadict['chosen_labels'].shape, '', dtype='<U20'
)
# Rejected responses.
self.datadict['rejected_input_strings'] = np.full(
self.datadict['rejected_input_ids'].shape, '', dtype='<U20'
)
self.datadict['rejected_label_strings'] = np.full(
self.datadict['rejected_labels'].shape, '', dtype='<U20'
)
else:
self.datadict['input_strings'] = np.full(
self.datadict[self.input_ids].shape, '', dtype='<U20'
)
self.datadict['label_strings'] = np.full(
self.datadict['labels'].shape, '', dtype='<U20'
)
def load_data(self):
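"""Read the raw arrays from the HDF5 file.

Returns (data, image_paths, image_data_locs) for most datasets, and
(data, labels, image_paths, image_data_locs) for MLM datasets; when
`mlm_with_gather` is set, the labels are scattered back to their original
token positions (non-masked positions are filled with `ignore_index`).
"""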
try:
with h5py.File(self.filepath, mode='r') as h5_file:
# Multimodal data has this format
if self.data_params['dataset'].get('is_multimodal', False):
return (
np.array(h5_file['data']),
np.array(
[
[i.decode('utf-8') for i in paths]
for paths in h5_file.get('img_path')
]
),
np.array(h5_file.get('img_data_loc')),
)
# MLM data has this format.
elif (
self.data_params['dataset'].get('training_objective')
== 'mlm'
):
if not self.mlm_with_gather:
labels = np.array(h5_file['labels'])
return (
np.array(h5_file['data']),
labels,
np.array([]),
np.array([]),
)
else:
labels = np.array(h5_file['labels'][:, 0, :])
masked_lm_positions_list = np.array(
h5_file['labels'][:, 1, :]
)
masked_lm_weights_list = np.array(
h5_file['labels'][:, 2, :]
)
updated_shape = (labels.shape[0], self.msl)
updated_labels = np.full(
updated_shape, self.ignore_index
)
for i in range(labels.shape[0]):
positions = masked_lm_positions_list[i]
updated_labels[i, positions] = labels[i]
return (
np.array(h5_file['data']),
updated_labels,
np.array([]),
np.array([]),
)
elif h5_file.get('data'):
return np.array(h5_file['data']), np.array([]), np.array([])
return (
np.array(h5_file['data_data']),
np.array([]),
np.array([]),
)
except Exception as e:
logging.error(f"Failed to load data from {self.filepath}: {str(e)}")
logging.error(traceback.format_exc())
raise RuntimeError(
f"Error while loading data from {self.filepath}: {str(e)}"
)
def get_stats(self):
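"""Flatten the processing/setup/dataset parameters into a single dict for
display, dropping empty values and keys that are not useful to the user."""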
stats = self.data_params['processing']
stats.update(self.data_params['setup'])
stats.update(self.data_params['dataset'])
stats['multimodal'] = self.data_params['dataset'].get(
'is_multimodal', False
)
# Remove empty keys
for attr in list(stats.keys()):
if stats[attr] is None:
stats.pop(attr)
# 'data' and 'input_dir' are internal details that are not useful to the user
if stats.get('data'):
stats.pop('data')
if stats.get('input_dir'):
stats.pop('input_dir')
return stats
def get_datadict(self, sequence):
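"""Build a JSON-serializable response for the sequence at index `sequence`.

Decodes input ids and labels into token strings (caching them back into
`self.datadict`), and attaches loss/attention masks, image bitmaps and paths,
and dataset stats. For DPO data, chosen and rejected fields are returned
separately.
"""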
response = {}
if self.mode != 'dpo':
response['input_ids'] = self.datadict[self.input_ids][sequence]
response['labels'] = self.datadict['labels'][sequence]
response['input_strings'] = self.tokenizer.convert_ids_to_tokens(
response['input_ids']
)
response['label_strings'] = self.tokenizer.convert_ids_to_tokens(
response['labels']
)
if 'images_bitmap' in self.datadict:
response['images_bitmap'] = self.datadict['images_bitmap'][
sequence
]
response['image_paths'] = []
if self.data_params['dataset'].get('is_multimodal', False):
response['image_paths'] = self.datadict['image_paths'][sequence]
if 'loss_mask' in self.datadict:
response['loss_mask'] = self.datadict['loss_mask'][sequence]
response['attention_mask'] = self.datadict['attention_mask'][
sequence
]
# Cache the decoded strings so subsequent requests for this sequence are faster
self.datadict['input_strings'][sequence] = response['input_strings']
self.datadict['label_strings'][sequence] = response['label_strings']
else:
# Update response for chosen.
response['chosen_input_ids'] = self.datadict['chosen_input_ids'][
sequence
]
response['chosen_labels'] = self.datadict['chosen_labels'][sequence]
response['chosen_input_strings'] = (
self.tokenizer.convert_ids_to_tokens(
response['chosen_input_ids']
)
)
response['chosen_label_strings'] = (
self.tokenizer.convert_ids_to_tokens(response['chosen_labels'])
)
if 'chosen_loss_mask' in self.datadict:
response['chosen_loss_mask'] = self.datadict[
'chosen_loss_mask'
][sequence]
response['chosen_attention_mask'] = self.datadict[
'chosen_attention_mask'
][sequence]
self.datadict['chosen_input_strings'][sequence] = response[
'chosen_input_strings'
]
self.datadict['chosen_label_strings'][sequence] = response[
'chosen_label_strings'
]
# Update response for rejected.
response['rejected_input_ids'] = self.datadict[
'rejected_input_ids'
][sequence]
response['rejected_labels'] = self.datadict['rejected_labels'][
sequence
]
response['rejected_input_strings'] = (
self.tokenizer.convert_ids_to_tokens(
response['rejected_input_ids']
)
)
response['rejected_label_strings'] = (
self.tokenizer.convert_ids_to_tokens(
response['rejected_labels']
)
)
if 'rejected_loss_mask' in self.datadict:
response['rejected_loss_mask'] = self.datadict[
'rejected_loss_mask'
][sequence]
response['rejected_attention_mask'] = self.datadict[
'rejected_attention_mask'
][sequence]
self.datadict['rejected_input_strings'][sequence] = response[
'rejected_input_strings'
]
self.datadict['rejected_label_strings'][sequence] = response[
'rejected_label_strings'
]
# Update responses related to image.
if 'images_bitmap' in self.datadict:
response['images_bitmap'] = self.datadict['images_bitmap'][sequence]
response['image_paths'] = []
if self.data_params['dataset'].get('is_multimodal', False):
response['image_paths'] = self.datadict['image_paths'][sequence]
response.update({'stats': self.get_stats()})
for key, val in response.items():
if isinstance(val, np.ndarray):
response[key] = val.tolist()
return response
def process_file_for_sequence_distribution(filename, bin_edges, pad_id, msl):
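"""Compute the sequence-length histogram for a single HDF5 file.

Each sequence's length is the position of the first pad token (or `msl` when
there is no padding) and is counted into one of the 5%-of-MSL wide bins.
Returns (bin_counts, sequence_lengths); on error, empty results are returned.
"""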
import os
import h5py
import numpy as np
length_of_sequences = np.zeros(len(bin_edges) - 1, dtype=int)
sequence_lengths = []
filename = os.path.abspath(filename)
try:
with h5py.File(filename, mode='r') as h5_file:
data = h5_file["data"][:]
no_of_sequences = data.shape[0]
for i in range(no_of_sequences):
tokens = data[i, 0]
if pad_id in tokens:
sequence_length = np.argmax(tokens == pad_id)
else:
sequence_length = (
msl # Sequence has no padding, use max length
)
sequence_lengths.append(sequence_length)
# Calculate the bin
percentage = (sequence_length * 100) // msl
bin_number = min(percentage // 5, len(length_of_sequences) - 1)
length_of_sequences[bin_number] += 1
except Exception as e:
return np.zeros(len(bin_edges) - 1, dtype=int), []
return length_of_sequences, sequence_lengths
def save_sequence_distribution(file_directory, data_params_path):
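"""Generate and save the sequence-length distribution plot.

Processes every HDF5 file in `file_directory` in parallel, aggregates the
per-file histograms, and writes sequence_distribution.png (annotated with the
mean and standard deviation) into the app's static/images directory.
"""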
import glob
import json
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from os import cpu_count
import matplotlib.pyplot as plt
import numpy as np
h5_files = glob.glob(os.path.join(file_directory, "*.h5"))
sequence_dist_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'static',
'images',
)
with open(data_params_path, 'r') as json_file:
data_params = json.load(json_file)
all_sequence_lengths = []
pad_id = data_params['processing'].get('pad_id')
msl = data_params['processing'].get('max_seq_length')
num_bins = 20
bin_edges = np.linspace(0, msl, num_bins + 1)
length_of_sequences_total = np.zeros(len(bin_edges) - 1, dtype=int)
with ProcessPoolExecutor(max_workers=cpu_count()) as executor:
futures = [
executor.submit(
process_file_for_sequence_distribution,
filename,
bin_edges,
pad_id,
msl,
)
for filename in h5_files
]
for future in as_completed(futures):
length_of_sequences, sequence_lengths = future.result()
length_of_sequences_total += length_of_sequences
all_sequence_lengths.extend(sequence_lengths)
total_sequences = np.sum(length_of_sequences_total)
# Calculate percentage distribution
if total_sequences > 0:
length_of_sequences_percentage = (
length_of_sequences_total / total_sequences
) * 100
else:
length_of_sequences_percentage = length_of_sequences_total
# Calculate mean and standard deviation of the sequence lengths
if all_sequence_lengths:
mean_sequence_length = np.mean(all_sequence_lengths)
std_sequence_length = np.std(all_sequence_lengths)
else:
mean_sequence_length = 0
std_sequence_length = 0
plt.figure(figsize=(12, 6))
ranges = [f'{int(bin_edges[i+1])}' for i in range(len(bin_edges) - 1)]
text_ranges = [
f'{int(bin_edges[i])}--{int(bin_edges[i+1])}'
for i in range(len(bin_edges) - 1)
]
plt.bar(ranges, length_of_sequences_percentage, color='skyblue')
# Create a string that contains the information for each range
percentage_info = "\n".join(
[
f"{range_label}: {v:.1f}%"
for range_label, v in zip(
text_ranges, length_of_sequences_percentage
)
]
)
# Add a text box inside the plot with mean and std
text_info = f"Mean: {mean_sequence_length:.1f}\nStd: {std_sequence_length:.1f}\n\n{percentage_info}"
plt.gca().text(
1.05,
0.5,
text_info,
transform=plt.gca().transAxes,
bbox=dict(
facecolor='white', edgecolor='black', boxstyle='round,pad=1.0'
),
verticalalignment='center',
fontsize=6,
)
plt.gca().set_position([0.1, 0.1, 0.75, 0.8])
plt.xlabel('Sequence length')
plt.ylabel('% of sequences')
plt.grid(axis='y', linestyle='--', alpha=0.7)
image_filename = 'sequence_distribution.png'
image_path = os.path.join(sequence_dist_dir, image_filename)
plt.title('Sequence Distribution Plot')
plt.savefig(image_path)
plt.close()
def get_data_or_error(filename):
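"""Return the cached TokenFlowDataProcessor for `filename`, creating it on
first access; on failure return an (error message, HTTP status code) tuple."""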
global data_processors
global data_params
try:
if not data_processors[filename]:
data_processors[filename] = TokenFlowDataProcessor(
filename, data_params
)
return data_processors[filename]
except Exception as e:
logging.error(f"Error processing the file {filename}: {str(e)}")
logging.error(traceback.format_exc())
return (
f"The requested file is not found (or error in processing -- please check the logs for more details.): {e}",
400,
)
def load_params(args):
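"""Load data_params.json from the path given on the command line; return None on failure."""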
try:
with open(args.data_params, 'r') as json_file:
return json.load(json_file)
except Exception:
return None
@app.route('/')
def index():
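"""Render the main page: discover the HDF5 files under --output_dir and show
the first sequence of the first file."""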
global args
global data_params
global data_processors
files = []
initial_data = None
# Load data_params here to make it a one time operation
data_params = load_params(args)
if not data_params:
return (
"Error in loading data_params.json! Please check if the output directory contains the data_params.json file, or specify it as a CLI argument.",
400,
)
if os.path.isdir(args.output_dir):
files = [
os.path.join(args.output_dir, f)
for f in os.listdir(args.output_dir)
if f.endswith('.h5')
]
if not files:
return (
"There are no HDF5 files present in the directory. Please check the directory.",
404,
)
elif os.path.isfile(args.output_dir) and args.output_dir.endswith('.h5'):
files = [args.output_dir]
if not files:
return (
"The passed file is not a valid HDF5 file. Please check the file.",
404,
)
# Load the initial data to display on the first page render.
data_processors = {file: None for file in files}
initial_data = get_data_or_error(files[0])
if not isinstance(initial_data, TokenFlowDataProcessor):
return jsonify({"error": initial_data[0], "code": initial_data[1]})
return render_template(
'index.html',
files=files,
initial_data=initial_data.get_datadict(0),
nstrings=initial_data.nstrings,
)
@app.route('/data', methods=['POST'])
def data():
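"""Return the decoded data for the requested file and sequence index as JSON."""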
filename = request.form['filename']
sequence = request.form['sequence']
processor = get_data_or_error(filename)
if not isinstance(processor, TokenFlowDataProcessor):
return jsonify({"error": processor[0], "code": processor[1]})
response = processor.get_datadict(int(sequence))
response['nstrings'] = processor.nstrings
return response
@app.route('/images/<path:filename>')
def serve_image(filename):
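"""Serve images referenced by multimodal sequences from the dataset's image_dir."""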
return send_from_directory(
os.path.join(
os.path.dirname(os.path.abspath(args.data_params)),
data_params['setup']['image_dir'],
),
filename,
)
@app.route('/get_data_params')
def get_data():
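"""Return the raw contents of data_params.json as JSON."""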
with open(args.data_params, 'r') as file:
data = json.load(file)
return jsonify(data)
@app.route('/generate_sequence_distribution', methods=['POST'])
def serve_sequence_distribution():
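"""Generate the sequence-length distribution plot and return the URL of the saved image."""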
try:
file_directory = args.output_dir
save_sequence_distribution(file_directory, args.data_params)
image_path = url_for(
'static', filename='images/sequence_distribution.png'
)
return jsonify({'image_path': image_path}), 200
except Exception as e:
return (
jsonify({'error': f'Unable to retrieve sequence distribution: {e}'}),
500,
)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument(
"--output_dir",
type=str,
help="Directory/location of one or more HDF5 files. In case a single file is passed, data_params should also be specified",
required=True,
)
parser.add_argument(
"--data_params",
type=str,
help="Location of data_params, required for loading heruistics related to the preprocessed data",
)
parser.add_argument(
"--port", type=int, help="Port to run the Flask app on", default=5000
)
global args
args = parser.parse_args()
if not args.data_params:
if os.path.isdir(args.output_dir):
args.data_params = os.path.join(args.output_dir, 'data_params.json')
else:
exit(
"Use --data_params <path/to/file> to specify the path of data_params.json. Required when passing a single HDF5 file."
)
app.run(debug=True, host='0.0.0.0', port=args.port)
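# Example invocation (illustrative paths; replace with the actual script location):
#   python <path/to/this_script>.py --output_dir /path/to/preprocessed_output --port 5000
#   python <path/to/this_script>.py --output_dir /path/to/file.h5 --data_params /path/to/data_params.json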