Source code for cerebras.modelzoo.data.vision.segmentation.InriaAerialDataProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random
from typing import Any, Literal, Optional, Sequence

import torch
from PIL import Image
from pydantic import Field
from torchvision import transforms
from torchvision.datasets import VisionDataset

import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.data.vision.classification.dataset_factory import (
    VisionSubset,
)
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
    adjust_brightness_transform,
    normalize_tensor_transform,
    rotation_90_transform,
)
from cerebras.modelzoo.data.vision.segmentation.UNetDataProcessor import (
    UNetDataProcessor,
    UNetDataProcessorConfig,
)
from cerebras.modelzoo.data.vision.transforms import LambdaWithParam
from cerebras.modelzoo.data.vision.utils import (
    FastDataLoader,
    ShardedSampler,
    create_worker_cache,
    num_tasks,
    task_id,
)


class InriaAerialDataset(VisionDataset):
    """Map-style dataset of (image, mask) pairs for the Inria Aerial data.

    Files are read from the following layout, where an image and its mask
    share the same file name::

        <root>/<split>/images/<file>
        <root>/<split>/gt/<file>

    Samples are ordered by the sorted file names found in the ``images``
    directory.
    """

    def __init__(
        self,
        root,
        split="train",
        transforms=None,  # pylint: disable=redefined-outer-name
        transform=None,
        target_transform=None,
        use_worker_cache=False,
    ):
        """
        Args:
            root: Directory that contains one sub-directory per split.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            transforms: Joint callable applied as
                ``transforms(image, target)``; forwarded to ``VisionDataset``.
            transform: Image-only transform; forwarded to ``VisionDataset``.
            target_transform: Mask-only transform; forwarded to
                ``VisionDataset``. Must be ``None`` for the ``"test"`` split.
            use_worker_cache: When true (and ``dist.is_streamer()`` is true),
                ``root`` is replaced by the result of
                ``create_worker_cache(root)``.

        Raises:
            ValueError: If ``split`` is not a valid split name, or if a
                ``target_transform`` is given for the ``"test"`` split.
            RuntimeError: If the worker cache is requested on a non-CS run.
        """
        super().__init__(root, transforms, transform, target_transform)

        if split not in ["train", "val", "test"]:
            raise ValueError(
                f"Invalid value={split} passed to `split` argument. "
                f"Valid are 'train' or 'val' or 'test' "
            )
        self.split = split

        if split == "test" and target_transform is not None:
            raise ValueError(
                f"split {split} has no mask images and hence target_transform should be None. "
                f"Got {target_transform}."
            )

        if use_worker_cache and dist.is_streamer():
            if not cstorch.use_cs():
                raise RuntimeError(
                    "use_worker_cache not supported for non-CS runs"
                )
            # NOTE(review): presumably returns the path of a worker-local
            # copy of the data — confirm against `create_worker_cache`.
            self.root = create_worker_cache(self.root)

        self.data_dir = os.path.join(self.root, self.split)
        self.image_dir = os.path.join(self.data_dir, "images")
        self.mask_dir = os.path.join(self.data_dir, "gt")
        self.file_list = sorted(os.listdir(self.image_dir))

        # The test split ships without masks. Only skip mask loading in that
        # case: train/val always load masks (and fail loudly if they are
        # missing), and a test split that does have a `gt` directory keeps
        # returning masks as before.
        self._has_masks = split != "test" or os.path.isdir(self.mask_dir)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Args:
            index (int): Position in the sorted list of image file names.

        Returns:
            tuple: ``(image, target)``. Both are PIL images as opened from
            disk, or whatever ``self.transforms`` returns for them when
            transforms are configured. ``target`` is ``None`` for a test
            split that has no ``gt`` directory.
        """
        file_name = self.file_list[index]

        image = Image.open(
            os.path.join(self.image_dir, file_name)
        )  # 3-channel PILImage

        if self._has_masks:
            target = Image.open(
                os.path.join(self.mask_dir, file_name)
            )  # PILImage
        else:
            target = None

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        """Return the number of files found in the ``images`` directory."""
        return len(self.file_list)
class InriaAerialDataProcessorConfig(UNetDataProcessorConfig):
    """Configuration for ``InriaAerialDataProcessor``.

    Extends ``UNetDataProcessorConfig`` with the dataset split, worker-cache
    and fast-dataloader switches, and the overfit (debug) options.
    """

    # Discriminator selecting this data processor.
    data_processor: Literal["InriaAerialDataProcessor"]

    # Forwarded to `InriaAerialDataset(use_worker_cache=...)`.
    use_worker_cache: bool = False

    # Debug options: train on a small fixed subset of the dataset.
    overfit: bool = False
    # Number of batches in the overfit subset; defaulted in `post_init`.
    overfit_num_batches: Optional[int] = None
    # Explicit dataset indices to overfit on; sampled randomly when None.
    overfit_indices: Optional[Sequence] = None

    split: Literal["train", "val", "test"] = "train"
    "Dataset split."

    # Use `FastDataLoader` instead of `torch.utils.data.DataLoader`.
    use_fast_dataloader: bool = False
    disable_sharding: bool = False

    # Deprecated fields, still accepted but marked `deprecated=True`.
    train_test_split: Optional[Any] = Field(default=None, deprecated=True)
    class_id: Optional[Any] = Field(default=None, deprecated=True)

    def post_init(self, context):
        """Fill in the default for ``overfit_num_batches``.

        When not set explicitly, the default is ``num_workers`` batches per
        task (``num_tasks() * num_workers``).
        """
        super().post_init(context)

        if self.overfit_num_batches is None:
            # With `num_workers == 0` the plain product would be 0, which
            # makes overfit mode sample an empty subset. Count that case as
            # a single worker so at least one batch per task is selected.
            self.overfit_num_batches = num_tasks() * (self.num_workers or 1)
class InriaAerialDataProcessor(UNetDataProcessor):
    """Builds the dataset and dataloader for the Inria Aerial data.

    Wraps ``InriaAerialDataset`` with the image/mask preprocessing and
    optional augmentation defined here, and picks the sampler and dataloader
    implementation from the config.
    """

    def __init__(self, config: InriaAerialDataProcessorConfig):
        """
        Args:
            config: Data processor configuration. Common UNet settings are
                consumed by the parent class.
        """
        super().__init__(config)

        self.split = config.split
        self.use_worker_cache = config.use_worker_cache
        # Shuffling is only ever enabled for the training split.
        self.shuffle = self.shuffle and self.split == "train"

        self.image_shape = config.image_shape  # of format (H, W, C)

        # Debug params:
        self.overfit = config.overfit
        # default is that each activation worker sends `num_workers`
        # batches so total batch_size * num_act_workers * num_pytorch_workers samples
        self.overfit_num_batches = config.overfit_num_batches
        self.random_indices = config.overfit_indices
        if self.overfit:
            logging.info(f"---- Overfitting {self.overfit_num_batches}! ----")

        # Using Faster Dataloader for mapstyle dataset.
        self.use_fast_dataloader = config.use_fast_dataloader
        # NOTE(review): stored but not consulted anywhere in this class —
        # confirm whether the parent class or callers rely on it.
        self.disable_sharding = config.disable_sharding

    def create_dataset(self):
        """Create the dataset for ``self.split``.

        Returns:
            An ``InriaAerialDataset`` whose samples go through
            ``transform_image_and_mask``. In overfit mode it is wrapped in a
            ``VisionSubset`` restricted to the chosen indices.
        """
        dataset = InriaAerialDataset(
            root=self.data_dir,
            split=self.split,
            # Use the joint transform so that the mask is converted to a
            # tensor alongside the image and both receive the same random
            # flip/rotation. An image-only `transform` would leave the mask
            # as a raw PIL image.
            transforms=self.transform_image_and_mask,
            use_worker_cache=self.use_worker_cache,
        )

        if self.overfit:
            random.seed(self.shuffle_seed)
            if self.random_indices is None:
                indices = random.sample(
                    range(0, len(dataset)),
                    self.overfit_num_batches * self.batch_size,
                )
            else:
                indices = self.random_indices

            dataset = VisionSubset(dataset, indices)
            logging.info(f"---- Overfitting {indices}! ----")

        return dataset

    def create_dataloader(self):
        """Create the dataloader over ``create_dataset()``.

        Sampler selection:
            * shuffle + ``duplicate_act_worker_data``: ``RandomSampler`` over
              the whole dataset, seeded per task.
            * shuffle otherwise: ``ShardedSampler``.
            * no shuffle: ``SequentialSampler``.

        Returns:
            A ``FastDataLoader`` or ``torch.utils.data.DataLoader``.

        Raises:
            AssertionError: If the sampler yields fewer samples than
                ``batch_size``.
        """
        dataset = self.create_dataset()

        generator_fn = torch.Generator(device="cpu")
        if self.shuffle_seed is not None:
            generator_fn.manual_seed(self.shuffle_seed)

        if self.shuffle:
            if self.duplicate_act_worker_data:
                # Multiples activation workers, each sending same data in different
                # order since the dataset is extremely small
                if self.shuffle_seed is None:
                    seed = task_id()
                else:
                    seed = self.shuffle_seed + task_id()
                generator_fn.manual_seed(seed)
                data_sampler = torch.utils.data.RandomSampler(
                    dataset, generator=generator_fn
                )
            else:
                data_sampler = ShardedSampler(
                    dataset, self.shuffle, self.shuffle_seed, self.drop_last
                )
        else:
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        num_samples_per_task = len(data_sampler)
        assert num_samples_per_task >= self.batch_size, (
            f"Number of samples available per task(={num_samples_per_task}) is less than "
            f"batch_size(={self.batch_size})"
        )

        if self.use_fast_dataloader:
            dataloader_fn = FastDataLoader
        else:
            dataloader_fn = torch.utils.data.DataLoader

        loader_kwargs = {
            "batch_size": self.batch_size,
            "drop_last": self.drop_last,
            "generator": generator_fn,
            "sampler": data_sampler,
        }
        if self.num_workers:
            # Worker-related options are only passed when worker processes
            # are actually used.
            loader_kwargs.update(
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
            )

        return dataloader_fn(dataset, **loader_kwargs)

    def _apply_normalization(
        self, image, normalize_data_method, *args, **kwargs
    ):
        """Normalize ``image`` with ``normalize_tensor_transform``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        return normalize_tensor_transform(image, normalize_data_method)

    def preprocess_image(self, image):
        """Convert a PIL image to a normalized tensor.

        Args:
            image: PIL image. It is converted to grayscale first when the
                configured ``image_shape`` has a single channel.

        Returns:
            The tensor produced by ``PILToTensor`` followed by normalization
            with ``self.normalize_data_method``.
        """
        if self.image_shape[-1] == 1:
            image = image.convert(
                "L"
            )  # convert PILImage to grayscale (H, W, 1)

        # converts to (C, H, W) format.
        to_tensor_transform = transforms.PILToTensor()

        # Normalize
        normalize_transform = LambdaWithParam(
            self._apply_normalization, self.normalize_data_method
        )

        transforms_list = [
            to_tensor_transform,
            normalize_transform,
        ]
        image = transforms.Compose(transforms_list)(image)
        return image

    def preprocess_mask(self, mask):
        """Convert a PIL mask to a tensor normalized with ``"zero_one"``.

        Args:
            mask: PIL mask image.

        Returns:
            The tensor produced by ``PILToTensor`` followed by the
            ``"zero_one"`` normalization.
        """
        to_tensor_transform = transforms.PILToTensor()
        normalize_transform = LambdaWithParam(
            self._apply_normalization, "zero_one"
        )
        transforms_list = [
            to_tensor_transform,
            normalize_transform,
        ]
        mask = transforms.Compose(transforms_list)(
            mask
        )  # output of shape (1, 5000, 5000)
        return mask

    def transform_image_and_mask(self, image, mask):
        """Preprocess an (image, mask) pair and optionally augment both.

        The same random horizontal flip and rotation count are applied to
        the image and the mask; random brightness is applied to the image
        only.

        Args:
            image: PIL image.
            mask: PIL mask.

        Returns:
            tuple: ``(image, mask)`` tensors.
        """
        image = self.preprocess_image(image)
        mask = self.preprocess_mask(mask)

        if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rots in range [0, 3)
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()

            if self.image_shape[0] != self.image_shape[1]:  # H != W
                # For a rectangle image
                # NOTE(review): presumably doubled so that only half-turns
                # are used and the (H, W) shape is kept — confirm against
                # `rotation_90_transform`.
                n_rotations = n_rotations * 2

            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )

            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)

        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)

        if cstorch.amp.mixed_precision():
            image = image.to(self.mp_type)

        return image, mask

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """Compose the augmentation pipeline for one tensor.

        Args:
            do_horizontal_flip: Whether to add a horizontal flip.
            n_rotations: Rotation count passed to ``rotation_90_transform``;
                no rotation step is added when it is 0.
            do_random_brightness: Whether to add
                ``adjust_brightness_transform`` with ``p=0.5, delta=0.2``.

        Returns:
            A ``transforms.Compose`` of the selected steps, in the order
            flip, rotation, brightness.
        """
        augment_transforms_list = []
        if do_horizontal_flip:
            horizontal_flip_transform = transforms.Lambda(
                transforms.functional.hflip
            )
            augment_transforms_list.append(horizontal_flip_transform)

        if n_rotations > 0:
            rotation_transform = transforms.Lambda(
                lambda x: rotation_90_transform(x, num_rotations=n_rotations)
            )
            augment_transforms_list.append(rotation_transform)

        if do_random_brightness:
            brightness_transform = transforms.Lambda(
                lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
            )
            augment_transforms_list.append(brightness_transform)

        return transforms.Compose(augment_transforms_list)