# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import cerebras.pytorch as cstorch
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
adjust_brightness_transform,
normalize_tensor_transform,
rotation_90_transform,
tile_image_transform,
)
from cerebras.modelzoo.data.vision.utils import (
FastDataLoader,
ShardedSampler,
num_tasks,
task_id,
)
@registry.register_datasetprocessor("UNetDataProcessor")
class UNetDataProcessor:
    """Common dataloader and preprocessing logic for UNet segmentation data.

    This class references ``self.create_dataset``, ``self.tgt_image_height``,
    ``self.tgt_image_width`` and ``self._tiling_image_shape`` without defining
    them; they are expected to be supplied by a subclass before the
    corresponding methods are called.
    """

    def __init__(self, params):
        """Read the data-processor configuration from ``params``.

        Args:
            params: Mapping of configuration values. ``data_dir``,
                ``num_classes``, ``loss`` and ``batch_size`` are required.
                Optional keys (with the default used when absent):
                ``normalize_data_method`` (None), ``shuffle_seed`` (None),
                ``augment_data`` (True), ``shuffle`` (True),
                ``num_workers`` (0), ``drop_last`` (True),
                ``prefetch_factor`` (10), ``persistent_workers`` (True),
                ``mixed_precision`` (None), ``use_fast_dataloader`` (False)
                and ``duplicate_act_worker_data`` (False).

        Raises:
            KeyError: If one of the required keys is missing.
        """
        self.data_dir = params["data_dir"]
        self.num_classes = params["num_classes"]
        self.loss_type = params["loss"]
        self.normalize_data_method = params.get("normalize_data_method")

        self.shuffle_seed = params.get("shuffle_seed", None)
        if self.shuffle_seed is not None:
            # Seeds the global torch RNG, which the augmentation draws in
            # `transform_image_and_mask` read from.
            torch.manual_seed(self.shuffle_seed)

        self.augment_data = params.get("augment_data", True)
        self.batch_size = get_streaming_batch_size(params["batch_size"])
        self.shuffle = params.get("shuffle", True)

        # Multi-processing params.
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        self.prefetch_factor = params.get("prefetch_factor", 10)
        self.persistent_workers = params.get("persistent_workers", True)

        self.mixed_precision = params.get("mixed_precision")
        if self.mixed_precision:
            self.mp_type = cstorch.amp.get_half_dtype()
        else:
            self.mp_type = torch.float32

        # Using Faster Dataloader for mapstyle dataset.
        self.use_fast_dataloader = params.get("use_fast_dataloader", False)

        # Each activation worker can access entire dataset when True
        self.duplicate_act_worker_data = params.get(
            "duplicate_act_worker_data", False
        )

    def create_dataloader(self, is_training=False):
        """Build the dataloader over ``self.create_dataset(is_training)``.

        Shuffling happens only when both ``self.shuffle`` and ``is_training``
        are true. When shuffling, a ``ShardedSampler`` is used unless
        ``duplicate_act_worker_data`` is set or the dataset is too small to
        give every task a full batch; in those cases a ``RandomSampler``
        seeded per task is used instead. Without shuffling a
        ``SequentialSampler`` is used.

        Side effect: sets ``self.disable_sharding``.

        Args:
            is_training: Forwarded to ``create_dataset`` and used to decide
                whether to shuffle.

        Returns:
            A ``FastDataLoader`` when ``use_fast_dataloader`` is set,
            otherwise a ``torch.utils.data.DataLoader``.
        """
        dataset = self.create_dataset(is_training)
        shuffle = self.shuffle and is_training

        generator_fn = torch.Generator(device="cpu")
        if self.shuffle_seed is not None:
            generator_fn.manual_seed(self.shuffle_seed)

        self.disable_sharding = False
        n_tasks = num_tasks()
        samples_per_task = len(dataset) // n_tasks
        if self.batch_size > samples_per_task:
            # A shard would hold less than one batch, so fall back to every
            # task reading the whole dataset.
            print(
                f"Dataset size: {len(dataset)} too small for num_tasks: "
                f"{n_tasks} and batch_size: {self.batch_size}, "
                f"using duplicate data for activation workers..."
            )
            self.disable_sharding = True

        if not shuffle:
            data_sampler = torch.utils.data.SequentialSampler(dataset)
        elif self.duplicate_act_worker_data or self.disable_sharding:
            # Multiple activation workers, each sending the same data in a
            # different order since the dataset is extremely small.
            if self.shuffle_seed is None:
                seed = task_id()
            else:
                seed = self.shuffle_seed + task_id()
            generator_fn.manual_seed(seed)
            data_sampler = torch.utils.data.RandomSampler(
                dataset, generator=generator_fn
            )
        else:
            data_sampler = ShardedSampler(
                dataset, shuffle, self.shuffle_seed, self.drop_last
            )

        if self.use_fast_dataloader:
            dataloader_fn = FastDataLoader
            print("-- Using FastDataloader -- ")
        else:
            dataloader_fn = torch.utils.data.DataLoader
            print("-- Using torch.utils.data.DataLoader -- ")

        loader_kwargs = {
            "batch_size": self.batch_size,
            "drop_last": self.drop_last,
            "generator": generator_fn,
            "sampler": data_sampler,
        }
        if self.num_workers:
            # Worker-related options are only passed when worker processes
            # are actually requested.
            loader_kwargs.update(
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
            )
        return dataloader_fn(dataset, **loader_kwargs)

    def transform_image_and_mask(self, image, mask):
        """Preprocess, optionally augment, and cast one (image, mask) pair.

        The same flip and rotation are applied to image and mask; random
        brightness is applied to the image only. The mask is then reshaped
        and cast according to ``self.loss_type`` (``"bce"``,
        ``"multilabel_bce"`` or ``"ssce"``); any other value leaves the mask
        as produced by the earlier steps.

        Args:
            image: Input passed to ``preprocess_image``.
            mask: Input passed to ``preprocess_mask``.

        Returns:
            Tuple ``(image, mask)`` of transformed values.
        """
        image = self.preprocess_image(image)
        mask = self.preprocess_mask(mask)

        if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rots in range [0, 3)
            # NOTE(review): `high` is exclusive, so only 0, 1 or 2 is drawn
            # and three quarter-turns are never sampled; for non-square
            # targets the doubling yields 0, 2 or 4. Confirm this is intended.
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()
            if self.tgt_image_height != self.tgt_image_width:
                # For a rectangle image
                n_rotations = n_rotations * 2

            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )
            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)

        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)
        elif self.loss_type == "multilabel_bce":
            mask = torch.squeeze(mask, 0)
            # Only long tensors are accepted by one_hot fcn.
            mask = mask.to(torch.long)
            # out shape: ((D), H, W, num_classes)
            mask = torch.nn.functional.one_hot(
                mask, num_classes=self.num_classes
            )
            # out shape: (num_classes, (D), H, W)
            mask_axes = list(range(mask.dim()))
            mask = torch.permute(mask, mask_axes[-1:] + mask_axes[:-1])
            mask = mask.to(self.mp_type)
        elif self.loss_type == "ssce":
            # out shape: ((D), H, W) with each value in [0, num_classes)
            mask = torch.squeeze(mask, 0)
            mask = mask.to(torch.int32)

        if self.mixed_precision:
            image = image.to(self.mp_type)

        return image, mask

    def preprocess_image(self, image):
        """Convert, resize, tile and normalize a single image.

        Applies, in order: ``PILToTensor``, bicubic ``Resize`` to
        ``[tgt_image_height, tgt_image_width]``, the tile transform from
        ``get_tile_transform`` and ``normalize_tensor_transform`` with
        ``self.normalize_data_method``.

        Args:
            image: Input accepted by ``transforms.PILToTensor``.

        Returns:
            The transformed image.
        """
        # converts to (C, (D), H, W) format.
        to_tensor_transform = transforms.PILToTensor()

        # Resize and convert to torch.Tensor
        resize_pil_transform = transforms.Resize(
            [self.tgt_image_height, self.tgt_image_width],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

        # Tiling when image shape qualifies
        tile_transform = self.get_tile_transform()

        # Normalize
        normalize_transform = transforms.Lambda(
            lambda x: normalize_tensor_transform(
                x, normalize_data_method=self.normalize_data_method
            )
        )

        pipeline = transforms.Compose(
            [
                to_tensor_transform,
                resize_pil_transform,
                tile_transform,
                normalize_transform,
            ]
        )
        return pipeline(image)

    def preprocess_mask(self, mask):
        """Apply only the tile transform to ``mask`` and return the result."""
        tile_transform = self.get_tile_transform()
        return tile_transform(mask)

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """Compose the augmentation pipeline selected by the arguments.

        Args:
            do_horizontal_flip: When truthy, include a horizontal flip.
            n_rotations: When greater than 0, include
                ``rotation_90_transform`` with ``num_rotations=n_rotations``.
            do_random_brightness: When truthy, include
                ``adjust_brightness_transform`` with ``p=0.5, delta=0.2``.

        Returns:
            A ``transforms.Compose`` of the selected transforms, in the order
            flip, rotation, brightness. The list may be empty.
        """
        augment_transforms_list = []

        if do_horizontal_flip:
            augment_transforms_list.append(
                transforms.Lambda(lambda x: transforms.functional.hflip(x))
            )

        if n_rotations > 0:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: rotation_90_transform(
                        x, num_rotations=n_rotations
                    )
                )
            )

        if do_random_brightness:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: adjust_brightness_transform(
                        x, p=0.5, delta=0.2
                    )
                )
            )

        return transforms.Compose(augment_transforms_list)

    @property
    def tiling_image_shape(self):
        """Return ``self._tiling_image_shape``.

        Raises:
            AttributeError: If ``_tiling_image_shape`` has not been set.
        """
        if not hasattr(self, "_tiling_image_shape"):
            raise AttributeError(
                "_tiling_image_shape not defined. "
                + "Please set it in __init__ method of DataProcessor of child class. "
                + "Format is (H, W, C)"
            )
        return self._tiling_image_shape

    def get_tile_transform(self):
        """Return a transform that tiles its input to the tiling image size.

        The height and width are the first two entries of
        ``self.tiling_image_shape`` and are read once, when this method is
        called.
        """
        tiling_image_height = self.tiling_image_shape[0]
        tiling_image_width = self.tiling_image_shape[1]
        return transforms.Lambda(
            lambda x: tile_image_transform(
                x, tiling_image_height, tiling_image_width
            )
        )
