# Source code for cerebras.modelzoo.data.vision.segmentation.UNetDataProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, List, Literal, Optional, Union

import matplotlib.pyplot as plt
import torch
from pydantic import Field, field_validator
from torchvision import transforms

import cerebras.pytorch as cstorch
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.config import DataConfig
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
    adjust_brightness_transform,
    normalize_tensor_transform,
    rotation_90_transform,
    tile_image_transform,
)
from cerebras.modelzoo.data.vision.utils import (
    FastDataLoader,
    ShardedSampler,
    num_tasks,
    task_id,
)


class UNetDataProcessorConfig(DataConfig):
    """Configuration for :class:`UNetDataProcessor`.

    ``image_shape``, ``num_classes`` and ``loss`` are mirrored from the
    model configuration in :meth:`post_init` and must come from there; a
    ``ValueError`` is raised if any of them is still ``None`` afterwards.
    """

    data_processor: Literal["UNetDataProcessor"]

    # One directory (or a list of directories) containing the dataset.
    data_dir: Union[str, List[str]] = ...
    # Filled in from the model config in `post_init`.
    num_classes: Optional[int] = None
    # Filled in from the model config in `post_init`; setting it here only
    # triggers a warning.
    loss: Optional[Literal["bce", "multilabel_bce", "ssce", "ssce_dice"]] = None
    normalize_data_method: Optional[
        Literal["zero_centered", "zero_one", "standard_score"]
    ] = None
    shuffle: bool = True
    shuffle_seed: Optional[int] = None
    augment_data: bool = True
    batch_size: int = ...
    num_workers: int = 0
    # Filled in from the model config in `post_init`.
    image_shape: Optional[List[int]] = None
    drop_last: bool = True
    prefetch_factor: Optional[int] = 10
    persistent_workers: bool = True
    use_fast_dataloader: bool = True
    # When True, every activation worker sees the entire dataset (in a
    # different shuffle order) instead of a disjoint shard.
    duplicate_act_worker_data: bool = False
    # When None, derived from the model loss in `set_convert_to_onehot`.
    convert_to_onehot: Optional[bool] = None

    # Deprecated knobs kept only for backward compatibility of configs.
    fp16_type: Optional[Any] = Field(default=None, deprecated=True)
    mixed_precision: Optional[Any] = Field(default=None, deprecated=True)

    @field_validator("convert_to_onehot")
    @classmethod
    def set_convert_to_onehot(cls, convert_to_onehot, info):
        """Default ``convert_to_onehot`` from the model loss when unset."""
        if info.context:
            model_config = info.context.get("model", {}).get("config")
            # Guard against a validation context without a model config;
            # the original code would raise AttributeError here.
            if convert_to_onehot is None and model_config is not None:
                convert_to_onehot = model_config.loss == "multilabel_bce"
        return convert_to_onehot

    def post_init(self, context):
        """Copy shape/class/loss settings from the model configuration."""
        model_config = context.get("model", {}).get("config")

        if model_config is not None:
            if hasattr(model_config, "image_shape"):
                self.image_shape = model_config.image_shape
            if hasattr(model_config, "num_classes"):
                self.num_classes = model_config.num_classes

            if "loss" in self.model_fields_set:
                logging.warning(
                    "Loss cannot be set in data configuration. "
                    "Defaulting to value in model configuration."
                )
            if hasattr(model_config, "loss"):
                self.loss = model_config.loss

        # These three are mandatory and can only come from the model config.
        if any(
            x is None
            for x in [
                self.image_shape,
                self.num_classes,
                self.loss,
            ]
        ):
            raise ValueError(
                "image_shape, num_classes and loss must "
                "be configured from the model config."
            )
