# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the CS Appliance mode runner"""
# pylint: disable=attribute-defined-outside-init
import logging
import os
import time
import warnings
from shutil import rmtree
from typing import Optional, Tuple

import torch

from cerebras_appliance.CSConfig import CSConfig
from cerebras_appliance.pb.workflow.appliance.common.common_config_pb2 import (
    DebugArgs,
)
from cerebras_appliance.run_utils import (
    get_debug_args,
    update_debug_args_with_autogen_policy,
    update_debug_args_with_mem_limits,
)

from modelzoo import CSOFT_PACKAGE, CSoftPackage
from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch
from modelzoo.common.pytorch.loss_utils import extract_loss
from modelzoo.common.pytorch.metrics import get_all_metrics
from modelzoo.common.pytorch.perf_utils import save_perf
from modelzoo.common.pytorch.pytorch_base_cs_runner import PyTorchBaseCSRunner
from modelzoo.common.pytorch.PyTorchBaseModel import PyTorchBaseModel
from modelzoo.common.pytorch.sparsity.appliance import (
    build_sparsify_grouper,
    validate_sparsity_params,
)

COMPILE_ONLY_MSG = "Compiling the model. This may take a few minutes."
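
# Execution model, as the hooks below implement it: the first batch is traced
# on the host with lazy tensors to capture the compute graph; that graph is
# then compiled and launched on the appliance. Every subsequent host-side
# "batch" only fetches outputs (e.g. the loss) streamed back from the
# appliance; see the `_num_batches_processed == 0` guards throughout.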
class PyTorchCSAppliance(PyTorchBaseCSRunner):
    """Class for compiling PyTorch models for Cerebras hardware."""

    def __init__(self, model: PyTorchBaseModel, params: dict):
        super().__init__(model, params)

        if self._save_stream_size:
            raise ValueError("Saving input streams on CSX is not supported.")

        self._save_losses = self._runconfig.get("save_losses", True)

        self._validate_only = self._runconfig.get("validate_only", False)
        self._compile_only = (
            self._runconfig.get("compile_only", False) or self._validate_only
        )
        if self._compile_only or self._validate_only:
            # nothing to save if compile only
            self._save_initial_checkpoint = False
        else:
            self._save_initial_state()

        self._initial_state_file = None
        self._num_batches_processed = 0
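        # Note: `_save_initial_state` is decorated with `@cm.step_closure`, so
        # the call above only registers the save to run at step time; that is
        # why `_initial_state_file` can still be initialized to None here.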
        debug_args = DebugArgs()
        if self._runconfig.get("debug_args_path"):
            debug_args = get_debug_args(self._runconfig["debug_args_path"])
        update_debug_args_with_autogen_policy(
            debug_args, self._runconfig.get("autogen_policy")
        )
        update_debug_args_with_mem_limits(debug_args, self._runconfig)
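        # Runconfig keys indexed with `[]` below are required; those read
        # with `.get()` are optional and default to None.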
        cs_config = CSConfig(
            num_csx=self._runconfig.get("num_csx"),
            max_wgt_servers=self._runconfig["num_wgt_servers"],
            mgmt_address=self._runconfig.get("mgmt_address"),
            mgmt_namespace=self._runconfig.get("mgmt_namespace"),
            credentials_path=self._runconfig.get("credentials_path"),
            debug_args=debug_args,
            mount_dirs=self._runconfig.get("mount_dirs"),
            python_paths=self._runconfig.get("python_paths"),
            transfer_processes=self._runconfig.get("transfer_processes"),
            num_workers_per_csx=self._runconfig["num_workers_per_csx"],
            job_labels=self._runconfig.get("job_labels"),
            max_act_per_csx=self._runconfig["num_act_servers"],
            job_time_sec=self._runconfig["job_time_sec"],
            disable_version_check=self._runconfig["disable_version_check"],
        )
        precision_opt_level = None
        model_pol = self._params.get("model", {}).get("precision_opt_level")
        if model_pol is not None:
            warnings.warn(
                "Passing `precision_opt_level` via `model` params is deprecated. "
                "Please use `params[\"runconfig\"][\"precision_opt_level\"]`."
            )

        precision_opt_level = self._runconfig.get(
            "precision_opt_level", model_pol
        )
        if precision_opt_level != model_pol and model_pol is not None:
            logging.warning(
                f"Using `precision_opt_level: {precision_opt_level}` from "
                f"`runconfig` instead of `{model_pol}` from `model`."
            )
        if precision_opt_level is None:
            precision_opt_level = 1

        cs_config.precision_opt_level = precision_opt_level
        use_cs_grad_accum = self._runconfig.get("use_cs_grad_accum", False)
        self.skip_train_recv_activations = self._runconfig.get(
            "skip_train_recv_activations", False
        )

        self._appliance = cbtorch.core.appliance.ApplianceMode(
            os.path.join(cbtorch.env().service_workdir, "cerebras_logs"),
            cbtorch.env().compile_dir,
            cs_config,
            use_cs_grad_accum,
            full_config=cbtorch.state().full_config,
        )

        # Cache the original xla loss tensor for retrieving its value later
        self._loss_tensor = None

        self.train_data_fn = None
        self.eval_data_fn = None

        self.send_weights_grouper = None
        sparsity = self._params.get("sparsity", {})
        if sparsity and sparsity.get("type") == "sideband":
            validate_sparsity_params(sparsity)
    @property
    def _should_log_extra_summaries(self):
        return self._log_summaries and self._num_batches_processed == 0
    def get_loss_value(self) -> torch.Tensor:
        """Fetch all activations and return the loss value."""
        assert self._loss_tensor is not None, "Loss tensor was not found!"

        logging.debug("Receiving activations")
        # This will fetch activations and store them in cbtorch.state()
        self._appliance.receive_activations(self._num_batches_processed)
        return cbtorch.state().get_activation_for_output(self._loss_tensor)
    def maybe_get_loss_value(self, step) -> torch.Tensor:
        """Fetch the loss value if this is a fetch step, otherwise return None."""
        if self._is_fetch_step_helper(step):
            loss = self.get_loss_value()
        else:
            loss = None
        return loss
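    # Whether a step is a "fetch step" follows from the run config set in
    # `on_train_start`/`on_eval_start`: the `1` passed to `cm.set_run_config`
    # makes every step a fetch step, though the loss may not be displayed.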
    ##################################################################
    #                         Training Hooks                         #
    ##################################################################

    def on_train_start(self):
        if not self._compile_only:
            self._start_time = time.time()

        if self._model.grad_scaler:
            self._scaler = self._model.grad_scaler

        # Losses are fetched (but maybe not displayed) at every step
        cm.set_run_config(self._total_steps, self._checkpoint_steps, 1)

        # Now that training has started, no need to store any new tensors
        os.environ["CEREBRAS_APPLIANCE_NO_STORAGE"] = "1"
    def on_train_end(self, early_exit=False):
        if not self._compile_only:
            save_perf(self._perf_dir)
            cm.run_step_closures()
            super().on_train_end(early_exit)

            # Delete appliance data if train was successful
            self._delete_initial_state()

    def on_train_epoch_end(self, early_exit: bool):
        pass
    def on_train_batch_start(self, data):
        if self._num_batches_processed == 0:
            self._appliance.tracker_execute.start("Tracing forward pass")

            sparsity = self._params.get("sparsity", {})
            if sparsity and sparsity.get("type") == "sideband":
                # Build tensor grouper before tracing model so the initial
                # weights can have their sparsity attributes annotated.
                self.send_weights_grouper = build_sparsify_grouper(
                    sparsity, self._model
                )

            return super().on_train_batch_start(data)
        return data
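    # The first call to `on_train_batch_end` below does the heavy lifting:
    # it compiles the traced graph and launches execution on the appliance.
    # Later calls only report the fetched loss and bump the batch counter.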
    def on_train_batch_end(self, loss, epoch: int = None, step: int = None):
        if self._num_batches_processed == 0:
            self._appliance.tracker_execute.stop("Tracing forward pass")

            logging.info(COMPILE_ONLY_MSG)

            batch_size = self._train_dataloader.batch_size

            with self._appliance.build_worker_image(
                should_skip=self._compile_only or self._validate_only
            ):
                self._appliance.compile(
                    cbtorch.state().outputs, batch_size, self._validate_only
                )
                logging.info("Compile for training completed successfully!")

            if not self._compile_only:
                assert self._initial_state_file is not None
                self._appliance.execute(
                    self.train_data_fn,
                    self.get_input_fn_params(),
                    batch_size,
                    self._total_steps,
                    self._checkpoint_steps,
                    self._active_mode,
                    self._initial_state_file,
                    cleanup_stack=self._cleanup_stack,
                    send_weights_grouper=self.send_weights_grouper,
                )
                logging.debug("Execute setup complete")

                if self.skip_train_recv_activations:
                    loss = self.maybe_get_loss_value(
                        self._num_batches_processed
                    )
                else:
                    loss = self.get_loss_value()

        if not self._compile_only:
            super().on_train_batch_end(loss, epoch, step)

        self._num_batches_processed += 1
    def train_forward(self, data):
        if self._num_batches_processed == 0:
            # Cache the loss lazy tensor used in compile
            self._loss_tensor = super().train_forward(data)
            # FIXME: Add a no-op to fix output mapping issue with ScopeBoundary
            self._loss_tensor = self._loss_tensor * 1
            return self._loss_tensor

        if self.skip_train_recv_activations:
            return self.maybe_get_loss_value(self._num_batches_processed + 1)
        else:
            return self.get_loss_value()
    def backward(self, loss):
        if self._num_batches_processed == 0:
            return super().backward(loss)
        return None

    def optimizer_zero_grad(self):
        if self._num_batches_processed == 0:
            return super().optimizer_zero_grad()
        return None

    def optimizer_step(self):
        if self._num_batches_processed == 0:
            return super().optimizer_step()
        return None

    def lr_scheduler_step(self):
        if self._num_batches_processed == 0:
            return super().lr_scheduler_step()
        return None
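    # The backward/optimizer/scheduler hooks above only run for the first
    # (traced) batch so the full training step is captured in the graph;
    # after that the appliance executes the whole loop, so they are no-ops.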
    ##################################################################
    #                        Evaluation Hooks                        #
    ##################################################################

    def on_eval_start(self):
        # Losses are fetched (but maybe not displayed) at every step
        cm.set_run_config(self._total_steps, self._checkpoint_steps, 1)

    def on_eval_end(self, early_exit=False):
        if not self._compile_only:
            save_perf(self._perf_dir)
            cm.run_step_closures()
            super().on_eval_end(early_exit)

            # Delete appliance data if eval was successful
            self._delete_initial_state()
    def eval_forward(self, data):
        if self._num_batches_processed == 0:
            outputs = super().eval_forward(data)

            # Need to track eval model outputs to compile
            loss = extract_loss(outputs)
            # Cache the loss lazy tensor used in compile
            self._loss_tensor = loss
            cbtorch.state().track_object({"loss": loss})
            cbtorch.state().track_object(outputs)

            return outputs
        else:
            return self.get_loss_value()
    def on_eval_epoch_end(self, early_exit: bool):
        if not self._compile_only:
            cm.run_step_closures()

    def on_eval_batch_start(self, data):
        if self._num_batches_processed == 0:
            self._appliance.tracker_execute.start("Tracing forward pass")

            sparsity = self._params.get("sparsity", {})
            if sparsity and sparsity.get("type") == "sideband":
                # Build tensor grouper before tracing model so the initial
                # weights can have their sparsity attributes annotated.
                # We don't actually save/apply this during eval.
                build_sparsify_grouper(sparsity, self._model)

            return super().on_eval_batch_start(data)
        return data
    def on_eval_batch_end(self, loss, epoch: int = None, step: int = None):
        if self._num_batches_processed == 0:
            self._appliance.tracker_execute.stop("Tracing forward pass")

            logging.info(COMPILE_ONLY_MSG)

            batch_size = self._eval_dataloader.batch_size

            with self._appliance.build_worker_image(
                should_skip=self._compile_only or self._validate_only
            ):
                self._appliance.compile(
                    cbtorch.state().outputs, batch_size, self._validate_only
                )
                logging.info("Compile for evaluation completed successfully!")

            if not self._compile_only:
                assert self._initial_state_file is not None
                self._appliance.execute(
                    self.eval_data_fn,
                    self.get_input_fn_params(),
                    batch_size,
                    self._total_steps,
                    0,  # checkpoint_steps
                    self._active_mode,
                    initial_checkpoint_file=self._initial_state_file,
                    cleanup_stack=self._cleanup_stack,
                )

                loss = self.get_loss_value()

        if not self._compile_only:
            super().on_eval_batch_end(loss, epoch, step)

        self._num_batches_processed += 1
    def compute_eval_metrics(self):
        if not self._compile_only:
            super().compute_eval_metrics()

    ##################################################################
    #                   Override Abstract Methods                    #
    ##################################################################
    def train(self, train_dataloader: torch.utils.data.DataLoader) -> None:
        dataloader = cbtorch.dataloader(
            train_dataloader, use_parallel_loader=False
        )
        super().train(dataloader)

    def evaluate(self, eval_dataloader: cbtorch.data.DataLoader):
        dataloader = cbtorch.dataloader(
            eval_dataloader, use_parallel_loader=False
        )
        super().evaluate(dataloader)
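    # `use_parallel_loader=False` in both wrappers above: in appliance mode
    # the input workers stream data on the appliance side (see
    # `num_workers_per_csx` in the CSConfig), so host-side parallel loading
    # is presumably unnecessary.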
    def _should_stop(self, epoch_step: int, mode: str) -> Tuple[bool, bool]:
        if self._compile_only:
            return True, True
        return super()._should_stop(epoch_step, mode)

    def is_master_ordinal(self):
        return True
    def _configure_run_steps(self, dataloader, mode: str):
        if self._compile_only:
            self._num_epochs = 1
            self._total_steps = 1
            self._checkpoint_steps = 0
            self._fetch_steps = 0
        else:
            super()._configure_run_steps(dataloader, mode)
    def _maybe_load_checkpoint(self, checkpoint_path: Optional[str], mode: str):
        """Optionally load checkpoint into the model.

        Args:
            checkpoint_path: Path to a checkpoint file.
        """
        if not checkpoint_path:
            logging.info(
                "No checkpoint was provided, using randomly initialized model "
                "parameters."
            )
            self._global_step = 0
            self._initial_step = 0
            return

        if CSOFT_PACKAGE in (CSoftPackage.SRC, CSoftPackage.WHEEL):
            from cerebras_pytorch.saver.pt_h5_saver import PyTorchH5Saver
        else:
            raise ImportError("Cerebras PyTorch package not installed")

        logging.info(f"Loading weights from checkpoint {checkpoint_path}")
        with self._appliance.tracker_execute.entry("Load Checkpoint"):
            saver = PyTorchH5Saver()
            if PyTorchH5Saver.is_valid_checkpoint(checkpoint_path):
                assert self._checkpoint_path == checkpoint_path
                tensor_names = saver.tensor_names(checkpoint_path)
                if "global_step" in tensor_names:
                    self._global_step = saver.load_tensor(
                        checkpoint_path, "global_step"
                    )
                else:
                    self._global_step = 0
            else:
                # If we get a normal PyTorch checkpoint, we need to convert it
                # into H5 format
                with self._appliance.tracker_execute.entry(
                    "Convert PyTorch Checkpoint"
                ):
                    state_dict = torch.load(
                        checkpoint_path, map_location=torch.device('cpu'),
                    )
                    self._global_step = state_dict.get("global_step", 0)

                    self._checkpoint_path = os.path.join(
                        self._model_dir, "loaded_checkpoint.mdl"
                    )
                    saver.save(self._checkpoint_path, state_dict)

        if self._is_pretrained_checkpoint:
            self._global_step = 0

        self._initial_step = int(self._global_step)
    @cm.step_closure
    def _save_initial_state(self):
        self._initial_state_file = os.path.join(
            self._model_dir, f"initial_state_{time.time()}.hdf5"
        )

        # construct initial state dict
        state_dict = self._model.get_state()
        state_dict["global_step"] = state_dict.get(
            "global_step", self._initial_step
        )

        initial_metric_state = {}
        for metric in get_all_metrics().values():
            on_device_state_dict = metric.on_device_state_dict()
            if on_device_state_dict:
                metric_name = metric.name.replace("/", "_")
                initial_metric_state[metric_name] = on_device_state_dict
        if initial_metric_state:
            state_dict[cm.METRIC_NAME_PREFIX] = initial_metric_state

        if CSOFT_PACKAGE in (CSoftPackage.SRC, CSoftPackage.WHEEL):
            from cerebras_pytorch.saver.pt_h5_saver import PyTorchH5Saver
        else:
            raise ImportError("Cerebras PyTorch package not installed")

        saver = PyTorchH5Saver(
            loaded_checkpoint=self._checkpoint_path,
            is_pretrained_checkpoint=self._is_pretrained_checkpoint,
        )
        saver.save(self._initial_state_file, state_dict)
    def _delete_initial_state(self):
        """Delete the initial state file and its associated data directory."""
        # Guard against the step closure that sets `_initial_state_file`
        # never having run (in which case it is still None).
        if not self._initial_state_file:
            return
        if os.path.exists(self._initial_state_file):
            os.remove(self._initial_state_file)
        if os.path.exists(f"{self._initial_state_file}.data"):
            rmtree(f"{self._initial_state_file}.data")
    @cm.step_closure
    def _save_checkpoint(self, step):  # pylint: disable=arguments-differ
        logging.info(f"Saving checkpoint at global step {step}")

        file_name = os.path.join(self._model_dir, f"checkpoint_{step}.mdl")
        if os.path.exists(file_name):
            # If the checkpoint path already exists, we need to come up with a
            # unique name. Appending the current time should be sufficient.
            file_name = os.path.join(
                self._model_dir, f"checkpoint_{step}_{time.time()}.mdl"
            )

        if CSOFT_PACKAGE in (CSoftPackage.SRC, CSoftPackage.WHEEL):
            from cerebras_pytorch.saver.pt_h5_saver import (
                CerebrasStateDict,
                PyTorchH5Saver,
            )
        else:
            raise ImportError("Cerebras PyTorch package not installed")

        saver = PyTorchH5Saver()

        state_dict = self._model.get_state()
        state_dict["global_step"] = state_dict.get("global_step", step)

        flattened, spec = saver.flatten_state_dict(state_dict)
        # save the spec before saving tensors so we know what was
        # intended to be saved, even if something fails
        saver.save_spec(file_name, spec)
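        # Two save paths follow: at the initial step the weights have not yet
        # been touched by the appliance, so they are copied straight from the
        # host-side initial state file; at any later step the current weights
        # are fetched back from the appliance via `save_weights`.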
        if step == self._initial_step:
            assert self._initial_state_file is not None
            with self._appliance.tracker_execute.entry(
                "Saving Initial Checkpoint"
            ):
                src_tensor_names = PyTorchH5Saver.tensor_names(
                    self._initial_state_file
                )
                for key in flattened:
                    if key in src_tensor_names:
                        val = saver.load_tensor(self._initial_state_file, key)
                        saver.save_tensor(file_name, key, val)
                    else:
                        saver.save_tensor(file_name, key, flattened[key])
        else:
            with self._appliance.tracker_execute.entry("Saving Checkpoint"):
                self._appliance.save_weights(
                    flattened.items(),
                    file_name,
                    step - self._initial_step - 1,
                    self._model.duplicate_params_map,
                )

                # Save dataloader checkpoint via WRK
                self._appliance.save_dataloader_checkpoint(
                    state_dict["global_step"]
                )

        def post_transfer_callback(state_dict):
            if "optimizer" in state_dict:
                state_dict[
                    "optimizer"
                ] = self._optimizer.convert_state_dict_for_checkpoint(
                    state_dict["optimizer"]
                )
            return state_dict

        post_transfer_callback(CerebrasStateDict.create(spec, file_name))

        logging.info(f"Saved checkpoint at global step: {step}")
        logging.debug(f"Checkpoint file: {file_name}")

        self.on_checkpoint_saved(file_name, step)
    def _increment_global_step(self):
        self._global_step += 1