# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic run scripts build using the cstorch API"""
import logging
import os
import re
import time
import warnings
from datetime import datetime
from pathlib import Path
import torch
from modelzoo.common.pytorch.half_dtype import half_dtype_instance
from modelzoo.common.pytorch.utils import (
named_parameters_requiring_grad,
partition_params_groups_with_adjusted_lr,
partition_params_groups_with_weight_decay,
setup_logging,
)
from modelzoo.common.run_utils.utils import DeviceType


def run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn):
"""
Set up the cstorch run and call the appropriate helper based on the mode
Args:
params: the params dictionary extracted from the params.yaml used
model_fn: A callable that takes in the params dictionary and returns
a torch.nn.Module
train_data_fn: A callable that takes in the param dictionary and
returns a torch.utils.data.DataLoader
eval_data_fn: A callable that takes in the param dictionary and
returns a torch.utils.data.DataLoader
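
    Example (illustrative only; assumes `params` has already been parsed from
    a params.yaml and that the callables follow the signatures above)::

        import yaml

        with open("params.yaml") as f:
            params = yaml.safe_load(f)
        run_cstorch_flow(params, my_model_fn, my_train_data_fn, my_eval_data_fn)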
"""
import cerebras_pytorch.experimental as cstorch
from cerebras_appliance.run_utils import get_debug_args
runconfig = params["runconfig"]
if "seed" in runconfig:
# Ensure we set seed before any model initialization
torch.manual_seed(runconfig["seed"])
debug_args = None
if runconfig.get("debug_args_path"):
debug_args = get_debug_args(runconfig["debug_args_path"])
# Configure the Cerebras Wafer Scale cluster
cs_config = cstorch.utils.CSConfig(
num_csx=runconfig.get("num_csx"),
max_wgt_servers=runconfig.get("num_wgt_servers"),
mgmt_address=runconfig.get("mgmt_address"),
credentials_path=runconfig.get("credentials_path"),
debug_args=debug_args,
mount_dirs=runconfig.get("mount_dirs"),
python_paths=runconfig.get("python_paths"),
transfer_processes=runconfig.get("transfer_processes"),
num_workers_per_csx=runconfig.get("num_workers_per_csx"),
job_labels=runconfig.get("job_labels"),
max_act_per_csx=runconfig.get("num_act_servers"),
job_time_sec=runconfig.get("job_time_sec"),
disable_version_check=runconfig["disable_version_check"],
)
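    # runconfig keys that are absent are passed to CSConfig as None (via
    # `.get`). An illustrative (not default) params.yaml snippet for the
    # cluster-related keys consumed above:
    #   num_csx: 1
    #   mount_dirs: ["/path/to/data"]
    #   python_paths: ["/path/to/modelzoo"]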
# Set up logging level
setup_logging(
runconfig.get("logging"),
runconfig.get("streamer_logging"),
logging_dir=runconfig.get("model_dir"),
)
if runconfig["mode"] == "train":
run_cstorch_train(params, model_fn, train_data_fn, cs_config)
elif runconfig["mode"] == "eval":
run_cstorch_eval(params, model_fn, eval_data_fn, cs_config)


def run_cstorch_train(params, model_fn, input_fn, cs_config):
"""
Runs the training workflow built using the cstorch API
Args:
params: the params dictionary extracted from the params.yaml used
model_fn: A callable that takes in the params dictionary and returns
a torch.nn.Module
        input_fn: A callable that takes in the param dictionary and
            returns a torch.utils.data.DataLoader
"""
import cerebras_pytorch.experimental as cstorch
runconfig = params["runconfig"]
model_dir = runconfig["model_dir"]
compile_dir = runconfig.get("compile_dir")
log_steps = runconfig.get("log_steps")
checkpoint_steps = runconfig.get("checkpoint_steps")
compile_only = runconfig.get("compile_only", False)
validate_only = runconfig.get("validate_only", False)
drop_data = runconfig.get("drop_data", False)
log_summaries = params["optimizer"].get("log_summaries", False)
precision_opt_level = None
model_pol = params["model"].get("precision_opt_level")
if model_pol is not None:
warnings.warn(
"Passing `precision_opt_level` via `model` params is deprecated. "
"Please use `params[\"runconfig\"][\"precision_opt_level\"]`"
)
precision_opt_level = runconfig.get("precision_opt_level", model_pol)
if precision_opt_level != model_pol and model_pol is not None:
logging.warning(
f"Using `precision_opt_level:{precision_opt_level}` from `runconfig` "
f"instead of `{model_pol}` from `model`"
)
if precision_opt_level is None:
precision_opt_level = 1
cs_config.precision_opt_level = precision_opt_level
use_bfloat16 = params["model"].get("use_bfloat16", False)
half_dtype_instance.use_bfloat16 = use_bfloat16
if use_bfloat16:
cstorch.amp.use_bfloat16(True)
optimizer_params = params["optimizer"]
grad_scaler = None
loss_scale = params["optimizer"].get("loss_scaling_factor", 1.0)
if loss_scale == "dynamic" and use_bfloat16:
optimizer_params["loss_scaling_factor"] = 1.0
        logging.info(
            "Dynamic loss scaling (DLS) is not needed when `use_bfloat16` "
            "is `True`. Setting `loss_scaling_factor` to `1.0`."
        )
use_cstorch_optimizer_step = runconfig.get(
"use_cstorch_optimizer_step", False
)
# Default to only keeping the 5 latest checkpoints.
max_checkpoints = runconfig.get("max_checkpoints", 5)
target_device = runconfig["target_device"]
if target_device == DeviceType.CSX:
backend = cstorch.backend(
"CSX",
artifact_dir=os.path.join(model_dir, "cerebras_logs"),
compile_dir=compile_dir,
compile_only=compile_only,
validate_only=validate_only,
drop_data=drop_data,
max_checkpoints=max_checkpoints,
)
elif target_device == DeviceType.CPU:
backend = cstorch.backend("CPU", max_checkpoints=max_checkpoints,)
with backend.device:
model = model_fn(params)
compiled_model = cstorch.compile(model, backend)
compiled_model.train()
# learning rate scaling params
lr_adjustment_scalars = []
lr_adjustment_layers = []
if optimizer_params.get("adjust_learning_rate"):
for layer_type, adjustment_scalar in optimizer_params.get(
"adjust_learning_rate"
).items():
lr_adjustment_layers.append(layer_type)
lr_adjustment_scalars.append(adjustment_scalar)
assert len(lr_adjustment_scalars) == len(
lr_adjustment_layers
), "number of keys for layer types should match the number of scalars"
param_optimizer = list(named_parameters_requiring_grad(model))
# default: assemble all params in 1 group
param_optimizer_grouped = [{"params": list(param_optimizer)}]
# split param_groups in 2 groups: with and without weight decay
param_optimizer_grouped = partition_params_groups_with_weight_decay(
model,
param_optimizer_grouped,
optimizer_params.get("weight_decay_rate", 0.0),
)
# create additional param groups for each layer type with lr adjustment scalar
param_optimizer_grouped = partition_params_groups_with_adjusted_lr(
model,
param_optimizer_grouped,
lr_adjustment_layers,
lr_adjustment_scalars,
)
# remove param name from the (name, param) tuple as the name was only used for referencing
# while grouping params
    for group in param_optimizer_grouped:
        group["params"] = [param for _, param in group["params"]]
optimizer = cstorch.optim.configure_optimizer(
optimizer_type=optimizer_params.pop("optimizer_type"),
params=param_optimizer_grouped,
**optimizer_params,
)
lr_scheduler = cstorch.optim.configure_lr_scheduler(
optimizer, optimizer_params.get("learning_rate"),
)
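    # Illustrative sketch of the `optimizer` params consumed by the two
    # configure_* calls above (example values only; the actual schema comes
    # from the model zoo configs, not from this file):
    #   optimizer_type: AdamW
    #   weight_decay_rate: 0.01
    #   loss_scaling_factor: dynamic
    #   learning_rate: 1.0e-4  # or a scheduler specification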
if loss_scale is not None:
if backend.is_csx:
grad_scaler = cstorch.amp.GradScaler(
loss_scale=optimizer_params.get("loss_scaling_factor"),
init_scale=optimizer_params.get("initial_loss_scale"),
steps_per_increase=optimizer_params.get("steps_per_increase"),
min_loss_scale=optimizer_params.get("min_loss_scale"),
max_loss_scale=optimizer_params.get("max_loss_scale"),
max_gradient_norm=optimizer_params.get("max_gradient_norm"),
)
elif backend.is_gpu:
grad_scaler = torch.cuda.amp.GradScaler()
else:
logging.warning(
f"Gradient scaling is not supported on "
f"{backend.backend_type.name}. "
f"Disabling gradient scaling for this run"
)
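    # `save_checkpoint` is wrapped in a checkpoint closure: the training loop
    # below calls it unconditionally every step, but a checkpoint is only
    # written on steps designated by `checkpoint_steps`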
@cstorch.checkpoint_closure
def save_checkpoint(step):
logging.info(f"Saving checkpoint at step {step}")
checkpoint_file = os.path.join(model_dir, f"checkpoint_{step}.mdl")
if os.path.exists(checkpoint_file):
# If checkpoint path already exists, need to come up with a unique
# name. Appending the current time, should be sufficient
checkpoint_file = os.path.join(
model_dir,
f"checkpoint_{step}_{datetime.now():%Y%m%d_%H%M%S}.mdl",
)
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
if lr_scheduler:
state_dict["lr_scheduler"] = lr_scheduler.state_dict()
if grad_scaler:
state_dict["grad_scaler"] = grad_scaler.state_dict()
state_dict["global_step"] = step
cstorch.save(state_dict, checkpoint_file)
logging.info(f"Saved checkpoint {checkpoint_file}")
def load_checkpoint(checkpoint_path):
logging.info(f"Loading weights from checkpoint {checkpoint_path}")
state_dict = cstorch.load(checkpoint_path)
model.load_state_dict(state_dict["model"])
if not runconfig.get("is_pretrained_checkpoint", False):
optimizer.load_state_dict(state_dict["optimizer"])
if lr_scheduler and "lr_scheduler" in state_dict:
lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
if grad_scaler and "grad_scaler" in state_dict:
grad_scaler.load_state_dict(state_dict["grad_scaler"])
global_step = state_dict.get("global_step", 0)
return global_step
global_step = 0
if compile_only or validate_only:
# Don't bother loading a checkpoint if only compiling/validating
pass
elif runconfig.get("checkpoint_path") is not None:
global_step = load_checkpoint(runconfig["checkpoint_path"])
elif runconfig.get("autoload_last_checkpoint", True):
# get the path to the checkpoint with the highest global step
last_checkpoint = get_latest_checkpoint(model_dir)
if last_checkpoint:
logging.info(f"Found latest checkpoint at {last_checkpoint}")
global_step = load_checkpoint(last_checkpoint)
else:
logging.info(
f"Expected checkpoints named `checkpoint_(\\d+).mdl` "
f"in {model_dir} but found None. "
f"Using randomly initialized model parameters."
)
else:
logging.info(
f"No checkpoint was provided, using randomly initialized model "
f"parameters."
)
if runconfig.get("save_initial_checkpoint", False) and not compile_only:
save_checkpoint(global_step)
writer = cstorch.utils.tensorboard.SummaryWriter(
log_dir=os.path.join(model_dir, "train")
)
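    # `training_step` below is a compile step: it is traced to build the
    # device graph, and the tensors it returns are only fetched on the host
    # inside step closures such as `post_training_step`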
@cstorch.compile_step
def training_step(*args, **kwargs):
loss = compiled_model(*args, **kwargs)
if log_summaries:
compute_params_norm(compiled_model)
if not grad_scaler:
optimizer.zero_grad()
loss.backward()
optimizer.step()
elif use_cstorch_optimizer_step:
cstorch.amp.optimizer_step(
loss,
optimizer,
grad_scaler,
max_gradient_norm=optimizer_params.get("max_gradient_norm"),
max_gradient_value=optimizer_params.get("max_gradient_value"),
)
else:
optimizer_step_with_summaries(
loss,
optimizer,
grad_scaler,
max_gradient_norm=optimizer_params.get("max_gradient_norm"),
max_gradient_value=optimizer_params.get("max_gradient_value"),
log_summaries=log_summaries,
model=compiled_model,
)
if lr_scheduler:
lr_scheduler.step()
        # The loss scale value is not currently extracted from the grad
        # scaler; pass None through to the step closure as a placeholder
        loss_scale = None
# return final values
return loss, loss_scale
# Counts the total number of steps actually executed
total_steps = 0
@cstorch.step_closure
def post_training_step(loss, loss_scale):
nonlocal total_steps
is_log_step = executor.on_final_iteration or (
log_steps and global_step % log_steps == 0
)
# extract the loss scalar
if is_log_step:
rate = executor.profiler.rate()
global_rate = executor.profiler.global_rate()
# Print some logs to provide an update to the client
logging.info(
f"| Train Device={backend.device}, "
f"Step={global_step}, "
f"Loss={loss.item():.5f}, "
f"Rate={rate:.2f} samples/sec, "
f"GlobalRate={global_rate:.2f} samples/sec"
)
# record rates in tensorboard for future reference
writer.add_scalar("local_samples_per_sec", rate, global_step)
writer.add_scalar("avg_samples_per_sec", global_rate, global_step)
writer.add_scalar(
"avg_steps_per_sec",
global_rate / dataloader.batch_size,
global_step,
)
# Save the loss value to be able to plot the loss curve
writer.add_scalar("loss", loss.item(), global_step)
        msg_postfix = (
            "This could potentially be due to selected hyperparameters such as "
            "the learning rate, batch size, etc. or it could be due to an "
            "internal error. Please try with a different set of hyperparameters "
            "and contact Cerebras Support if the issue persists."
        )
if torch.isnan(loss).any().item():
raise ValueError(f"NaN loss detected. {msg_postfix}")
if torch.isinf(loss).any().item():
raise ValueError(f"inf loss detected. {msg_postfix}")
if lr_scheduler:
for group, lr in enumerate(lr_scheduler.get_last_lr()):
writer.add_scalar(f"lr.{group}", lr, global_step)
total_steps += 1
dataloader = cstorch.utils.data.DataLoader(input_fn, params)
if compile_only or validate_only:
num_steps = None
else:
num_steps = cstorch.utils.data.compute_num_steps(
dataloader,
initial_step=global_step,
num_steps=runconfig.get("num_steps"),
max_steps=runconfig.get("max_steps"),
num_epochs=runconfig.get("num_epochs"),
steps_per_epoch=runconfig.get("steps_per_epoch"),
)
executor = cstorch.utils.data.DataExecutor(
dataloader,
num_steps=num_steps,
checkpoint_steps=checkpoint_steps,
cs_config=cs_config,
writer=writer,
)
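    # The DataExecutor drives the run: iterating it yields `num_steps`
    # batches, and it carries the checkpoint cadence, cluster config, and the
    # profiler whose rates are reported in `post_training_step`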
start_time = time.time()
# The main training loop
try:
for batch in executor:
loss, loss_scale = training_step(batch)
global_step += 1
# Wait for outputs to become available to fetch from the CS system(s)
post_training_step(loss, loss_scale)
# only saves checkpoint if current step is a checkpoint step
save_checkpoint(global_step)
if not (compile_only or validate_only):
logging.info("Training completed successfully!")
finally:
if not (compile_only or validate_only):
            # compute the total samples processed from the number of steps
            # actually executed and the per-step batch size
            total_samples = total_steps * dataloader.batch_size
end_time = time.time()
logging.info(
f"Processed {total_samples} sample(s) "
f"in {end_time - start_time} seconds."
)


def run_cstorch_eval(params, model_fn, input_fn, cs_config):
"""
    Runs the evaluation workflow built using the cstorch API
Args:
params: the params dictionary extracted from the params.yaml used
model_fn: A callable that takes in the params dictionary and returns
a torch.nn.Module
        input_fn: A callable that takes in the param dictionary and
            returns a torch.utils.data.DataLoader
"""
import cerebras_pytorch.experimental as cstorch
import cerebras_pytorch.experimental.metrics as metrics
runconfig = params["runconfig"]
model_dir = runconfig["model_dir"]
compile_dir = runconfig.get("compile_dir")
log_steps = runconfig.get("log_steps")
compile_only = runconfig.get("compile_only", False)
validate_only = runconfig.get("validate_only", False)
drop_data = runconfig.get("drop_data", False)
precision_opt_level = None
model_pol = params["model"].get("precision_opt_level")
if model_pol is not None:
warnings.warn(
"Passing `precision_opt_level` via `model` params is deprecated. "
"Please use `params[\"runconfig\"][\"precision_opt_level\"]`"
)
precision_opt_level = runconfig.get("precision_opt_level", model_pol)
if precision_opt_level != model_pol and model_pol is not None:
logging.warning(
f"Using `precision_opt_level:{precision_opt_level}` from `runconfig` "
f"instead of `{model_pol}` from `model`"
)
if precision_opt_level is None:
precision_opt_level = 1
cs_config.precision_opt_level = precision_opt_level
use_bfloat16 = params["model"].get("use_bfloat16", False)
half_dtype_instance.use_bfloat16 = use_bfloat16
if use_bfloat16:
cstorch.amp.use_bfloat16(True)
target_device = runconfig["target_device"]
if target_device == DeviceType.CSX:
backend = cstorch.backend(
"CSX",
artifact_dir=os.path.join(model_dir, "cerebras_logs"),
compile_dir=compile_dir,
compile_only=compile_only,
validate_only=validate_only,
drop_data=drop_data,
)
elif target_device == DeviceType.CPU:
backend = cstorch.backend("CPU")
with backend.device:
model = model_fn(params)
compiled_model = cstorch.compile(model, backend)
def load_checkpoint(checkpoint_path):
logging.info(f"Loading weights from checkpoint {checkpoint_path}")
state_dict = cstorch.load(checkpoint_path)
model.load_state_dict(state_dict["model"])
global_step = state_dict.get("global_step", 0)
return global_step
global_step = 0
if compile_only or validate_only:
# Don't bother loading a checkpoint if only compiling/validating
pass
elif runconfig.get("checkpoint_path") is not None:
global_step = load_checkpoint(runconfig["checkpoint_path"])
elif runconfig.get("autoload_last_checkpoint", True):
# get the path to the checkpoint with the highest global step
last_checkpoint = get_latest_checkpoint(model_dir)
if last_checkpoint:
logging.info(f"Found latest checkpoint at {last_checkpoint}")
global_step = load_checkpoint(last_checkpoint)
else:
logging.info(
f"Expected checkpoints named `checkpoint_(\\d+).mdl` "
f"in {model_dir} but found None. "
f"Using randomly initialized model parameters."
)
else:
logging.info(
f"No checkpoint was provided, using randomly initialized model "
f"parameters."
)
writer = cstorch.utils.tensorboard.SummaryWriter(
log_dir=os.path.join(model_dir, "eval")
)
@cstorch.compile_step
def eval_step(*args, **kwargs):
loss = compiled_model(*args, **kwargs)
return loss
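    # Eval metrics registered by the model via the cstorch metrics API
    # accumulate across eval steps and are read out once after the loop via
    # `metrics.compute_all_metrics()`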
total_loss = 0
total_steps = 0
@cstorch.step_closure
def post_eval_step(loss, step):
nonlocal total_loss
nonlocal total_steps
        is_log_step = executor.on_final_iteration or (
            log_steps and step % log_steps == 0
        )
        rate = executor.profiler.rate()
        global_rate = executor.profiler.global_rate()
if is_log_step:
# Print some logs to provide an update to the client
logging.info(
f"| Eval Device={backend.device}, "
f"Step={step}, "
f"Loss={loss.item():.5f}, "
f"Rate={rate:.2f} samples/sec, "
f"GlobalRate={global_rate:.2f} samples/sec"
)
if executor.on_final_iteration:
# log the throughput of the eval run to tensorboard on the last step
writer.add_scalar("local_samples_per_sec", rate, global_step)
writer.add_scalar("avg_samples_per_sec", global_rate, global_step)
writer.add_scalar(
"avg_steps_per_sec",
global_rate / dataloader.batch_size,
global_step,
)
if torch.isnan(loss).any().item():
raise ValueError("NaN loss detected.")
if torch.isinf(loss).any().item():
raise ValueError("inf loss detected.")
total_loss += loss.item()
total_steps += 1
dataloader = cstorch.utils.data.DataLoader(input_fn, params)
if compile_only or validate_only:
num_steps = None
else:
num_steps = cstorch.utils.data.compute_num_steps(
dataloader, num_steps=runconfig.get("eval_steps"), num_epochs=1
)
executor = cstorch.utils.data.DataExecutor(
dataloader, num_steps=num_steps, cs_config=cs_config, writer=writer,
)
start_time = time.time()
try:
for step, batch in enumerate(executor, start=1):
loss = eval_step(batch)
post_eval_step(loss, step)
if not (compile_only or validate_only):
for name, value in metrics.compute_all_metrics().items():
writer.add_scalar(name, value, global_step)
logging.info(f"Metric: {name} = {value}")
avg_eval_loss = total_loss / total_steps
writer.add_scalar("loss", avg_eval_loss, global_step)
logging.info(f"Avg Eval Loss: {avg_eval_loss}")
logging.info("Evaluation completed successfully!")
finally:
if not (compile_only or validate_only):
            # compute the total samples processed from the number of steps
            # actually executed and the per-step batch size
            end_time = time.time()
            total_samples = total_steps * dataloader.batch_size
logging.info(
f"Processed {total_samples} sample(s) "
f"in {end_time - start_time} seconds."
)


def get_latest_checkpoint(model_dir):
    """Get the path to the checkpoint with the highest global step"""
    checkpoint_pattern = re.compile(r"checkpoint_(\d+)\.mdl")
    # Ignore files (e.g. timestamped duplicates) that don't follow the
    # `checkpoint_<step>.mdl` naming scheme so the sort key below can't fail
    checkpoints = sorted(
        (
            p
            for p in Path(model_dir).glob("checkpoint_*.mdl")
            if checkpoint_pattern.fullmatch(p.name)
        ),
        key=lambda p: int(checkpoint_pattern.fullmatch(p.name).group(1)),
    )
    if len(checkpoints) > 0:
        return checkpoints[-1]
    return None


def compute_params_norm(model):
"""Compute the model wise norm of the parameters"""
import cerebras_pytorch.experimental as cstorch
param_norm = torch.tensor(0.0).to(model.device)
for _, param in model.named_parameters():
if param.requires_grad:
# simply add if we want to include all params
param_norm += torch.pow(torch.norm(param), 2.0)
cstorch.summarize_scalar("model_wise_params_norm", torch.sqrt(param_norm))


def compute_grad_norm(model):
"""Compute the model wise and per layer norm of the gradients"""
import cerebras_pytorch.experimental as cstorch
params_grad_norm = torch.tensor(0.0).to(model.device)
for _, param in model.named_parameters():
if param.grad is not None:
params_grad_norm += torch.pow(torch.norm(param.grad), 2.0)
params_grad_norm = torch.sqrt(params_grad_norm)
cstorch.summarize_scalar("model_wise_grad_norm", params_grad_norm)
per_layer_grad_norm = {}
layer_pattern = re.compile(r".*(layers\.)(\d+)(\.).*")
for name, param in model.named_parameters():
if param.grad is None:
continue
# get a match if module name contains `layers.i.0` where i is layer num
match = layer_pattern.match(name)
if match:
layer_id = match.group(2)
if layer_id not in per_layer_grad_norm:
per_layer_grad_norm[layer_id] = torch.tensor(0.0).to(
model.device
)
per_layer_grad_norm[layer_id] += torch.pow(
torch.norm(param.grad), 2.0
)
for layer_id in per_layer_grad_norm:
cstorch.summarize_scalar(
f"per_layer_grad_norm/layer_{layer_id}",
torch.sqrt(per_layer_grad_norm[layer_id]),
)


def optimizer_step_with_summaries(
loss: torch.Tensor,
optimizer: "cstorch.optim.Optimizer",
grad_scaler: "cstorch.amp.GradScaler",
max_gradient_norm: float = None,
max_gradient_value: float = None,
log_summaries: bool = False,
model: torch.nn.Module = None,
):
"""
Customized equivalent to cstorch.amp.optimizer_step
additionally featuring grad norm summaries
"""
optimizer.zero_grad()
grad_scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
grad_scaler.unscale_(optimizer)
if log_summaries:
assert model is not None
compute_grad_norm(model)
# gradient clipping
if max_gradient_norm is not None and max_gradient_norm < 0.0:
raise ValueError(
f"max_gradient_norm has to be a non-negative float. Got "
f"{max_gradient_norm}"
)
if max_gradient_value is not None and max_gradient_value < 0.0:
raise ValueError(
f"max_gradient_value has to be a non-negative float. Got "
f"{max_gradient_value}"
)
if max_gradient_norm is not None and max_gradient_value is not None:
raise ValueError(
f"Gradients can be clipped by norm(={max_gradient_norm}) or by "
f"value(={max_gradient_value}), but not both. "
f"Do not set both `max_gradient_norm` and `max_gradient_value`."
)
params = (
p
for param_group in optimizer.param_groups
for p in param_group["params"]
)
if max_gradient_norm is not None:
torch.nn.utils.clip_grad_norm_(list(params), max_gradient_norm)
elif max_gradient_value is not None:
torch.nn.utils.clip_grad_value_(list(params), max_gradient_value)
grad_scaler.step(optimizer)
grad_scaler.update()