# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import cerebras.pytorch as cstorch
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
    adjust_brightness_transform,
    normalize_tensor_transform,
    rotation_90_transform,
    tile_image_transform,
)
from cerebras.modelzoo.data.vision.utils import (
    FastDataLoader,
    ShardedSampler,
    num_tasks,
    task_id,
)
@registry.register_datasetprocessor("UNetDataProcessor")
class UNetDataProcessor:
    """Shared dataloader and preprocessing logic for UNet segmentation data.

    This class is a base: it reads several members that are not defined
    here and are expected to be supplied by a subclass:

    * ``create_dataset(is_training)`` -- returns the dataset to load from
      (it must support ``len()``).
    * ``tgt_image_height`` / ``tgt_image_width`` -- resize target used by
      :meth:`preprocess_image` and :meth:`transform_image_and_mask`.
    * ``_tiling_image_shape`` -- tiling shape, format ``(H, W, C)``.
    """

    def __init__(self, params):
        """Read configuration from ``params``.

        Args:
            params: Mapping of options. Required keys are ``data_dir``,
                ``num_classes``, ``loss`` and ``batch_size``. Optional keys
                and their defaults: ``normalize_data_method`` (None),
                ``shuffle_seed`` (None), ``augment_data`` (True),
                ``shuffle`` (True), ``num_workers`` (0), ``drop_last``
                (True), ``prefetch_factor`` (10), ``persistent_workers``
                (True), ``mixed_precision`` (None),
                ``use_fast_dataloader`` (False) and
                ``duplicate_act_worker_data`` (False).

        Raises:
            KeyError: If a required key is missing from ``params``.
        """
        self.data_dir = params["data_dir"]
        self.num_classes = params["num_classes"]
        self.loss_type = params["loss"]
        self.normalize_data_method = params.get("normalize_data_method")

        # NOTE: seeding here affects torch's *global* RNG, which is also the
        # source of the augmentation randomness in transform_image_and_mask.
        self.shuffle_seed = params.get("shuffle_seed", None)
        if self.shuffle_seed is not None:
            torch.manual_seed(self.shuffle_seed)

        self.augment_data = params.get("augment_data", True)
        self.batch_size = get_streaming_batch_size(params["batch_size"])
        self.shuffle = params.get("shuffle", True)

        # Multi-processing params.
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        self.prefetch_factor = params.get("prefetch_factor", 10)
        self.persistent_workers = params.get("persistent_workers", True)

        # Tensors are cast to the half dtype reported by cstorch when mixed
        # precision is requested, otherwise float32 is used.
        self.mixed_precision = params.get("mixed_precision")
        if self.mixed_precision:
            self.mp_type = cstorch.amp.get_half_dtype()
        else:
            self.mp_type = torch.float32

        # Using Faster Dataloader for mapstyle dataset.
        self.use_fast_dataloader = params.get("use_fast_dataloader", False)

        # Each activation worker can access entire dataset when True
        self.duplicate_act_worker_data = params.get(
            "duplicate_act_worker_data", False
        )

        # Recomputed on every create_dataloader() call; defined here so the
        # attribute always exists.
        self.disable_sharding = False

    def create_dataloader(self, is_training=False):
        """Build a dataloader over ``self.create_dataset(is_training)``.

        Shuffling only happens when both ``self.shuffle`` and
        ``is_training`` are true. When shuffling:

        * if ``duplicate_act_worker_data`` is set, or the dataset is too
          small to give every task at least one batch, each task draws from
          the whole dataset with a ``RandomSampler`` seeded per task;
        * otherwise a ``ShardedSampler`` is used.

        Without shuffling a ``SequentialSampler`` over the whole dataset is
        used.

        Args:
            is_training: Forwarded to ``create_dataset`` and used to gate
                shuffling.

        Returns:
            A ``FastDataLoader`` if ``use_fast_dataloader`` is set, else a
            ``torch.utils.data.DataLoader``.
        """
        dataset = self.create_dataset(is_training)
        shuffle = self.shuffle and is_training

        generator_fn = torch.Generator(device="cpu")
        if self.shuffle_seed is not None:
            generator_fn.manual_seed(self.shuffle_seed)

        # If a per-task share of the dataset cannot fill a single batch,
        # fall back to letting every task see the full dataset.
        self.disable_sharding = False
        n_tasks = num_tasks()
        samples_per_task = len(dataset) // n_tasks
        if self.batch_size > samples_per_task:
            print(
                f"Dataset size: {len(dataset)} too small for "
                f"num_tasks: {n_tasks} and batch_size: {self.batch_size}, "
                "using duplicate data for activation workers..."
            )
            self.disable_sharding = True

        if shuffle:
            if self.duplicate_act_worker_data or self.disable_sharding:
                # Multiples activation workers, each sending same data in
                # different order since the dataset is extremely small
                if self.shuffle_seed is None:
                    seed = task_id()
                else:
                    seed = self.shuffle_seed + task_id()
                generator_fn.manual_seed(seed)
                data_sampler = torch.utils.data.RandomSampler(
                    dataset, generator=generator_fn
                )
            else:
                data_sampler = ShardedSampler(
                    dataset, shuffle, self.shuffle_seed, self.drop_last
                )
        else:
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if self.use_fast_dataloader:
            dataloader_fn = FastDataLoader
            print("-- Using FastDataloader -- ")
        else:
            dataloader_fn = torch.utils.data.DataLoader
            print("-- Using torch.utils.data.DataLoader -- ")

        loader_kwargs = dict(
            batch_size=self.batch_size,
            drop_last=self.drop_last,
            generator=generator_fn,
            sampler=data_sampler,
        )
        if self.num_workers:
            # Worker-only options are passed solely when workers are in use;
            # torch's DataLoader rejects them when num_workers == 0.
            loader_kwargs.update(
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
            )
        return dataloader_fn(dataset, **loader_kwargs)

    def transform_image_and_mask(self, image, mask):
        """Preprocess, optionally augment, and cast an (image, mask) pair.

        The same flip and rotation are applied to image and mask; random
        brightness is applied to the image only. The mask is then shaped
        according to ``self.loss_type``:

        * ``"bce"``: cast to ``self.mp_type``.
        * ``"multilabel_bce"``: leading dim squeezed, one-hot encoded with
          ``num_classes`` classes, class axis moved to the front, cast to
          ``self.mp_type``.
        * ``"ssce"``: leading dim squeezed, cast to ``torch.int32``.
        * anything else: mask is returned as produced by the earlier steps.

        Args:
            image: Input accepted by :meth:`preprocess_image`.
            mask: Input accepted by :meth:`preprocess_mask`.

        Returns:
            Tuple ``(image, mask)`` of transformed tensors.
        """
        image = self.preprocess_image(image)
        mask = self.preprocess_mask(mask)

        if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rots in range [0, 3)
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()

            if self.tgt_image_height != self.tgt_image_width:
                # For a rectangle image only even rotation counts are used,
                # giving n_rotations in {0, 2, 4}.
                # NOTE(review): if rotation_90_transform applies
                # `num_rotations` quarter turns, 4 is the same as 0 --
                # confirm this distribution is intended.
                n_rotations = n_rotations * 2

            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )

            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)

        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)
        elif self.loss_type == "multilabel_bce":
            mask = torch.squeeze(mask, 0)
            # Only long tensors are accepted by one_hot fcn.
            mask = mask.to(torch.long)
            # out shape: ((D), H, W, num_classes)
            mask = torch.nn.functional.one_hot(
                mask, num_classes=self.num_classes
            )
            # out shape: (num_classes, (D), H, W)
            mask_axes = list(range(mask.ndim))
            mask = torch.permute(mask, mask_axes[-1:] + mask_axes[:-1])
            mask = mask.to(self.mp_type)
        elif self.loss_type == "ssce":
            # out shape: ((D), H, W) with each value in [0, num_classes)
            mask = torch.squeeze(mask, 0)
            mask = mask.to(torch.int32)

        if self.mixed_precision:
            image = image.to(self.mp_type)

        return image, mask

    def preprocess_image(self, image):
        """Convert, resize, tile and normalize an input image.

        Applies, in order: ``PILToTensor``, bicubic antialiased ``Resize``
        to ``(tgt_image_height, tgt_image_width)``, the tile transform from
        :meth:`get_tile_transform`, and ``normalize_tensor_transform`` with
        ``self.normalize_data_method``.

        Args:
            image: Input image accepted by ``transforms.PILToTensor``.

        Returns:
            The transformed image.
        """
        # converts to (C, (D), H, W) format.
        to_tensor_transform = transforms.PILToTensor()

        # Resize and convert to torch.Tensor
        resize_pil_transform = transforms.Resize(
            [self.tgt_image_height, self.tgt_image_width],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )

        # Tiling when image shape qualifies
        tile_transform = self.get_tile_transform()

        # Normalize
        normalize_transform = transforms.Lambda(
            lambda x: normalize_tensor_transform(
                x, normalize_data_method=self.normalize_data_method
            )
        )

        transforms_list = [
            to_tensor_transform,
            resize_pil_transform,
            tile_transform,
            normalize_transform,
        ]
        image = transforms.Compose(transforms_list)(image)
        return image

    def preprocess_mask(self, mask):
        """Apply only the tile transform to ``mask`` and return the result."""
        tile_transform = self.get_tile_transform()
        return tile_transform(mask)

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """Compose the augmentation pipeline for one sample.

        Args:
            do_horizontal_flip: If true, add a horizontal flip.
            n_rotations: If greater than zero, add
                ``rotation_90_transform`` with this ``num_rotations``.
            do_random_brightness: If true, add
                ``adjust_brightness_transform`` with ``p=0.5, delta=0.2``.

        Returns:
            A ``transforms.Compose`` of the selected steps, applied in the
            order flip, rotation, brightness (possibly empty).
        """
        augment_transforms_list = []

        if do_horizontal_flip:
            augment_transforms_list.append(
                transforms.Lambda(transforms.functional.hflip)
            )

        if n_rotations > 0:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: rotation_90_transform(
                        x, num_rotations=n_rotations
                    )
                )
            )

        if do_random_brightness:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: adjust_brightness_transform(
                        x, p=0.5, delta=0.2
                    )
                )
            )

        return transforms.Compose(augment_transforms_list)

    @property
    def tiling_image_shape(self):
        """Tiling shape ``(H, W, C)`` stored in ``_tiling_image_shape``.

        Raises:
            AttributeError: If ``_tiling_image_shape`` has not been set.
        """
        if not hasattr(self, "_tiling_image_shape"):
            raise AttributeError(
                "_tiling_image_shape not defined. "
                + "Please set it in __init__ method of DataProcessor of child class. "
                + "Format is (H, W, C)"
            )
        return self._tiling_image_shape

    def get_tile_transform(self):
        """Return a transform calling ``tile_image_transform``.

        The transform passes its input together with the height and width
        taken from the first two entries of :attr:`tiling_image_shape`.
        """
        tiling_image_height, tiling_image_width = (
            self.tiling_image_shape[0],
            self.tiling_image_shape[1],
        )
        tile_transform = transforms.Lambda(
            lambda x: tile_image_transform(
                x, tiling_image_height, tiling_image_width
            )
        )
        return tile_transform