class UNetDataProcessor:
    """Base dataloader builder for UNet segmentation datasets.

    Child classes must implement :meth:`create_dataset`, and are expected
    to set ``self._tiling_image_shape`` (H, W, C) as well as
    ``self.tgt_image_height`` / ``self.tgt_image_width`` in their
    ``__init__`` — several methods below read those attributes.
    """

    def __init__(self, config: UNetDataProcessorConfig):
        self.data_dir = config.data_dir

        self.num_classes = config.num_classes

        self.loss_type = config.loss
        self.normalize_data_method = config.normalize_data_method

        self.shuffle_seed = config.shuffle_seed
        if self.shuffle_seed is not None:
            torch.manual_seed(self.shuffle_seed)

        self.augment_data = config.augment_data
        # Per-worker (streaming) batch size, not the global batch size.
        self.batch_size = get_streaming_batch_size(config.batch_size)
        self.shuffle = config.shuffle

        # Multi-processing params.
        self.num_workers = config.num_workers
        self.drop_last = config.drop_last
        self.prefetch_factor = config.prefetch_factor
        self.persistent_workers = config.persistent_workers

        self.mp_type = cstorch.amp.get_floating_point_dtype()

        # Using Faster Dataloader for mapstyle dataset.
        self.use_fast_dataloader = config.use_fast_dataloader

        # Each activation worker can access entire dataset when True
        self.duplicate_act_worker_data = config.duplicate_act_worker_data

    def create_dataset(self):
        """Return a map-style dataset; must be overridden by child classes."""
        # FIX: the original signature omitted `self`, so calling
        # `self.create_dataset()` raised TypeError instead of this error.
        raise NotImplementedError(
            "create_dataset must be implemented in a child class!!"
        )

    def create_dataloader(self):
        """Build a (possibly sharded) dataloader over `create_dataset()`."""
        dataset = self.create_dataset()

        generator_fn = torch.Generator(device="cpu")
        if self.shuffle_seed is not None:
            generator_fn.manual_seed(self.shuffle_seed)

        # If the per-task shard is smaller than one batch, fall back to
        # giving every task the full dataset in a task-specific order.
        self.disable_sharding = False
        samples_per_task = len(dataset) // num_tasks()
        if self.batch_size > samples_per_task:
            # FIX: original interpolated the function object `num_tasks`
            # instead of its value.
            print(
                f"Dataset size: {len(dataset)} too small for num_tasks:"
                f" {num_tasks()} and batch_size: {self.batch_size},"
                " using duplicate data for activation workers..."
            )
            self.disable_sharding = True

        if self.shuffle:
            if self.duplicate_act_worker_data or self.disable_sharding:
                # Multiples activation workers, each sending same data in
                # different order since the dataset is extremely small.
                # Derive a distinct seed per task so orders differ.
                if self.shuffle_seed is None:
                    seed = task_id()
                else:
                    seed = self.shuffle_seed + task_id()
                generator_fn.manual_seed(seed)
                data_sampler = torch.utils.data.RandomSampler(
                    dataset, generator=generator_fn
                )
            else:
                data_sampler = ShardedSampler(
                    dataset, self.shuffle, self.shuffle_seed, self.drop_last
                )
        else:
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if self.use_fast_dataloader:
            dataloader_fn = FastDataLoader
            print("-- Using FastDataloader -- ")
        else:
            dataloader_fn = torch.utils.data.DataLoader
            print("-- Using torch.utils.data.DataLoader -- ")

        if self.num_workers:
            dataloader = dataloader_fn(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
                drop_last=self.drop_last,
                generator=generator_fn,
                sampler=data_sampler,
            )
        else:
            # prefetch_factor / persistent_workers are only valid with
            # worker processes.
            dataloader = dataloader_fn(
                dataset,
                batch_size=self.batch_size,
                drop_last=self.drop_last,
                generator=generator_fn,
                sampler=data_sampler,
            )
        return dataloader

    def transform_image_and_mask(self, image, mask):
        """Preprocess, optionally augment, and dtype-cast one (image, mask)."""
        image = self.preprocess_image(image)
        mask = self.preprocess_mask(mask)

        if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rots in range [0, 3)
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()

            if self.tgt_image_height != self.tgt_image_width:
                # For a rectangle image: only 180-degree multiples keep the
                # spatial shape unchanged.
                n_rotations = n_rotations * 2

            # Image and mask must share flip/rotation; only the image gets
            # brightness jitter.
            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )

            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)

        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)
        elif self.loss_type == "multilabel_bce":
            mask = torch.squeeze(mask, 0)
            # Only long tensors are accepted by one_hot fcn.
            mask = mask.to(torch.long)

            # out shape: ((D), H, W, num_classes)
            mask = torch.nn.functional.one_hot(
                mask, num_classes=self.num_classes
            )
            # out shape: (num_classes, (D), H, W)
            mask_axes = [_ for _ in range(len(mask.shape))]
            mask = torch.permute(mask, mask_axes[-1:] + mask_axes[0:-1])
            mask = mask.to(self.mp_type)
        elif self.loss_type == "ssce":
            # out shape: ((D), H, W) with each value in [0, num_classes)
            mask = torch.squeeze(mask, 0)
            mask = mask.to(torch.int32)

        if cstorch.amp.mixed_precision():
            image = image.to(self.mp_type)

        return image, mask

    def preprocess_image(self, image):
        """Convert a PIL image to a resized, tiled, normalized tensor."""
        # converts to (C, (D), H, W) format.
        to_tensor_transform = transforms.PILToTensor()

        # Resize and convert to torch.Tensor
        resize_pil_transform = transforms.Resize(
            [self.tgt_image_height, self.tgt_image_width],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )
        # Tiling when image shape qualifies
        tile_transform = self.get_tile_transform()

        # Normalize
        normalize_transform = transforms.Lambda(
            lambda x: normalize_tensor_transform(
                x, normalize_data_method=self.normalize_data_method
            )
        )
        transforms_list = [
            to_tensor_transform,
            resize_pil_transform,
            tile_transform,
            normalize_transform,
        ]
        image = transforms.Compose(transforms_list)(image)
        return image

    def preprocess_mask(self, mask):
        """Tile the mask to the tiling shape (no resize/normalize)."""
        tile_transform = self.get_tile_transform()
        return tile_transform(mask)

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """Compose flip / 90-degree-rotation / brightness augmentations."""
        augment_transforms_list = []
        if do_horizontal_flip:
            horizontal_flip_transform = transforms.Lambda(
                lambda x: transforms.functional.hflip(x)
            )
            augment_transforms_list.append(horizontal_flip_transform)

        if n_rotations > 0:
            rotation_transform = transforms.Lambda(
                lambda x: rotation_90_transform(x, num_rotations=n_rotations)
            )
            augment_transforms_list.append(rotation_transform)

        if do_random_brightness:
            brightness_transform = transforms.Lambda(
                lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
            )
            augment_transforms_list.append(brightness_transform)

        return transforms.Compose(augment_transforms_list)

    @property
    def tiling_image_shape(self):
        """(H, W, C) tiling target; must be set by the child class."""
        if not hasattr(self, "_tiling_image_shape"):
            raise AttributeError(
                "_tiling_image_shape not defined. "
                + "Please set it in __init__ method of DataProcessor of child class. "
                + "Format is (H, W, C)"
            )
        return self._tiling_image_shape

    def get_tile_transform(self):
        """Build a transform that tiles inputs to the tiling H x W."""
        tiling_image_height, tiling_image_width = (
            self.tiling_image_shape[0],
            self.tiling_image_shape[1],
        )
        tile_transform = transforms.Lambda(
            lambda x: tile_image_transform(
                x, tiling_image_height, tiling_image_width
            )
        )
        return tile_transform
def visualize_dataset(dataset, num_samples=3):
    """Plot ``num_samples`` random (image, label) pairs side by side.

    Each sampled pair occupies one row: image on the left, label mask on
    the right. Both are max-normalized for display only.
    """
    figure = plt.figure(figsize=(10, 10))
    rows, cols = num_samples, 2
    # Step by 2: each iteration fills one image cell and one label cell.
    for i in range(1, cols * rows + 1, 2):
        sample_idx = torch.randint(len(dataset), size=(1,)).item()
        img, label = dataset[sample_idx]

        figure.add_subplot(rows, cols, i)
        plt.axis("off")
        # CHW -> HWC for imshow; scale into [0, 1] by the max value.
        plt.imshow(img.permute(1, 2, 0) / torch.max(img))

        figure.add_subplot(rows, cols, i + 1)
        plt.axis("off")
        plt.imshow(label / torch.max(label))