# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
import torchvision
from torch.utils.data import Subset
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.vision import StandardTransform
from cerebras.modelzoo.common.input_utils import get_streaming_batch_size
from cerebras.modelzoo.data.vision.classification.mixup import (
    RandomCutmix,
    RandomMixup,
)
from cerebras.modelzoo.data.vision.classification.sampler import (
    RepeatedAugSampler,
)
from cerebras.modelzoo.data.vision.classification.utils import (
    create_preprocessing_params_with_defaults,
)
from cerebras.modelzoo.data.vision.preprocessing import get_preprocess_transform
from cerebras.modelzoo.data.vision.transforms import LambdaWithParam
from cerebras.modelzoo.data.vision.utils import is_gpu_distributed, task_id
[docs]class Processor:
[docs]    def __init__(self, params):
        # data settings
        self.data_dir = params.get("data_dir", ".")
        self.image_size = params.get("image_size", 224)
        self.num_classes = params.get("num_classes")
        self.allowable_split = None
        # params for preprocessing dataset
        self.pp_params = create_preprocessing_params_with_defaults(params)
        # params for data loader
        self.batch_size = get_streaming_batch_size(
            params.get("batch_size", 128)
        )
        self.shuffle = params.get("shuffle", True)
        self.shuffle_seed = params.get("shuffle_seed", None)
        if self.shuffle_seed is not None:
            torch.manual_seed(self.shuffle_seed)
        self.drop_last = params.get("drop_last", True)
        # multi-processing params.
        self.num_workers = params.get("num_workers", 0)
        self.prefetch_factor = params.get("prefetch_factor", 10)
        self.persistent_workers = params.get("persistent_workers", True)
        self.distributed = is_gpu_distributed()
        # sampler
        self.sampler = params.get("sampler", "random")
        self.ra_sampler_num_repeat = params.get("ra_sampler_num_repeat", 3)
        self.mixup_alpha = params.get("mixup_alpha", 0.1)
        self.cutmix_alpha = params.get("cutmix_alpha", 0.1) 
    def create_dataloader(self, dataset, is_training=False):
        assert (
            isinstance(dataset, torchvision.datasets.VisionDataset)
            or isinstance(dataset, VisionSubset)
            or isinstance(dataset, torch.utils.data.Subset)
        ), f"Got {type(dataset)} but dataset must be type VisionDataset, "
        "VisionSubset, or torch.utils.data.Subset"
        shuffle = self.shuffle and is_training
        mixup_transforms = []
        if self.mixup_alpha > 0.0:
            mixup_transforms.append(
                RandomMixup(self.num_classes, p=1.0, alpha=self.mixup_alpha)
            )
        if self.cutmix_alpha > 0.0:
            mixup_transforms.append(
                RandomCutmix(self.num_classes, p=1.0, alpha=self.cutmix_alpha)
            )
        if mixup_transforms:
            mixup_fn = torchvision.transforms.RandomChoice(mixup_transforms)
            collate_fn = lambda batch: mixup_fn(*default_collate(batch))
        if self.distributed:
            # distributed samplers require a seed
            if self.shuffle_seed is None:
                self.shuffle_seed = 0
            if self.sampler == "repeated-aug":
                data_sampler = RepeatedAugSampler(
                    dataset,
                    shuffle=shuffle,
                    seed=self.shuffle_seed,
                    num_repeats=self.ra_sampler_num_repeat,
                    batch_size=self.batch_size,
                )
            else:
                data_sampler = torch.utils.data.distributed.DistributedSampler(
                    dataset,
                    shuffle=shuffle,
                    seed=self.shuffle_seed,
                )
        else:
            if shuffle:
                data_sampler = torch.utils.data.RandomSampler(
                    dataset, generator=self._generator_fn()
                )
            else:
                data_sampler = torch.utils.data.SequentialSampler(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=data_sampler,
            num_workers=self.num_workers,
            pin_memory=self.distributed,
            drop_last=self.drop_last,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
            worker_init_fn=self._worker_init_fn,
        )
        return dataloader
    def create_dataset(self, use_training_transforms=True, split="train"):
        raise NotImplementedError(
            "create_dataset must be implemented in a child class!!"
        )
    def _get_target_transform(self, x, *args, **kwargs):
        return np.int32(x)
    def process_transform(self, use_training_transforms=True):
        if self.pp_params["noaugment"]:
            transform_specs = [
                {"name": "resize", "size": self.image_size},
                {"name": "to_tensor"},
            ]
            logging.warning(
                "User specified `noaugment=True`. The input data will only be "
                "resized to `image_size` and converted to tensor."
            )
            self.pp_params["transforms"] = transform_specs
        transform = get_preprocess_transform(self.pp_params)
        target_transform = LambdaWithParam(self._get_target_transform)
        return transform, target_transform
    def check_split_valid(self, split):
        if split not in self.allowable_split:
            raise ValueError(
                f"Dataset split {split} is invalid. Only values in "
                f"{self.allowable_split} are allowed."
            )
    def split_dataset(self, dataset, split_percent, seed):
        num_sample = len(dataset)
        rng = np.random.default_rng(seed)
        sample_idx = self.create_shuffled_idx(num_sample, rng)
        split_idx = [0]
        if sum(split_percent) != 100:
            raise ValueError(
                f"Sum of split percentage must be 100%! Got {sum(split_percent)}"
            )
        for sp in split_percent[:-1]:
            offset = num_sample * sp // 100
            new_end = split_idx[-1] + offset
            split_idx.append(new_end)
        split_idx.append(num_sample)
        return [
            VisionSubset(dataset, sample_idx[start:end])
            for start, end in zip(split_idx[:-1], split_idx[1:])
        ]
    def create_shuffled_idx(self, num_sample, rng):
        shuffled_idx = np.arange(num_sample)
        rng.shuffle(shuffled_idx)
        return shuffled_idx
    def _worker_init_fn(self, worker_id):
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        if self.shuffle_seed is not None:
            np.random.seed(self.shuffle_seed + worker_id)
    def _generator_fn(self):
        generator_fn = None
        if self.shuffle_seed is not None:
            seed = self.shuffle_seed + task_id()
            generator_fn = torch.Generator(device="cpu")
            generator_fn.manual_seed(seed)
        return generator_fn 
[docs]class VisionSubset(Subset):
[docs]    def __init__(self, dataset, indices):
        assert isinstance(
            dataset, torchvision.datasets.VisionDataset
        ), f"Dataset must be type VisionDataset, but got {type(dataset)} instead."
        super().__init__(dataset, indices) 
    def truncate_to_idx(self, new_length):
        self.indices = self.indices[:new_length]