Source code for cerebras.modelzoo.data.vision.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import filecmp
import math
import os
import random
import shutil

import torch
import torch.distributed as dist
from tqdm import tqdm

import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist


def is_gpu_distributed():
    """
    Report whether DDP is enabled, i.e. ``torch.distributed`` is both
    available in this build and has an initialized process group.
    """
    ddp_available = torch.distributed.is_available()
    if not ddp_available:
        # Nothing more to query when the distributed package is missing.
        return ddp_available
    return torch.distributed.is_initialized()
def task_id():
    """
    Return the index of the current data-loading task.

    Resolution order:
        1. If ``dist.is_streamer()`` is true, the streaming rank reported by
           ``dist.get_streaming_rank()``.
        2. If a ``torch.distributed`` process group is initialized
           (see `is_gpu_distributed`), the rank of this process in it.
        3. Otherwise ``0`` (single-task execution).

    Returns:
        int: zero-based index of this task.
    """
    if dist.is_streamer():
        return dist.get_streaming_rank()
    if is_gpu_distributed():
        # The alias `dist` is bound twice at the top of this module
        # (torch.distributed, then cerebras.pytorch.distributed); the later
        # import wins, so the torch API has to be addressed explicitly here.
        return torch.distributed.get_rank()
    return 0
def num_tasks():
    """
    Return the number of data-loading tasks the dataset is split across.

    Resolution order:
        1. If ``dist.is_streamer()`` is true, the count reported by
           ``dist.num_streamers()``.
        2. If a ``torch.distributed`` process group is initialized
           (see `is_gpu_distributed`), the world size of that group.
        3. Otherwise ``1`` (single-task execution).

    Returns:
        int: total number of tasks.
    """
    if dist.is_streamer():
        return dist.num_streamers()
    if is_gpu_distributed():
        # The alias `dist` is bound twice at the top of this module
        # (torch.distributed, then cerebras.pytorch.distributed); the later
        # import wins, so the torch API has to be addressed explicitly here.
        return torch.distributed.get_world_size()
    return 1
class ShardedSampler(torch.utils.data.Sampler):
    """
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

    Sampler that restricts data loading to a subset of the dataset.
    Dataset is assumed to be of constant size.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        shuffle (bool, optional): If `True` (default), sampler will shuffle
            the indices.
        seed (int, optional): Random seed used to pick padding indices and to
            shuffle the sampler if :attr:`shuffle=True`. This number should be
            identical across all processes in the distributed group.
            Default: `None`, which seeds non-deterministically, so separate
            processes will not agree on the shuffled order.
        drop_last (bool, optional): If `True`, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If `False`, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: `False`.

    Raises:
        ValueError: if ``cstorch.use_cs()`` is true and `drop_last` is `False`.
    """

    def __init__(self, dataset, shuffle=True, seed=None, drop_last=False):
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.dataset = dataset
        self.dataset_len = len(self.dataset)
        self.drop_last = drop_last

        if cstorch.use_cs() and not self.drop_last:
            raise ValueError(
                "On CS2 we do not support unequal batch sizes so `drop_last` "
                "must be set to `True`."
            )

        # If the dataset length is evenly divisible by # of replicas, then
        # there is no need to drop any data, since the dataset will be split
        # equally.
        if self.drop_last and self.dataset_len % self.num_tasks:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each task receives the same amount of data
            # when using this sampler.
            self.num_samples = self.dataset_len // self.num_tasks
        else:
            self.num_samples = math.ceil(self.dataset_len / self.num_tasks)
        self.total_size = self.num_samples * self.num_tasks
        self.shuffle = shuffle
        self.seed = seed

        self.indices = list(range(self.dataset_len))
        if not self.drop_last:
            # add extra samples to make it evenly divisible across tasks
            padding_indices_size = self.total_size - self.dataset_len
            # Choose padding indices at random to reduce the chance of
            # reusing samples. A private generator is used so the
            # process-global `random` state is left untouched; it yields the
            # same values as `random.seed(seed)` + `random.sample(...)`.
            rng = random.Random(self.seed)
            if padding_indices_size <= self.dataset_len:
                padding_indices = rng.sample(
                    self.indices, padding_indices_size
                )
            else:
                # More padding than distinct samples (dataset much smaller
                # than the number of tasks): sampling without replacement
                # is impossible, so allow repeats.
                padding_indices = rng.choices(
                    self.indices, k=padding_indices_size
                )
            self.indices += padding_indices
        else:
            # remove tail of data to make it evenly divisible.
            self.indices = self.indices[: self.total_size]
        assert len(self.indices) == self.total_size, (
            f"Total `indices` after dropping/padding indices must be equal "
            f"to `total_size` of the dataset. Received total indices: "
            f"`{len(self.indices)}` and total size is: `{self.total_size}`."
        )

    def __iter__(self):
        """
        Yield the indices assigned to this task.

        When shuffling is enabled, `self.indices` is shuffled in place on
        every iteration (seeded with `self.seed`) before every
        `num_tasks`-th index starting at `task_id` is selected.
        """
        if self.shuffle:
            # Private generator: same permutation as seeding the global
            # `random` module, without resetting its state for other users.
            random.Random(self.seed).shuffle(self.indices)

        # subsample
        indices = self.indices[self.task_id : self.total_size : self.num_tasks]
        assert len(indices) == self.num_samples, (
            f"Total `indices` for tasks must be equal to `num_samples` in a "
            f"task. Received total indices: `{len(indices)}` and samples in "
            f"task are: `{self.num_samples}`."
        )

        yield from indices

    def __len__(self):
        """Return the number of samples served to this task per epoch."""
        return self.num_samples