def visualize_dataset(dataset, num_samples=3):
    """Draw randomly chosen (image, label) pairs from ``dataset``.

    Builds a 10x10 inch matplotlib figure with ``num_samples`` rows and two
    columns: the image on the left and its label on the right, both scaled
    by their maximum value. Samples are drawn with replacement, so the same
    index may appear more than once. The figure is not shown or saved here.

    Args:
        dataset: Indexable object supporting ``len()`` whose items unpack
            into ``(img, label)`` tensors. ``img`` must have three
            dimensions, since it is permuted with ``(1, 2, 0)`` before
            plotting.
        num_samples: Number of pairs (rows) to draw.
    """

    def _scale_by_max(tensor, peak):
        # Guard against 0 / 0 = NaN when the tensor is entirely zero,
        # e.g. a label with no foreground pixels.
        return tensor / peak if peak != 0 else tensor

    figure = plt.figure(figsize=(10, 10))
    rows, cols = num_samples, 2
    for row in range(rows):
        sample_idx = torch.randint(len(dataset), size=(1,)).item()
        img, label = dataset[sample_idx]

        # Subplot positions are 1-based: odd slots hold images, even labels.
        figure.add_subplot(rows, cols, 2 * row + 1)
        plt.axis("off")
        plt.imshow(_scale_by_max(img.permute(1, 2, 0), torch.max(img)))

        figure.add_subplot(rows, cols, 2 * row + 2)
        plt.axis("off")
        plt.imshow(_scale_by_max(label, torch.max(label)))