Source code for cerebras.modelzoo.data_preparation.raw_dataset_processor.RawDatasetProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    Dataset processor for processing a raw dataset on the fly.

    This module contains methods for loading and tokenizing the dataset;
    all data transformations are handled as part of the collator function.
"""

import logging
import os
import random
from typing import Any, Dict, Iterator, List, Literal, Optional

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, default_collate

from cerebras.modelzoo.config import DataConfig
from cerebras.modelzoo.data.common.input_utils import (
    num_tasks,
    shard_list_contiguous,
    task_id,
)
from cerebras.modelzoo.data.vision.preprocessing import get_preprocess_transform
from cerebras.modelzoo.data_preparation.data_preprocessing.data_preprocessor import (
    DataPreprocessor,
)
from cerebras.modelzoo.data_preparation.raw_dataset_processor.utils import (
    Reader,
)

LOGGER = logging.getLogger(__name__)


class RawDatasetProcessorConfig(DataConfig):
    """Configuration class for RawDatasetProcessor."""

    data_processor: Literal["RawDatasetProcessor"]

    batch_size: int = ...

    ##TODO: Create a config class for preprocessing as well
    preprocessing: dict = ...
    """ The dataset preprocessing configuration. """

    shuffle: bool = True
    shuffle_seed: int = 0
    num_workers: int = 0
    prefetch_factor: Optional[int] = 10
    persistent_workers: bool = True
    drop_last: bool = True
    seed: Optional[int] = None

    def post_init(self, context):
        if not self.num_workers:
            self.prefetch_factor = None  # the default value in DataLoader
            self.persistent_workers = False
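# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the general shape of
# a dict that RawDatasetProcessorConfig accepts. The "preprocessing" payload
# is consumed by DataPreprocessor, so its contents are a placeholder here and
# must follow whatever schema your DataPreprocessor setup expects.
#
#   example_config = {
#       "data_processor": "RawDatasetProcessor",
#       "batch_size": 32,
#       "preprocessing": {...},  # placeholder: DataPreprocessor configuration
#       "shuffle": True,
#       "num_workers": 2,
#       "seed": 42,
#   }
# ---------------------------------------------------------------------------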
class MultimodalRawDatasetProcessorConfig(RawDatasetProcessorConfig):
    """Multimodal configuration class for RawDatasetProcessor."""

    data_processor: Literal["MultimodalRawDatasetProcessor"]

    image_data_size: List[int] = ...
    """ The final C x H x W shape of the image. """

    transforms: List[dict] = ...
    """ List of transformations to apply to images. """

    img_data_dir: str = ...
    """ The directory containing the image data. """
class RawDatasetProcessor(torch.utils.data.IterableDataset):
    def __init__(self, config: RawDatasetProcessorConfig):
        if isinstance(config, dict):
            config = RawDatasetProcessorConfig(**config)

        super(RawDatasetProcessor, self).__init__()

        self.config = config
        self.batch_size = self.config.batch_size
        self.preprocessing_params = self.config.preprocessing
        self.dataset_processor = DataPreprocessor(self.preprocessing_params)
        self.features_list = self.dataset_processor.token_generator.features
        self.num_workers = self.config.num_workers
        self.prefetch_factor = self.config.prefetch_factor
        self.persistent_workers = self.config.persistent_workers
        self.reader = Reader(
            self.dataset_processor.input_files,
            keys=self.dataset_processor.data_keys,
            read_hook_fn=self.dataset_processor.read_hook_fn,
        )
        self.seed = self.config.seed
        self.rng = random.Random(self.seed)

        self.input_files_in_this_task = shard_list_contiguous(
            self.dataset_processor.input_files, task_id(), num_tasks()
        )

    def _worker_init_fn(self, worker_id: int):
        """
        Initialization function for each worker in a DataLoader.

        Args:
            worker_id (int): The ID of the current worker.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process
            worker_id = 0
            num_workers = 1

        if self.seed is not None:
            # Use a unique seed for each worker.
            random.seed(self.seed + worker_id)

        # Shard the data files between workers
        self.input_files_in_this_worker = shard_list_contiguous(
            self.input_files_in_this_task, worker_id, num_workers
        )

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Returns an iterator over the items of the class.

        Returns:
            Iterator[Dict[str, np.ndarray]]: An iterator yielding dictionaries
                with string keys and NumPy array values.
        """
        return self.get_next_item()
    def get_next_item(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Returns the next item in the iteration.

        This function iterates over the data stream from the reader, tokenizes
        the data, and yields dictionaries containing features as keys and
        NumPy arrays as values.

        Returns:
            Iterator[Dict[str, np.ndarray]]: An iterator yielding dictionaries
                with string keys and NumPy array values.
        """
        for data in self.reader.stream_data():
            data_array = self.dataset_processor.read_hook_fn(data)
            # Tokenize the data and get stats
            tokenized_data, stats = self.dataset_processor.token_generator.encode(
                data_array
            )
            # Continue to next iteration if "data" key is not present
            if "data" not in tokenized_data.keys():
                continue
            # Iterate through the tokenized data and yield feature dictionaries
            for d in tokenized_data["data"]:
                yield {
                    feature: np.array(d[i], np.int32)
                    for i, feature in enumerate(self.features_list)
                }
    def collate_fn(self, batch: List[Dict[str, np.ndarray]]) -> Any:
        """
        Collates a list of dictionaries into a batch.

        Args:
            batch (List[Dict[str, np.ndarray]]): A list of dictionaries, where
                each dictionary contains string keys and NumPy array values.

        Returns:
            Any: The collated batch.
        """
        if self.dataset_processor.shuffle:
            random.shuffle(batch)
        return default_collate(batch)
    def create_dataloader(self) -> DataLoader:
        """
        Creates the dataloader object for this dataset.

        Returns:
            DataLoader: A DataLoader object for the dataset.
        """
        # Create the DataLoader object with the specified parameters
        dataloader = DataLoader(
            self,
            # Number of samples per batch
            batch_size=self.batch_size,
            # Drop the last incomplete batch if the dataset size is not
            # divisible by the batch size
            drop_last=self.config.drop_last,
            # Function to merge a list of samples to form a mini-batch
            collate_fn=self.collate_fn,
            # Number of subprocesses to use for data loading
            num_workers=self.num_workers,
            # Number of samples loaded in advance by each worker
            prefetch_factor=(
                self.prefetch_factor if self.num_workers > 0 else None
            ),
            # Keep worker processes alive after they finish their tasks
            persistent_workers=(
                self.persistent_workers if self.num_workers > 0 else False
            ),
            # Function to initialize the worker process
            worker_init_fn=(
                self._worker_init_fn
                if self.num_workers > 0 and self.seed is not None
                else None
            ),
        )
        # Shard the input files for this process when num_workers == 0,
        # since no worker_init_fn will be invoked by the DataLoader
        if self.num_workers == 0:
            self._worker_init_fn(0)
        return dataloader
class MultimodalRawDatasetProcessor(RawDatasetProcessor):
    """Dataset processor for multimodal data (e.g., image data)."""

    def __init__(self, config: MultimodalRawDatasetProcessorConfig):
        if isinstance(config, dict):
            config = MultimodalRawDatasetProcessorConfig(**config)

        super(MultimodalRawDatasetProcessor, self).__init__(config)

        self.img_data_dir = self.config.img_data_dir
        self.image_data_size = self.config.image_data_size
        self.transforms = get_preprocess_transform(
            {
                "transforms": self.config.transforms,
            }
        )

    def preprocess_img(self, path_list):
        img_list = []
        for img_paths in path_list:
            imgs_per_sample_list = []
            # Iterate over all the image paths in one data sample
            for path in img_paths:
                path = path.decode("utf-8")
                if path != "None":
                    image_path = os.path.join(self.img_data_dir, path)
                    image = Image.open(image_path).convert("RGB")
                else:
                    image = Image.new(
                        mode="RGB",
                        size=(self.image_data_size[2], self.image_data_size[1]),
                    )
                imgs_per_sample_list.append(self.transforms(image).unsqueeze(0))
            imgs_per_sample = torch.cat(
                imgs_per_sample_list, dim=0
            )  # shape: max_num_img x C x H x W
            img_list.append(imgs_per_sample.unsqueeze(0))
        img = torch.cat(
            img_list, dim=0
        )  # shape: batch_size x max_num_img x C x H x W
        return img
    def get_next_item(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Returns the next item in the iteration.

        This function iterates over the data stream from the reader, tokenizes
        the data, and yields dictionaries containing features as keys and
        NumPy arrays as values.

        Returns:
            Iterator[Dict[str, np.ndarray]]: An iterator yielding dictionaries
                with string keys and NumPy array values.
        """
        for data in self.reader.stream_data():
            data_array = self.dataset_processor.read_hook_fn(data)
            # Tokenize the data and get stats
            tokenized_data, stats = self.dataset_processor.token_generator.encode(
                data_array
            )
            # Continue to next iteration if "data" key is not present
            if "data" not in tokenized_data.keys():
                continue
            # Apply image transformation
            tokenized_data['image_data'] = self.preprocess_img(
                tokenized_data['img_path']
            )
            for i in range(len(tokenized_data["data"])):
                data = {
                    feature: np.array(
                        tokenized_data["data"][i][feature_idx], np.int32
                    )
                    for feature_idx, feature in enumerate(self.features_list)
                }
                data.update(
                    {
                        "image_data": tokenized_data["image_data"][i],
                        "image_data_loc": tokenized_data["img_data_loc"][i],
                    }
                )
                yield data
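# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It assumes a valid
# DataPreprocessor configuration is available as `preprocessing_params`;
# the empty dict below is a placeholder and must be replaced with a real
# preprocessing setup before this will run end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    preprocessing_params = {}  # placeholder: real DataPreprocessor config goes here

    dataset = RawDatasetProcessor(
        RawDatasetProcessorConfig(
            data_processor="RawDatasetProcessor",
            batch_size=32,
            preprocessing=preprocessing_params,
            num_workers=0,
            seed=42,
        )
    )
    dataloader = dataset.create_dataloader()

    # Each batch is a dict mapping feature names to tensors whose leading
    # dimension is the batch size.
    for batch in dataloader:
        for feature, tensor in batch.items():
            LOGGER.info("%s: %s", feature, tuple(tensor.shape))
        break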