Source code for cerebras.modelzoo.data.nlp.dpo.DPOSyntheticDataset

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processors for synthetic data for DPO Training
"""

import numpy as np
import torch
from torch.utils.data import Dataset

from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.data.common.input_utils import is_distributed


class DPODataset(Dataset):
    """A dataset for DPO training, inheriting from torch.utils.data.Dataset."""

    def __init__(self, data, data_processor):
        self.data = data
        self.length = data_processor.num_examples
        super().__init__()

    def __getitem__(self, index):
        # Each example pairs a "chosen" and a "rejected" sequence, as
        # expected by the DPO loss.
        feature = {
            "chosen_input_ids": self.data["chosen_input_ids"][index],
            "chosen_attention_mask": self.data["chosen_attention_mask"][index],
            "chosen_labels": self.data["chosen_labels"][index],
            "rejected_input_ids": self.data["rejected_input_ids"][index],
            "rejected_attention_mask": self.data["rejected_attention_mask"][
                index
            ],
            "rejected_labels": self.data["rejected_labels"][index],
        }
        return feature

    def __len__(self):
        return self.length
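# The function below is an illustrative sketch, not part of the original
# module: it shows the minimal contract DPODataset relies on, namely a `data`
# dict of per-example arrays and a `data_processor` exposing `num_examples`.
# The `_Proc` stand-in and all values here are hypothetical.
def _example_dpo_dataset_usage():
    class _Proc:
        # Hypothetical processor stand-in carrying only the attribute
        # DPODataset reads.
        num_examples = 4

    keys = [
        "chosen_input_ids",
        "chosen_attention_mask",
        "chosen_labels",
        "rejected_input_ids",
        "rejected_attention_mask",
        "rejected_labels",
    ]
    data = {k: np.zeros((4, 8), dtype=np.int32) for k in keys}
    dataset = DPODataset(data, _Proc())
    assert len(dataset) == 4
    # __getitem__ returns a dict holding the six chosen/rejected arrays.
    assert set(dataset[0].keys()) == set(keys)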
class DPOSyntheticDataProcessor:
    """
    Synthetic dataset generator for DPO training.

    :param dict params: dict containing training input parameters for
        creating the dataset. Expects the following fields:

    - "num_examples" (int): Number of training examples.
    - "vocab_size" (int): Vocabulary size.
    - "max_sequence_length" (int): Maximum length of the sequences to generate.
    - "batch_size" (int): Batch size.
    - "shuffle" (bool): Flag to enable data shuffling.
    - "shuffle_seed" (int, optional): Shuffle seed. Defaults to None.
    - "sampler" (optional): Sampler to pass to the dataloader. Defaults to None.
    - "batch_sampler" (optional): Batch sampler to pass to the dataloader. Defaults to None.
    - "num_workers" (int, optional): Number of dataloader worker processes. Defaults to 8.
    - "pin_memory" (bool, optional): Whether the dataloader pins memory. Defaults to False.
    - "drop_last" (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
    - "timeout" (int, optional): Dataloader timeout in seconds. Defaults to 0.
    """
    def __init__(self, params):
        self.num_examples = params["num_examples"]
        self.vocab_size = params["vocab_size"]
        self.max_seq_len = params["max_sequence_length"]
        self.batch_size = get_streaming_batch_size(params["batch_size"])
        self.shuffle = params["shuffle"]
        self.shuffle_seed = params.get("shuffle_seed", None)
        self.input_pad_id = params.get("input_pad_id", None)
        self.label_pad_id = params.get("label_pad_id", None)
        self.sampler = params.get("sampler", None)
        self.batch_sampler = params.get("batch_sampler", None)
        self.num_workers = params.get("num_workers", 8)
        self.pin_memory = params.get("pin_memory", False)
        self.drop_last = params.get("drop_last", True)
        self.timeout = params.get("timeout", 0)

        assert self.batch_size > 0, "Batch size should be positive."

        if self.drop_last and self.num_examples < self.batch_size:
            raise ValueError(
                f"This dataset does not return any batches because the number "
                f"of examples in the dataset ({self.num_examples}) is less "
                f"than the batch size ({self.batch_size}) and `drop_last` is "
                f"True."
            )
    def create_dataloader(self):
        """
        Create dataloader.

        :returns: dataloader
        """
        np.random.seed(seed=0)
        data = dict()

        # Note: the original code computed this via `np.cast["int32"](...)`,
        # which was removed in NumPy 2.0; integer floor division is
        # equivalent here.
        seq_mid_idx = self.max_seq_len // 2

        # The "chosen" and "rejected" halves are generated identically: each
        # sequence is random token IDs, zeroed out from a random position in
        # its second half onward.
        for prefix in ("chosen", "rejected"):
            input_mask = np.zeros(
                (self.num_examples, self.max_seq_len), dtype=np.int32
            )
            for i in range(self.num_examples):
                start_idx = np.random.randint(
                    seq_mid_idx, self.max_seq_len + 1
                )
                input_mask[i, start_idx : self.max_seq_len] = 1
            data[f"{prefix}_attention_mask"] = 1 - input_mask
            data[f"{prefix}_input_ids"] = (
                np.random.randint(
                    low=0,
                    high=self.vocab_size,
                    size=(self.num_examples, self.max_seq_len),
                    dtype=np.int32,
                )
                * data[f"{prefix}_attention_mask"]
            )
            data[f"{prefix}_labels"] = data[f"{prefix}_input_ids"]

        dataset = DPODataset(data, self)

        if is_distributed():
            assert (
                self.sampler is None
            ), "Cannot use sampler in config with DDP"
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                shuffle=self.shuffle,
                seed=self.shuffle_seed,
            )

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            # `shuffle` and `sampler` are mutually exclusive in DataLoader;
            # when a sampler is set (e.g. the DistributedSampler above), it
            # is responsible for shuffling.
            shuffle=self.shuffle if self.sampler is None else None,
            sampler=self.sampler,
            batch_sampler=self.batch_sampler,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            timeout=self.timeout,
        )
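
# The block below is an illustrative sketch, not part of the original module:
# a minimal end-to-end usage of DPOSyntheticDataProcessor with a hypothetical
# params dict (field names follow the class docstring above), assuming a
# local, non-distributed run.
if __name__ == "__main__":
    example_params = {
        "num_examples": 16,
        "vocab_size": 100,
        "max_sequence_length": 32,
        "batch_size": 4,
        "shuffle": False,
        "num_workers": 0,  # keep the sketch single-process
    }
    processor = DPOSyntheticDataProcessor(example_params)
    dataloader = processor.create_dataloader()
    batch = next(iter(dataloader))
    # Each batch holds paired chosen/rejected tensors of shape
    # (batch_size, max_sequence_length), i.e. (4, 32) here.
    print({name: tuple(tensor.shape) for name, tensor in batch.items()})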