Source code for cerebras.modelzoo.data.vision.segmentation.Hdf5BaseIterDataProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from abc import ABC, abstractmethod
from typing import Literal

import torch
from torchvision import transforms

import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.common.pytorch_utils import BufferedShuffleDataset
from cerebras.modelzoo.data.vision.segmentation.Hdf5BaseDataProcessor import (
    Hdf5BaseDataProcessorConfig,
)
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
    normalize_tensor_transform,
)
from cerebras.modelzoo.data.vision.utils import create_worker_cache


class Hdf5BaseIterDataProcessorConfig(Hdf5BaseDataProcessorConfig):
    data_processor: Literal["Hdf5BaseIterDataProcessor"]

class Hdf5BaseIterDataProcessor(ABC, torch.utils.data.IterableDataset):
    """
    An HDF5 dataset processor for the UNet HDF5 dataset.
    Performs on-the-fly augmentation of images and labels.

    Functionality includes:
        Reading data from HDF5 files
        Augmenting data
    """

    def __init__(self, config: Hdf5BaseIterDataProcessorConfig):
        super(Hdf5BaseIterDataProcessor, self).__init__()

        use_worker_cache = config.use_worker_cache
        self.data_dir = config.data_dir
        if use_worker_cache and dist.is_streamer():
            if not cstorch.use_cs():
                raise RuntimeError(
                    "use_worker_cache not supported for non-CS runs"
                )
            else:
                self.data_dir = create_worker_cache(self.data_dir)

        self.num_classes = config.num_classes
        self.normalize_data_method = config.normalize_data_method

        if self.normalize_data_method:
            # Normalize
            self.normalize_transform = transforms.Lambda(
                self._apply_normalization
            )

        self.image_shape = config.image_shape  # of format (H, W, C)
        (
            self.tgt_image_height,
            self.tgt_image_width,
            self.channels,
        ) = self.image_shape

        self.loss_type = config.loss
        self.shuffle_seed = config.shuffle_seed
        if self.shuffle_seed is not None:
            torch.manual_seed(self.shuffle_seed)

        self.augment_data = config.augment_data
        self.batch_size = get_streaming_batch_size(config.batch_size)
        self.shuffle = config.shuffle

        self.shuffle_buffer = config.shuffle_buffer
        self.num_workers = config.num_workers
        self.drop_last = config.drop_last
        self.prefetch_factor = config.prefetch_factor
        self.persistent_workers = config.persistent_workers

        self.mp_type = cstorch.amp.get_floating_point_dtype()

        self.num_tasks = dist.num_streamers() if dist.is_streamer() else 1
        self.task_id = dist.get_streaming_rank() if dist.is_streamer() else 0

        # set later once the processor gets a call to create a dataloader
        self.num_examples_in_this_task = 0
        self.files_in_this_task = []

    @abstractmethod
    def _shard_files(self, is_training=False):
        pass

    @abstractmethod
    def _load_buffer(self, data_partitions):
        pass

    @abstractmethod
    def __iter__(self):
        pass

    @abstractmethod
    def _shard_dataset(self, worker_id, num_workers):
        pass

    def __len__(self):
        """
        Returns the length of the dataset on the task process.
        """
        return self.num_examples_in_this_task

    def _worker_init_fn(self, worker_id):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process
            worker_id = 0
            num_workers = 1

        if self.shuffle_seed is not None:
            # Use a unique seed for each worker.
            random.seed(self.shuffle_seed + worker_id)

        self.data_partitions = self._shard_dataset(worker_id, num_workers)
    def create_dataloader(self):
        """
        Creates and returns the dataloader object for this processor.
        """
        is_training = self.split == "train"
        self._shard_files(is_training)
        if self.shuffle:
            random.seed(self.shuffle_seed)
            random.shuffle(self.files_in_this_task)
        data_loader = torch.utils.data.DataLoader(
            (
                BufferedShuffleDataset(
                    dataset=self, buffer_size=self.shuffle_buffer
                )
                if self.shuffle
                else self
            ),
            batch_size=self.batch_size,
            drop_last=self.drop_last,
            num_workers=self.num_workers,
            prefetch_factor=(
                self.prefetch_factor if self.num_workers > 0 else None
            ),
            persistent_workers=(
                self.persistent_workers if self.num_workers > 0 else False
            ),
            worker_init_fn=self._worker_init_fn,
        )
        # set self.data_partitions in case self.num_workers == 0
        if self.num_workers == 0:
            self._worker_init_fn(0)
        return data_loader
    def _apply_normalization(self, x):
        return normalize_tensor_transform(
            x, normalize_data_method=self.normalize_data_method
        )
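
# A minimal sketch (not part of the original source) of how a concrete
# subclass might implement the four abstract hooks above. The HDF5 layout
# (one "image" and one "label" dataset per file, one example per file), the
# class name, and the use of h5py/glob are illustrative assumptions only;
# real subclasses in the ModelZoo may shard and buffer differently.
import glob
import os

import h5py


class Hdf5ToyIterDataProcessor(Hdf5BaseIterDataProcessor):
    def _shard_files(self, is_training=False):
        # Strided assignment of files to this streaming task.
        all_files = sorted(glob.glob(os.path.join(self.data_dir, "*.h5")))
        self.files_in_this_task = all_files[self.task_id :: self.num_tasks]
        # Assumes one example per HDF5 file for simplicity.
        self.num_examples_in_this_task = len(self.files_in_this_task)

    def _shard_dataset(self, worker_id, num_workers):
        # Further split this task's files across dataloader workers.
        return self.files_in_this_task[worker_id::num_workers]

    def _load_buffer(self, data_partitions):
        # Yield raw (image, label) pairs from each file in the partition.
        for fname in data_partitions:
            with h5py.File(fname, "r") as f:
                yield f["image"][()], f["label"][()]

    def __iter__(self):
        for image, label in self._load_buffer(self.data_partitions):
            image = torch.as_tensor(image).to(self.mp_type)
            label = torch.as_tensor(label)
            if self.normalize_data_method:
                image = self.normalize_transform(image)
            yield image, label


# Hypothetical usage, assuming `config` carries a `split` attribute that the
# surrounding framework exposes as `self.split` before create_dataloader()
# is called:
#
#   processor = Hdf5ToyIterDataProcessor(config)
#   for batch in processor.create_dataloader():
#       ...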