Source code for common.pytorch.pytorch_base_runner

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modulek containing the Base PyTorch Runner"""
# pylint: disable=no-self-use, attribute-defined-outside-init

import abc
import copy
import logging
import math
import os
import re
import warnings
from collections import defaultdict
from contextlib import ExitStack, contextmanager, nullcontext
from inspect import isclass
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch, modes
from modelzoo.common.pytorch.dump_context import DumpContext
from modelzoo.common.pytorch.loss_utils import LossSaver, extract_loss
from modelzoo.common.pytorch.metrics import (
    compute_all_metrics,
    reset_all_metrics,
)
from modelzoo.common.pytorch.PyTorchBaseModel import PyTorchBaseModel
from modelzoo.common.pytorch.summaries import (
    discard_cached_summaries,
    save_all_summaries,
    scalar_summary,
)
from modelzoo.common.pytorch.utils import (
    RunConfigParamsValidator,
    visit_structure,
)
from modelzoo.common.run_utils.utils import DeviceType, ExecutionStrategy


class PyTorchBaseRunner(metaclass=abc.ABCMeta):
    """The base class for running PyTorch models on any device."""
    def __init__(self, model: PyTorchBaseModel, params: dict):
        """Construct a `PyTorchBaseRunner` instance.

        Args:
            model: The PyTorch model to run.
            params: A dict of params that specify the behavior of the model.
        """
        self._model = model
        self._params = copy.deepcopy(params)

        self._optimizer = None
        self._lr_scheduler = None
        self._scaler = None

        mode = self._runconfig["mode"]
        if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
            self._optimizer = model.get_optimizer()
            self._lr_scheduler = model.get_lr_scheduler()

        # The mode that is currently active
        self._active_mode = mode

        # Mandatory config options
        self._mixed_precision = self._params["model"]["mixed_precision"]

        # Optional config options
        self._grad_accum_steps = self._params["optimizer"].get(
            "grad_accum_steps", 1
        )
        self._log_summaries = self._params["optimizer"].get(
            "log_summaries", False
        )

        self._show_debug_metrics = self._runconfig.get(
            "show_debug_metrics", False
        )
        self._save_losses = self._runconfig.get("save_losses", True)
        self._check_loss_values = self._runconfig.get("check_loss_values", True)

        self._model_dir = self._runconfig.get("model_dir", "./")
        self._checkpoint_path = self._runconfig.get("checkpoint_path")
        if not self._checkpoint_path and self._runconfig.get(
            "autoload_last_checkpoint", True
        ):
            self._checkpoint_path = self._get_last_checkpoint()
        self._is_pretrained_checkpoint = self._runconfig.get(
            "is_pretrained_checkpoint", False,
        )
        self._save_initial_checkpoint = self._runconfig.get(
            "save_initial_checkpoint", False
        )
        self._save_stream_size = self._runconfig.get("save_stream_size", 0)

        # Summary writer objects
        self._writers = {}

        # A cleanup stack that subclasses can hook into during a run to do any
        # necessary cleanups in case of failures
        self.__cleanup_stack: Union[ExitStack, None] = None

        # These are set up at the start of the execution loop
        self._global_step = None
        self._initial_step = None
        self._total_steps = None

        if self._runconfig.get("dump_activations", False):
            if cm.use_cs():
                logging.warning(
                    "Actually dumping activations on CS2 requires setting "
                    "additional debug arguments. This runconfig option just "
                    "enables extra debug location info in the compile."
                )
            self._dump_ctx = DumpContext(
                os.path.join(self._model_dir, "act_dumps"), self._model.model,
            )
        else:
            self._dump_ctx = nullcontext()
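    # NOTE: Illustrative sketch, not part of the original source. The
    # constructor above assumes a params dict with at least the following
    # nested keys; the concrete values shown are hypothetical placeholders.
    #
    #     params = {
    #         "model": {"mixed_precision": True},
    #         "optimizer": {"grad_accum_steps": 1, "log_summaries": False},
    #         "runconfig": {
    #             "mode": modes.TRAIN,
    #             "model_dir": "./model_dir",
    #             "log_steps": 100,
    #             "checkpoint_steps": 0,
    #         },
    #     }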
    def _log_summaries_params_norm(self, requires_grad=True):
        """
        Args:
            requires_grad (bool): whether to only include params that require
                gradient updates
        """
        # Computes the global norm over all params by first calculating the
        # norm of each param individually and then combining the norms.
        if cm.use_cs():
            device = self._model.model.device
        else:
            device = self._model.device
        param_norm = torch.tensor(0.0).to(device)
        for _, param in self._model.model.named_parameters():
            if not requires_grad:
                # simply add if we want to include all params
                param_norm += torch.pow(torch.norm(param), 2.0)
            elif param.requires_grad:
                # only add the param if it requires gradient updates
                param_norm += torch.pow(torch.norm(param), 2.0)
        param_norm = torch.sqrt(param_norm)
        scalar_summary("model_wise_params_norm", param_norm)

    def _log_summaries_grad_norm(self, is_clipped=False, is_scaled=False):
        """
        Args:
            is_clipped (bool): whether to log clipped gradients
            is_scaled (bool): whether to log scaled gradients
        """
        # This should be called after unscaling and before grad clipping
        if cm.use_cs():
            device = self._model.model.device
        else:
            device = self._model.device
        param_grad_norm = torch.tensor(0.0).to(device)
        for _, param in self._model.model.named_parameters():
            if param.grad is not None:
                param_grad_norm += torch.pow(torch.norm(param.grad), 2.0)
        param_grad_norm = torch.sqrt(param_grad_norm)

        summary_str = "model_wise_grad_norm"
        summary_str += "_clipped" if is_clipped else "_unclipped"
        summary_str += "_scaled" if is_scaled else "_unscaled"
        scalar_summary(summary_str, param_grad_norm)

    def _log_summaries_grad_norm_per_layer(
        self, is_clipped=False, is_scaled=False
    ):
        """
        Args:
            is_clipped (bool): whether to log clipped gradients
            is_scaled (bool): whether to log scaled gradients
        """
        # Computes a gradient norm per layer by first calculating the norm of
        # each of the layer's param gradients and then combining those norms.
        param_norm = {}
        layer_pattern_str = r'.*(layers\.)(\d+)(\.).*'
        layer_pattern = re.compile(layer_pattern_str)
        for name, param in self._model.model.named_parameters():
            if param.grad is None:
                continue
            # get a match if the module name contains `layers.<i>.` where
            # `<i>` is the layer number
            match = layer_pattern.match(name)
            if match:
                layer_id = match.group(2)
                if layer_id not in param_norm:
                    param_norm[layer_id] = torch.tensor(0.0).to(param.device)
                param_norm[layer_id] += torch.pow(torch.norm(param.grad), 2.0)

        for layer_id in param_norm:
            param_norm[layer_id] = torch.sqrt(param_norm[layer_id])

            summary_str = "per_layer_grad_norm"
            summary_str += "_clipped" if is_clipped else "_unclipped"
            summary_str += "_scaled" if is_scaled else "_unscaled"
            summary_str += f"/layer_{layer_id}"
            scalar_summary(summary_str, param_norm[layer_id])

    def _log_summaries_learing_rate(self):
        if self._lr_scheduler:
            # self._lr_scheduler.get_last_lr() returns a list of LRs for the
            # different param groups of the optimizer. We create one learning
            # rate tracker for each param group with identifier
            # `lr_params_group_{i}` where i is the index of the param group.
            last_lrs = self._lr_scheduler.get_last_lr()
            if not isinstance(last_lrs, list):
                last_lrs = [last_lrs]
            for i, last_lr in enumerate(last_lrs):
                scalar_summary(f"lr_params_group_{i}", last_lr)

    def _log_summaries_loss_scale(self):
        if self._scaler:
            scalar_summary("loss_scale", self._scaler.get_scale())

    @property
    def _should_log_extra_summaries(self):
        return self._log_summaries

    @property
    def _writer(self) -> Optional[SummaryWriter]:
        return self._writers.get(self._active_mode)

    @property
    def _runconfig(self):
        return self._params["runconfig"]

    @property
    def _run_step(self) -> int:
        """Returns the current execution step.

        This is different from global_step in that it indicates the
        execution step of the current run.
        """
        assert self._global_step is not None
        assert self._initial_step is not None
        return self._global_step - self._initial_step

    @property
    def _loss_dir(self) -> str:
        """Return the directory to use for saving intermediate losses."""
        loss_dir = os.path.join(self._model_dir, "losses")
        os.makedirs(loss_dir, exist_ok=True)
        return loss_dir

    @property
    def _cleanup_stack(self) -> ExitStack:
        """Returns an ExitStack only available during execution."""
        assert (
            self.__cleanup_stack
        ), "Cleanup stack is only available during execution."
        return self.__cleanup_stack

    def _validate_config(self):
        """Check that the provided config is valid.

        Raises:
            AssertionError or ValueError if any of the config options are
            invalid.
        """
        # num_epochs and num_steps are mutually exclusive. max_steps is
        # optional unless neither num_epochs nor num_steps are provided, in
        # which case max_steps must be provided.
        if self._num_epochs is not None and self._num_steps is not None:
            raise ValueError(
                "Please specify only one of `num_epochs` or `num_steps`."
            )
        elif self._num_epochs is not None:
            assert (
                self._num_epochs > 0
            ), "When provided, `num_epochs` must be greater than zero."
        elif self._num_steps is not None:
            assert (
                self._num_steps > 0
            ), "When provided, `num_steps` must be greater than zero."
        else:
            if self._max_steps is None:
                raise ValueError(
                    "`max_steps` must be specified if neither `num_epochs` "
                    "nor `num_steps` are specified."
                )

        assert self._max_steps is None or self._max_steps > 0
        assert (
            self._train_steps_per_epoch is None
            or self._train_steps_per_epoch > 0
        )
        assert (
            self._eval_steps_per_epoch is None or self._eval_steps_per_epoch > 0
        )
        assert self._grad_accum_steps >= 1
        assert self._checkpoint_steps >= 0
        assert self._log_steps >= 0

        if cm.use_cs() and self._grad_accum_steps > 1:
            raise ValueError(
                "Gradient Accumulation not supported on CS workflow."
            )

    def _should_stop(self, epoch_step: int, mode: str) -> Tuple[bool, bool]:
        """Return a tuple indicating whether to stop epoch/training.

        Args:
            epoch_step: The current step in the epoch loop.
            mode: The mode of the current run (i.e., train or eval).

        Returns:
            A tuple of booleans. The first item indicates whether to exit
            the epoch loop. The second item indicates whether to exit
            training.
        """
        exit_epoch = exit_training = False

        if mode == modes.TRAIN:
            steps_per_epoch = self._train_steps_per_epoch
        elif mode == modes.EVAL:
            steps_per_epoch = self._eval_steps_per_epoch
        else:
            raise ValueError(f"Unhandled mode: {mode}.")

        if steps_per_epoch and epoch_step + 1 >= steps_per_epoch:
            exit_epoch = True
        if mode == modes.TRAIN and self._run_step >= self._total_steps:
            exit_epoch = True
            exit_training = True
        return exit_epoch, exit_training

    def _configure_run_steps(self, dataloader, mode: str):
        """Configure steps which specify training behavior.

        This method sets `self._total_steps`, `self._checkpoint_steps`,
        `self._fetch_steps`, and `self._num_epochs` based on the
        user-provided config. If the current global step exceeds max steps,
        this method raises an error.

        Args:
            dataloader: The data loader instance that is used for training.
            mode: The mode of the current run.

        Raises:
            RuntimeError if global step exceeds max steps.
        """
        assert self._global_step is not None

        self._log_steps = self._runconfig["log_steps"]
        if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
            self._num_epochs = self._runconfig.get("num_epochs", None)
            self._num_steps = self._runconfig.get("num_steps", None)
            self._max_steps = self._runconfig.get("max_steps", None)
            self._checkpoint_steps = self._runconfig.get("checkpoint_steps", 0)
        elif mode == modes.EVAL:
            self._num_epochs = 1
            self._num_steps = None
            self._max_steps = None
            self._checkpoint_steps = 0
        else:
            raise ValueError(f"Unhandled mode: {mode}.")

        self._train_steps_per_epoch = self._runconfig.get(
            "steps_per_epoch", None
        )
        self._eval_steps_per_epoch = self._runconfig.get("eval_steps", None)

        self._validate_config()

        if self._max_steps is not None and self._global_step >= self._max_steps:
            raise RuntimeError(
                f"Global step {self._global_step} already exceeds "
                f"max step {self._max_steps}."
            )

        if mode == modes.TRAIN:
            train_dataloader = dataloader
        elif mode == modes.EVAL:
            eval_dataloader = dataloader
        elif mode == modes.TRAIN_AND_EVAL:
            train_dataloader, eval_dataloader = dataloader

        train_steps_per_epoch = 0
        if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
            try:
                # Dataset length is known
                train_steps_per_epoch = len(train_dataloader)
                assert (
                    train_steps_per_epoch > 0
                ), "Train Dataloader does not generate any batches."

                if self._train_steps_per_epoch is not None:
                    assert (
                        self._train_steps_per_epoch <= train_steps_per_epoch
                    ), (
                        f"The requested steps per epoch of "
                        f"{self._train_steps_per_epoch} exceeds total steps "
                        f"in an epoch, which is {train_steps_per_epoch}."
                    )
                    train_steps_per_epoch = self._train_steps_per_epoch

                # With grad accumulation, the global step is incremented every
                # Nth batch, so our effective steps per epoch needs to be
                # adjusted.
                assert self._grad_accum_steps <= train_steps_per_epoch, (
                    f"Gradient accumulation steps of {self._grad_accum_steps} "
                    f"is greater than batches per epoch of "
                    f"{train_steps_per_epoch}."
                )
                train_steps_per_epoch //= self._grad_accum_steps
            except TypeError:
                # Dataset length is not known
                assert self._num_epochs is None, (
                    "Specifying num_epochs for datasets with unknown length "
                    "is not allowed. Please control training behavior "
                    "through number of steps instead."
                )
                train_steps_per_epoch = 1

        if mode in (modes.EVAL, modes.TRAIN_AND_EVAL):
            try:
                # Dataset length is known
                eval_steps_per_epoch = len(eval_dataloader)
                assert (
                    eval_steps_per_epoch > 0
                ), "Eval Dataloader does not generate any batches."

                if self._eval_steps_per_epoch is not None:
                    # Do a sanity check that we can generate requested batches
                    assert self._eval_steps_per_epoch <= eval_steps_per_epoch, (
                        f"The requested steps per epoch of "
                        f"{self._eval_steps_per_epoch} exceeds total steps in "
                        f"an epoch, which is {eval_steps_per_epoch}."
                    )
                else:
                    # If not explicitly specified, run the whole eval epoch
                    self._eval_steps_per_epoch = eval_steps_per_epoch
            except TypeError:
                # Dataset length is not known, so eval steps must be specified.
                # We assume the dataloader generates as many steps as the user
                # specified. Otherwise, we may get a stall. There's no way for
                # us to validate this since we can't query the dataset length.
                assert self._eval_steps_per_epoch is not None, (
                    "`eval_steps` must be specified for datasets with unknown "
                    "length."
                )

        if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
            steps_per_epoch = train_steps_per_epoch

            # Calculate total steps
            self._total_steps = math.inf
            if self._num_epochs is not None:
                self._total_steps = min(
                    self._total_steps, self._num_epochs * steps_per_epoch
                )
            if self._num_steps is not None:
                self._total_steps = min(self._total_steps, self._num_steps)
            if self._max_steps is not None:
                remaining_steps = self._max_steps - self._global_step
                assert remaining_steps > 0, (  # This was checked above
                    f"Global step {self._global_step} already exceeds "
                    f"max step {self._max_steps}."
                )
                self._total_steps = min(self._total_steps, remaining_steps)

            # At least one of the above if blocks must have been true.
            # Adding an assert in case someone makes a mistake.
            assert not math.isinf(
                self._total_steps
            ), "One of num_epochs, num_steps, or max_steps must be provided"

            # Override num_epochs based on total steps and steps per epoch
            self._num_epochs = math.ceil(self._total_steps / steps_per_epoch)

            self._checkpoint_steps = min(
                self._checkpoint_steps, self._total_steps
            )
        elif mode == modes.EVAL:
            self._total_steps = self._eval_steps_per_epoch

        self._fetch_steps = min(self._log_steps, self._total_steps)
        if self._fetch_steps == 0:
            # Always fetch the outputs of the last step
            self._fetch_steps = self._total_steps

    def _is_fetch_step_helper(self, step, step_offset: int = 0):
        """
        Checks whether the step provided is a step where values are
        pre-scheduled to come off of the Cerebras system. Primarily for
        performance reasons.

        Args:
            step: The step to check.
            step_offset: Used to offset the run step in eval where the
                global step is not incremented.
        """
        step = step + step_offset
        return step == self._total_steps or (
            self._fetch_steps > 0 and step % self._fetch_steps == 0
        )

    def _is_fetch_step(self, step_offset: int = 0):
        """
        Checks whether we are on a step where values are pre-scheduled to
        come off of the Cerebras system. Primarily for performance reasons.

        Args:
            step_offset: Used to offset the run step in eval where the
                global step is not incremented.
        """
        return self._is_fetch_step_helper(self._run_step, step_offset)

    def _is_checkpoint_step(self, step_offset: int = 0):
        """
        Checks whether we are on a step where a checkpoint is pre-scheduled
        to come off of the Cerebras system. Primarily for performance
        reasons.

        Args:
            step_offset: Used to offset the run step in eval where the
                global step is not incremented.
        """
        step = self._run_step + step_offset
        return self._checkpoint_steps > 0 and (
            step == self._total_steps or step % self._checkpoint_steps == 0
        )
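    # NOTE: Illustrative sketch, not part of the original source, of how the
    # fetch/checkpoint scheduling above behaves. With, say,
    # _total_steps = 1000, _log_steps = 100, and _checkpoint_steps = 250:
    #
    #     _fetch_steps = min(100, 1000) = 100
    #     _is_fetch_step_helper(step) -> True at run steps 100, 200, ..., 1000
    #     _is_checkpoint_step()       -> True at run steps 250, 500, 750, 1000
    #
    # If _log_steps were 0, _fetch_steps falls back to _total_steps, so only
    # the final step's outputs are fetched.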
    def is_master_ordinal(self):
        """
        Checks if distributed training is enabled and, if so, whether this
        process is the main process. Most reading and writing should only
        happen on the main process.
        """
        return cm.is_master_ordinal()
    @contextmanager
    def _configure_run(
        self, mode: str, dataloader: torch.utils.data.DataLoader
    ):
        """Configure the run for the mode using the provided dataloader.

        The setup involves loading a checkpoint if specified as well as
        configuring the run steps for performance.

        Args:
            mode: the mode to configure the run for.
            dataloader: the dataloader used to configure the run.
        """
        if not self._model.supports_mode(mode):
            raise ValueError(
                f"{mode} not supported for model. "
                f"Supported modes include: {self._model.supported_modes}."
            )

        self._maybe_load_checkpoint(self._checkpoint_path, mode)

        if self.is_master_ordinal():
            # Save initial checkpoint
            if self._save_initial_checkpoint:
                self._save_checkpoint(self._global_step)

            # Save dataloader streams for testing
            if self._save_stream_size:
                self._save_stream(dataloader, mode)

            # Create tensorboard summary writers for logging
            if mode in (modes.EVAL, modes.TRAIN_AND_EVAL):
                self._writers[modes.EVAL] = SummaryWriter(
                    log_dir=os.path.join(self._model_dir, modes.EVAL)
                )
                self._active_mode = modes.EVAL
            if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
                self._writers[modes.TRAIN] = SummaryWriter(
                    log_dir=os.path.join(self._model_dir, modes.TRAIN)
                )
                self._active_mode = modes.TRAIN

        self._loss_saver = LossSaver(self._writer)

        # Configure the number of steps to run based on
        # the size of the dataloader
        self._configure_run_steps(dataloader, mode)

        with ExitStack() as self.__cleanup_stack:
            try:
                yield
            finally:
                if self.is_master_ordinal():
                    for writer in self._writers.values():
                        writer.flush()
                        writer.close()
    def on_checkpoint_saved(self, checkpoint_path: str, step: int):
        """Function to execute after a checkpoint is saved."""

    def on_train_start(self):
        """Function to execute before training starts"""

    def on_train_end(self, early_exit: bool):
        """Function to execute after training ends"""

    def on_eval_start(self):
        """Function to execute before eval starts"""

    def on_eval_end(self, early_exit: bool):
        """Function to execute after eval ends"""

    def on_train_epoch_start(self):
        """Function to execute before the training epoch begins"""

    def on_train_epoch_end(self, early_exit: bool):
        """Function to execute after the training epoch ends"""

    def on_eval_epoch_start(self):
        """Function to execute before the eval epoch begins"""

    def on_eval_epoch_end(self, early_exit: bool):
        """Function to execute after the eval epoch ends"""

    def on_train_batch_start(self, data):
        """Optionally pre-process data before train batch start"""
        return data

    def on_train_batch_end(self, loss, epoch: int = None, step: int = None):
        """Actions to perform after the train batch iteration is complete"""
        # Add step closures
        self._maybe_write_log(loss)
        self._maybe_check_loss_value(loss)
        self._maybe_save_loss(loss)
        self._maybe_save_summaries()
        self._maybe_log_throughput()
        self._maybe_save_checkpoint()

    def on_eval_batch_start(self, data):
        """Optionally pre-process data before eval batch start"""
        return data

    def on_eval_batch_end(self, loss, epoch: int = None, step: int = None):
        """Actions to perform after the eval batch iteration is complete"""
        eval_step = step + 1

        # Add step closures
        self._maybe_write_log(loss, step_offset=eval_step, base_step=0)
        self._maybe_check_loss_value(loss, step_offset=eval_step)
        self._maybe_save_summaries(step_offset=eval_step, base_step=0)
        self._accumulate_loss_value(loss)

    def train_forward(self, data):
        """
        Runs the train forward pass.

        Override this method to provide any additional functionality around
        the train forward pass call.
        """
        return self._model(data)

    def eval_forward(self, data):
        """
        Runs the eval forward pass.

        Override this method to provide any additional functionality around
        the eval forward pass call.
        """
        return self._model(data)

    def backward(self, loss):
        """
        Runs the backward pass.

        Override this method to provide any additional functionality around
        the backward call.
        """
        if self._scaler:
            self._scaler.scale(loss).backward()
        else:
            loss.backward()

    def optimizer_zero_grad(self):
        """Zeroes out the gradients in the optimizer"""
        self._optimizer.zero_grad()
    def optimizer_step(self):
        """Performs the optimizer step"""
        if self._scaler:
            # Unscales the gradients of optimizer's assigned params in-place
            self._scaler.unscale_(self._optimizer)

        if self._should_log_extra_summaries:
            # gather unclipped gradients after unscale and before grad clipping
            self._log_summaries_grad_norm(is_clipped=False, is_scaled=False)
            self._log_summaries_grad_norm_per_layer(
                is_clipped=False, is_scaled=False
            )

        # gradient clipping
        if (
            hasattr(self._optimizer, "gradient_clipper")
            and self._optimizer.gradient_clipper is not None
        ):
            self._optimizer.gradient_clipper(self._model.model.parameters())

        if self._scaler:
            self._scaler.step(self._optimizer)
            # Compute new loss scale.
            self._scaler.update()
        else:
            self._optimizer.step()
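    # NOTE: Illustrative sketch, not part of the original source. When a
    # gradient scaler is present, optimizer_step() above follows the standard
    # torch.cuda.amp.GradScaler ordering:
    #
    #     scaler.scale(loss).backward()  # done earlier, in backward()
    #     scaler.unscale_(optimizer)     # unscale before inspecting/clipping
    #     ... log grad norms / clip grads ...
    #     scaler.step(optimizer)         # skips the step on inf/NaN grads
    #     scaler.update()                # recompute the loss scale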
    def lr_scheduler_step(self):
        """Performs the lr_scheduler step"""
        if self._lr_scheduler:
            self._lr_scheduler.step()
    def train(self, train_dataloader: torch.utils.data.DataLoader):
        """Train the model with data generated by the given dataloader.

        Args:
            train_dataloader: A data loader for generating data to feed to
                the model.
        """
        self._train_dataloader = train_dataloader
        with self._configure_run(modes.TRAIN, train_dataloader):
            self.on_train_start()
            exit_training = False
            for epoch in range(self._num_epochs):
                exit_training = self.train_epoch(epoch, train_dataloader)
                if exit_training:
                    break
            self.on_train_end(exit_training)
    def train_epoch(
        self, epoch: int, dataloader: torch.utils.data.DataLoader
    ) -> bool:
        """Runs an epoch of training

        Args:
            epoch: The current epoch
            dataloader: The dataloader to iterate through
        """
        self._active_mode = modes.TRAIN
        exit_epoch = False
        exit_training = False
        accum_loss = None
        grad_accum_step = 0

        # Set the appropriate writers
        self._loss_saver.writer = self._writer

        self.on_train_epoch_start()

        # Clear the loss to stop any noise from a previous epoch
        self._loss_saver.clear()

        self._model.train()  # Enable training mode

        for epoch_step, data in enumerate(dataloader):
            data = self.on_train_batch_start(data)

            with self._dump_ctx:
                # Only zero out the gradients if on first step or immediately
                # following an optimizer step
                if grad_accum_step % self._grad_accum_steps == 0:
                    self.optimizer_zero_grad()

                loss = self.train_forward(data)
                self.backward(loss)

            # accumulate the losses in a way that doesn't unnecessarily add
            # an addition op to the compute graph
            accum_loss = loss if not accum_loss else accum_loss + loss

            grad_accum_step += 1

            if grad_accum_step % self._grad_accum_steps == 0:
                if self._should_log_extra_summaries:
                    self._log_summaries_params_norm()

                with cbtorch.name_scope("optimizer"):
                    self.optimizer_step()

                # Learning rate must be summarized before it is stepped for
                # CS-X and non-CS-X runs to match. This is because stepping it
                # on non-CS-X will eagerly update the value and we'll be
                # printing the learning rate used in the next step, not the
                # current one.
                if self._should_log_extra_summaries:
                    self._log_summaries_learing_rate()
                    self._log_summaries_loss_scale()

                self.lr_scheduler_step()

                self._increment_global_step()

                self.on_train_batch_end(accum_loss, epoch, epoch_step)
                accum_loss = None

            # Check for early stopping in epoch and training loop.
            exit_epoch, exit_training = self._should_stop(
                epoch_step, modes.TRAIN
            )
            if exit_epoch:
                break

        assert grad_accum_step >= self._grad_accum_steps, (
            f"There were only {grad_accum_step} batches in epoch, which is "
            f"less than the grad accumulation steps {self._grad_accum_steps}. "
            f"This prevents model training as no optimizer step is taken."
        )
        if grad_accum_step % self._grad_accum_steps != 0:
            warnings.warn(
                "There were leftover gradients in the accumulation step. "
                "They will effectively vanish, which could potentially lead "
                "to different convergence behaviour."
            )

        self.on_train_epoch_end(exit_epoch)
        return exit_training
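    # NOTE: Illustrative sketch, not part of the original source. With
    # grad_accum_steps = 4 and a dataloader yielding 10 batches per epoch,
    # train_epoch() above zeroes gradients on batches 0, 4, and 8, takes an
    # optimizer/scheduler step (and increments the global step) after batches
    # 3 and 7, and then warns that the gradients from the 2 leftover batches
    # are effectively dropped.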
    def evaluate(self, eval_dataloader: torch.utils.data.DataLoader):
        """Evaluate the model with data generated by the given dataloader.

        Args:
            eval_dataloader: A data loader for generating data to feed to
                the model.
        """
        self._eval_dataloader = eval_dataloader
        with self._configure_run(modes.EVAL, eval_dataloader):
            self.on_eval_start()
            self.eval_epoch(eval_dataloader)
            self.on_eval_end(early_exit=False)
    @torch.no_grad()
    def eval_epoch(self, dataloader, epoch: int = None):
        """Runs an epoch of evaluation

        Args:
            dataloader: The dataloader to iterate through
            epoch: The current epoch
        """
        self._active_mode = modes.EVAL
        exit_epoch = False

        reset_all_metrics()

        self.on_eval_epoch_start()

        # Clear the loss to stop any noise from a previous epoch
        self._loss_saver.clear()

        self._model.eval()

        step = 0
        for step, data in enumerate(dataloader):
            data = self.on_eval_batch_start(data)

            with self._dump_ctx:
                outputs = self.eval_forward(data)
                loss = extract_loss(outputs)

            self.on_eval_batch_end(loss, epoch, step)

            exit_epoch, _ = self._should_stop(step, modes.EVAL)
            if exit_epoch:
                break

        self.on_eval_epoch_end(exit_epoch)
        self.compute_eval_metrics()
        self._maybe_log_throughput(step + 1)
    def compute_eval_metrics(self):
        """Compute and log the eval metrics"""
        self.print_eval_metrics(compute_all_metrics())
    def print_eval_metrics(self, eval_metrics):
        """Log the computed eval metrics"""
        if eval_metrics:
            if self._writer:
                for metric_scope, metric_value in visit_structure(
                    eval_metrics,
                    select_fn=lambda struct: isinstance(struct, (int, float)),
                    strict=True,
                ):
                    key = "/".join(metric_scope)
                    self._writer.add_scalar(
                        key, metric_value, self._global_step
                    )
            logging.info(f"Avg eval_metrics = {eval_metrics}")

        # Normalize total loss
        avg_eval_loss = self._loss_saver.average_loss
        if self._writer:
            self._writer.add_scalar("loss", avg_eval_loss, self._global_step)
        logging.info(f"Avg Eval Loss: {avg_eval_loss}")
    def train_and_eval(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        eval_dataloader: torch.utils.data.DataLoader,
    ):
        """Train and evaluate the model with data generated by dataloaders.

        In each epoch, this method trains the model first, then runs
        evaluation every epoch.

        Args:
            train_dataloader: A data loader for generating training data to
                feed to the model.
            eval_dataloader: A data loader for generating evaluation data to
                feed to the model.
        """
        self._train_dataloader = train_dataloader
        self._eval_dataloader = eval_dataloader
        with self._configure_run(
            modes.TRAIN_AND_EVAL, (train_dataloader, eval_dataloader),
        ):
            self.on_train_start()
            exit_training = False
            for epoch in range(self._num_epochs):
                exit_training = self.train_epoch(epoch, train_dataloader)
                self.eval_epoch(eval_dataloader)
                if exit_training:
                    break
            self.on_train_end(exit_training)

        logging.info("Training and Evaluation Completed Successfully!")
    @abc.abstractmethod
    def _increment_global_step(self):
        raise NotImplementedError()

    def _maybe_write_log(self, loss, step_offset=0, base_step=None):
        if self._is_fetch_step(step_offset):
            if base_step is None:
                base_step = self._global_step
            self._write_log(loss, base_step + step_offset)

    @abc.abstractmethod
    def _write_log(self, loss, global_step):
        raise NotImplementedError()

    def _maybe_save_loss(self, loss, epoch=None, step_offset=0):
        if self._save_losses and self._is_fetch_step(step_offset):
            self._save_loss(loss, self._global_step, epoch, step_offset)

    @cm.step_closure
    def _save_loss(
        self,
        loss: torch.Tensor,
        global_step: int,
        epoch: int = None,
        step_offset: int = 0,
    ):
        """Save the current step's loss

        Args:
            loss: The loss tensor.
            global_step: The global step
            epoch: The current epoch. Used in `train_and_eval` mode to
                distinguish the eval losses on a per-epoch basis.
            step_offset: The amount to offset the global step to account for
                the fact that the global step is not incremented in eval mode
        """
        if epoch is not None:
            self._loss_saver.add(loss, step_offset, epoch)
        else:
            self._loss_saver.add(loss, global_step + step_offset)

    def _maybe_save_summaries(self, step_offset=0, base_step=None):
        """Saves summaries calculated at the current step."""
        if self._is_fetch_step(step_offset):
            if base_step is None:
                base_step = self._global_step
            save_all_summaries(self._writer, base_step + step_offset)
        else:
            discard_cached_summaries()

    def _maybe_save_checkpoint(self, step_offset=0):
        if self._is_checkpoint_step(step_offset):
            self._save_checkpoint(self._global_step)

    @abc.abstractmethod
    def _save_checkpoint(self, *args, **kwargs):
        raise NotImplementedError()

    def _maybe_load_checkpoint(self, checkpoint_path: Optional[str], mode: str):
        """Optionally load checkpoint into the model.

        Args:
            checkpoint_path: Path to a checkpoint file.
            mode: The mode of the current run.

        Returns:
            The loaded state dict. If checkpoint path was None, returns None.
        """
        if checkpoint_path:
            logging.info(
                f"Loading weights from checkpoint {self._checkpoint_path}"
            )
            state_dict = cbtorch.load(checkpoint_path)
            self._model.set_state(state_dict)
        else:
            logging.info(
                f"No checkpoint was provided, using randomly initialized model "
                f"parameters."
            )
            state_dict = None

        if state_dict and not self._is_pretrained_checkpoint:
            self._global_step = state_dict.get("global_step", 0)
        else:
            self._global_step = 0
        self._initial_step = self._global_step

        return state_dict

    # Returns path to the last checkpoint, or None if no checkpoints exist
    def _get_last_checkpoint(self):
        # Used when running on interruptible instances in order to reload
        # from the last checkpoint.
        if not os.path.exists(self._model_dir):
            return None
        last_ckpt = (None, None)  # (step of last ckpt, path of last ckpt)
        for name in os.listdir(self._model_dir):
            match = re.fullmatch(rf'checkpoint_(\d+)\.mdl', name)
            if match:
                step = int(match.group(1))
                if last_ckpt[0] is None or step > last_ckpt[0]:
                    last_ckpt = (step, os.path.join(self._model_dir, name))
        if last_ckpt[1] is not None:
            logging.info(
                f"Found latest checkpoint at {last_ckpt[1]}. "
                f"This checkpoint will be used for loading model state."
            )
        return last_ckpt[1]

    def _maybe_check_loss_value(self, loss, step_offset=0):
        if self._check_loss_values and self._is_fetch_step(step_offset):
            self._check_loss_value(loss)

    @cm.step_closure
    def _check_loss_value(self, loss: torch.Tensor):
        """Checks to see if loss is NaN/inf.

        Args:
            loss: The loss tensor.

        Raises:
            ValueError if the loss is either NaN or inf.
        """
        loss = cm.to_cpu(loss.detach())

        msg_postfix = (
            "This could potentially be due to selected hyperparameters such "
            "as the learning rate, batch size, etc. or it could be due to an "
            "internal error. Please try with a different set of "
            "hyperparameters and contact Cerebras Support if the issue "
            "persists."
        )

        if torch.isnan(loss).any().item():
            raise ValueError(f"NaN loss detected. {msg_postfix}")
        if torch.isinf(loss).any().item():
            raise ValueError(f"inf loss detected. {msg_postfix}")

    def _maybe_log_throughput(self, step_offset=0):
        """Conditionally add throughput details to tensorboard"""
        if self._is_fetch_step(step_offset):
            self._log_throughput(self._global_step + step_offset)

    def _log_throughput(self, step):
        """Add throughput details to tensorboard"""
        # Can be optionally implemented downstream

    @cm.step_closure
    def _accumulate_loss_value(self, loss: torch.Tensor):
        """
        Accumulates our loss value into a total loss

        Args:
            loss: The loss tensor.
        """
        self._loss_saver.accumulate(loss)

    def _save_stream(self, data_loader, mode: str):
        if mode == modes.TRAIN_AND_EVAL:
            train_data_loader, eval_data_loader = data_loader
            self._save_stream(train_data_loader, modes.TRAIN)
            self._save_stream(eval_data_loader, modes.EVAL)
            return

        data = defaultdict(list)
        i = 0
        while i < self._save_stream_size:
            for batch in data_loader:
                for scope, tensor in visit_structure(
                    batch, lambda t: isinstance(t, torch.Tensor), strict=True
                ):
                    data[".".join(map(str, scope))].append(
                        cm.to_cpu(tensor.detach()).numpy()
                    )
                i += 1
                if i >= self._save_stream_size:
                    break

        ordinal = cm.get_ordinal()
        stream_dir = os.path.join(self._model_dir, mode)
        os.makedirs(stream_dir, exist_ok=True)
        np.savez(os.path.join(stream_dir, f"streams.{ordinal}.npz"), **data)
    @staticmethod
    def create(
        model_fn: Callable[[dict, Optional[torch.device]], PyTorchBaseModel],
        params: dict,
    ) -> "PyTorchBaseRunner":
        """
        Creates and returns an instance of PyTorchBaseRunner that has been
        configured based on the hardware specified by the provided params
        dictionary.

        Args:
            model_fn: A callable that takes in a 'params' argument and
                optionally a torch.device which it uses to configure and
                return a PyTorchBaseModel
            params: A dictionary containing all the parameters required to
                initialize and configure both the model and the runner
        """
        runconfig_params = params["runconfig"]
        RunConfigParamsValidator().validate(runconfig_params)

        if runconfig_params.get("max_checkpoints"):
            warnings.warn("`max_checkpoints` is only supported using cstorch")

        target_device = runconfig_params["target_device"]

        if target_device == DeviceType.CSX:
            from cerebras_appliance import DEFAULT_COMPILE_DIR
            from modelzoo.common.pytorch.pytorch_cs_appliance import (
                PyTorchCSAppliance,
            )

            cbtorch.initialize(
                service_workdir=runconfig_params["service_dir"],
                compile_dir=(
                    runconfig_params.get("compile_dir") or DEFAULT_COMPILE_DIR
                ),
                compile_only=(
                    runconfig_params["compile_only"]
                    or runconfig_params["validate_only"]
                ),
                appliance=True,
                use_appliance_data=runconfig_params.get(
                    "use_appliance_data", True
                ),
                use_cbfloat16=params.get("csconfig", {}).get(
                    "use_cbfloat16", False
                ),
                log_initialization=runconfig_params.get(
                    "log_initialization", True
                ),
            )

            seed = runconfig_params.get("seed")
            if seed is not None:
                cm.set_rng_state(seed)

            execution_strategy = (
                runconfig_params.get("execution_strategy")
                or ExecutionStrategy.weight_streaming
            )
            if (
                runconfig_params["num_csx"] > 1
                and execution_strategy != ExecutionStrategy.weight_streaming
            ):
                raise ValueError(
                    f"Using multiple CS-X systems is only supported in "
                    f"appliance mode with "
                    f"{ExecutionStrategy.weight_streaming} execution "
                    f"strategy."
                )
            cbtorch.env().weight_streaming_mode = (
                execution_strategy == ExecutionStrategy.weight_streaming
            )
            logging.debug(
                f"Running with {execution_strategy} execution strategy"
            )

            if not cbtorch.env().weight_streaming_mode:
                use_bfloat16 = params["model"].get("use_bfloat16", False)
                if use_bfloat16:
                    warnings.warn(
                        f"bfloat16 is not supported on Pipeline execution "
                        f"strategy. Setting use_bfloat16 in model config to "
                        f"False."
                    )
                    params["model"]["use_bfloat16"] = False

            model = _base_model_compat(model_fn, params)
            return PyTorchCSAppliance(model, params)

        elif target_device == DeviceType.CPU:
            from modelzoo.common.pytorch.pytorch_runner import PyTorchRunner

            device = torch.device("cpu")
            if params["model"]["mixed_precision"] and not params["model"].get(
                "use_bfloat16", False
            ):
                warnings.warn(
                    "Mixed precision on CPU is only supported with bfloat16. "
                    "Setting use_bfloat16 in model config to True."
                )
                params["model"]["use_bfloat16"] = True

            model = _base_model_compat(model_fn, params, device)
            return PyTorchRunner(device, model, params)

        elif target_device == DeviceType.GPU:
            if not torch.cuda.is_available():
                raise RuntimeError(
                    f"{DeviceType.GPU} was specified as the target device, "
                    f"but CUDA is not available. Please make sure you have a "
                    f"PyTorch installation with CUDA enabled to run on "
                    f"{DeviceType.GPU}."
                )

            world_size = torch.cuda.device_count()
            enable_distributed = params["runconfig"].get(
                "enable_distributed", False
            )

            if not enable_distributed:
                # Single GPU
                from modelzoo.common.pytorch.pytorch_runner import PyTorchRunner

                if world_size > 1:
                    warnings.warn(
                        "Distributed training was not enabled even though "
                        "more than 1 GPU is available."
                    )

                device = torch.device("cuda")
                model = _base_model_compat(model_fn, params, device)
                return PyTorchRunner(device, model, params)
            else:
                # Distributed GPU
                from modelzoo.common.pytorch.pytorch_dist_runner import (
                    PyTorchDistRunner,
                )

                if world_size == 1:
                    warnings.warn(
                        "Distributed training was enabled, but only "
                        "1 GPU was detected."
                    )

                # Model with no device, used to create optimizer and scheduler;
                # actual models will be created in on_process_start()
                model = _base_model_compat(model_fn, params)
                return PyTorchDistRunner(model, params)
        else:
            raise ValueError(
                f"Unsupported target device: {target_device}. "
                f"Supported devices are: {', '.join(DeviceType.devices)}"
            )
def _base_model_compat(model_fn, params, device=None) -> PyTorchBaseModel:
    # Initialize the model
    if isclass(model_fn) and issubclass(model_fn, PyTorchBaseModel):
        # to keep compatibility in case a user inherits from PyTorchBaseModel
        model = model_fn(params)
    else:
        model = PyTorchBaseModel(params, model_fn, device)
    return model
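# NOTE: Illustrative usage sketch, not part of the original source. Entry
# point scripts typically construct a runner roughly like this; `MyModel`,
# `my_params`, and the dataloaders are hypothetical placeholders.
#
#     from modelzoo.common.pytorch.pytorch_base_runner import PyTorchBaseRunner
#
#     runner = PyTorchBaseRunner.create(MyModel, my_params)
#     mode = my_params["runconfig"]["mode"]
#     if mode == modes.TRAIN:
#         runner.train(train_dataloader)
#     elif mode == modes.EVAL:
#         runner.evaluate(eval_dataloader)
#     elif mode == modes.TRAIN_AND_EVAL:
#         runner.train_and_eval(train_dataloader, eval_dataloader)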