Source code for common.pytorch.pytorch_dist_runner

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch.metrics import compute_all_metrics, get_all_metrics
from modelzoo.common.pytorch.pytorch_runner import PyTorchRunner
from modelzoo.common.pytorch.utils import visit_structure


class PyTorchDistRunner(PyTorchRunner):
    """Class for running PyTorch models on multiple GPUs."""
    def __init__(self, model, params):
        self._epoch = 0
        # The main process is used to aggregate results and do most of the I/O
        self._main_process_id = params["runconfig"].get("main_process_id", 0)
        self._dist_backend = params["runconfig"].get("dist_backend", "nccl")
        self._init_method = params["runconfig"].get("init_method", "env://")
        self._should_sync_batchnorm = params["runconfig"].get(
            "sync_batchnorm", False
        )

        if not dist.is_available():
            raise RuntimeError("torch.distributed package is not available")

        dist_addr = params["runconfig"].get("dist_addr", "localhost:8888")
        master_addr, master_port = dist_addr.split(":")
        os.environ['MASTER_ADDR'] = master_addr
        os.environ['MASTER_PORT'] = master_port

        # pass in an instance of the model for housekeeping
        super().__init__(device=None, model=model, params=params)
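The constructor reads all of its distributed settings from the `runconfig` section of `params`. A minimal sketch of that section, for illustration only; the keys mirror the reads in `__init__` above, and every value shown is a placeholder assumption rather than a documented default beyond those visible in the code:

params = {
    "runconfig": {
        "main_process_id": 0,           # rank that aggregates results and does I/O
        "dist_backend": "nccl",         # torch.distributed backend
        "init_method": "env://",        # read MASTER_ADDR/MASTER_PORT from the env
        "sync_batchnorm": True,         # convert BatchNorm layers to SyncBatchNorm
        "dist_addr": "localhost:8888",  # split into MASTER_ADDR and MASTER_PORT
    }
}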
    def is_master_ordinal(self):
        """
        Checks whether distributed is enabled and, if so, whether this is the
        main process; most reading and writing should happen only on the main
        process.
        """
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == self._main_process_id
        else:
            return cm.is_master_ordinal()
    ##################################################################
    #                         Training Hooks                         #
    ##################################################################
    def on_train_batch_start(self, data):
        return self._to_device(data, non_blocking=True)
    def on_train_epoch_start(self):
        if hasattr(self._train_sampler, "set_epoch"):
            self._train_sampler.set_epoch(self._epoch)
    def on_train_epoch_end(self, early_exit: bool):
        # advance _epoch so the dataloader sampler reshuffles next epoch
        self._epoch += 1
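The `set_epoch` call in `on_train_epoch_start` follows the standard `DistributedSampler` contract: the sampler seeds its per-epoch shuffle with the epoch number, so without the call every epoch replays the same shard order on every rank. A minimal standalone sketch using only stock PyTorch APIs (the toy dataset and the explicit `num_replicas`/`rank` are assumptions so the snippet runs without an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))  # toy stand-in dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed the shuffle; else every epoch repeats
    for (batch,) in loader:
        ...  # training step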
    def on_train_batch_end(self, loss, epoch: int = None, step: int = None):
        """Actions to perform after the train batch iteration is complete."""
        self._maybe_check_loss_value(loss)

        if not torch.is_tensor(loss):
            loss = torch.tensor(loss).to(self._device)

        # check _is_fetch_step ahead of time to minimize loss syncing
        if self._is_fetch_step(0):
            # not using AVG since it's only available with NCCL
            dist.reduce(loss, 0, op=dist.ReduceOp.SUM)
            loss /= dist.get_world_size()
            dist.barrier()
            if self.is_master_ordinal():
                super().on_train_batch_end(loss, epoch=epoch, step=step)
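The reduce-then-divide idiom above is worth reading in isolation: `ReduceOp.AVG` is only available with the NCCL backend, so the hook sums onto the destination rank and divides by the world size instead. A minimal sketch of the same pattern, assuming an already initialized process group; the function name is mine, not part of the module:

import torch
import torch.distributed as dist

def average_loss_on_master(loss: torch.Tensor, dst: int = 0) -> torch.Tensor:
    # After dist.reduce, only rank `dst` is guaranteed to hold the global
    # sum, so the averaged value is only meaningful on that rank.
    dist.reduce(loss, dst, op=dist.ReduceOp.SUM)
    return loss / dist.get_world_size()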
    def on_train_end(self, early_exit: bool):
        if self.is_master_ordinal():
            logging.info("Training Completed Successfully!")
    ##################################################################
    #                        Evaluation Hooks                        #
    ##################################################################
    def on_eval_batch_start(self, data):
        return self._to_device(data, non_blocking=True)
    def on_eval_batch_end(self, loss, epoch: int = None, step: int = None):
        """Actions to perform after the eval batch iteration is complete."""
        self._maybe_check_loss_value(loss, step_offset=step + 1)

        if not torch.is_tensor(loss):
            loss = torch.tensor(loss).to(self._device)

        # not using AVG since it's only available with NCCL
        dist.reduce(loss, 0, op=dist.ReduceOp.SUM)
        loss /= dist.get_world_size()
        dist.barrier()
        if self.is_master_ordinal():
            super().on_eval_batch_end(loss, epoch=epoch, step=step)
    def on_eval_end(self, early_exit: bool):
        if self.is_master_ordinal():
            logging.info("Evaluation Completed Successfully!")
    def compute_eval_metrics(self):
        """Compute and log the eval metrics."""
        eval_metrics = compute_all_metrics()

        # aggregate eval metrics across processes
        aggregated_eval_metrics = dict()
        for key, metric in eval_metrics.items():
            if not torch.is_tensor(metric):
                metric = torch.tensor(metric).to(self._device)
            dist.reduce(metric, 0, op=dist.ReduceOp.SUM)
            metric /= dist.get_world_size()
            dist.barrier()
            if self.is_master_ordinal():
                aggregated_eval_metrics[key] = metric.detach().cpu().item()

        if self.is_master_ordinal():
            if aggregated_eval_metrics:
                if self._writer:
                    for metric_scope, metric_value in visit_structure(
                        aggregated_eval_metrics,
                        select_fn=lambda struct: isinstance(
                            struct, (int, float)
                        ),
                        strict=True,
                    ):
                        key = "/".join(metric_scope)
                        self._writer.add_scalar(
                            key, metric_value, self._global_step
                        )
                logging.info(f"Avg eval_metrics = {aggregated_eval_metrics}")

            # Normalize total loss
            avg_eval_loss = self._loss_saver.average_loss
            if self._writer:
                self._writer.add_scalar(
                    "loss", avg_eval_loss, self._global_step
                )
            logging.info(f"Avg Eval. Loss = {avg_eval_loss}")
        dist.barrier()
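The `visit_structure` call above flattens a nested metrics dict into scope paths joined with "/" for TensorBoard. A hypothetical stand-in, shown only to make that flattening concrete; the real utility lives in modelzoo.common.pytorch.utils and its signature and strict-mode semantics may differ:

def visit_structure_sketch(struct, select_fn, scope=()):
    # Walk a nested dict, yielding (scope, leaf) for leaves accepted by select_fn.
    if isinstance(struct, dict):
        for key, value in struct.items():
            yield from visit_structure_sketch(value, select_fn, scope + (key,))
    elif select_fn(struct):
        yield scope, struct

metrics = {"bleu": {"dev": 0.31, "test": 0.29}, "accuracy": 0.87}
flat = {
    "/".join(scope): value
    for scope, value in visit_structure_sketch(
        metrics, select_fn=lambda leaf: isinstance(leaf, (int, float))
    )
}
# flat == {"bleu/dev": 0.31, "bleu/test": 0.29, "accuracy": 0.87}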
    ##################################################################
    #                 Override train/eval functions                  #
    ##################################################################
    def on_process_start(self, all_metrics):
        get_all_metrics().update(all_metrics)
        logging.getLogger().setLevel(logging.INFO)

        rank = dist.get_rank()
        self._device = torch.device(rank)
        self._model.model.to(self._device)
        self._optimizer = self._model.get_optimizer()
        self._optimizer.to(self._device)
        self._lr_scheduler = self._model.get_lr_scheduler()

        if self._should_sync_batchnorm:
            self._model.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self._model.model
            )

        self._model.model = DistributedDataParallel(
            self._model.model, device_ids=[rank], output_device=rank
        )
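The per-rank setup above, reduced to its essentials: a minimal sketch using only stock PyTorch APIs, assuming one GPU per process and an already initialized process group (the function name and the `sync_bn` flag are mine):

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_for_rank(model: torch.nn.Module, rank: int, sync_bn: bool = False):
    device = torch.device(rank)  # an integer ordinal denotes cuda:<rank>
    model = model.to(device)
    if sync_bn:
        # replace every BatchNorm layer with a cross-process SyncBatchNorm
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DistributedDataParallel(
        model, device_ids=[rank], output_device=rank
    )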
    def _train_dist(self, rank, world_size, train_data_fn, all_metrics):
        dist.init_process_group(
            backend=self._dist_backend,
            init_method=self._init_method,
            world_size=world_size,
            rank=rank,
        )
        self.on_process_start(all_metrics)
        train_dataloader = train_data_fn(self._params)
        self._train_sampler = train_dataloader.sampler
        super().train(train_dataloader)
        dist.destroy_process_group()
    def train(self, train_data_fn):
        all_metrics = get_all_metrics()
        world_size = torch.cuda.device_count()
        mp.spawn(
            self._train_dist,
            nprocs=world_size,
            args=(world_size, train_data_fn, all_metrics),
        )
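Note that `mp.spawn` invokes `fn(i, *args)` for each `i` in `range(nprocs)`, prepending the process index itself; that is why `_train_dist` receives `rank` as its first argument even though `args` only carries `(world_size, train_data_fn, all_metrics)`. A minimal standalone sketch of that convention (the worker is a throwaway example):

import torch.multiprocessing as mp

def _worker(rank: int, world_size: int):
    print(f"worker {rank} of {world_size} started")

if __name__ == "__main__":
    mp.spawn(_worker, nprocs=4, args=(4,))  # rank is prepended automatically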
    def _evaluate_dist(self, rank, world_size, eval_data_fn, all_metrics):
        dist.init_process_group(
            backend=self._dist_backend,
            init_method=self._init_method,
            world_size=world_size,
            rank=rank,
        )
        self.on_process_start(all_metrics)
        eval_dataloader = eval_data_fn(self._params)
        super().evaluate(eval_dataloader)
        dist.destroy_process_group()
    def evaluate(self, eval_data_fn):
        """Evaluate the model with data generated by the given dataloader.

        Args:
            eval_data_fn: A callable that takes the params dict and returns
                the dataloader feeding data to the model.
        """
        all_metrics = get_all_metrics()
        world_size = torch.cuda.device_count()
        mp.spawn(
            self._evaluate_dist,
            nprocs=world_size,
            args=(world_size, eval_data_fn, all_metrics),
        )
    def _train_and_eval_dist(
        self, rank, world_size, train_data_fn, eval_data_fn, all_metrics
    ):
        dist.init_process_group(
            backend=self._dist_backend,
            init_method=self._init_method,
            world_size=world_size,
            rank=rank,
        )
        self.on_process_start(all_metrics)
        train_dataloader = train_data_fn(self._params)
        self._train_sampler = train_dataloader.sampler
        eval_dataloader = eval_data_fn(self._params)
        super().train_and_eval(train_dataloader, eval_dataloader)
        dist.destroy_process_group()
    def train_and_eval(self, train_data_fn, eval_data_fn):
        """Train and evaluate the model with data generated by the given
        dataloaders.

        Args:
            train_data_fn: A callable that takes the params dict and returns
                the training dataloader.
            eval_data_fn: A callable that takes the params dict and returns
                the evaluation dataloader.
        """
        all_metrics = get_all_metrics()
        world_size = torch.cuda.device_count()
        mp.spawn(
            self._train_and_eval_dist,
            nprocs=world_size,
            args=(world_size, train_data_fn, eval_data_fn, all_metrics),
        )
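A hedged end-to-end usage sketch to close the listing; `MyModel`, `build_train_dataloader`, and the `params` layout are hypothetical stand-ins, shown only to illustrate how the runner is driven (the real model wrapper and dataloader factories come from the surrounding modelzoo code):

params = {"runconfig": {"dist_backend": "nccl", "dist_addr": "localhost:8888"}}

def build_train_dataloader(params):
    ...  # construct and return a DataLoader with a DistributedSampler

model = MyModel(params)               # hypothetical model wrapper
runner = PyTorchDistRunner(model, params)
runner.train(build_train_dataloader)  # spawns one process per visible GPU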