# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modulek containing the Base PyTorch Runner"""
# pylint: disable=no-self-use, attribute-defined-outside-init
import abc
import copy
import logging
import math
import os
import re
import warnings
from collections import defaultdict
from contextlib import ExitStack, contextmanager, nullcontext
from inspect import isclass
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch, modes
from modelzoo.common.pytorch.dump_context import DumpContext
from modelzoo.common.pytorch.loss_utils import LossSaver, extract_loss
from modelzoo.common.pytorch.metrics import (
compute_all_metrics,
reset_all_metrics,
)
from modelzoo.common.pytorch.PyTorchBaseModel import PyTorchBaseModel
from modelzoo.common.pytorch.summaries import (
discard_cached_summaries,
save_all_summaries,
scalar_summary,
)
from modelzoo.common.pytorch.utils import (
RunConfigParamsValidator,
visit_structure,
)
from modelzoo.common.run_utils.utils import DeviceType, ExecutionStrategy
class PyTorchBaseRunner(metaclass=abc.ABCMeta):
"""The base class for running PyTorch models on any device."""
def __init__(self, model: PyTorchBaseModel, params: dict):
"""Construct a `PyTorchRunner` instance.
Args:
model: The PyTorch model to run.
param: A dict of params that specify the behavior of the model.
"""
self._model = model
self._params = copy.deepcopy(params)
self._optimizer = None
self._lr_scheduler = None
self._scaler = None
mode = self._runconfig["mode"]
if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
self._optimizer = model.get_optimizer()
self._lr_scheduler = model.get_lr_scheduler()
# The mode that is currently active
self._active_mode = mode
# Mandatory config options
self._mixed_precision = self._params["model"]["mixed_precision"]
# Optional config options
self._grad_accum_steps = self._params["optimizer"].get(
"grad_accum_steps", 1
)
self._log_summaries = self._params["optimizer"].get(
"log_summaries", False
)
self._show_debug_metrics = self._runconfig.get(
"show_debug_metrics", False
)
self._save_losses = self._runconfig.get("save_losses", True)
self._check_loss_values = self._runconfig.get("check_loss_values", True)
self._model_dir = self._runconfig.get("model_dir", "./")
self._checkpoint_path = self._runconfig.get("checkpoint_path")
if not self._checkpoint_path and self._runconfig.get(
"autoload_last_checkpoint", True
):
self._checkpoint_path = self._get_last_checkpoint()
self._is_pretrained_checkpoint = self._runconfig.get(
"is_pretrained_checkpoint", False,
)
self._save_initial_checkpoint = self._runconfig.get(
"save_initial_checkpoint", False
)
self._save_stream_size = self._runconfig.get("save_stream_size", 0)
# Summary writer object
self._writers = {}
# A cleanup stack that subclasses can hook into during a run to do any
# necessary cleanups in case of failures
self.__cleanup_stack: Union[ExitStack, None] = None
# These are set up at the start of execution loop
self._global_step = None
self._initial_step = None
self._total_steps = None
if self._runconfig.get("dump_activations", False):
if cm.use_cs():
logging.warning(
"Actually dumping activations on CS2 requires setting "
"additional debug arguments. This runconfig option just "
"enables extra debug location info in the compile."
)
self._dump_ctx = DumpContext(
os.path.join(self._model_dir, "act_dumps"), self._model.model,
)
else:
self._dump_ctx = nullcontext()
def _log_summaries_params_norm(self, requires_grad=True):
"""
Args:
requires_grad (bool): whether to only include params that require gradient updates
"""
# Computes the global norm of all params by first calculating the
# norm of each parameter individually and then combining the norms.
if cm.use_cs():
device = self._model.model.device
else:
device = self._model.device
param_norm = torch.tensor(0.0).to(device)
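# Note: summing squared per-tensor norms and taking a single sqrt is
# equivalent to the 2-norm of all parameters viewed as one flattened
# vector, e.g. sqrt(||a||^2 + ||b||^2) == ||concat(a.flatten(), b.flatten())||.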
for _, param in self._model.model.named_parameters():
if not requires_grad:
# simply add if we want to include all params
param_norm += torch.pow(torch.norm(param), 2.0)
elif param.requires_grad:
# only add the param if it requires gradient update
param_norm += torch.pow(torch.norm(param), 2.0)
param_norm = torch.sqrt(param_norm)
scalar_summary("model_wise_params_norm", param_norm)
def _log_summaries_grad_norm(self, is_clipped=False, is_scaled=False):
"""
Args:
is_clipped (bool): whether to log clipped gradient
is_scaled (bool): whether to log scaled gradient
"""
# This should be called after unscaling and before grad clipping
if cm.use_cs():
device = self._model.model.device
else:
device = self._model.device
param_grad_norm = torch.tensor(0.0).to(device)
for _, param in self._model.model.named_parameters():
if param.grad is not None:
param_grad_norm += torch.pow(torch.norm(param.grad), 2.0)
param_grad_norm = torch.sqrt(param_grad_norm)
summary_str = "model_wise_grad_norm"
summary_str += "_clipped" if is_clipped else "_unclipped"
summary_str += "_scaled" if is_scaled else "_unscaled"
scalar_summary(summary_str, param_grad_norm)
def _log_summaries_grad_norm_per_layer(
self, is_clipped=False, is_scaled=False
):
"""
Args:
is_clipped (bool): whether to log clipped gradient
is_scaled (bool): whether to log scaled gradient
"""
# Computes a gradient norm per layer by accumulating the squared norms
# of each layer's parameter gradients and taking the square root per layer.
param_norm = {}
layer_pattern_str = r'.*(layers\.)(\d+)(\.).*'
layer_pattern = re.compile(layer_pattern_str)
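# e.g. a (hypothetical) parameter named "encoder.layers.3.attn.weight"
# matches with group(2) == "3", which is used as the layer id below.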
for name, param in self._model.model.named_parameters():
if param.grad is None:
continue
# get a match if the module name contains `layers.<i>.` where i is the layer number
match = layer_pattern.match(name)
if match:
layer_id = match.group(2)
if layer_id not in param_norm:
param_norm[layer_id] = torch.tensor(0.0).to(param.device)
param_norm[layer_id] += torch.pow(torch.norm(param.grad), 2.0)
for layer_id in param_norm:
param_norm[layer_id] = torch.sqrt(param_norm[layer_id])
summary_str = "per_layer_grad_norm"
summary_str += "_clipped" if is_clipped else "_unclipped"
summary_str += "_scaled" if is_scaled else "_unscaled"
summary_str += f"/layer_{layer_id}"
scalar_summary(summary_str, param_norm[layer_id])
def _log_summaries_learning_rate(self):
if self._lr_scheduler:
# self._lr_scheduler.get_last_lr() returns a list of LRs, one per
# param group of the optimizer. We create one learning rate tracker
# for each param group with identifier `lr_params_group_{i}`, where
# i is the index of the param group.
last_lrs = self._lr_scheduler.get_last_lr()
if not isinstance(last_lrs, list):
last_lrs = [last_lrs]
for i, last_lr in enumerate(last_lrs):
scalar_summary(f"lr_params_group_{i}", last_lr)
def _log_summaries_loss_scale(self):
if self._scaler:
scalar_summary("loss_scale", self._scaler.get_scale())
@property
def _should_log_extra_summaries(self):
return self._log_summaries
@property
def _writer(self) -> Optional[SummaryWriter]:
return self._writers.get(self._active_mode)
@property
def _runconfig(self):
return self._params["runconfig"]
@property
def _run_step(self) -> int:
"""Returns the current execution step.
This is different from global_step in that it indicates the execution
step of the current run.
"""
assert self._global_step is not None
assert self._initial_step is not None
return self._global_step - self._initial_step
@property
def _loss_dir(self) -> str:
"""Return the directory to use for saving intermediate losses."""
loss_dir = os.path.join(self._model_dir, "losses")
os.makedirs(loss_dir, exist_ok=True)
return loss_dir
@property
def _cleanup_stack(self) -> ExitStack:
"""Returns an ExitStack only available during execution."""
assert (
self.__cleanup_stack
), "Cleanup stack is only available during execution."
return self.__cleanup_stack
def _validate_config(self):
"""Check that the provided config is valid.
Raises:
AssertionError if any of the config options is invalid.
"""
# num_epochs and num_steps are mutually exclusive. max_steps is optional
# unless neither num_epochs nor num_steps are provided, in which case
# max_steps must be provided.
if self._num_epochs is not None and self._num_steps is not None:
raise ValueError(
"Please specify only one of `num_epochs` or `num_steps`."
)
elif self._num_epochs is not None:
assert (
self._num_epochs > 0
), "When provided, `num_epochs` must be greater than zero."
elif self._num_steps is not None:
assert (
self._num_steps > 0
), "When provided, `num_steps` must be greater than zero."
else:
if self._max_steps is None:
raise ValueError(
"`max_steps` must be specified if neither `num_epochs` "
"nor `num_steps` are specified."
)
assert self._max_steps is None or self._max_steps > 0
assert (
self._train_steps_per_epoch is None
or self._train_steps_per_epoch > 0
)
assert (
self._eval_steps_per_epoch is None or self._eval_steps_per_epoch > 0
)
assert self._grad_accum_steps >= 1
assert self._checkpoint_steps >= 0
assert self._log_steps >= 0
if cm.use_cs() and self._grad_accum_steps > 1:
raise ValueError(
"Gradient Accumulation not supported on CS workflow."
)
def _should_stop(self, epoch_step: int, mode: str) -> Tuple[bool, bool]:
"""Return a tuple indicating whether to stop epoch/training.
Args:
epoch_step: The current step in the epoch loop.
mode: The mode of the current run (i.e., train or eval).
Returns:
A tuple of booleans. The first item indicates whether to exit the
epoch loop. The second item indicates whether to exit training.
"""
exit_epoch = exit_training = False
if mode == modes.TRAIN:
steps_per_epoch = self._train_steps_per_epoch
elif mode == modes.EVAL:
steps_per_epoch = self._eval_steps_per_epoch
else:
raise ValueError(f"Unhandled mode: {mode}.")
if steps_per_epoch and epoch_step + 1 >= steps_per_epoch:
exit_epoch = True
if mode == modes.TRAIN and self._run_step >= self._total_steps:
exit_epoch = True
exit_training = True
return exit_epoch, exit_training
def _configure_run_steps(self, dataloader, mode: str):
"""Configure steps which specify training behavior.
This method sets `self._total_steps`, `self._checkpoint_steps`,
`self._fetch_steps`, and `self._num_epochs` based on the user-provided
config. If the current global step exceeds max steps, this method raises
an error.
Args:
dataloader: The data loader instance (or pair of data loaders in
train_and_eval mode) used to configure the run.
Raises:
RuntimeError if the global step exceeds max steps.
"""
assert self._global_step is not None
self._log_steps = self._runconfig["log_steps"]
if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
self._num_epochs = self._runconfig.get("num_epochs", None)
self._num_steps = self._runconfig.get("num_steps", None)
self._max_steps = self._runconfig.get("max_steps", None)
self._checkpoint_steps = self._runconfig.get("checkpoint_steps", 0)
elif mode == modes.EVAL:
self._num_epochs = 1
self._num_steps = None
self._max_steps = None
self._checkpoint_steps = 0
else:
raise ValueError(f"Unhandled mode: {mode}.")
self._train_steps_per_epoch = self._runconfig.get(
"steps_per_epoch", None
)
self._eval_steps_per_epoch = self._runconfig.get("eval_steps", None)
self._validate_config()
if self._max_steps is not None and self._global_step >= self._max_steps:
raise RuntimeError(
f"Global step {self._global_step} already exceeds "
f"max step {self._max_steps}."
)
if mode == modes.TRAIN:
train_dataloader = dataloader
elif mode == modes.EVAL:
eval_dataloader = dataloader
elif mode == modes.TRAIN_AND_EVAL:
train_dataloader, eval_dataloader = dataloader
train_steps_per_epoch = 0
if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
try:
# Dataset length is known
train_steps_per_epoch = len(train_dataloader)
assert (
train_steps_per_epoch > 0
), "Train Dataloader does not generate any batches."
if self._train_steps_per_epoch is not None:
assert (
self._train_steps_per_epoch <= train_steps_per_epoch
), (
f"The requested steps per epoch of {self._train_steps_per_epoch} "
f"exceeds total steps in an epoch, which is "
f"{train_steps_per_epoch}."
)
train_steps_per_epoch = self._train_steps_per_epoch
# With grad accumulation, the global step is incremented every Nth
# batch, so our effective steps per epoch needs to be adjusted.
assert self._grad_accum_steps <= train_steps_per_epoch, (
f"Gradient accumulation steps of {self._grad_accum_steps} is "
f"greater than batches per epoch of {train_steps_per_epoch}."
)
train_steps_per_epoch //= self._grad_accum_steps
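# e.g. 1000 batches per epoch with grad_accum_steps=4 yields 250
# effective (optimizer) steps per epoch.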
except TypeError:
# Dataset length is not known
assert self._num_epochs is None, (
"Specifying num_epochs for datasets with unknown length is "
"not allowed. Please control training behavior through "
"number of steps instead."
)
train_steps_per_epoch = 1
if mode in (modes.EVAL, modes.TRAIN_AND_EVAL):
try:
# Dataset length is known
eval_steps_per_epoch = len(eval_dataloader)
assert (
eval_steps_per_epoch > 0
), "Eval Dataloader does not generate any batches."
if self._eval_steps_per_epoch is not None:
# Do a sanity check that we can generate requested batches
assert self._eval_steps_per_epoch <= eval_steps_per_epoch, (
f"The requested steps per epoch of "
f"{self._eval_steps_per_epoch} exceeds total steps in "
f"an epoch, which is {eval_steps_per_epoch}."
)
else:
# If not explicitly specified, run the whole eval epoch
self._eval_steps_per_epoch = eval_steps_per_epoch
except TypeError:
# Dataset length is not known, so eval steps must be specified.
# We assume the dataloader generates as many steps as the user
# specified. Otherwise, we may get a stall. There's no way for
# us to validate this since we can't query the dataset length.
assert self._eval_steps_per_epoch is not None, (
"`eval_steps` must be specified for datasets with unknown "
"length."
)
if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
steps_per_epoch = train_steps_per_epoch
# Calculate total steps
self._total_steps = math.inf
if self._num_epochs is not None:
self._total_steps = min(
self._total_steps, self._num_epochs * steps_per_epoch
)
if self._num_steps is not None:
self._total_steps = min(self._total_steps, self._num_steps)
if self._max_steps is not None:
remaining_steps = self._max_steps - self._global_step
assert remaining_steps > 0, ( # This was checked above
f"Global step {self._global_step} already exceeds "
f"max step {self._max_steps}."
)
self._total_steps = min(self._total_steps, remaining_steps)
# At least one of the above if blocks must have been true.
# Adding an assert in case someone makes a mistake.
assert not math.isinf(
self._total_steps
), "One of num_epochs, num_steps, or max_steps must be provided"
# Override num_epochs based on total steps and steps per epoch
self._num_epochs = math.ceil(self._total_steps / steps_per_epoch)
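# e.g. total_steps=250 with steps_per_epoch=100 gives num_epochs=3; the
# final (partial) epoch exits early once _should_stop sees the run step
# reach total_steps.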
self._checkpoint_steps = min(
self._checkpoint_steps, self._total_steps
)
elif mode == modes.EVAL:
self._total_steps = self._eval_steps_per_epoch
self._fetch_steps = min(self._log_steps, self._total_steps)
if self._fetch_steps == 0: # Always fetch the outputs of the last step
self._fetch_steps = self._total_steps
def _is_fetch_step_helper(self, step, step_offset: int = 0):
"""
Checks whether the step provided is a step where values are pre-scheduled to
come off of the Cerebras system.
Primarily for performance reasons.
Args:
step: The step to check.
step_offset: Used to offset the run step in eval where the global
step is not incremented.
"""
step = step + step_offset
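# Fetch on the very last step and on every multiple of `_fetch_steps`,
# e.g. with _fetch_steps=10 and _total_steps=45: steps 10, 20, 30, 40, 45.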
return step == self._total_steps or (
self._fetch_steps > 0 and step % self._fetch_steps == 0
)
def _is_fetch_step(self, step_offset: int = 0):
"""
Checks whether we are on a step where values are pre-scheduled to
come off of the Cerebras system.
Primarily for performance reasons.
Args:
step_offset: Used to offset the run step in eval where the global
step is not incremented.
"""
return self._is_fetch_step_helper(self._run_step, step_offset)
def _is_checkpoint_step(self, step_offset: int = 0):
"""
Checks whether we are on a step where a checkpoint is pre-scheduled to
come off of the Cerebras system.
Primarily for performance reasons.
Args:
step_offset: Used to offset the run step in eval where the global
step is not incremented.
"""
step = self._run_step + step_offset
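# e.g. with checkpoint_steps=100 and total_steps=250, checkpoints are
# taken at steps 100, 200, and 250 (the final step).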
return self._checkpoint_steps > 0 and (
step == self._total_steps or step % self._checkpoint_steps == 0
)
def is_master_ordinal(self):
"""
Checks if distributed training is enabled and, if so, whether this
is the main process. Most reading and writing should only happen
on the main process.
"""
return cm.is_master_ordinal()
@contextmanager
def _configure_run(
self, mode: str, dataloader: torch.utils.data.DataLoader
):
"""Configure the run for the mode using the provided dataloader
The setup involves loading a checkpoint if specified
as well as configuring the run steps for performance.
Args:
mode: the mode to configure the run for.
dataloader: the dataloader used to configure the run.
"""
if not self._model.supports_mode(mode):
raise ValueError(
f"{mode} not supported for model. "
f"Supported modes include: {self._model.supported_modes}."
)
self._maybe_load_checkpoint(self._checkpoint_path, mode)
if self.is_master_ordinal():
# Save initial checkpoint
if self._save_initial_checkpoint:
self._save_checkpoint(self._global_step)
# Save dataloader streams for testing
if self._save_stream_size:
self._save_stream(dataloader, mode)
# Create tensorboard summary writer for logging
if mode in (modes.EVAL, modes.TRAIN_AND_EVAL):
self._writers[modes.EVAL] = SummaryWriter(
log_dir=os.path.join(self._model_dir, modes.EVAL)
)
self._active_mode = modes.EVAL
if mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
self._writers[modes.TRAIN] = SummaryWriter(
log_dir=os.path.join(self._model_dir, modes.TRAIN)
)
self._active_mode = modes.TRAIN
self._loss_saver = LossSaver(self._writer)
# Configure the number of steps to run based on
# the size of the dataloader
self._configure_run_steps(dataloader, mode)
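# Subclasses may push callbacks onto this ExitStack (via `_cleanup_stack`)
# during execution; they are invoked when the `with` block exits, whether
# the run succeeds or fails.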
with ExitStack() as self.__cleanup_stack:
try:
yield
finally:
if self.is_master_ordinal():
for writer in self._writers.values():
writer.flush()
writer.close()
def on_checkpoint_saved(self, checkpoint_path: str, step: int):
"""Function to execute after a checkpoint is saved."""
def on_train_start(self):
"""Function to execute before training starts"""
def on_train_end(self, early_exit: bool):
"""Function to execute after training ends"""
def on_eval_start(self):
"""Function to execute before eval starts"""
def on_eval_end(self, early_exit: bool):
"""Function to execute after eval ends"""
def on_train_epoch_start(self):
"""Function to execute before the training epoch begins"""
def on_train_epoch_end(self, early_exit: bool):
"""Function to execute after the training epoch ends"""
def on_eval_epoch_start(self):
"""Function to execute before the eval epoch begins"""
def on_eval_epoch_end(self, early_exit: bool):
"""Function to execute after the eval epoch ends"""
def on_train_batch_start(self, data):
"""Optionally pre-process data before train batch start"""
return data
def on_train_batch_end(self, loss, epoch: int = None, step: int = None):
"""Actions to perform after the train batch iteration is complete"""
# Add step closures
self._maybe_write_log(loss)
self._maybe_check_loss_value(loss)
self._maybe_save_loss(loss)
self._maybe_save_summaries()
self._maybe_log_throughput()
self._maybe_save_checkpoint()
def on_eval_batch_start(self, data):
"""Optionally pre-process data before eval batch start"""
return data
def on_eval_batch_end(self, loss, epoch: int = None, step: int = None):
"""Actions to perform after the eval batch iteration is complete"""
eval_step = step + 1
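# The global step is not incremented during eval, so logging uses the
# 1-based eval step as an offset from a base step of 0 instead.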
# Add step closures
self._maybe_write_log(loss, step_offset=eval_step, base_step=0)
self._maybe_check_loss_value(loss, step_offset=eval_step)
self._maybe_save_summaries(step_offset=eval_step, base_step=0)
self._accumulate_loss_value(loss)
def train_forward(self, data):
"""
Runs the train forward pass.
Override this method to provide any additional functionality around
the train forward pass call.
"""
return self._model(data)
def eval_forward(self, data):
"""
Runs the eval forward pass.
Override this method to provide any additional functionality around
the eval forward pass call.
"""
return self._model(data)
def backward(self, loss):
"""
Runs the backward pass.
Override this method to provide any additional functionality around
the backward call.
"""
if self._scaler:
self._scaler.scale(loss).backward()
else:
loss.backward()
def optimizer_zero_grad(self):
"""Zeroes out the gradients in the optimizer"""
self._optimizer.zero_grad()
def optimizer_step(self):
"""Performs the optimizer step"""
if self._scaler:
# Unscales the gradients of optimizer's assigned params in-place
self._scaler.unscale_(self._optimizer)
if self._should_log_extra_summaries:
# gather unclipped gradients after unscale and before grad clipping
self._log_summaries_grad_norm(is_clipped=False, is_scaled=False)
self._log_summaries_grad_norm_per_layer(
is_clipped=False, is_scaled=False
)
# gradient clipping
if (
hasattr(self._optimizer, "gradient_clipper")
and self._optimizer.gradient_clipper is not None
):
self._optimizer.gradient_clipper(self._model.model.parameters())
if self._scaler:
self._scaler.step(self._optimizer)
# Compute new loss scale.
self._scaler.update()
else:
self._optimizer.step()
def lr_scheduler_step(self):
"""Performs the lr_scheduler step"""
if self._lr_scheduler:
self._lr_scheduler.step()
def train(self, train_dataloader: torch.utils.data.DataLoader):
"""Train the model with data generated by the given dataloader.
Args:
train_dataloader: A data loader for generating data to feed to the model.
"""
self._train_dataloader = train_dataloader
with self._configure_run(modes.TRAIN, train_dataloader):
self.on_train_start()
exit_training = False
for epoch in range(self._num_epochs):
exit_training = self.train_epoch(epoch, train_dataloader)
if exit_training:
break
self.on_train_end(exit_training)
def train_epoch(
self, epoch: int, dataloader: torch.utils.data.DataLoader
) -> bool:
"""Runs an epoch of training
Args:
epoch: The current epoch
dataloader: The dataloader to iterate through
"""
self._active_mode = modes.TRAIN
exit_epoch = False
exit_training = False
accum_loss = None
grad_accum_step = 0
# Set the appropriate writers
self._loss_saver.writer = self._writer
self.on_train_epoch_start()
# Clear the loss to stop any noise from a previous epoch
self._loss_saver.clear()
self._model.train() # Enable training mode
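# With gradient accumulation, backward() runs on every batch, but the
# optimizer only steps every `_grad_accum_steps` batches; intermediate
# losses are summed into `accum_loss` for logging.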
for epoch_step, data in enumerate(dataloader):
data = self.on_train_batch_start(data)
with self._dump_ctx:
# Only zero out the gradients if on first step or immediately
# following an optimizer step
if grad_accum_step % self._grad_accum_steps == 0:
self.optimizer_zero_grad()
loss = self.train_forward(data)
self.backward(loss)
# accumulate the losses in a way that doesn't unnecessarily add
# an addition op to the compute graph
accum_loss = loss if not accum_loss else accum_loss + loss
grad_accum_step += 1
if grad_accum_step % self._grad_accum_steps == 0:
if self._should_log_extra_summaries:
self._log_summaries_params_norm()
with cbtorch.name_scope("optimizer"):
self.optimizer_step()
# Learning rate must be summarized before it is stepped for CS-X
# and non-CS-X runs to match. This is because stepping it on
# non-CS-X will eagerly update the value and we'll be printing
# the learning rate used in the next step, not the current one.
if self._should_log_extra_summaries:
self._log_summaries_learning_rate()
self._log_summaries_loss_scale()
self.lr_scheduler_step()
self._increment_global_step()
self.on_train_batch_end(accum_loss, epoch, epoch_step)
accum_loss = None
# Check for early stopping in epoch and training loop.
exit_epoch, exit_training = self._should_stop(
epoch_step, modes.TRAIN
)
if exit_epoch:
break
assert grad_accum_step >= self._grad_accum_steps, (
f"There were only {grad_accum_step} batches in epoch, which is "
f"less than the grad accumulation steps {self._grad_accum_steps}. "
f"This prevents model training as no optimizer step is taken."
)
if grad_accum_step % self._grad_accum_steps != 0:
warnings.warn(
"There were leftover gradients in the accumulation step. "
"They will effectively vanish, which could potentially lead "
"to different convergence behaviour."
)
self.on_train_epoch_end(exit_epoch)
return exit_training
def evaluate(self, eval_dataloader: torch.utils.data.DataLoader):
"""Evaluate the model with data generated by the given dataloader.
Args:
eval_dataloader: A data loader for generating data to feed to the model.
"""
self._eval_dataloader = eval_dataloader
with self._configure_run(modes.EVAL, eval_dataloader):
self.on_eval_start()
self.eval_epoch(eval_dataloader)
self.on_eval_end(early_exit=False)
@torch.no_grad()
def eval_epoch(self, dataloader, epoch: int = None):
"""Runs an epoch of evaluation
Args:
dataloader: The dataloader to iterate through
epoch: The current epoch
"""
self._active_mode = modes.EVAL
exit_epoch = False
reset_all_metrics()
self.on_eval_epoch_start()
# Clear the loss to stop any noise from a previous epoch
self._loss_saver.clear()
self._model.eval()
step = 0
for step, data in enumerate(dataloader):
data = self.on_eval_batch_start(data)
with self._dump_ctx:
outputs = self.eval_forward(data)
loss = extract_loss(outputs)
self.on_eval_batch_end(loss, epoch, step)
exit_epoch, _ = self._should_stop(step, modes.EVAL)
if exit_epoch:
break
self.on_eval_epoch_end(exit_epoch)
self.compute_eval_metrics()
self._maybe_log_throughput(step + 1)
def compute_eval_metrics(self):
"""Compute and log the eval metrics"""
self.print_eval_metrics(compute_all_metrics())
def print_eval_metrics(self, eval_metrics):
"""Log the given eval metrics and the average eval loss"""
if eval_metrics:
if self._writer:
for metric_scope, metric_value in visit_structure(
eval_metrics,
select_fn=lambda struct: isinstance(struct, (int, float)),
strict=True,
):
key = "/".join(metric_scope)
self._writer.add_scalar(
key, metric_value, self._global_step
)
logging.info(f"Avg eval_metrics = {eval_metrics}")
# Normalize total loss
avg_eval_loss = self._loss_saver.average_loss
if self._writer:
self._writer.add_scalar("loss", avg_eval_loss, self._global_step)
logging.info(f"Avg Eval Loss: {avg_eval_loss}")
def train_and_eval(
self,
train_dataloader: torch.utils.data.DataLoader,
eval_dataloader: torch.utils.data.DataLoader,
):
"""Train and evaluate the model with data generated by dataloaders.
In each epoch, this method trains the model first, then runs
evaluation on the eval dataloader.
Args:
train_dataloader: A data loader for generating training data to
feed to the model.
eval_dataloader: A data loader for generating evaluation data to
feed to the model.
"""
self._train_dataloader = train_dataloader
self._eval_dataloader = eval_dataloader
with self._configure_run(
modes.TRAIN_AND_EVAL, (train_dataloader, eval_dataloader),
):
self.on_train_start()
exit_training = False
for epoch in range(self._num_epochs):
exit_training = self.train_epoch(epoch, train_dataloader)
self.eval_epoch(eval_dataloader)
if exit_training:
break
self.on_train_end(exit_training)
logging.info("Training and Evaluation Completed Successfully!")
@abc.abstractmethod
def _increment_global_step(self):
raise NotImplementedError()
def _maybe_write_log(self, loss, step_offset=0, base_step=None):
if self._is_fetch_step(step_offset):
if base_step is None:
base_step = self._global_step
self._write_log(loss, base_step + step_offset)
@abc.abstractmethod
def _write_log(self, loss, global_step):
raise NotImplementedError()
def _maybe_save_loss(self, loss, epoch=None, step_offset=0):
if self._save_losses and self._is_fetch_step(step_offset):
self._save_loss(loss, self._global_step, epoch, step_offset)
@cm.step_closure
def _save_loss(
self,
loss: torch.Tensor,
global_step: int,
epoch: int = None,
step_offset: int = 0,
):
"""Save the current step's loss
Args:
loss: The loss tensor.
global_step: The global step
epoch: The current epoch. Used in `train_and_eval` mode to distinguish
the eval losses on a per epoch basis.
step_offset: The amount by which to offset the global step to account
for the fact that the global step is not incremented in eval mode.
"""
if epoch is not None:
self._loss_saver.add(loss, step_offset, epoch)
else:
self._loss_saver.add(loss, global_step + step_offset)
def _maybe_save_summaries(self, step_offset=0, base_step=None):
"""Saves summaries calculated at the current step."""
if self._is_fetch_step(step_offset):
if base_step is None:
base_step = self._global_step
save_all_summaries(self._writer, base_step + step_offset)
else:
discard_cached_summaries()
def _maybe_save_checkpoint(self, step_offset=0):
if self._is_checkpoint_step(step_offset):
self._save_checkpoint(self._global_step)
@abc.abstractmethod
def _save_checkpoint(self, *args, **kwargs):
raise NotImplementedError()
def _maybe_load_checkpoint(self, checkpoint_path: Optional[str], mode: str):
"""Optionally load checkpoint into the model.
Args:
checkpoint_path: Path to a checkpoint file.
Returns:
The loaded state dict. If checkpoint path was None, returns None.
"""
if checkpoint_path:
logging.info(
f"Loading weights from checkpoint {self._checkpoint_path}"
)
state_dict = cbtorch.load(checkpoint_path)
self._model.set_state(state_dict)
else:
logging.info(
f"No checkpoint was provided, using randomly initialized model "
f"parameters."
)
state_dict = None
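# A pretrained checkpoint is treated as initial weights only: the global
# step restarts from 0 instead of resuming the saved step count.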
if state_dict and not self._is_pretrained_checkpoint:
self._global_step = state_dict.get("global_step", 0)
else:
self._global_step = 0
self._initial_step = self._global_step
return state_dict
# Returns path to last checkpoint or None if no checkpoints exist
def _get_last_checkpoint(self):
# Used when running on interruptible instances in order to reload from
# the last checkpoint.
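# Checkpoints are expected to be named `checkpoint_<step>.mdl` (e.g.
# checkpoint_2000.mdl); the one with the largest step is returned.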
if not os.path.exists(self._model_dir):
return None
last_ckpt = (None, None) # (step of last ckpt, path of last ckpt)
for name in os.listdir(self._model_dir):
match = re.fullmatch(r'checkpoint_(\d+)\.mdl', name)
if match:
step = int(match.group(1))
if last_ckpt[0] is None or step > last_ckpt[0]:
last_ckpt = (step, os.path.join(self._model_dir, name))
if last_ckpt[1] is not None:
logging.info(
f"Found latest checkpoint at {last_ckpt[1]}."
f"This checkpoint will be used for loading model state."
)
return last_ckpt[1]
def _maybe_check_loss_value(self, loss, step_offset=0):
if self._check_loss_values and self._is_fetch_step(step_offset):
self._check_loss_value(loss)
@cm.step_closure
def _check_loss_value(self, loss: torch.Tensor):
"""Checks to see if loss is Nan/inf.
Args:
loss: The loss tensor.
Raises:
ValueError if the loss is either NaN or inf.
"""
loss = cm.to_cpu(loss.detach())
msg_postfix = (
"This could potentially be due to selected hyperparameters such as "
"the learning rate, batch size, etc. or it could due an internal "
"error. Please try with different set of hyperparameters and "
"contact Cerebras Support if the issue persists."
)
if torch.isnan(loss).any().item():
raise ValueError(f"NaN loss detected. {msg_postfix}")
if torch.isinf(loss).any().item():
raise ValueError(f"inf loss detected. {msg_postfix}")
def _maybe_log_throughput(self, step_offset=0):
"""Conditionally add throughput details to tensorboard"""
if self._is_fetch_step(step_offset):
self._log_throughput(self._global_step + step_offset)
def _log_throughput(self, step):
"""Add throughput details to tensorboard"""
# Can be optionally implemented downstream
@cm.step_closure
def _accumulate_loss_value(self, loss: torch.Tensor):
"""
Accumulates the loss value into the loss saver's running total.
Args:
loss: The loss tensor.
"""
self._loss_saver.accumulate(loss)
def _save_stream(self, data_loader, mode: str):
if mode == modes.TRAIN_AND_EVAL:
train_data_loader, eval_data_loader = data_loader
self._save_stream(train_data_loader, modes.TRAIN)
self._save_stream(eval_data_loader, modes.EVAL)
return
data = defaultdict(list)
i = 0
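# Collect up to `_save_stream_size` batches, flattening each (possibly
# nested) batch structure into dot-joined tensor names, e.g. a nested
# entry at ("features", "input_ids") is stored under "features.input_ids".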
while i < self._save_stream_size:
for batch in data_loader:
for scope, tensor in visit_structure(
batch, lambda t: isinstance(t, torch.Tensor), strict=True
):
data[".".join(map(str, scope))].append(
cm.to_cpu(tensor.detach()).numpy()
)
i += 1
if i >= self._save_stream_size:
break
ordinal = cm.get_ordinal()
stream_dir = os.path.join(self._model_dir, mode)
os.makedirs(stream_dir, exist_ok=True)
np.savez(os.path.join(stream_dir, f"streams.{ordinal}.npz"), **data)
@staticmethod
def create(
model_fn: Callable[[dict, Optional[torch.device]], PyTorchBaseModel],
params: dict,
) -> "PyTorchBaseRunner":
"""
Creates and returns an instance of PyTorchBaseRunner
that has been configured based on the hardware specified by the
provided params dictionary
Args:
model_fn: A callable that takes in a 'params' argument
and optionally a torch.device which it uses to configure
and return a PyTorchBaseModel
params: A dictionary containing all the parameters required
to initialize and configure both the model and the runner
"""
runconfig_params = params["runconfig"]
RunConfigParamsValidator().validate(runconfig_params)
if runconfig_params.get("max_checkpoints"):
warnings.warn("`max_checkpoints` is only supported using cstorch")
target_device = runconfig_params["target_device"]
if target_device == DeviceType.CSX:
from cerebras_appliance import DEFAULT_COMPILE_DIR
from modelzoo.common.pytorch.pytorch_cs_appliance import (
PyTorchCSAppliance,
)
cbtorch.initialize(
service_workdir=runconfig_params["service_dir"],
compile_dir=(
runconfig_params.get("compile_dir") or DEFAULT_COMPILE_DIR
),
compile_only=(
runconfig_params["compile_only"]
or runconfig_params["validate_only"]
),
appliance=True,
use_appliance_data=runconfig_params.get(
"use_appliance_data", True
),
use_cbfloat16=params.get("csconfig", {}).get(
"use_cbfloat16", False
),
log_initialization=runconfig_params.get(
"log_initialization", True
),
)
seed = runconfig_params.get("seed")
if seed is not None:
cm.set_rng_state(seed)
execution_strategy = (
runconfig_params.get("execution_strategy")
or ExecutionStrategy.weight_streaming
)
if (
runconfig_params["num_csx"] > 1
and execution_strategy != ExecutionStrategy.weight_streaming
):
raise ValueError(
f"Using multiple CS-X systems is only supported in "
f"appliance mode with "
f"{ExecutionStrategy.weight_streaming} execution "
f"strategy."
)
cbtorch.env().weight_streaming_mode = (
execution_strategy == ExecutionStrategy.weight_streaming
)
logging.debug(
f"Running with {execution_strategy} execution strategy"
)
if not cbtorch.env().weight_streaming_mode:
use_bfloat16 = params["model"].get("use_bfloat16", False)
if use_bfloat16:
warnings.warn(
f"bfloat16 is not supported on Pipeline execution "
f"strategy. Setting use_bfloat16 in model config to "
f"False."
)
params["model"]["use_bfloat16"] = False
model = _base_model_compat(model_fn, params)
return PyTorchCSAppliance(model, params)
elif target_device == DeviceType.CPU:
from modelzoo.common.pytorch.pytorch_runner import PyTorchRunner
device = torch.device("cpu")
if params["model"]["mixed_precision"] and not params["model"].get(
"use_bfloat16", False
):
warnings.warn(
"Mixed precision on CPU is only supported with bfloat16. "
"Setting use_bfloat16 in model config to True."
)
params["model"]["use_bfloat16"] = True
model = _base_model_compat(model_fn, params, device)
return PyTorchRunner(device, model, params)
elif target_device == DeviceType.GPU:
if not torch.cuda.is_available():
raise RuntimeError(
f"{DeviceType.GPU} was specified as the target device, but "
f"CUDA is not available. Please make sure you have a "
f"PyTorch installation with CUDA enabled to run on "
f"{DeviceType.GPU}."
)
world_size = torch.cuda.device_count()
enable_distributed = params["runconfig"].get(
"enable_distributed", False
)
if not enable_distributed: # Single GPU
from modelzoo.common.pytorch.pytorch_runner import PyTorchRunner
if world_size > 1:
warnings.warn(
"Distributed training was not enabled even though "
"more than 1 GPU is available."
)
device = torch.device("cuda")
model = _base_model_compat(model_fn, params, device)
return PyTorchRunner(device, model, params)
else: # Distributed GPU
from modelzoo.common.pytorch.pytorch_dist_runner import (
PyTorchDistRunner,
)
if world_size == 1:
warnings.warn(
"Distributed training was enabled, but only "
"1 GPU was detected."
)
# Model with no device, used to create optimizer and scheduler
# actual models will be created in on_process_start()
model = _base_model_compat(model_fn, params)
return PyTorchDistRunner(model, params)
else:
raise ValueError(
f"Unsupported target device: {target_device}. "
f"Supported devices are: {', '.join(DeviceType.devices)}"
)
def _base_model_compat(model_fn, params, device=None) -> PyTorchBaseModel:
# Initialize the model and runner
if isclass(model_fn) and issubclass(model_fn, PyTorchBaseModel):
# to keep compatibility for if a user inherits from PyTorchBaseModel
model = model_fn(params)
else:
model = PyTorchBaseModel(params, model_fn, device)
return model