# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processors for synthetic data for DPO Training
"""
import numpy as np
import torch
from torch.utils.data import Dataset

from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.data.common.input_utils import is_distributed


class DPODataset(Dataset):
    """
    Map-style dataset of pre-generated DPO features; inherits from
    ``torch.utils.data.Dataset``.
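
    Illustrative usage sketch (``data`` holds equal-length per-key arrays,
    and ``data_processor`` only needs a ``num_examples`` attribute; the
    stand-in class below is hypothetical):

        >>> import numpy as np
        >>> keys = (
        ...     "chosen_input_ids", "chosen_attention_mask", "chosen_labels",
        ...     "rejected_input_ids", "rejected_attention_mask",
        ...     "rejected_labels",
        ... )
        >>> data = {k: np.zeros((4, 8), dtype=np.int32) for k in keys}
        >>> class _Proc:  # hypothetical stand-in for the data processor
        ...     num_examples = 4
        >>> dataset = DPODataset(data, _Proc())
        >>> len(dataset)
        4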
"""

    def __init__(self, data, data_processor):
        self.data = data
        self.length = data_processor.num_examples
        super().__init__()

    def __getitem__(self, index):
        # One preference pair: the "chosen" (preferred) response and the
        # "rejected" response, each with input ids, attention mask, and
        # labels.
        feature = {
            "chosen_input_ids": self.data["chosen_input_ids"][index],
            "chosen_attention_mask": self.data["chosen_attention_mask"][index],
            "chosen_labels": self.data["chosen_labels"][index],
            "rejected_input_ids": self.data["rejected_input_ids"][index],
            "rejected_attention_mask": self.data["rejected_attention_mask"][
                index
            ],
            "rejected_labels": self.data["rejected_labels"][index],
        }
        return feature

    def __len__(self):
        return self.length


class DPOSyntheticDataProcessor:
"""
Synthetic dataset generator.
:param dict params: dict containing training
input parameters for creating dataset.
Expects the following fields:
- "num_examples (int): Number of training examples
- "vocab_size" (int): Vocabulary size
- "max_seq_length (int): Maximum length of the sequence to generate
- "batch_size" (int): Batch size.
- "shuffle" (bool): Flag to enable data shuffling.
- "shuffle_seed" (int): Shuffle seed.
"""

    def __init__(self, params):
        self.num_examples = params["num_examples"]
        self.vocab_size = params["vocab_size"]
        self.max_seq_len = params["max_sequence_length"]
        self.batch_size = get_streaming_batch_size(params["batch_size"])
        self.shuffle = params["shuffle"]
        self.shuffle_seed = params.get("shuffle_seed", None)
        self.input_pad_id = params.get("input_pad_id", None)
        self.label_pad_id = params.get("label_pad_id", None)
        self.sampler = params.get("sampler", None)
        self.batch_sampler = params.get("batch_sampler", None)
        self.num_workers = params.get("num_workers", 8)
        self.pin_memory = params.get("pin_memory", False)
        self.drop_last = params.get("drop_last", True)
        self.timeout = params.get("timeout", 0)

        assert self.batch_size > 0, "Batch size must be positive."
        if self.drop_last and self.num_examples < self.batch_size:
            raise ValueError(
                f"This dataset does not return any batches because the "
                f"number of examples in the dataset ({self.num_examples}) "
                f"is less than the batch size ({self.batch_size}) and "
                f"`drop_last` is True."
            )

    def create_dataloader(self):
        """
        Create a dataloader over freshly generated synthetic DPO features.

        :returns: torch.utils.data.DataLoader
        """
        np.random.seed(seed=0)
        data = {}
        # The "chosen" and "rejected" sides are generated identically: pick a
        # random padding start index in the second half of each sequence,
        # build the attention mask (1 for real tokens, 0 for padding), and
        # sample random token ids, zeroed out over the padded tail.
        seq_mid_idx = self.max_seq_len // 2
        for side in ("chosen", "rejected"):
            input_mask = np.zeros(
                (self.num_examples, self.max_seq_len), dtype=np.int32
            )
            for i in range(self.num_examples):
                start_idx = np.random.randint(
                    seq_mid_idx, self.max_seq_len + 1
                )
                input_mask[i, start_idx : self.max_seq_len] = 1
            data[f"{side}_attention_mask"] = 1 - input_mask
            data[f"{side}_input_ids"] = (
                np.random.randint(
                    low=0,
                    high=self.vocab_size,
                    size=(self.num_examples, self.max_seq_len),
                    dtype=np.int32,
                )
                * data[f"{side}_attention_mask"]
            )
            data[f"{side}_labels"] = data[f"{side}_input_ids"]

        dataset = DPODataset(data, self)

        if is_distributed():
            assert (
                self.sampler is None
            ), "Cannot use sampler in config with DDP"
            # DistributedSampler expects an integer seed; fall back to its
            # default of 0 when no shuffle seed is configured.
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                shuffle=self.shuffle,
                seed=self.shuffle_seed if self.shuffle_seed is not None else 0,
            )

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            # DataLoader treats `shuffle` and `sampler` as mutually
            # exclusive, so defer to the sampler whenever one is set
            # (e.g. under DDP).
            shuffle=self.shuffle if self.sampler is None else None,
            sampler=self.sampler,
            batch_sampler=self.batch_sampler,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            timeout=self.timeout,
        )
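

if __name__ == "__main__":
    # Minimal smoke-test sketch: the parameter values below are hypothetical,
    # and this assumes `get_streaming_batch_size` falls back to the
    # configured batch size when run outside a Cerebras cluster.
    params = {
        "num_examples": 16,
        "vocab_size": 100,
        "max_sequence_length": 32,
        "batch_size": 4,
        "shuffle": False,
        "num_workers": 0,  # single-process loading for a quick local check
    }
    dataloader = DPOSyntheticDataProcessor(params).create_dataloader()
    batch = next(iter(dataloader))
    for name, tensor in batch.items():
        print(name, tuple(tensor.shape), tensor.dtype)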