# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import filecmp
import math
import os
import random
import shutil

import torch
from tqdm import tqdm

import cerebras_pytorch as cstorch
import cerebras_pytorch.distributed as dist


def is_gpu_distributed():
    """
    Returns True if torch.distributed is available and has been initialized
    (i.e., DDP is enabled).
    """
    return (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )


def task_id():
    """Returns the rank of the current task (streamer rank, DDP rank, or 0)."""
    if dist.is_streamer():
        return dist.get_streaming_rank()
    elif is_gpu_distributed():
        return torch.distributed.get_rank()
    else:
        return 0


def num_tasks():
    """Returns the total number of tasks (number of streamers, DDP world size,
    or 1)."""
    if dist.is_streamer():
        return dist.num_streamers()
    elif is_gpu_distributed():
        return torch.distributed.get_world_size()
    else:
        return 1
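

# A minimal sketch (illustrative only; `dataset` is a hypothetical map-style
# dataset) of how `task_id()` and `num_tasks()` are typically combined to
# shard work across tasks. `ShardedSampler` below wraps this pattern with
# shuffling and drop/pad handling:
#
#     indices = list(range(len(dataset)))
#     shard = indices[task_id()::num_tasks()]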


class ShardedSampler(torch.utils.data.Sampler):
    """
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler
    Sampler that restricts data loading to a subset of the dataset.
    Dataset is assumed to be of constant size.
    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        mode (modes): Instance of `modes` to indicate train or eval mode.
        shuffle (bool, optional): If `True` (default), sampler will shuffle 
            the indices.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: `0`.
        drop_last (bool, optional): If `True`, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If `False`, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: `False`.
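
    Example (an illustrative sketch, assuming a map-style `dataset` object)::

        sampler = ShardedSampler(dataset, shuffle=True, seed=0, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=32, sampler=sampler
        )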
    """

    def __init__(self, dataset, shuffle=True, seed=None, drop_last=False):
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.dataset = dataset
        self.dataset_len = len(self.dataset)
        self.drop_last = drop_last
        if cstorch.use_cs() and not self.drop_last:
            raise ValueError(
                "On CS2 we do not support unequal batch sizes so `drop_last` "
                "must be set to `True`."
            )
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_tasks:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each task receives the same amount of data when
            # using this sampler.
            self.num_samples = len(self.dataset) // self.num_tasks
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_tasks)
        self.total_size = self.num_samples * self.num_tasks
        self.shuffle = shuffle
        self.seed = seed
        self.indices = list(range(self.dataset_len))
        if not self.drop_last:
            # add extra samples to make it evenly divisible across tasks
            padding_indices_size = self.total_size - self.dataset_len
            # choose padding indices at random to reduce the chance of
            # reusing samples.
            random.seed(self.seed)
            padding_indices = random.sample(self.indices, padding_indices_size)
            self.indices += padding_indices
        else:
            # remove tail of data to make it evenly divisible.
            self.indices = self.indices[: self.total_size]
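        # Illustrative example: with a dataset of 10 samples and 3 tasks,
        # drop_last=True keeps 9 indices (3 per task), while drop_last=False
        # pads to 12 indices (4 per task) by re-sampling 2 existing indices.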
        assert len(self.indices) == self.total_size, (
            f"Total `indices` after dropping/padding indices must be equal "
            f"to `total_size` of the dataset. Received total indices: "
            f"`{len(self.indices)}` and total size is: `{self.total_size}`."
        ) 

    def __iter__(self):
        if self.shuffle:
            random.seed(self.seed)
            random.shuffle(self.indices)
        # subsample
        indices = self.indices[self.task_id : self.total_size : self.num_tasks]
        assert len(indices) == self.num_samples, (
            f"Total `indices` for tasks must be equal to `num_samples` in a "
            f"task. Received total indices: `{len(indices)}` and samples in "
            f"task are: `{self.num_samples}`."
        )
        yield from indices

    def __len__(self):
        return self.num_samples


##### Experimental: reduces first-batch loading times; for map-style datasets only #####
class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (torch.utils.data.Sampler): The sampler to repeat.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses a single iterator over a repeating batch
    sampler, so that worker processes are kept alive across epochs instead
    of being re-spawned each time the loader is iterated."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(
            self, 'batch_sampler', _RepeatSampler(self.batch_sampler)
        )
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)
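

# A minimal usage sketch (illustrative only; `my_dataset`, the batch size,
# epoch count, and worker count are hypothetical):
#
#     loader = FastDataLoader(
#         my_dataset,
#         batch_size=32,
#         sampler=ShardedSampler(my_dataset, drop_last=True),
#         num_workers=4,
#     )
#     for epoch in range(10):
#         for batch in loader:
#             ...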


def _get_worker_cache_dir(src_dir):
    """Returns the worker cache directory corresponding to `src_dir`,
    creating it if it does not already exist."""
    src_dir = os.path.abspath(src_dir)
    # `src_dir` is absolute, so use a plain string join (os.path.join would
    # discard WORKER_CACHE_ROOT) and normalize the resulting path.
    cache_dir = os.path.normpath("/".join([dist.WORKER_CACHE_ROOT, src_dir]))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
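

# For illustration only (the actual value of dist.WORKER_CACHE_ROOT depends on
# the cluster setup): if WORKER_CACHE_ROOT were "/worker_cache", then
# _get_worker_cache_dir("/data/train") would return "/worker_cache/data/train".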


def _same_dirs_shallow(src_dir, dest_dir):
    """Performs a shallow, recursive comparison of `src_dir` and `dest_dir`.

    The shallow comparison recursively checks the following:
        1. Both directories exist; if either is missing, return False.
        2. No files differ between the two directories, and neither side has
           extra entries (entries that exist only in `src_dir` are tolerated
           if they are broken symlinks); otherwise return False.
        3. Steps 1 and 2 are repeated on every subdirectory.
    """
    def _same_dirs_shallow_helper(dcmp: filecmp.dircmp):
        if not os.path.exists(dcmp.left) or not os.path.exists(dcmp.right):
            return False
        if dcmp.left_only:
            # Left-only entries that are broken symlinks (neither files nor
            # directories) are tolerated; anything else is a mismatch.
            parent = dcmp.left
            for left_file in dcmp.left_only:
                if os.path.isdir(
                    os.path.join(parent, left_file)
                ) or os.path.isfile(os.path.join(parent, left_file)):
                    return False
        if dcmp.diff_files or dcmp.right_only:
            return False
        for sub_dcmp in dcmp.subdirs.values():
            if not _same_dirs_shallow_helper(sub_dcmp):
                return False
        return True
    return _same_dirs_shallow_helper(filecmp.dircmp(src_dir, dest_dir))
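

# A minimal sketch of how the helper above is used (paths are illustrative):
#
#     if _same_dirs_shallow("/data/train", "/worker_cache/data/train"):
#         ...  # cache hit: the copy can be skipped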


def create_worker_cache(src_dir: str, force_overwrite: bool = False):
    """Checks the worker cache (SSD) on the worker node for a directory
    corresponding to `src_dir`.

    If the cached directory exists and matches `src_dir`, its path on the
    worker cache is returned. Otherwise, `src_dir` is copied to the worker
    cache and the resulting path is returned. The copy can take a while,
    depending on the size of `src_dir`; a progress bar showing the copy
    progress is displayed in the worker logs. If `force_overwrite` is True,
    the cache is overwritten even on a cache hit.
    """
    from filelock import FileLock
    if (
        os.path.commonprefix([src_dir, dist.WORKER_CACHE_ROOT])
        == dist.WORKER_CACHE_ROOT
    ):
        raise RuntimeError(
            f"Ensure that the src_dir path does not have "
            f"a worker_cache path prefix: {dist.WORKER_CACHE_ROOT}"
        )
    if not dist.is_streamer():
        raise RuntimeError(
            "Ensure that create_worker_cache is called only for a worker node."
        )
    dest_dir = _get_worker_cache_dir(src_dir)
    # Provide read/write permissions for the lock for all users
    with FileLock(f"{dest_dir}.lock", mode=0o666):
        if _same_dirs_shallow(src_dir, dest_dir) and not force_overwrite:
            print(f"WORKER CACHE HIT: Skipping overwrite")
        else:
            (
                is_limit_hit,
                dir_size,
                available_space_for_copy,
            ) = dist.hit_worker_cache_limit(src_dir, dest_dir)
            if is_limit_hit:
                raise RuntimeError(
                    f"Failed to copy the directory {src_dir} to the worker "
                    f"cache: directory size {dir_size} exceeds the available "
                    f"space on the worker cache ({available_space_for_copy}). "
                    "Please contact your system administrator to clear the "
                    "worker cache."
                )
            if os.path.exists(dest_dir):
                shutil.rmtree(dest_dir)
            # copy dirs to destination
            # get the total number of files to copy
            total_files = sum(
                [len(files) for root, dirs, files in os.walk(src_dir)]
            )
            # copy directory with progress bar
            def copy2_with_progress(src_path, dst_path, update):
                # skip if it's a broken symlink
                if os.path.isfile(src_path):
                    shutil.copy2(src_path, dst_path)
                    update(1)
            with tqdm(
                total=total_files,
                desc="Overwriting cache",
                unit="files",
                dynamic_ncols=True,
            ) as pbar:
                shutil.copytree(
                    src_dir,
                    dest_dir,
                    symlinks=False,
                    ignore=None,
                    ignore_dangling_symlinks=True,
                    copy_function=lambda f, d: copy2_with_progress(
                        f, d, pbar.update
                    ),
                )
    return dest_dir
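

# A minimal usage sketch (illustrative only; "/data/my_dataset" is a
# hypothetical path): on a worker node, point the input pipeline at the
# cached copy of the data on the local SSD:
#
#     if dist.is_streamer():
#         data_dir = create_worker_cache("/data/my_dataset")
#     else:
#         data_dir = "/data/my_dataset"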