def visualize_dataset(dataset, num_samples=3):
    """Plot randomly chosen ``(image, label)`` pairs side by side.

    Draws a ``num_samples`` x 2 grid on a new 10x10 inch figure: image on
    the left, label on the right. Each row picks an index independently
    with ``torch.randint``, so the same sample may appear more than once.

    Args:
        dataset: Indexable object supporting ``len()`` whose items unpack
            into ``(img, label)`` tensors. ``img`` must be 3-dimensional
            (assumed ``(C, H, W)`` -- TODO confirm); ``label`` is handed to
            ``imshow`` without reordering its axes.
        num_samples: Number of rows (sample pairs) to draw.
    """

    def _scaled_by_max(tensor):
        # Divide by the maximum for display. An all-zero tensor (e.g. a
        # background-only mask) would otherwise turn into NaNs via 0 / 0.
        peak = torch.max(tensor)
        if peak == 0:
            peak = torch.ones_like(peak)
        return tensor / peak

    figure = plt.figure(figsize=(10, 10))
    rows, cols = num_samples, 2
    for row in range(rows):
        # Subplot positions are 1-based; each row fills two cells.
        left_cell = row * cols + 1
        sample_idx = torch.randint(len(dataset), size=(1,)).item()
        img, label = dataset[sample_idx]

        figure.add_subplot(rows, cols, left_cell)
        plt.axis("off")
        # Move the leading axis last, as imshow expects (H, W, C).
        plt.imshow(_scaled_by_max(img.permute(1, 2, 0)))

        figure.add_subplot(rows, cols, left_cell + 1)
        plt.axis("off")
        plt.imshow(_scaled_by_max(label))