# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running Cerebras Pytorch Models"""
import argparse
import inspect
import math
import os
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
import yaml
import modelzoo.common.pytorch.half_dtype as half_dtype
from modelzoo import CSOFT_PACKAGE, CSoftPackage
from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import is_ltc_mlir_mode_enabled, modes
from modelzoo.common.pytorch.pytorch_base_runner import PyTorchBaseRunner
from modelzoo.common.pytorch.utils import (
RunConfigParamsValidator,
get_checkpoints,
setup_logging,
)
from modelzoo.common.run_utils.cli_parser import get_params_from_args
from modelzoo.common.run_utils.utils import DeviceType
# Signature shared by the train/eval dataloader factories accepted below:
# a callable that receives a dict and returns a torch DataLoader.
DATA_FN_TYPE = Callable[[dict], torch.utils.data.DataLoader]
# Module-level alias of the object exposed by modelzoo's half_dtype module.
half_dtype_instance = half_dtype.half_dtype_instance
def arg_filter(arg: str, keyword: str) -> bool:
    """Tell whether a command-line token is the flag named ``keyword``.

    Surrounding whitespace is ignored. A token matches when it is the flag
    itself with one or two leading dashes (``-keyword`` / ``--keyword``) or
    that flag immediately followed by ``=`` (``--keyword=value``).

    Args:
        arg: A single command-line token.
        keyword: The flag name, without leading dashes.

    Returns:
        True if the token is the given flag, False otherwise.
    """
    token = arg.strip()
    for dashes in ("--", "-"):
        flag = f"{dashes}{keyword}"
        # Requiring the "=" keeps longer flags sharing a prefix
        # (e.g. --model_dir vs --mode) from matching.
        if token == flag or token.startswith(f"{flag}="):
            return True
    return False
def update_sideband_mode_arg(
    arguments: List[str], new_mode_arg: str, old_mode: str
) -> List[str]:
    """Rewrite the mode selection in a list of command-line arguments.

    Two passes are made over ``arguments``:

    1. Any token equal to ``old_mode`` that directly follows a mode flag
       (``--mode`` / ``-mode`` / ``--m`` / ``-m``, as recognised by
       ``arg_filter``) is dropped, which handles the space-separated form
       ``--mode <old_mode>``.
    2. Every remaining mode flag token (with or without ``=value``) is
       replaced by ``new_mode_arg``.

    Args:
        arguments: The original command-line arguments (not modified).
        new_mode_arg: The full replacement token, e.g. ``"--mode=eval"``.
        old_mode: The mode value to strip when passed as a separate token.

    Returns:
        A new list of arguments selecting the new mode.
    """

    def _is_mode_flag(arg: Optional[str]) -> bool:
        # ``None`` stands for "no preceding argument". It must never reach
        # arg_filter, which calls ``.strip()`` on its input and would raise
        # AttributeError when the very first argument equals ``old_mode``.
        if arg is None:
            return False
        return arg_filter(arg, "mode") or arg_filter(arg, "m")

    # Pair every argument with its predecessor; the first one has none.
    offset_arguments = [None] + arguments
    without_old_mode = [
        an_arg
        for an_arg, prev_arg in zip(arguments, offset_arguments)
        if an_arg != old_mode or not _is_mode_flag(prev_arg)
    ]

    # Swap each mode flag for the new, self-contained ``--mode=<value>`` token.
    return [
        new_mode_arg if _is_mode_flag(an_arg) else an_arg
        for an_arg in without_old_mode
    ]
