Learning rate scheduling#

Overview#

As with the optimizers, the vanilla PyTorch learning rate schedulers are incompatible with traced lazy execution. Exact drop-in replacements for all commonly used learning rate schedulers are available in cerebras.pytorch.optim.lr_scheduler. For example:

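import cerebras.pytorch as cstorch

# `optimizer` is assumed to be a previously constructed Cerebras-compatible optimizer
lr_scheduler = cstorch.optim.lr_scheduler.LinearLR(
    optimizer,
    initial_learning_rate=0.01,
    end_learning_rate=0.0001,
    total_iters=1000,
)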

Writing Custom Learning Rate Schedulers#

To define a Cerebras-compliant learning rate scheduler, create a subclass of cerebras.pytorch.optim.lr_scheduler.LRScheduler.

For example:

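import torch
import cerebras.pytorch as cstorch

class CustomScheduler(cstorch.optim.lr_scheduler.LRScheduler):

    def __init__(self, optimizer, ...):
        ...
        # Pass the optimizer, the optional iteration count, and the starting epoch to the base class
        super().__init__(optimizer, total_iters=..., last_epoch=...)

    ...

    # Abstract in the base class; must return the schedule as a tensor
    def _get_closed_form_lr(self) -> torch.Tensor:
        ...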

As the example above shows, the base LRScheduler class expects three arguments: the optimizer whose learning rate is being scheduled, the total number of iterations over which the schedule runs (optional), and the last epoch to start on.

In addition, the following abstract method must be overridden:

  • _get_closed_form_lr

    This method is where the full scheduler is defined in closed form. Because of the nature of lazy tensor tracing and execution, Python-level conditionals and loops must not be used to define control flow dynamically. Any branching that depends on a tensor value must be expressed with torch ops (such as torch.where).

    Static structures, however, are allowed: for example, a loop with a fixed number of iterations, or a Python conditional whose condition involves only constants and no torch tensors.

    This method is expected to return a torch.Tensor that represents the full learning rate schedule as a computed tensor.

    See the existing LR scheduler implementations for examples of how to correctly define the schedule.
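
To make the constraint on control flow concrete, here is a minimal sketch of a closed-form schedule written in plain PyTorch. It is not taken from the Cerebras implementation: the function name, its parameters, and the `step` tensor (standing in for whatever step value the scheduler base class exposes) are all hypothetical, and the warmup-then-linear-decay shape is only an illustration. Inside a real subclass, the same expression would form the body of _get_closed_form_lr.

```python
import torch


def warmup_then_linear_decay(
    step: torch.Tensor,
    peak_lr: float = 0.01,
    end_lr: float = 0.0001,
    warmup_steps: int = 100,
    total_iters: int = 1000,
) -> torch.Tensor:
    """Hypothetical closed-form schedule: linear warmup, then linear decay.

    `step` is an assumed stand-in for the scheduler's current step tensor.
    """
    step = step.to(torch.float32)

    # Decay branch: interpolate from peak_lr to end_lr, clamped so the
    # learning rate stays at end_lr once total_iters has been reached.
    decay_frac = (step - warmup_steps) / (total_iters - warmup_steps)
    decay_lr = peak_lr + (end_lr - peak_lr) * torch.clamp(decay_frac, 0.0, 1.0)

    # Static structure: this condition involves only a Python constant,
    # never a tensor, so it is resolved once when the schedule is traced.
    if warmup_steps == 0:
        return decay_lr

    # Warmup branch: ramp linearly from 0 to peak_lr.
    warmup_lr = peak_lr * step / warmup_steps

    # Dynamic choice: both branches are computed and torch.where selects
    # between them, so no Python `if` ever inspects a tensor value.
    return torch.where(step < warmup_steps, warmup_lr, decay_lr)
```

The tensor-dependent decision (is the current step still in warmup?) goes through torch.where, while the decision that depends only on the constant warmup_steps is an ordinary Python conditional.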

Once you have written your custom LR scheduler, and as long as it is available in the global scope, you can use it in ModelZoo as you would a built-in one: set the scheduler to the name of your custom LR scheduler class in your params YAML file.