# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision import datasets, transforms
import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.UNetDataProcessor import (
    UNetDataProcessor,
)
from cerebras.modelzoo.data.vision.utils import create_worker_cache
class Cityscapes(datasets.Cityscapes):
    """Wrapper around ``torchvision.datasets.Cityscapes`` that sorts the
    (image, target) file lists so the sample order is reproducible across
    runs and filesystems.

    Args:
        use_worker_cache (bool): If True and this process is a streamer on a
            CS system, copy the dataset root to a per-worker cache and read
            from there.
        **kwargs: Forwarded unchanged to ``torchvision.datasets.Cityscapes``.

    Raises:
        RuntimeError: If ``use_worker_cache`` is True on a non-CS run.
    """

    def __init__(self, use_worker_cache=False, **kwargs):
        super().__init__(**kwargs)
        if use_worker_cache and dist.is_streamer():
            if not cstorch.use_cs():
                raise RuntimeError(
                    "use_worker_cache not supported for non-CS runs"
                )
            # NOTE(review): the parent __init__ has already resolved file
            # paths from the original root — confirm create_worker_cache
            # keeps those paths valid under the cached root.
            self.root = create_worker_cache(self.root)
        # Sort image/target pairs *together* so the pairing is preserved
        # while making iteration order deterministic.
        self.images, self.targets = zip(
            *sorted(zip(self.images, self.targets))
        )
@registry.register_datasetprocessor("CityscapesDataProcessor")
class CityscapesDataProcessor(UNetDataProcessor):
    """UNet data processor for Cityscapes semantic segmentation.

    Expects ``params`` to contain:
        use_worker_cache (bool): Copy data to a per-worker cache on CS runs.
        image_shape (list): Target image shape in (H, W, C) format.
        max_image_shape (list, optional): (H, W) cap applied before tiling.
            Defaults to ``[1024, 2048]``.
    """

    def __init__(self, params):
        super().__init__(params)
        self.use_worker_cache = params["use_worker_cache"]
        self.image_shape = params["image_shape"]  # of format (H, W, C)
        # Shape the tile transform should produce, kept in (H, W, C).
        self._tiling_image_shape = self.image_shape
        # Tiling param:
        # If `image_shape` < 1K x 2K, do not tile.
        # If `image_shape` > 1K x 2K in any dimension,
        #   first resize image to min(img_shape, max_image_shape)
        #   and then tile to target height and width specified in yaml
        self.max_image_shape = params.get("max_image_shape", [1024, 2048])
        self.image_shape = self._update_image_shape()
        (
            self.tgt_image_height,
            self.tgt_image_width,
            self.channels,
        ) = self.image_shape

    def _update_image_shape(self):
        """Clamp (H, W) element-wise to ``max_image_shape``.

        The channel dimension passes through unchanged.

        Returns:
            list: ``[H, W, C]`` with H and W capped at ``max_image_shape``.
        """
        capped_hw = [
            min(dim, cap)
            for dim, cap in zip(self.image_shape[:2], self.max_image_shape)
        ]
        return capped_hw + self.image_shape[-1:]  # (H, W, C)

    def create_dataset(self, is_training):
        """Build the fine-annotation Cityscapes dataset for one split.

        Args:
            is_training (bool): Selects the "train" split when True,
                otherwise "val".

        Returns:
            Cityscapes: Dataset yielding (image, semantic-mask) pairs run
            through ``self.transform_image_and_mask``.
        """
        split = "train" if is_training else "val"
        return Cityscapes(
            root=self.data_dir,
            split=split,
            mode="fine",
            target_type="semantic",
            transforms=self.transform_image_and_mask,
            use_worker_cache=self.use_worker_cache,
        )

    def preprocess_mask(self, mask):
        """Convert a PIL label mask into a tiled tensor of training ids.

        Pipeline: resize (nearest, to keep labels discrete) -> PILToTensor
        (C, H, W) -> cast to long for indexing -> remap label ids via LUT
        -> cast to the mixed-precision dtype -> tile transform.

        Args:
            mask: PIL image of raw Cityscapes label ids (0-33).

        Returns:
            torch.Tensor: Remapped, resized and tiled mask.
        """
        # Refer to :
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py#L56-L99
        # Mapping all classes with `ignoreInEval`=True(from above link)
        # to background class with id 0
        def lookup_table(mask):
            # Index i holds the training id for raw Cityscapes label id i;
            # the 19 evaluated classes map to 1..19, everything else to 0.
            # fmt: off
            lut = torch.tensor([
                0, 0, 0, 0, 0, 0, 0, 1, 2, 0,
                0, 3, 4, 5, 0, 0, 0, 6, 0, 7,
                8, 9, 10, 11, 12, 13, 14, 15,
                16, 0, 0, 17, 18, 19,
                ],
                dtype=torch.uint8,
            )
            # fmt: on
            return lut[mask]

        # Resize with nearest-neighbor so label values are never blended.
        resize_pil_transform = transforms.Resize(
            [self.tgt_image_height, self.tgt_image_width],
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        # converts to (C, H, W) format.
        to_tensor_transform = transforms.PILToTensor()
        # Convert to long for lookup
        convert_to_long_transform = transforms.Lambda(
            lambda x: x.to(torch.long)
        )
        # Map target ids based on lookup table
        lookup_table_transform = transforms.Lambda(lambda x: lookup_table(x))
        # Convert to mp type
        convert_to_mp_type_transform = transforms.Lambda(
            lambda x: x.to(self.mp_type)
        )
        tile_transform = self.get_tile_transform()
        transforms_list = [
            resize_pil_transform,
            to_tensor_transform,
            convert_to_long_transform,
            lookup_table_transform,
            convert_to_mp_type_transform,
            tile_transform,
        ]
        return transforms.Compose(transforms_list)(mask)