def sideband_eval_all(
    filename: str, arguments: List[str], params: Dict[Any, Any]
):
    """Temporary support for evaluating every checkpoint via subprocesses.

    The given script is re-invoked once per checkpoint found under the
    run's ``model_dir``, with the mode switched to ``eval`` and a
    ``--checkpoint_path`` argument appended for that checkpoint.

    Args:
        filename: Path of the script to launch for every evaluation.
        arguments: Command-line arguments of the current invocation.
        params: Parsed params; ``params['runconfig']['model_dir']`` is read.

    Raises:
        ValueError: If a checkpoint path argument was supplied, or if no
            checkpoints are found in the model directory.
        subprocess.CalledProcessError: If an evaluation subprocess exits
            with a non-zero status.
    """
    for an_arg in arguments:
        if arg_filter(an_arg, "checkpoint_path"):
            raise ValueError(
                "Checkpoint path cannot be provided with eval_all. Checkpoints inferred from model_dir"
            )

    base_args = update_sideband_mode_arg(
        arguments, "--mode=eval", modes.EVAL_ALL
    )

    # Gather all checkpoints
    model_dir = params['runconfig']['model_dir']
    checkpoints = get_checkpoints(model_dir)
    if len(checkpoints) == 0:
        raise ValueError(f"No checkpoints found at {model_dir}")

    for a_chkpt in checkpoints:
        # Launching the script from the top gives every run its own process.
        command = [
            sys.executable,
            filename,
            *base_args,
            f"--checkpoint_path={a_chkpt}",
        ]
        subprocess.run(command, check=True)
def sideband_train_eval_all(
    filename: str, arguments: List[str], params: Dict[Any, Any]
):
    """Temporary support for alternating train and eval via subprocesses.

    The given script is launched repeatedly: one training subprocess
    followed by one evaluation subprocess per iteration.

    * With ``num_steps`` set, each training run is given
      ``--num_steps=<eval_frequency>`` and there are
      ``ceil(num_steps / eval_frequency)`` iterations; when ``num_steps`` is
      not a multiple of ``eval_frequency`` the final training run is given
      the remainder instead.
    * With ``num_epochs`` set, there is one iteration per epoch and each
      training run is given ``--num_epochs=1``.

    Args:
        filename: Path of the script to launch.
        arguments: Command-line arguments of the current invocation.
        params: Parsed params; the ``runconfig`` section is read.

    Raises:
        ValueError: If the step/epoch settings in ``runconfig`` are
            missing, conflicting, or ``eval_frequency`` is not positive.
        RuntimeError: If a training or evaluation subprocess fails.
    """
    train_args = update_sideband_mode_arg(
        arguments, "--mode=train", modes.TRAIN_AND_EVAL
    )
    eval_args = update_sideband_mode_arg(
        arguments, "--mode=eval", modes.TRAIN_AND_EVAL
    )

    runconfig = params['runconfig']
    # Only populated when the last training run needs a shorter step count.
    last_train_args = None
    last_steps = 0

    if runconfig.get('num_steps', None) is not None:
        if runconfig.get('num_epochs', None) is not None:
            raise ValueError(
                "num_steps and num_epochs cannot both be specified "
                "in the runconfig section of params"
            )
        if runconfig.get('steps_per_epoch', None) is not None:
            raise ValueError(
                "num_steps and steps_per_epoch cannot both be specified "
                "in the runconfig section of params"
            )
        if runconfig.get('eval_frequency', None) is None:
            raise ValueError(
                "if num_steps is specified, eval_frequency is needed "
                "to dictate how many train steps before each eval"
            )
        total_steps = int(runconfig['num_steps'])
        train_steps = int(runconfig['eval_frequency'])
        # Zero would raise a bare ZeroDivisionError below, and a negative
        # value would give a negative iteration count so that nothing runs
        # at all; reject both with an explicit message.
        if train_steps <= 0:
            raise ValueError(
                f"eval_frequency must be a positive integer, "
                f"got {runconfig['eval_frequency']}"
            )
        num_iters = math.ceil(total_steps / train_steps)
        last_steps = total_steps % train_steps
        # add num_steps overwrite
        if last_steps > 0:
            last_train_args = train_args + [f"--num_steps={last_steps}"]
        train_args.append(f"--num_steps={train_steps}")
    elif runconfig.get('num_epochs', None) is not None:
        num_iters = int(runconfig['num_epochs'])
        # add num_epochs overwrite
        train_args.append("--num_epochs=1")
    else:
        raise ValueError(
            "For train_and_eval mode, one of `num_steps` or `num_epochs` "
            " must be specified and not be None."
        )

    single_run = [sys.executable, filename]
    train_cmd = single_run + train_args
    eval_cmd = single_run + eval_args

    for i in range(num_iters):
        # TRAIN
        if i == num_iters - 1 and last_steps > 0:
            # Final, shorter training run covering the leftover steps.
            train_cmd = single_run + last_train_args
        try:
            subprocess.run(train_cmd, check=True)
        except Exception as e:
            raise RuntimeError(f"Training at iteration {i} failed.") from e

        # EVAL
        try:
            subprocess.run(eval_cmd, check=True)
        except Exception as e:
            raise RuntimeError(f"Evaluate at iteration {i} failed.") from e
