# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
import torch
from torchvision.datasets import DatasetFolder
import cerebras.pytorch.distributed as dist
from cerebras.modelzoo.data.vision.diffusion.DiffusionBaseProcessor import (
    DiffusionBaseProcessor,
)
from cerebras.modelzoo.data.vision.utils import create_worker_cache
FILE_EXTENSIONS = (".npz", ".npy")
[docs]class CategoricalDataset(torch.utils.data.Dataset):
[docs]    def __init__(self, datasets, probs=None, seed=None):
        self.datasets = datasets
        self.num_datasets = len(datasets)
        self.probs = probs
        if self.probs is None:
            self.probs = [1 / self.num_datasets] * self.num_datasets
        if not isinstance(self.probs, torch.Tensor):
            self.probs = torch.tensor(self.probs)
        assert (
            len(self.probs) == self.num_datasets
        ), f"Probability values(={len(self.probs)}) != number of datasets(={self.num_datasets})"
        assert (
            torch.sum(self.probs) == 1.0
        ), f"Probability values don't add up to 1.0"
        self.len_datasets = [len(ds) for ds in self.datasets]
        self.max_len = max(self.len_datasets)
        if seed is None:
            # large random number chosen as `high` upper bound
            seed = torch.randint(0, 2147483647, (1,), dtype=torch.int64).item()
        self.seed = seed
        self.generator = None 
    def __len__(self):
        return self.max_len
    def __getitem__(self, idx):
        if self.generator is None:
            self.generator = torch.Generator()
            self.generator.manual_seed(self.seed)
        # Pick a dataset
        ds_id = torch.multinomial(self.probs, 1, generator=self.generator)
        ds = self.datasets[ds_id]
        # get sample from dataset selected
        sample_id = idx % self.len_datasets[ds_id]
        return ds[sample_id] 
[docs]class ImageNetLatentDataset(DatasetFolder):
[docs]    def __init__(
        self,
        root: str,
        split: str,
        latent_size: List,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        strict_check: Optional[bool] = False,
    ):
        split_folder = os.path.join(root, split)
        self.strict_check = strict_check
        self.latent_size = latent_size
        super().__init__(
            split_folder,
            self.loader,
            FILE_EXTENSIONS,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=None,
        ) 
    def loader(self, path: str):
        data = np.load(path)
        return data
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        latent = torch.from_numpy(sample["vae_output"])
        label = torch.from_numpy(sample["label"])
        assert (
            list(latent.shape) == self.latent_size
        ), f"Mismatch between shapes {latent.shape} vs expected shape:{self.latent_size}"
        if self.strict_check:
            assert (
                sample["dest_path"] == path
            ), f"Mismatch between image and latent files, please check data creation process."
            assert (
                label == target
            ), f"Mismatch between labels written to npz file and inferred according to folder structure"
        if self.transform is not None:
            latent = self.transform(latent)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return latent, target 
[docs]class DiffusionLatentImageNet1KProcessor(DiffusionBaseProcessor):
[docs]    def __init__(self, params):
        super().__init__(params)
        if not isinstance(self.data_dir, list):
            self.data_dir = [self.data_dir]
        self.num_classes = 1000 
    def create_dataset(self, use_training_transforms=True, split="train"):
        if self.use_worker_cache and dist.is_streamer():
            data_dir = []
            for _dir in self.data_dir:
                data_dir.append(create_worker_cache(_dir))
            self.data_dir = data_dir
        self.check_split_valid(split)
        transform, target_transform = self.process_transform()
        dataset_list = []
        for _dir in self.data_dir:
            if not os.path.isdir(os.path.join(_dir, split)):
                raise RuntimeError(f"No directory {split} under root dir")
            dataset_list.append(
                ImageNetLatentDataset(
                    root=_dir,
                    latent_size=[
                        2 * self.latent_channels,
                        self.latent_height,
                        self.latent_width,
                    ],
                    split=split,
                    transform=transform,
                    target_transform=target_transform,
                )
            )
        dataset = CategoricalDataset(dataset_list, seed=self.shuffle_seed)
        return dataset 
if __name__ == "__main__":
    import os
    import yaml
    from cerebras.modelzoo.models.vision.dit.utils import set_defaults
    fpath = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            "../configs/params_dit_small_patchsize_2x2.yaml",
        )
    )
    with open(fpath, "r") as fid:
        params = yaml.safe_load(fid)
    params = set_defaults(params)
    data_obj = DiffusionLatentImageNet1KProcessor(params['train_input'])
    dataset = data_obj.create_dataset(
        use_training_transforms=True, split=params["train_input"]["split"]
    )
    print("Dataset features: \n")
    print(f"Dataset 0th sample: {dataset[0]}, {dataset[0][0].shape}")
    dataloader = data_obj.create_dataloader(dataset, is_training=True)
    print("Dataloader features as below: \n")
    for ii, data in enumerate(dataloader):
        if ii == 1:
            break
        for k, v in data.items():
            print(f"{k} -- {v.shape}, {v.dtype}")
        print("----")