# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.Hdf5BaseDataProcessor import (
    Hdf5BaseDataProcessor,
)
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
    adjust_brightness_transform,
    normalize_tensor_transform,
    rotation_90_transform,
)
@registry.register_datasetprocessor("Hdf5DataProcessor")
class Hdf5DataProcessor(Hdf5BaseDataProcessor):
    """
    An HDF5 dataset processor for the UNet HDF5 dataset.
    Performs on-the-fly augmentation of image and label.
    Functionality includes:
        Reading data from HDF5 documents
        Augmenting data
    :param dict params: dict containing training
        input parameters for creating dataset.
    Expects the following fields:
    - "data_dir" (str or list of str): Path(s) to directories containing
       the dataset HDF5 (``*.h5``) files.
    - "num_classes" (int): Number of segmentation classes.
    - "image_shape" (list of int): Expected shape of output images and label, used in assert checks.
    - "loss" (str): Loss type, supported: {"bce", "multilabel_bce", "ssce"}
    - "normalize_data_method" (str): Can be one of {None, "zero_centered", "zero_one"}
    - "batch_size" (int): Batch size.
    - "shuffle" (bool): Flag to enable data shuffling.
    - "shuffle_buffer" (int): Size of shuffle buffer in samples.
    - "shuffle_seed" (int): Shuffle seed.
    - "num_workers" (int):  How many subprocesses to use for data loading.
    - "drop_last" (bool): If True and the dataset size is not divisible
       by the batch size, the last incomplete batch will be dropped.
    - "prefetch_factor" (int): Number of samples loaded in advance by each worker.
    - "persistent_workers" (bool): If True, the data loader will not shutdown
       the worker processes after a dataset has been consumed once.
    """

    def _shard_files(self, is_training=False):
        """Discover the HDF5 files and record which ones belong to this task.

        Files are collected per directory in sorted order and assigned to
        tasks round-robin: the i-th file belongs to task
        ``i % self.num_tasks``.

        Sets ``self.features_list``, ``self.all_files``,
        ``self.files_in_this_task`` (both as DataFrames with columns
        ``file_path`` / ``num_examples_in_file``), ``self.num_examples`` and
        ``self.num_examples_in_this_task``.

        :param bool is_training: Not used by this implementation.
        :raises RuntimeError: If no ``*.h5`` files are found.
        """
        # Features stored for every example in the HDF5 files.
        self.features_list = ["image", "label"]
        assert self.batch_size > 0, "Batch size should be positive."

        # `data_dir` is documented as a str or a list of str; accept both.
        data_dirs = (
            self.data_dir
            if isinstance(self.data_dir, (list, tuple))
            else [self.data_dir]
        )
        files = []
        for data_dir in data_dirs:
            p = Path(data_dir)
            assert p.is_dir(), f"{data_dir} is not a directory."
            files.extend(sorted(p.glob('*.h5')))
        if not files:
            raise RuntimeError('No hdf5 datasets found')
        all_files = [str(file.resolve()) for file in files]

        self.all_files = []
        self.files_in_this_task = []
        self.num_examples = 0
        self.num_examples_in_this_task = 0
        for idx, file_path in enumerate(all_files):
            with h5py.File(file_path, mode='r') as h5_file:
                # h5py may hand back a numpy scalar; keep counts as plain ints.
                num_examples_in_file = int(h5_file.attrs["n_examples"])
            file_details = (file_path, num_examples_in_file)
            self.all_files.append(file_details)
            self.num_examples += num_examples_in_file
            if idx % self.num_tasks == self.task_id:
                self.files_in_this_task.append(file_details)
                self.num_examples_in_this_task += num_examples_in_file
        # Prevent CoW which is effectively copy on read behavior for PT,
        # see: https://github.com/pytorch/pytorch/issues/13246
        self.all_files = pd.DataFrame(
            self.all_files, columns=["file_path", "num_examples_in_file"]
        )
        self.files_in_this_task = pd.DataFrame(
            self.files_in_this_task,
            columns=["file_path", "num_examples_in_file"],
        )

    def _apply_normalization(self, x):
        """Normalize ``x`` according to ``self.normalize_data_method``."""
        return normalize_tensor_transform(
            x, normalize_data_method=self.normalize_data_method
        )

    def _load_buffer(self, data_partitions):
        """Lazily yield HDF5 example groups for each partition.

        :param data_partitions: Iterable of
            ``(file_path, start_idx, num_examples)`` tuples.
        :yields: The ``example_{idx}`` group for each index in
            ``[start_idx, start_idx + num_examples)``. Each yielded object
            is tied to the open file handle, which is closed once the
            generator moves past that file, so consume it before advancing.
        """
        for file_path, start_idx, num_examples in data_partitions:
            with h5py.File(file_path, mode='r') as h5_file:
                for idx in range(start_idx, start_idx + num_examples):
                    yield h5_file[f"example_{idx}"]

    def _maybe_shard_dataset(self, num_workers):
        """Build a flat index over every example this dataset should serve.

        Uses all files when ``self.disable_sharding`` is set, otherwise only
        the files assigned to this task by ``_shard_files``.

        :param int num_workers: Number of data-loader workers. Not used to
            split the index here.
        :returns: dict mapping consecutive integers (starting at 0) to
            ``(file_path, sample_name)``, where ``sample_name`` is the
            ``"example_{i}"`` group name inside that file.
        """
        files = (
            self.all_files if self.disable_sharding else self.files_in_this_task
        )
        samples = [
            (file_path, f"example_{file_idx}")
            for file_path, num_examples_in_file in zip(
                files["file_path"], files["num_examples_in_file"]
            )
            for file_idx in range(num_examples_in_file)
        ]
        return dict(enumerate(samples))

    def __len__(self):
        """Number of examples served: all files, or only this task's files."""
        return (
            self.num_examples
            if self.disable_sharding
            else self.num_examples_in_this_task
        )

    def __getitem__(self, index):
        """Get item at a particular index.

        Reads the features in ``self.features_list`` for the example mapped
        to ``index`` in ``self.data_partitions`` and returns the result of
        ``transform_image_and_mask`` on them.

        :returns: ``(image, label)`` tensors.
        """
        file_path, sample_name = self.data_partitions[index]
        with h5py.File(file_path, mode='r') as h5_file:
            example = h5_file[sample_name]
            # np.array copies the dataset into memory, so the tensors stay
            # valid after the file is closed.
            example_dict = {
                feature: torch.from_numpy(np.array(example[feature]))
                for feature in self.features_list
            }
        return self.transform_image_and_mask(
            example_dict["image"], example_dict["label"]
        )

    def transform_image_and_mask(self, image, mask):
        """Normalize/augment an image-mask pair and cast to loss-ready dtypes.

        The same random flip/rotation is applied to image and mask; random
        brightness is applied to the image only.

        :returns: ``(image, mask)``. Mask handling depends on
            ``self.loss_type``:
            - ``"bce"``: cast to ``self.mp_type``.
            - ``"multilabel_bce"``: leading dim squeezed, one-hot encoded and
              permuted to ``(num_classes, H, W)``, cast to ``self.mp_type``.
            - ``"ssce"``: leading dim squeezed, cast to ``torch.int32``.
            The image is cast to ``self.mp_type`` when
            ``self.mixed_precision`` is set.
        """
        if self.normalize_data_method:
            image = self.normalize_transform(image)
        if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rotations in {0, 1, 2} (`high` is exclusive).
            # NOTE(review): 270 degrees is never sampled; confirm whether
            # high=4 was intended before changing the augmentation policy.
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()
            if self.tgt_image_height != self.tgt_image_width:
                # For a rectangle image only even numbers of quarter turns
                # keep H and W unchanged, so rotations become {0, 2, 4}.
                n_rotations = n_rotations * 2
            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )
            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)
        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)
        elif self.loss_type == "multilabel_bce":
            mask = torch.squeeze(mask, 0)
            # Only long tensors are accepted by one_hot fcn.
            mask = mask.to(torch.long)
            # out shape: (H, W, num_classes)
            mask = torch.nn.functional.one_hot(
                mask, num_classes=self.num_classes
            )
            # out shape: (num_classes, H, W)
            mask = torch.permute(mask, [2, 0, 1])
            mask = mask.to(self.mp_type)
        elif self.loss_type == "ssce":
            # out shape: (H, W) with each value in [0, num_classes)
            mask = torch.squeeze(mask, 0)
            # TODO: Add MZ tags here when supported.
            # SW-82348 workaround: Pass `labels` in `int32`
            # PT crossentropy loss takes in `int64`,
            # view and typecast does not change the original `labels`.
            mask = mask.to(torch.int32)
        if self.mixed_precision:
            image = image.to(self.mp_type)
        return image, mask

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """Compose the augmentation pipeline for a single sample.

        :param bool do_horizontal_flip: Include a horizontal flip.
        :param int n_rotations: Number of ``rotation_90_transform`` steps;
            the rotation is skipped when this is 0.
        :param bool do_random_brightness: Include a random brightness
            adjustment (``p=0.5``, ``delta=0.2``).
        :returns: ``transforms.Compose`` applying the selected steps in the
            order flip -> rotate -> brightness.
        """
        augment_transforms_list = []
        if do_horizontal_flip:
            augment_transforms_list.append(
                transforms.Lambda(transforms.functional.hflip)
            )
        if n_rotations > 0:
            # A fresh closure is built on every call, so capturing
            # `n_rotations` here is safe.
            rotation_transform = transforms.Lambda(
                lambda x: rotation_90_transform(x, num_rotations=n_rotations)
            )
            augment_transforms_list.append(rotation_transform)
        if do_random_brightness:
            brightness_transform = transforms.Lambda(
                lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
            )
            augment_transforms_list.append(brightness_transform)
        return transforms.Compose(augment_transforms_list)