Source code for cerebras.modelzoo.common.run_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for running Cerebras Pytorch Models."""

import argparse
import inspect
import os
import subprocess
import sys
from shutil import which
from typing import Any, Callable, Dict, List, Optional

from cerebras.modelzoo.common.utils.run.cli_pytorch import get_params_from_args
from cerebras.modelzoo.common.utils.run.utils import DeviceType


[docs]def torchrun(filename: str, arguments: List[str]): """Starts a distributed GPU run using torchrun.""" import torch torchrun_cmd = [ which("torchrun", path=os.path.dirname(sys.executable)), "--nnodes=1", f"--nproc_per_node={torch.cuda.device_count()}", filename, *arguments, ] try: print( f"Starting distributed GPU run using torchrun:\n" f"{' '.join(torchrun_cmd)}" ) subprocess.run(torchrun_cmd, check=True) except Exception as e: raise RuntimeError( f"Failed to spawn distributed GPU run using torchrun" ) from e
[docs]def run( extra_args_parser_fn: Optional[ Callable[[], List[argparse.ArgumentParser]] ] = None, ): """Entry point to running pytorch models including CLI argument parsing. Args: extra_args_parser_fn: An optional callable that adds any extra parser args not covered in `get_parser` fn. """ parent = inspect.getouterframes(inspect.currentframe())[1] params = get_params_from_args(extra_args_parser_fn) main( params, script=parent.filename, extra_args_parser_fn=extra_args_parser_fn, )
[docs]def main( params: Dict[str, Any], script: Optional[str] = None, extra_args_parser_fn: Optional[ Callable[[], List[argparse.ArgumentParser]] ] = None, ): """Runs a full end-to-end CS/non-CS workflow for a PyTorch model. Args: params: The parsed YAML config dictionary. script: The script to run in subprocesses for distributed GPU runs. extra_args_parser_fn: An optional callable that adds any extra parser args not covered in `get_parser` fn. """ if not script: parent = inspect.getouterframes(inspect.currentframe())[1] script = parent.filename if ( "runconfig" in params # If using distributed GPU with experimental API and params["runconfig"]["target_device"] == DeviceType.GPU and params["runconfig"].get("enable_distributed", False) # If this is already set, we've already launched distributed training and os.environ.get("LOCAL_RANK") is None ): # use torchrun to launch distributed training torchrun(script, sys.argv[1:]) return None from cerebras.modelzoo.common.pytorch_utils import RunConfigParamsValidator from cerebras.modelzoo.trainer.utils import ( inject_cli_args_to_trainer_params, run_trainer, ) runconfig_params = params["runconfig"] RunConfigParamsValidator(extra_args_parser_fn).validate(runconfig_params) mode = runconfig_params["mode"] # Recursively update the params with the runconfig if "runconfig" in params and "trainer" in params: params = inject_cli_args_to_trainer_params( params.pop("runconfig"), params ) return run_trainer(mode, params)