def run(
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    default_params_fn: Optional[Callable[[dict], dict]] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """Backward compatible entry point to running pytorch models.

    Params are obtained with ``get_params_from_args`` using the directory
    of the calling script, optionally post-processed by
    ``default_params_fn``, and then handed to ``main``.

    Args:
        model_fn: A callable that takes the params and returns a
            ``torch.nn.Module``.
        train_data_fn: Optional callable that takes the params and returns
            the training dataloader.
        eval_data_fn: Optional callable that takes the params and returns
            the evaluation dataloader.
        default_params_fn: Optional callable applied to the params. If it
            returns a truthy value that value replaces the params;
            otherwise the params object it was given keeps being used.
        extra_args_parser_fn: Optional callable providing extra argument
            parsers; used both for parsing and for the later run-config
            validation.

    Returns:
        Whatever ``main`` returns.
    """
    # Index 1 is the frame that called ``run``, i.e. the model's run script.
    parent = inspect.getouterframes(inspect.currentframe())[1]
    run_dir = os.path.dirname(os.path.abspath(parent.filename))
    params = get_params_from_args(run_dir, extra_args_parser_fn)
    if default_params_fn:
        params = default_params_fn(params) or params
    # Forward the extra-args parser so the validation done further down the
    # call chain is built with the same parser that produced the params,
    # and propagate the result instead of discarding it.
    return main(
        params,
        model_fn,
        train_data_fn,
        eval_data_fn,
        script=parent.filename,
        extra_args_parser_fn=extra_args_parser_fn,
    )
def main(
    params: Dict[str, Any],
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    script: Optional[str] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """Entry point to running pytorch models.

    Dispatches on ``params["runconfig"]["mode"]``: the ``eval_all`` mode,
    and the ``train_and_eval`` mode when targeting CSX, are executed through
    the subprocess-based sideband helpers; everything else goes through
    ``run_with_params`` in the current process.

    Args:
        params: Parsed params containing a ``runconfig`` section.
        model_fn: A callable that takes the params and returns a
            ``torch.nn.Module``.
        train_data_fn: Optional callable returning the training dataloader.
        eval_data_fn: Optional callable returning the evaluation dataloader.
        script: Path of the script to relaunch for sideband runs. Defaults
            to the file of the caller of this function.
        extra_args_parser_fn: Optional callable providing extra argument
            parsers, forwarded to ``run_with_params``.

    Returns:
        ``None`` for sideband runs, otherwise the result of
        ``run_with_params``.
    """
    if not script:
        # Entry 1 of the stack is whoever called ``main``.
        script = inspect.stack()[1].filename

    runconfig = params["runconfig"]
    mode = runconfig["mode"]

    if mode == modes.EVAL_ALL:
        sideband_eval_all(script, sys.argv[1:], params)
        return None
    # TODO ambiguity on what to return, possibly just run the final checkpoint in
    # the main process below

    # TODO enable existing train_and_eval functionality to work with cs
    if (
        mode == modes.TRAIN_AND_EVAL
        and runconfig["target_device"] == DeviceType.CSX
    ):
        sideband_train_eval_all(script, sys.argv[1:], params)
        return None

    return run_with_params(
        params,
        model_fn,
        train_data_fn,
        eval_data_fn,
        extra_args_parser_fn=extra_args_parser_fn,
    )
def run_with_params(
    params: Dict[str, Any],
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """
    Runs a full end-to-end CS/non-CS workflow for a given model

    The ``runconfig`` section is validated first. The experimental cstorch
    flow is then used when the target device is CSX or CPU, the experimental
    API is requested (through ``runconfig`` or LTC MLIR mode), and a CSoft
    package is available; otherwise the PyTorchBaseModel/PyTorchBaseRunner
    flow runs. ``runconfig["experimental_api"]`` is overwritten to record
    which flow was chosen.

    Args:
        params: The params dictionary; must contain a ``runconfig`` section.
        model_fn: A callable that takes in a 'params' argument
            which it uses to configure and return a torch.nn.Module
        train_data_fn: A callable that takes in a 'params' argument
            which it uses to configure and return a PyTorch dataloader
            corresponding to the training dataset
        eval_data_fn: A callable that takes in a 'params' argument
            which it uses to configure and return a PyTorch dataloader
            corresponding to the evaluation dataset
        extra_args_parser_fn: An optional callable that adds any
            extra parser args not covered in `get_parser` fn.

    Returns:
        The result of whichever flow was executed.
    """
    runconfig = params["runconfig"]
    RunConfigParamsValidator(extra_args_parser_fn).validate(runconfig)

    use_cstorch_flow = (
        runconfig["target_device"] in (DeviceType.CSX, DeviceType.CPU)
        and (
            runconfig.get("experimental_api", False)
            or is_ltc_mlir_mode_enabled()
        )
        # TODO: Remove this check once we add a no dependency flow
        and CSOFT_PACKAGE != CSoftPackage.NONE
    )

    if not use_cstorch_flow:
        runconfig["experimental_api"] = False
        # Default to the previous PyTorchBaseModel/PyTorchBaseRunner flow
        return run_base_model_flow(
            params, model_fn, train_data_fn, eval_data_fn,
        )

    runconfig["experimental_api"] = True

    # Imported lazily so the module is only loaded when this flow is taken.
    # pylint: disable=import-outside-toplevel
    from modelzoo.common.pytorch.run_cstorch_flow import run_cstorch_flow

    # Use the new experimental API flow
    return run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn)
def run_base_model_flow(params, model_fn, train_data_fn, eval_data_fn):
    """Runs PytorchBaseModel and Runner flow

    Sets up logging and the optional torch seed, creates the runner, dumps
    the params as YAML into ``<model_dir>/<mode>/params_<mode>.yaml`` (on the
    master ordinal only), builds the dataloaders required by the mode and
    finally runs train, evaluate or train-and-eval on the runner.

    Args:
        params: The params dictionary; must contain a ``runconfig`` section.
        model_fn: Callable handed to ``PyTorchBaseRunner.create`` together
            with the params.
        train_data_fn: Callable returning the training dataloader; required
            for the train and train-and-eval modes.
        eval_data_fn: Callable returning the evaluation dataloader; required
            for the eval and train-and-eval modes.

    Raises:
        ValueError: If the mode is not train, eval or train-and-eval.
    """
    runconfig = params["runconfig"]

    # Set up logging level as soon as possible. Ideally, we could even move this
    # to the pytorch_cli entry point to set up logging during import of the
    # modelzoo model...
    setup_logging(
        runconfig.get("logging"),
        runconfig.get("streamer_logging"),
        logging_dir=runconfig.get("model_dir"),
    )

    if "seed" in runconfig:
        torch.manual_seed(runconfig["seed"])

    runner = PyTorchBaseRunner.create(model_fn, params)

    mode = runconfig["mode"]

    # using this dir structure to keep in sync with runners
    summary_dir = os.path.join(params["runconfig"]["model_dir"], f"{mode}")
    os.makedirs(summary_dir, exist_ok=True)

    # Save params.yaml only in master task
    if cm.is_master_ordinal():
        params_file = os.path.join(summary_dir, f"params_{mode}.yaml")
        with open(params_file, "w+") as _fout:
            yaml.dump(params, _fout, default_flow_style=False)

    # Initialize the dataloaders depending on the mode
    train_modes = (modes.TRAIN, modes.TRAIN_AND_EVAL)
    eval_modes = (modes.EVAL, modes.TRAIN_AND_EVAL)

    if mode in train_modes:
        assert train_data_fn, "Train dataloader function has not been provided"
        train_loader = train_data_fn(params)
        runner.train_data_fn = train_data_fn

    if mode in eval_modes:
        assert eval_data_fn, "Eval dataloader function has not been provided"
        eval_loader = eval_data_fn(params)
        runner.eval_data_fn = eval_data_fn

    if mode == modes.TRAIN:
        runner.train(train_loader)
        return
    if mode == modes.EVAL:
        runner.evaluate(eval_loader)
        return
    if mode == modes.TRAIN_AND_EVAL:
        runner.train_and_eval(train_loader, eval_loader)
        return
    raise ValueError(f"Mode {mode} is not supported.")