# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random

import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets import VisionDataset

import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.classification.dataset_factory import (
VisionSubset,
)
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
adjust_brightness_transform,
normalize_tensor_transform,
rotation_90_transform,
)
from cerebras.modelzoo.data.vision.transforms import LambdaWithParam
from cerebras.modelzoo.data.vision.utils import (
FastDataLoader,
ShardedSampler,
create_worker_cache,
num_tasks,
task_id,
)


class InriaAerialDataset(VisionDataset):
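    """Inria Aerial Image Labeling dataset.

    Expects the following directory layout under `root`:
        <root>/<split>/images  - aerial images
        <root>/<split>/gt      - ground-truth masks (absent for "test")

    Args:
        root (str): Root directory of the dataset.
        split (str): One of "train", "val" or "test".
        transforms (callable, optional): Joint transform applied to
            (image, target) pairs.
        transform (callable, optional): Transform applied to the image.
        target_transform (callable, optional): Transform applied to the
            mask; must be None for the "test" split.
        use_worker_cache (bool): If True, cache the dataset on the worker
            node (CS runs only).
    """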
    def __init__(
self,
root,
split="train",
transforms=None, # pylint: disable=redefined-outer-name
transform=None,
target_transform=None,
use_worker_cache=False,
):
super(InriaAerialDataset, self).__init__(
root, transforms, transform, target_transform
)
if split not in ["train", "val", "test"]:
raise ValueError(
f"Invalid value={split} passed to `split` argument. "
f"Valid are 'train' or 'val' or 'test' "
)
self.split = split
if split == "test" and target_transform is not None:
raise ValueError(
f"split {split} has no mask images and hence target_transform should be None. "
f"Got {target_transform}."
)
if use_worker_cache and dist.is_streamer():
if not cstorch.use_cs():
raise RuntimeError(
"use_worker_cache not supported for non-CS runs"
)
else:
self.root = create_worker_cache(self.root)
self.data_dir = os.path.join(self.root, self.split)
self.image_dir = os.path.join(self.data_dir, "images")
self.mask_dir = os.path.join(self.data_dir, "gt")
self.file_list = sorted(os.listdir(self.image_dir))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where `target` is the ground-truth
            segmentation mask for `image`; `self.transforms` is applied
            to both when set.
        """
image_file_path = os.path.join(self.image_dir, self.file_list[index])
image = Image.open(image_file_path) # 3-channel PILImage
mask_file_path = os.path.join(self.mask_dir, self.file_list[index])
target = Image.open(mask_file_path) # PILImage
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target

    def __len__(self):
return len(self.file_list)


@registry.register_datasetprocessor("InriaAerialDataProcessor")
class InriaAerialDataProcessor:
    def __init__(self, params):
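        """
        Args:
            params (dict): Input parameters. Keys accessed via `params[...]`
                below ("use_worker_cache", "data_dir", "num_classes",
                "image_shape", "loss", "batch_size") are required; keys
                accessed via `params.get` are optional and fall back to
                the defaults shown.
        """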
self.use_worker_cache = params["use_worker_cache"]
self.data_dir = params["data_dir"]
self.num_classes = params["num_classes"]
self.image_shape = params["image_shape"] # of format (H, W, C)
self.duplicate_act_worker_data = params.get(
"duplicate_act_worker_data", False
)
self.loss_type = params["loss"]
self.normalize_data_method = params.get("normalize_data_method")
self.shuffle_seed = params.get("shuffle_seed", None)
if self.shuffle_seed is not None:
torch.manual_seed(self.shuffle_seed)
self.augment_data = params.get("augment_data", True)
self.batch_size = get_streaming_batch_size(params["batch_size"])
self.shuffle = params.get("shuffle", True)
# Multi-processing params.
self.num_workers = params.get("num_workers", 0)
self.drop_last = params.get("drop_last", True)
self.prefetch_factor = params.get("prefetch_factor", 10)
self.persistent_workers = params.get("persistent_workers", True)
self.mixed_precision = params.get("mixed_precision")
if self.mixed_precision:
self.mp_type = cstorch.amp.get_half_dtype()
else:
self.mp_type = torch.float32
# Debug params:
self.overfit = params.get("overfit", False)
        # By default, each activation worker sends `num_workers` batches,
        # for a total of batch_size * num_act_workers * num_pytorch_workers
        # samples.
self.overfit_num_batches = params.get(
"overfit_num_batches", num_tasks() * self.num_workers
)
self.random_indices = params.get("overfit_indices", None)
if self.overfit:
logging.info(f"---- Overfitting {self.overfit_num_batches}! ----")
# Using Faster Dataloader for mapstyle dataset.
self.use_fast_dataloader = params.get("use_fast_dataloader", False)

    def create_dataset(self, is_training):
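        """Creates the `InriaAerialDataset` for the "train" or "val" split.
        When `overfit` is enabled, restricts the dataset to a fixed subset
        of `overfit_num_batches * batch_size` samples."""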
split = "train" if is_training else "val"
dataset = InriaAerialDataset(
root=self.data_dir,
split=split,
transforms=self.transform_image_and_mask,
use_worker_cache=self.use_worker_cache,
)
if self.overfit:
random.seed(self.shuffle_seed)
if self.random_indices is None:
indices = random.sample(
range(0, len(dataset)),
self.overfit_num_batches * self.batch_size,
)
else:
indices = self.random_indices
dataset = VisionSubset(dataset, indices)
logging.info(f"---- Overfitting {indices}! ----")
return dataset

    def create_dataloader(self, is_training=False):
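        """Builds the dataloader over the dataset. With shuffling enabled,
        either every activation worker iterates over the full dataset in
        its own order (`duplicate_act_worker_data`) or the data is sharded
        across tasks via `ShardedSampler`."""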
dataset = self.create_dataset(is_training)
shuffle = self.shuffle and is_training
generator_fn = torch.Generator(device="cpu")
if self.shuffle_seed is not None:
generator_fn.manual_seed(self.shuffle_seed)
if shuffle:
if self.duplicate_act_worker_data:
                # Multiple activation workers, each sending the same data
                # in a different order, since the dataset is extremely small
if self.shuffle_seed is None:
seed = task_id()
else:
seed = self.shuffle_seed + task_id()
generator_fn.manual_seed(seed)
data_sampler = torch.utils.data.RandomSampler(
dataset, generator=generator_fn
)
else:
data_sampler = ShardedSampler(
dataset, shuffle, self.shuffle_seed, self.drop_last
)
else:
data_sampler = torch.utils.data.SequentialSampler(dataset)
num_samples_per_task = len(data_sampler)
assert num_samples_per_task >= self.batch_size, (
f"Number of samples available per task(={num_samples_per_task}) is less than "
f"batch_size(={self.batch_size})"
)
if self.use_fast_dataloader:
dataloader_fn = FastDataLoader
else:
dataloader_fn = torch.utils.data.DataLoader
if self.num_workers:
dataloader = dataloader_fn(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
drop_last=self.drop_last,
generator=generator_fn,
sampler=data_sampler,
)
else:
dataloader = dataloader_fn(
dataset,
batch_size=self.batch_size,
drop_last=self.drop_last,
generator=generator_fn,
sampler=data_sampler,
)
return dataloader

    def _apply_normalization(
self, image, normalize_data_method, *args, **kwargs
):
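        """Wrapper around `normalize_tensor_transform` with the
        (data, *args, **kwargs) signature expected by `LambdaWithParam`;
        extra arguments are ignored."""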
return normalize_tensor_transform(image, normalize_data_method)

    def preprocess_image(self, image):
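        """Converts a PIL image to a (C, H, W) tensor normalized according
        to `normalize_data_method`, converting to grayscale first when
        `image_shape` specifies a single channel."""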
        if self.image_shape[-1] == 1:
            # convert PIL image to grayscale, i.e. (H, W, 1)
            image = image.convert("L")
# converts to (C, H, W) format.
to_tensor_transform = transforms.PILToTensor()
# Normalize
normalize_transform = LambdaWithParam(
self._apply_normalization, self.normalize_data_method
)
transforms_list = [
to_tensor_transform,
normalize_transform,
]
image = transforms.Compose(transforms_list)(image)
return image

    def preprocess_mask(self, mask):
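        """Converts a PIL mask to a (1, H, W) tensor with values scaled
        to [0, 1]."""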
to_tensor_transform = transforms.PILToTensor()
normalize_transform = LambdaWithParam(
self._apply_normalization, "zero_one"
)
transforms_list = [
to_tensor_transform,
normalize_transform,
]
        # output of shape (1, 5000, 5000)
        mask = transforms.Compose(transforms_list)(mask)
return mask

    def transform_image_and_mask(self, image, mask):
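        """Preprocesses and (optionally) augments an (image, mask) pair.
        The same flip/rotation parameters are used for both tensors so
        they stay spatially aligned; random brightness is applied to the
        image only."""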
image = self.preprocess_image(image)
mask = self.preprocess_mask(mask)
if self.augment_data:
do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
# n_rots in range [0, 3)
n_rotations = torch.randint(low=0, high=3, size=(1,)).item()
            if self.image_shape[0] != self.image_shape[1]:  # H != W
                # Rectangular images keep their shape only under rotations
                # that are multiples of 180 degrees, so make the number of
                # 90-degree rotations even.
                n_rotations = n_rotations * 2
augment_transform_image = self.get_augment_transforms(
do_horizontal_flip=do_horizontal_flip,
n_rotations=n_rotations,
do_random_brightness=True,
)
augment_transform_mask = self.get_augment_transforms(
do_horizontal_flip=do_horizontal_flip,
n_rotations=n_rotations,
do_random_brightness=False,
)
image = augment_transform_image(image)
mask = augment_transform_mask(mask)
        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
if self.loss_type == "bce":
mask = mask.to(self.mp_type)
if self.mixed_precision:
image = image.to(self.mp_type)
return image, mask

    def get_augment_transforms(
self, do_horizontal_flip, n_rotations, do_random_brightness
):
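        """Composes the requested augmentations: horizontal flip, rotation
        by `n_rotations` * 90 degrees, and random brightness adjustment
        (p=0.5, delta=0.2)."""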
augment_transforms_list = []
if do_horizontal_flip:
horizontal_flip_transform = transforms.Lambda(
transforms.functional.hflip
)
augment_transforms_list.append(horizontal_flip_transform)
if n_rotations > 0:
rotation_transform = transforms.Lambda(
lambda x: rotation_90_transform(x, num_rotations=n_rotations)
)
augment_transforms_list.append(rotation_transform)
if do_random_brightness:
brightness_transform = transforms.Lambda(
lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
)
augment_transforms_list.append(brightness_transform)
return transforms.Compose(augment_transforms_list)
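

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the `params` values below are
# assumptions chosen for demonstration, not defaults shipped with this
# module, and the dataset path is hypothetical.
#
#     params = {
#         "use_worker_cache": False,
#         "data_dir": "/path/to/inria-aerial",  # hypothetical dataset root
#         "num_classes": 2,
#         "image_shape": [5000, 5000, 3],       # (H, W, C)
#         "loss": "bce",
#         "batch_size": 1,
#     }
#     processor = InriaAerialDataProcessor(params)
#     train_loader = processor.create_dataloader(is_training=True)
#     image, mask = next(iter(train_loader))    # image: (1, C, H, W) tensor
# ---------------------------------------------------------------------------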