# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets import VisionDataset
import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.UNetDataProcessor import (
    UNetDataProcessor,
)
from cerebras.modelzoo.data.vision.utils import create_worker_cache
class SeverstalBinaryClassDataset(VisionDataset):
    """Binary segmentation dataset for a single Severstal steel-defect class.

    Reads ``train.csv`` and the ``train_images/`` directory under ``root``.
    It keeps only the CSV rows whose ``ClassId`` equals
    ``class_id_to_consider`` and splits them, in file order, into a
    ``"train"`` part (the first ``floor(train_test_split * total_rows)``
    rows) and a ``"val"`` part (the rest).

    Each sample is ``(image, target)``. Before ``transforms`` is applied,
    ``image`` is a grayscale PIL image. ``target`` is an int32 tensor of
    shape ``(1, 256, 1600)``, decoded from the row's run-length encoding:
    1 on defect pixels, 0 elsewhere.

    Args:
        root (str): Directory containing ``train.csv`` and ``train_images/``.
        train_test_split (float): Fraction of the filtered rows assigned to
            the ``"train"`` split.
        class_id_to_consider (int): Defect class to extract (1-4).
        split (str): Either ``"train"`` or ``"val"``.
        transforms (callable, optional): Joint transform called as
            ``transforms(image, target)``.
        transform (callable, optional): Forwarded to ``VisionDataset``.
        target_transform (callable, optional): Forwarded to
            ``VisionDataset``.
        use_worker_cache (bool): When True on a streamer, ``root`` is
            replaced by the result of ``create_worker_cache(root)``. This is
            only allowed on CS runs.

    Raises:
        ValueError: If ``split`` or ``class_id_to_consider`` is invalid.
        RuntimeError: If ``use_worker_cache`` is requested on a non-CS run.
    """

    def __init__(
        self,
        root,
        train_test_split,
        class_id_to_consider,
        split="train",
        transforms=None,
        transform=None,
        target_transform=None,
        use_worker_cache=False,
    ):
        super().__init__(root, transforms, transform, target_transform)
        # Validate cheap arguments first, before any (potentially expensive)
        # worker-cache copy is made.
        if split not in ["train", "val"]:
            raise ValueError(
                f"Invalid value={split} passed to `split` argument. "
                f"Valid are 'train' or 'val'"
            )
        # Severstal defect classes are numbered 1-4. An `assert` would be
        # stripped under `python -O`, so raise explicitly instead.
        if not 1 <= class_id_to_consider <= 4:
            raise ValueError(
                f"Invalid class_id_to_consider={class_id_to_consider}. "
                f"Maximum 4 available classes; valid ids are 1, 2, 3, 4."
            )
        self.train_test_split = train_test_split
        self.class_id_to_consider = class_id_to_consider

        if use_worker_cache and dist.is_streamer():
            if not cstorch.use_cs():
                raise RuntimeError(
                    "use_worker_cache not supported for non-CS runs"
                )
            self.root = create_worker_cache(self.root)
        self.data_dir = self.root

        self.split = split
        self.images_dir, self.csv_file_path = self._get_data_dirs()
        train_dataframe, val_dataframe = self._process_csv_file()
        self.data = train_dataframe if split == "train" else val_dataframe

    def _get_data_dirs(self):
        """Return ``(images_dir, csv_file_path)`` for the dataset root."""
        images_dir = os.path.join(self.root, "train_images")
        csv_file = os.path.join(self.root, "train.csv")
        return images_dir, csv_file

    def _process_csv_file(self):
        """Filter the CSV to the selected class and split it into train/val.

        Side effects:
            * Sets ``self.total_rows`` to the number of rows kept for the
              selected class.
            * Writes those rows to a temporary ``.csv`` file and stores its
              path in ``self.class_id_dataset_path``. The file is created
              with ``delete=False`` and is not removed by this class.

        Returns:
            tuple(pandas.DataFrame, pandas.DataFrame): ``(train_data,
            val_data)``, where ``train_data`` holds the first
            ``floor(train_test_split * total_rows)`` rows and ``val_data``
            holds the remainder.
        """
        csv_data = pd.read_csv(self.csv_file_path)
        csv_data = csv_data[csv_data["ClassId"] == self.class_id_to_consider]
        self.total_rows = len(csv_data.index)

        # `with` closes the handle only if it was actually created, so a
        # failure inside NamedTemporaryFile propagates unmasked. The handle
        # is closed before writing by path, which keeps this portable to
        # platforms that forbid reopening an open temp file.
        with tempfile.NamedTemporaryFile(
            suffix=".csv", delete=False
        ) as temp_file:
            self.class_id_dataset_path = temp_file.name
        csv_data.to_csv(self.class_id_dataset_path, index=True)

        # Positional split: the filtered index is non-contiguous, so use
        # `.iloc` to make the positional intent explicit.
        train_rows = int(np.floor(self.train_test_split * self.total_rows))
        train_data = csv_data.iloc[:train_rows]
        val_data = csv_data.iloc[train_rows:]
        return train_data, val_data

    def __getitem__(self, index):
        """Load one sample of the current split.

        Args:
            index (int): Position of the sample within this split.

        Returns:
            tuple: ``(image, target)``. Before ``self.transforms`` is
            applied, ``image`` is a grayscale (mode ``"L"``) PIL image and
            ``target`` is an int32 tensor of shape ``(1, 256, 1600)``
            (C, H, W) with 1 on defect pixels and 0 elsewhere.
        """
        image_filename, _class_id, encoded_pixels = self.data.iloc[index]
        image_file_path = os.path.join(self.images_dir, image_filename)
        # convert() returns a new, loaded image, so the source file can be
        # closed deterministically by the context manager.
        with Image.open(image_file_path) as raw_image:
            image = raw_image.convert("L")

        # (W, H) = (1600, 256) is the standard image size for this dataset
        _W = 1600
        _H = 256
        target = torch.zeros(_W * _H, dtype=torch.int32)

        # Run Length Encoding: "start length start length ...".
        # A missing cell (read by pandas as NaN), an empty string or "-1"
        # all mean "no defect pixels" and leave the mask empty.
        rle_list = (
            encoded_pixels.split() if isinstance(encoded_pixels, str) else []
        )
        if rle_list and rle_list[0] != "-1":
            rle_numbers = [int(x) for x in rle_list]
            start_pixels = rle_numbers[::2]
            lengths = rle_numbers[1::2]
            # EncodedPixels are numbered from top to bottom,
            # then left to right: 1 is pixel (1,1), 2 is pixel (2,1), etc
            # Refer to: https://www.kaggle.com/c/severstal-steel-defect-detection/overview/evaluation
            for start, lgth in zip(start_pixels, lengths):
                start_loc = start - 1  # Since one-based encoding
                target[start_loc : start_loc + lgth] = 1

        # The flat mask is column-major (top-to-bottom first), so reshape to
        # (W, H) and transpose to get row-major (H, W).
        target = torch.reshape(target, (_W, _H))
        target = torch.transpose(target, 0, 1)
        target = torch.unsqueeze(target, dim=0)  # outshape: (C, H, W)

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Return the number of samples in the current split."""
        return len(self.data.index)
@registry.register_datasetprocessor("SeverstalBinaryClassDataProcessor")
class SeverstalBinaryClassDataProcessor(UNetDataProcessor):
    """UNet data processor for single-class Severstal binary segmentation.

    Reads these keys from ``params``:

    * ``use_worker_cache``
    * ``train_test_split``
    * ``class_id``
    * ``image_shape`` (``[H, W, C]``)
    * ``max_image_shape`` (optional, ``[H, W]``, default ``[256, 1600]``)

    The remaining params go to ``UNetDataProcessor``. ``self.data_dir``,
    ``self.transform_image_and_mask`` and ``self.get_tile_transform`` are
    used here but not defined in this class. Presumably the base class
    provides them; verify against ``UNetDataProcessor``.
    """

    def __init__(self, params):
        super().__init__(params)
        self.use_worker_cache = params["use_worker_cache"]
        self.train_test_split = params["train_test_split"]
        self.class_id_to_consider = params["class_id"]
        self.image_shape = params["image_shape"]  # of format (H, W, C)
        # Keep the originally requested shape; it is stored as the tiling
        # target before `image_shape` is clamped below.
        self._tiling_image_shape = self.image_shape  # out format: (H, W, C)
        # Tiling param:
        # Images/masks are first resized to
        # min(image_shape, max_image_shape) per spatial dim; the tile
        # transform then works toward `_tiling_image_shape`.
        # NOTE(review): exact tiling behavior lives in
        # `get_tile_transform` (not shown here) -- confirm.
        # self.max_image_shape: (H, W)
        self.max_image_shape = params.get("max_image_shape", [256, 1600])
        self.image_shape = self._update_image_shape()
        (
            self.tgt_image_height,
            self.tgt_image_width,
            self.channels,
        ) = self.image_shape

    def _update_image_shape(self):
        """Clamp the spatial dims of ``self.image_shape`` to the maximum.

        Returns:
            list: ``[min(H, max_H), min(W, max_W), C]``, in (H, W, C) order.
        """
        clamped_hw = [
            min(self.image_shape[i], self.max_image_shape[i]) for i in range(2)
        ]
        # list(...) so that a tuple-valued `image_shape` does not raise
        # TypeError on `list + tuple`.
        return clamped_hw + list(self.image_shape[-1:])

    def create_dataset(self, is_training):
        """Build the train (``is_training=True``) or val dataset."""
        split = "train" if is_training else "val"
        dataset = SeverstalBinaryClassDataset(
            root=self.data_dir,
            train_test_split=self.train_test_split,
            class_id_to_consider=self.class_id_to_consider,
            split=split,
            transforms=self.transform_image_and_mask,
            use_worker_cache=self.use_worker_cache,
        )
        return dataset

    def preprocess_mask(self, mask):
        """Resize ``mask`` to the clamped target size, then tile it.

        Nearest-neighbour interpolation is used so that the binary mask
        values are not blended during resizing.
        """
        resize_transform = transforms.Resize(
            [self.tgt_image_height, self.tgt_image_width],
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        tile_transform = self.get_tile_transform()
        transforms_list = [resize_transform, tile_transform]
        mask = transforms.Compose(transforms_list)(mask)
        return mask