##### Experimental to reduce first batch loading times for MAP style only #####


class _RepeatSampler:
    """Wrap a sampler so that iterating over the wrapper never terminates.

    Each time the wrapped sampler is exhausted it is iterated again from the
    start.

    Args:
        sampler (Sampler): the sampler to cycle through indefinitely.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            # One full pass over the wrapped sampler, then start over.
            for item in self.sampler:
                yield item
class FastDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses a single underlying iterator across epochs.

    The batch sampler is wrapped in `_RepeatSampler` so it never runs out,
    and one iterator over the parent ``DataLoader`` is created in
    `__init__`. Every call to `__iter__` then pulls one epoch's worth of
    batches from that same iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        endless_batches = _RepeatSampler(self.batch_sampler)
        # Assigned through object.__setattr__ rather than plain attribute
        # assignment — presumably because DataLoader guards this attribute
        # after initialization; verify against the installed torch version.
        object.__setattr__(self, 'batch_sampler', endless_batches)
        self.iterator = super().__iter__()

    def __len__(self):
        # `batch_sampler` is the repeating wrapper; the length of one epoch
        # comes from the sampler it wraps.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        batches_per_epoch = len(self)
        for _ in range(batches_per_epoch):
            yield next(self.iterator)
def _get_worker_cache_dir(src_dir):
    """Return the worker-cache directory that mirrors ``src_dir``.

    The absolute form of ``src_dir`` is appended to
    ``dist.WORKER_CACHE_ROOT``; the resulting directory is created if it
    does not exist yet.
    """
    absolute_src = os.path.abspath(src_dir)
    joined = "/".join((dist.WORKER_CACHE_ROOT, absolute_src))
    cache_dir = os.path.normpath(joined)
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _same_dirs_shallow(src_dir, dest_dir):
    """Recursively compare ``src_dir`` and ``dest_dir`` with ``filecmp.dircmp``.

    The directories are considered the same when, at every level:
        1. Both directories exist.
        2. Every entry present only in ``src_dir`` is neither a file nor a
           directory (e.g. a broken symlink).
        3. No common file differs and ``dest_dir`` has no extra entries.
        4. All common sub-directories satisfy 1-3 as well.

    Returns:
        bool: `True` if the trees match under the rules above.
    """

    def _is_live_entry(path):
        # Broken symlinks are neither a directory nor a file.
        return os.path.isdir(path) or os.path.isfile(path)

    def _matches(cmp: filecmp.dircmp):
        if not (os.path.exists(cmp.left) and os.path.exists(cmp.right)):
            return False

        # Source-only entries are tolerated only when they are broken
        # symlinks.
        if any(
            _is_live_entry(os.path.join(cmp.left, name))
            for name in cmp.left_only
        ):
            return False

        if cmp.diff_files or cmp.right_only:
            return False

        return all(_matches(child) for child in cmp.subdirs.values())

    return _matches(filecmp.dircmp(src_dir, dest_dir))
def create_worker_cache(src_dir: str, force_overwrite: bool = False):
    """Mirror ``src_dir`` into the worker cache and return the cached path.

    Checks for the dir in the worker_cache (SSD) on the worker node
    corresponding to the src_dir. If the directory exists and is the same as
    the src_dir (see `_same_dirs_shallow`), the copy is skipped. Otherwise
    any stale copy is removed and the directory is written to the
    worker_cache. Writing to the cache can take a while, depending on the
    size of the src_dir, so a progress bar is displayed while copying.

    The whole check-and-copy sequence runs under a file lock placed next to
    the destination directory.

    Args:
        src_dir: Directory to cache. Must not itself live under
            ``dist.WORKER_CACHE_ROOT``.
        force_overwrite: When `True`, re-copy even on a cache hit.

    Returns:
        str: Path of the directory inside the worker cache.

    Raises:
        RuntimeError: if ``src_dir`` is located inside the worker cache, if
            the caller is not a streamer (worker) node, or if
            ``dist.hit_worker_cache_limit`` reports that the copy would
            exceed the available cache space.
    """
    from filelock import FileLock

    # Compare whole path components on absolute paths. A character-wise
    # prefix test would also reject siblings such as "<root>_backup/..."
    # and would miss relative paths that resolve into the cache.
    cache_root = os.path.abspath(dist.WORKER_CACHE_ROOT)
    resolved_src = os.path.abspath(src_dir)
    if os.path.commonpath([resolved_src, cache_root]) == cache_root:
        raise RuntimeError(
            f"Ensure that the src_dir path does not have "
            f"a worker_cache path prefix: {dist.WORKER_CACHE_ROOT}"
        )
    if not dist.is_streamer():
        raise RuntimeError(
            "Ensure that create_worker_cache is called only for a worker node."
        )
    dest_dir = _get_worker_cache_dir(src_dir)

    # Provide read/write permissions for the lock for all users
    with FileLock(f"{dest_dir}.lock", mode=0o666):
        if _same_dirs_shallow(src_dir, dest_dir) and not force_overwrite:
            print("WORKER CACHE HIT: Skipping overwrite")
        else:
            (
                is_limit_hit,
                dir_size,
                available_space_for_copy,
            ) = dist.hit_worker_cache_limit(src_dir, dest_dir)
            if is_limit_hit:
                raise RuntimeError(
                    f"Failed when copying the directory to the worker cache: {src_dir},"
                    f" directory size: {dir_size} exceeds the available space on worker cache: {available_space_for_copy}."
                    f"Please contact your system administrator to clear the worker cache."
                )

            # copytree requires that the destination does not exist yet.
            if os.path.exists(dest_dir):
                shutil.rmtree(dest_dir)

            # get the total number of files to copy
            total_files = sum(
                len(files) for _root, _dirs, files in os.walk(src_dir)
            )

            # copy directory with progress bar
            def copy2_with_progress(src_path, dst_path, update):
                # skip if its a broken symlink
                if os.path.isfile(src_path):
                    shutil.copy2(src_path, dst_path)
                    update(1)

            with tqdm(
                total=total_files,
                desc="Overwriting cache",
                unit="files",
                dynamic_ncols=True,
            ) as pbar:
                shutil.copytree(
                    src_dir,
                    dest_dir,
                    symlinks=False,
                    ignore=None,
                    ignore_dangling_symlinks=True,
                    copy_function=lambda f, d: copy2_with_progress(
                        f, d, pbar.update
                    ),
                )

    return dest_dir