Trainer API
- class cerebras.modelzoo.Trainer(device=None, backend=None, model_dir=Ellipsis, model=Ellipsis, optimizer=None, schedulers=None, precision=None, sparsity=None, loop=None, checkpoint=None, logging=None, callbacks=None, loggers=None, seed=None)
- The Trainer class is the main entry point for training models in ModelZoo.
- Parameters
- device (Optional[str]) – The device to train the model on. It must be one of “CSX”, “CPU”, or “GPU”. 
- backend (Optional[Backend]) – The backend used to train the model. This argument is mutually exclusive with device. 
- model_dir (str) – The directory where the model artifacts are saved. 
- model (Union[Callable[[], torch.nn.Module], torch.nn.Module]) – The model to train. It must be one of the following:
  - If a callable is passed, it is assumed to be a function that takes no arguments and returns a torch.nn.Module.
  - If a torch.nn.Module is passed, it is used as is.
 
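- optimizer (Union[Optimizer, Callable[[torch.nn.Module], Optimizer], None]) – The optimizer used to optimize the model. It must be one of the following:
  - If an Optimizer is passed, it is used as is.
  - If a callable is passed, it is assumed to be a function that takes in the torch.nn.Module being trained and returns an Optimizer.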
- schedulers (SchedulersInput) – The set of optimizer schedulers to use; common schedulers include LR schedulers. It must be a list whose items are one of the following:
  - If a cstorch.optim.scheduler.Scheduler is passed, it is used as is.
  - If a callable is passed, it is assumed to be a function that takes in an Optimizer and returns a cstorch.optim.scheduler.Scheduler.
  If None, no optimizer param group scheduling is performed.
 
- precision (Optional[Precision]) – The Precision callback used during training.
- sparsity (Optional[SparsityAlgorithm]) – The sparsity algorithm used to sparsify weights during training/validation. It must be one of the following:
  - If a callable is passed, it is assumed to be a function that takes no arguments and returns a SparsityAlgorithm.
  - If a SparsityAlgorithm is passed, it is used as is.
 
- loop (Optional[LoopCallback]) – The loop callback to use for training. It must be an instance of LoopCallback. If not provided, the default loop is TrainingLoop(num_epochs=1). 
- checkpoint (Optional[Checkpoint]) – The checkpoint callback to use for saving/loading checkpoints. It must be an instance of Checkpoint. If not provided, then no checkpoints are saved.
- logging (Optional[Logging]) – The logging callback used to set up Python logging. This callback also controls when logs are emitted. If not provided, the default logging settings Logging(log_steps=1, log_level="INFO") are used.
- callbacks (Optional[List[Callback]]) – A list of callbacks to be used by the trainer. The order in which the callbacks are provided is important, as it determines the order in which the callbacks’ hooks are executed.
- loggers (Optional[List[Logger]]) – A list of loggers to use for logging. 
- seed (Optional[int]) – Initial seed for the torch random number generator. 
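Several constructor arguments (model, optimizer, sparsity, and the items of schedulers) accept either a ready-made object or a callable that builds one. The following is a minimal sketch of how such an argument could be normalized, not the ModelZoo implementation; `resolve_arg` and the stand-in `Module`, `MyModel`, and `Optimizer` classes are hypothetical. It highlights one pitfall: a torch.nn.Module instance is itself callable (calling it runs the forward pass), so the isinstance check has to come before the callable check.

```python
from typing import Any


class Module:
    """Hypothetical stand-in for torch.nn.Module: instances are callable."""

    def __call__(self, *args: Any) -> Any:
        return "forward output"


class MyModel(Module):
    """Hypothetical user model."""


class Optimizer:
    """Hypothetical stand-in optimizer that records the model it optimizes."""

    def __init__(self, model: Module) -> None:
        self.model = model


def resolve_arg(value: Any, expected: type, *factory_args: Any) -> Any:
    """Return `value` unchanged if it is already an `expected` instance,
    otherwise treat it as a factory and call it with `factory_args`."""
    # Instance check first: a Module is callable, so testing callable()
    # first would wrongly run its forward pass instead of returning it.
    if isinstance(value, expected):
        return value
    if callable(value):
        result = value(*factory_args)
        if not isinstance(result, expected):
            raise TypeError(
                f"factory returned {type(result).__name__}, expected {expected.__name__}"
            )
        return result
    raise TypeError(f"expected {expected.__name__} or a callable returning one")


model = resolve_arg(MyModel, Module)  # factory taking no arguments
same = resolve_arg(model, Module)  # instance passed through unchanged
optimizer = resolve_arg(lambda m: Optimizer(m), Optimizer, model)  # factory taking the model
```

Checking the instance case first is what keeps a model passed as an object from being invoked as if it were a factory.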
 
 - property all_callbacks: Generator[cerebras.modelzoo.trainer.callbacks.callback.Callback, None, None]
- Get all callback objects available to the trainer. 
 - property validation_callbacks: List[cerebras.modelzoo.trainer.callbacks.callback.ValidationCallback]
- Returns all validation callbacks in the Trainer’s callback list. 
 - call(hook_name, *args, **kwargs)
- Call the hook with name hook_name for all callbacks in the Trainer’s callback list as well as the callbacks in the global registry. Each callback’s method receives the trainer object itself, followed by any args and kwargs passed into this method.
- Parameters
- hook_name (str) – The name of the hook to call. It must be the name of a method in the Callback class.
- args – Other positional arguments to forward along to the called hook.
- kwargs – Other keyword arguments to forward along to the called hook.
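The hook dispatch described above can be sketched in plain Python. This is a simplified model, not the ModelZoo code: `Callback`, `MiniTrainer`, `Recorder`, `global_callbacks`, and the hook names `on_train_start` and `on_step_end` are all hypothetical, and the sketch assumes that global-registry callbacks run after the trainer's own callbacks, an order the text above does not specify.

```python
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class: every hook is a no-op by default."""

    def on_train_start(self, trainer: "MiniTrainer") -> None:
        pass

    def on_step_end(self, trainer: "MiniTrainer", outputs: Dict[str, Any]) -> None:
        pass


# Hypothetical stand-in for the global callback registry.
global_callbacks: List[Callback] = []


class MiniTrainer:
    def __init__(self, callbacks: List[Callback]) -> None:
        # Hooks run in the order the callbacks were provided.
        self.callbacks = list(callbacks)

    def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        # Only methods defined on the base Callback class are valid hooks.
        if not hasattr(Callback, hook_name):
            raise AttributeError(f"{hook_name!r} is not a Callback hook")
        for cb in self.callbacks + global_callbacks:
            # Each hook receives the trainer first, then the forwarded arguments.
            getattr(cb, hook_name)(self, *args, **kwargs)


class Recorder(Callback):
    """Hypothetical callback that records every on_step_end call."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name, self.log = name, log

    def on_step_end(self, trainer: MiniTrainer, outputs: Dict[str, Any]) -> None:
        self.log.append(f"{self.name}:loss={outputs['loss']}")


events: List[str] = []
trainer = MiniTrainer([Recorder("a", events), Recorder("b", events)])
trainer.call("on_step_end", {"loss": 0.5})
print(events)  # ['a:loss=0.5', 'b:loss=0.5']
```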
 
 
 - property precision: Optional[cerebras.modelzoo.trainer.callbacks.precision.Precision]
- Returns the precision callback instance if it exists.
 - property grad_accum: cerebras.modelzoo.trainer.callbacks.grad_accum.GradientAccumulationCallback
- Returns the gradient accumulation callback instance.
 - property should_run_optimizer_step: bool
- Returns True if the optimizer step should run.
- The gradient accumulation callback may set this to False if gradients are being accumulated and the configured number of accumulation steps has not yet been reached. Note that this only applies to CPU/GPU runs.
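A rough illustration of the gradient-accumulation rule for CPU/GPU runs. The helper `should_run_optimizer_step` and its `accum_steps` parameter are hypothetical, and the sketch assumes a zero-based micro-batch counter; the actual GradientAccumulationCallback may count steps differently.

```python
def should_run_optimizer_step(micro_step: int, accum_steps: int) -> bool:
    """Return True when `micro_step` (0-based) completes an accumulation window.

    Hypothetical helper: with accum_steps == 1 the optimizer runs every step.
    """
    if accum_steps < 1:
        raise ValueError("accum_steps must be >= 1")
    return (micro_step + 1) % accum_steps == 0


# With 4 accumulation steps, the optimizer runs on micro-steps 3, 7, 11, ...
steps = [s for s in range(12) if should_run_optimizer_step(s, accum_steps=4)]
print(steps)  # [3, 7, 11]
```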
 - property loop: cerebras.modelzoo.trainer.callbacks.loop.LoopCallback
- Returns the default loop settings. 
 - property checkpoint: cerebras.modelzoo.trainer.callbacks.checkpoint.Checkpoint
- Returns the checkpoint callback. 
 - property logging: Logging
- Returns the logging callback.
 - property logger: logging.Logger
- Returns the Trainer’s Python logger object. 
 - property is_log_step: bool
- Returns True if the current step is a log step. 
 - property is_first_iteration: bool
- Returns True if the executor is on its first iteration. 
 - property is_final_iteration: bool
- Returns True if the executor is on its final iteration. 
 - property is_tracing: bool
- Returns True if we are currently tracing the model. 
 - final log_metrics(**kwargs)
- Log the given kwargs to all loggers.
- Example usage:
    trainer.log_metrics(loss=loss.item())
- Parameters
- kwargs – The key-value pairs to log.
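A minimal sketch of the fan-out performed by log_metrics, assuming a hypothetical logger interface `log_metrics(metrics, step)`; the real Logger base class may use different method names or arguments, and `PrintLogger` and `MiniTrainer` are illustrative stand-ins.

```python
from typing import Any, Dict, List, Tuple


class PrintLogger:
    """Hypothetical logger that keeps every (step, metrics) pair it receives."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, Dict[str, Any]]] = []

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        self.records.append((step, dict(metrics)))


class MiniTrainer:
    def __init__(self, loggers: List[PrintLogger]) -> None:
        self.loggers = loggers
        self.global_step = 0

    def log_metrics(self, **kwargs: Any) -> None:
        # Forward the same key-value pairs, tagged with the current step, to every logger.
        for logger in self.loggers:
            logger.log_metrics(kwargs, self.global_step)


tb, csv = PrintLogger(), PrintLogger()
trainer = MiniTrainer([tb, csv])
trainer.global_step = 10
trainer.log_metrics(loss=0.25, lr=1e-4)
print(tb.records)  # [(10, {'loss': 0.25, 'lr': 0.0001})]
```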
 
 - final name_scope(name)
- Append name to the trainer’s name scope stack while inside the context.
- Parameters
- name (str) – The name to append to the name scope stack.
 
 - property name_scope_path: str
- Returns the current name scope path.
- This is the name scope stack joined by ‘/’.
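The name scope stack and name_scope_path can be modeled with a context manager that pushes a name on entry and pops it on exit. This sketch uses a hypothetical `ScopedTrainer` class and illustrative scope names; it is not the Trainer's implementation.

```python
from contextlib import contextmanager
from typing import Iterator, List


class ScopedTrainer:
    """Hypothetical stand-in showing a name scope stack."""

    def __init__(self) -> None:
        self._name_scope: List[str] = []

    @contextmanager
    def name_scope(self, name: str) -> Iterator[None]:
        # Push on entry and always pop on exit, even if the body raises.
        self._name_scope.append(name)
        try:
            yield
        finally:
            self._name_scope.pop()

    @property
    def name_scope_path(self) -> str:
        # The path is the stack joined by "/".
        return "/".join(self._name_scope)


t = ScopedTrainer()
with t.name_scope("validate"):
    with t.name_scope("dataloader_0"):
        print(t.name_scope_path)  # validate/dataloader_0
    print(t.name_scope_path)  # validate
print(repr(t.name_scope_path))  # ''
```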
 - final get_val_dataloader_scope(val_dataloader)
- Get the name scope for the given val dataloader. 
 - final training_step(batch)
- Run a single training step on the given batch.
- Note that if retrace is off, the body of this method only runs on the first iteration, so any inputs to it must be either constant across iterations or torch tensors.
- Parameters
- batch – The batch of data to train on.
- batch_idx – The index of the batch in the dataloader. 
 
- Returns
- A dictionary containing the loss and any other outputs. 
- Return type
- Dict[str, Any] 
 
 - final forward(batch)
- Run the forward pass on the given batch.
- Parameters
- batch – The batch of data to run the forward pass on. 
- Returns
- A dictionary containing the loss and any other outputs. 
- Return type
- Dict[str, Any] 
 
 - final backward(outputs)
- Run the backward pass on the loss contained in the given outputs.
- Parameters
- outputs (dict) – The outputs of the model. The key ‘loss’ is expected to be present.
 
 - on_exception(hook)
- Context manager to handle exceptions in the given hook.
- Parameters
- hook – The hook to handle exceptions for. 
 
 - final fit(train_dataloader, val_dataloader=None, ckpt_path=Ellipsis)
- Complete a full training run on the given train and validation dataloaders.
- Parameters
- train_dataloader (cerebras.appliance.log.named_class_logger) – The training dataloader.
- val_dataloader (Optional[Union[cerebras.appliance.log.named_class_logger, List[cerebras.appliance.log.named_class_logger]]]) – The validation dataloader.
  - If provided, validation is run every eval_frequency steps as defined in the loop callback.
  - If not provided, only training is run.
  - If a list of dataloaders is provided, then each dataloader is validated in sequence.
- ckpt_path (Optional[str]) – The path to the checkpoint to load before starting training. If not provided and autoload_last_checkpoint is True, then the latest checkpoint is loaded.
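The validation schedule described for fit can be approximated as follows. Strings stand in for dataloaders, and `fit_schedule` is a hypothetical helper; it assumes validation triggers only at steps that are multiples of eval_frequency and ignores any other validation triggers the loop callback may have.

```python
from typing import List, Optional, Sequence, Union


def fit_schedule(
    num_steps: int,
    eval_frequency: Optional[int],
    val_dataloaders: Union[None, str, Sequence[str]] = None,
) -> List[str]:
    """Return the order of events in a hypothetical training run."""
    # Accept no dataloader, a single one, or a list, as fit does.
    if val_dataloaders is None:
        val_list: List[str] = []
    elif isinstance(val_dataloaders, str):
        val_list = [val_dataloaders]
    else:
        val_list = list(val_dataloaders)

    events: List[str] = []
    for step in range(1, num_steps + 1):
        events.append(f"train:{step}")
        # Validate only if dataloaders were given and the step hits the frequency.
        if val_list and eval_frequency and step % eval_frequency == 0:
            for dl in val_list:  # each dataloader is validated in sequence
                events.append(f"val:{dl}@{step}")
    return events


print(fit_schedule(4, 2, ["wiki", "code"]))
# ['train:1', 'train:2', 'val:wiki@2', 'val:code@2', 'train:3', 'train:4', 'val:wiki@4', 'val:code@4']
```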
 
 
 - validation_step
 - final validate(val_dataloader=None, ckpt_path=Ellipsis, loop=None)
- Complete a full validation run on the validation dataloader.
- Parameters
- val_dataloader (Optional[cerebras.appliance.log.named_class_logger]) – The validation dataloader. If a list of dataloaders is provided, then each dataloader is validated in sequence.
- ckpt_path (Optional[str]) – The path to the checkpoint to load before starting validation. If not provided and autoload_last_checkpoint is True, then the latest checkpoint is loaded.
- loop (Optional[cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop]) – The loop callback to use for validation. If not provided, the default loop is used. If provided, it must be an instance of ValidationLoop. Note that this should only be provided if the loop callback provided in the constructor is not sufficient.
 
 
 - final validate_all(val_dataloaders=None, ckpt_paths=Ellipsis, loop=None)
- Runs all upstream and downstream validation permutations:
    for ckpt_path in ckpt_paths:
        for val_dataloader in val_dataloaders:
            trainer.validate(val_dataloader, ckpt_path)
        # run downstream validation
        run_validation(...)
- Parameters
- val_dataloaders (Optional[Union[cerebras.appliance.log.named_class_logger, List[cerebras.appliance.log.named_class_logger]]]) – A list of validation dataloaders to run validation on.
- ckpt_paths (Optional[Union[List[str], str]]) – A list of checkpoint paths to run validation on. Each checkpoint path must be a path to a checkpoint file, or a glob pattern.
- loop (Optional[cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop]) – The validation loop to use for validation. If not provided, then the default loop is used.
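The permutation order of validate_all, combined with the rule that each checkpoint path may be a file or a glob pattern, can be sketched with the standard glob module. `expand_ckpt_paths`, the `validate`/`run_downstream` callables, and the `checkpoint_<step>.mdl` file names are hypothetical; running downstream validation once per checkpoint is this sketch's reading of the pseudocode for validate_all.

```python
import glob
import os
import tempfile
from typing import Callable, List, Sequence, Tuple, Union


def expand_ckpt_paths(ckpt_paths: Union[str, Sequence[str]]) -> List[str]:
    """Normalize a path or glob (or a list of them) into a list of files."""
    if isinstance(ckpt_paths, str):
        ckpt_paths = [ckpt_paths]
    expanded: List[str] = []
    for pattern in ckpt_paths:
        matches = sorted(glob.glob(pattern))  # lexical order (an assumption)
        if not matches:
            raise FileNotFoundError(f"no checkpoint matches {pattern!r}")
        expanded.extend(matches)
    return expanded


def validate_all(
    validate: Callable[[str, str], None],
    run_downstream: Callable[[str], None],
    val_dataloaders: Sequence[str],
    ckpt_paths: Union[str, Sequence[str]],
) -> None:
    # Outer loop over checkpoints, inner loop over dataloaders.
    for ckpt_path in expand_ckpt_paths(ckpt_paths):
        for val_dataloader in val_dataloaders:
            validate(val_dataloader, ckpt_path)
        # Downstream validation once per checkpoint.
        run_downstream(ckpt_path)


# Usage: two empty placeholder checkpoint files matched by a single glob pattern.
calls: List[Tuple[str, str]] = []
with tempfile.TemporaryDirectory() as ckpt_dir:
    for step in (100, 200):
        open(os.path.join(ckpt_dir, f"checkpoint_{step}.mdl"), "w").close()
    validate_all(
        validate=lambda dl, ck: calls.append((dl, os.path.basename(ck))),
        run_downstream=lambda ck: calls.append(("downstream", os.path.basename(ck))),
        val_dataloaders=["val_a", "val_b"],
        ckpt_paths=os.path.join(ckpt_dir, "checkpoint_*.mdl"),
    )
print(calls)
```

Matches are sorted lexically here, which is an assumption; ordering by step number (so that checkpoint_1000 follows checkpoint_200) would need a numeric sort key.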
 
 
 - final save_checkpoint()
- Save a checkpoint at the current global step.
- The checkpoint state dict is constructed by the callbacks that implement the on_save_checkpoint method.
 - final load_checkpoint(ckpt_path=None)
- Load a checkpoint from the given path.
- The checkpoint state dict is loaded and processed by the callbacks that implement the on_load_checkpoint method.
- Parameters
- ckpt_path (Optional[str]) – The path to the checkpoint to load. If not provided and autoload_last_checkpoint is True, then the latest checkpoint is loaded.
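The save/load flow, in which callbacks build and consume a shared state dict, can be sketched as below. The `on_save_checkpoint(trainer, state_dict)` and `on_load_checkpoint(trainer, state_dict)` signatures are assumptions, `StepCounter` and `MiniTrainer` are hypothetical, and the sketch passes the dict in memory instead of reading or writing a checkpoint file.

```python
from typing import Any, Dict, List


class StepCounter:
    """Hypothetical callback that persists a step counter."""

    def __init__(self) -> None:
        self.step = 0

    def on_save_checkpoint(self, trainer: Any, state_dict: Dict[str, Any]) -> None:
        state_dict["step_counter"] = self.step

    def on_load_checkpoint(self, trainer: Any, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict.get("step_counter", 0)


class MiniTrainer:
    def __init__(self, callbacks: List[Any]) -> None:
        self.callbacks = callbacks

    def save_checkpoint(self) -> Dict[str, Any]:
        # Each callback contributes its own entries to a shared state dict.
        state_dict: Dict[str, Any] = {}
        for cb in self.callbacks:
            cb.on_save_checkpoint(self, state_dict)
        return state_dict  # a real trainer would write this to disk

    def load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        # Each callback restores whatever it previously saved.
        for cb in self.callbacks:
            cb.on_load_checkpoint(self, state_dict)


counter = StepCounter()
counter.step = 500
saved = MiniTrainer([counter]).save_checkpoint()
restored = StepCounter()
MiniTrainer([restored]).load_checkpoint(saved)
print(saved, restored.step)  # {'step_counter': 500} 500
```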
 
 
Cerebras Model Zoo Callbacks API
This module contains the base Callback class, the core callbacks invoked directly by the Trainer, and other optional callbacks that can be used to extend the functionality of the Trainer.
Cerebras Model Zoo Extensions API
This module contains integrations of external tools with the Trainer.
Cerebras Model Zoo Loggers API
This module contains the base Logger class as well as a few useful Logger subclasses.