# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Abstract base class for PyTorch models.
"""
import logging
import warnings
from contextlib import nullcontext
from typing import Callable, Union
import torch
from modelzoo.common.pytorch import amp
from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch, modes
from modelzoo.common.pytorch.gradient_clipper import GradientClipper
from modelzoo.common.pytorch.half_dtype import half_dtype_instance
from modelzoo.common.pytorch.optim import (
ASGD,
SGD,
Adadelta,
Adafactor,
Adagrad,
Adam,
Adamax,
AdamW,
Lamb,
Lion,
NAdam,
RAdam,
RMSprop,
Rprop,
lr_scheduler,
)
from modelzoo.common.pytorch.utils import (
named_parameters_requiring_grad,
partition_params_groups_with_adjusted_lr,
partition_params_groups_with_weight_decay,
)
SUPPORTED_OPTIMIZERS = [
'Adadelta',
'Adafactor',
'Adagrad',
'Adam',
'AdamW',
'Adamax',
'ASGD',
'Lamb',
'Lion',
'NAdam',
'RAdam',
'RMSprop',
'Rprop',
'SGD',
]
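# Illustrative shape of the params["optimizer"] section consumed below. The
# key names are the ones read by this module, but the values are placeholders,
# not recommendations:
#
#   optimizer:
#     optimizer_type: "AdamW"
#     learning_rate: 1.0e-4
#     weight_decay_rate: 0.01
#     loss_scaling_factor: "dynamic"
#     max_gradient_norm: 1.0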
class Final(type):
"""Placeholder class for deprecation warning"""
    def __new__(mcs, name, bases, classdict):
for b in bases:
if isinstance(b, Final):
warnings.warn(
"Inheriting from PyTorchBaseModel is now deprecated and "
"will be removed in a future release. Please change your "
"model to inherit from torch.nn.Module instead."
)
return type.__new__(mcs, name, bases, dict(classdict))
class PyTorchBaseModel(metaclass=Final):
"""Base Model Definition for Cerebras runners"""
    def __init__(
self,
params: dict,
model_fn: Union[Callable[[dict], torch.nn.Module], torch.nn.Module],
device: torch.device = None,
):
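        """
        Constructs the model from the config and, for training modes, the
        optimizer, LR scheduler, and (on CS runs) the grad scaler.
        :param params: Dict of config params; the "model", "runconfig", and
            "optimizer" sections are read below.
        :param model_fn: Either a callable taking ``params`` and returning a
            ``torch.nn.Module``, or (deprecated) an already constructed
            ``torch.nn.Module``.
        :param device: Optional torch device to move the model to.
        """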
use_bfloat16 = params["model"].get("use_bfloat16", False)
half_dtype_instance.use_bfloat16 = use_bfloat16
if use_bfloat16 and cm.use_cs():
amp.use_bfloat16(use_bfloat16)
if isinstance(model_fn, torch.nn.Module):
# To keep compatibility with deprecated usage
self.torch_model = model_fn
else:
self.torch_model = model_fn(params)
if cm.use_cs():
self.model = cbtorch.module(self.torch_model, device)
elif device:
self.model = self.torch_model.to(device)
else:
self.model = self.torch_model
self.device = device
if hasattr(self.model, "_post_device_transfer"):
self.model._post_device_transfer()
self.mode = params["runconfig"]["mode"]
self.mixed_precision = params["model"]["mixed_precision"]
self.is_pretrained_checkpoint = params["runconfig"].get(
"is_pretrained_checkpoint", False
)
        # Whether or not to allow multireplica runs.
        # Multireplica is only allowed in train mode; it is disabled for all
        # other modes regardless of the config setting.
self.allow_multireplica = (
params["model"].get("allow_multireplica", True)
and self.mode == "train"
)
seed = params["runconfig"].get("seed", None)
if seed is not None:
torch.manual_seed(seed)
oparams = params["optimizer"]
loss_scaling_factor = oparams.get("loss_scaling_factor", 1.0)
if loss_scaling_factor == "dynamic" and use_bfloat16:
oparams["loss_scaling_factor"] = 1.0
            logging.info(
                "Dynamic loss scaling is not needed when `use_bfloat16` is "
                "set to `True`. Setting `loss_scaling_factor` to `1.0`."
            )
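        # Loss scaling settings read from params["optimizer"] (illustrative
        # values, not recommendations); with fp16 these feed amp.GradScaler
        # below, while with bfloat16 dynamic scaling is forced to 1.0 here:
        #   loss_scaling_factor: "dynamic"   # or a fixed float such as 1.0
        #   initial_loss_scale: 65536.0      # placeholder value
        #   steps_per_increase: 2000         # placeholder value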
# Learning rate params
self.lr_scheduler = None
lr_params = {
"learning_rate": oparams["learning_rate"],
"disable_lr_steps_reset": oparams.get(
"disable_lr_steps_reset", False
),
"adjust_learning_rate": oparams.get("adjust_learning_rate"),
}
# learning rate scaling params
lr_adjustment_scalars = []
lr_adjustment_layers = []
if oparams.get("adjust_learning_rate"):
for layer_type, adjustment_scalar in oparams.get(
"adjust_learning_rate"
).items():
lr_adjustment_layers.append(layer_type)
lr_adjustment_scalars.append(adjustment_scalar)
assert len(lr_adjustment_scalars) == len(
lr_adjustment_layers
), "number of keys for layer types should match the number of scalars"
        if not isinstance(lr_params["learning_rate"], (float, str, dict, list)):
            raise ValueError(
                f"Learning rate must be a float, a string, a dict, or a list "
                f"of dicts. Got {type(lr_params['learning_rate'])}"
            )
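        # Accepted learning_rate forms (values illustrative only):
        #   learning_rate: 1.0e-3                              # constant LR
        #   learning_rate: {"scheduler": "Exponential", ...}   # one schedule
        #   learning_rate: [{...}, {...}]                      # sequence
        # See _configure_lr_scheduler for the recognized scheduler dicts.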
self.optimizer = None
if self.mode in (modes.TRAIN, modes.TRAIN_AND_EVAL):
if cm.is_appliance():
ctx = cbtorch.state().init_tracker.entry("configure_optimizer")
else:
ctx = nullcontext()
with ctx:
self.optimizer = self._configure_optimizer(
oparams, lr_adjustment_layers, lr_adjustment_scalars
)
if cm.use_cs():
self.optimizer = cbtorch.optimizer(self.optimizer)
self.lr_scheduler = self._configure_lr_scheduler(lr_params)
if cm.use_cs() and cbtorch.env().weight_streaming_mode:
# pylint: disable=no-member
self.optimizer.set_main_lr_scheduler(self.lr_scheduler)
if cm.use_cs(): # init grad scaler for mixed precision
self.grad_scaler = amp.GradScaler(
loss_scale=oparams.get("loss_scaling_factor"),
init_scale=oparams.get("initial_loss_scale"),
steps_per_increase=oparams.get("steps_per_increase"),
min_loss_scale=oparams.get("min_loss_scale"),
max_loss_scale=oparams.get("max_loss_scale"),
max_gradient_norm=oparams.get("max_gradient_norm"),
mixed_precision=self.mixed_precision,
)
if self.optimizer:
# Gradient clipping params
self.optimizer.gradient_clipper = GradientClipper(
oparams.get("max_gradient_norm", 0.0),
oparams.get("max_gradient_value", 0.0),
)
        # build the map of duplicate names for params and buffers in the model
self._duplicate_params_map = self._named_members(
# pylint: disable=protected-access
self.model,
lambda module: module._parameters.items(),
)
self._duplicate_params_map.update(
# pylint: disable=protected-access
self._named_members(
self.model, lambda module: module._buffers.items()
)
)
if cm.use_cs():
progress = cbtorch.state().progress_tracker
if progress is not None:
                # Initialization is done, so we can now close the progress tracker
progress.close()
    def train(self):
"""
Sets the model into training mode, equivalent to .train() called on a torch.nn.Module.
"""
self.model.train()
# Setting up train mode across immediate children following
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train
for module in vars(self).values():
if hasattr(module, "train"):
module.train()
self.mode = modes.TRAIN
    def eval(self):
"""
Sets the model into eval mode, equivalent to .eval() called on a torch.nn.Module.
"""
self.model.eval()
# Setting up eval mode across immediate children following
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.eval
for module in vars(self).values():
if hasattr(module, "eval"):
module.eval()
self.mode = modes.EVAL
@property
def duplicate_params_map(self):
"""
        Returns a map of param names which hold the same tensors.
        Keys and values match the names that appear in the state_dict.
"""
return self._duplicate_params_map
@property
def supported_cs_modes(self):
"""
Returns a list of modes that are supported for CS runs.
        By default, train and eval are supported; however, this property
        is designed to be overridden on a model-by-model basis.
"""
return (modes.TRAIN, modes.EVAL)
@property
def supported_non_cs_modes(self):
"""
Returns a list of modes that are supported for non-CS (CPU/GPU) runs.
        By default, train, eval, and train_and_eval are supported; however,
        this property is designed to be overridden on a model-by-model basis.
"""
return (modes.TRAIN, modes.EVAL, modes.TRAIN_AND_EVAL)
@property
def supported_modes(self):
"""Supported modes conditional on hardware backend"""
if cm.use_cs():
return self.supported_cs_modes
return self.supported_non_cs_modes
    def supports_mode(self, mode) -> bool:
"""Check if model supports provided mode"""
if cm.use_cs():
return mode in self.supported_cs_modes
else:
return mode in self.supported_non_cs_modes
def _configure_optimizer(
self,
oparams: dict,
lr_adjustment_layers: list,
lr_adjustment_scalars: list,
):
"""
Configure an optimizer based on the params and return it
"""
optimizer_type = oparams["optimizer_type"].lower()
learning_rate = oparams["learning_rate"]
if isinstance(learning_rate, (float, str)):
learning_rate = float(learning_rate)
        else:
            # A learning rate schedule is configured; the scheduler sets the
            # actual LR, so use a placeholder value here.
            learning_rate = 0.1
if cm.use_cs():
if not cbtorch.env().weight_streaming_mode:
assert optimizer_type in [
"sgd",
"adafactor",
"adam",
"adamw",
], "Only SGD Adafactor, and Adam/AdamW Optimizers are supported in pipeline mode."
        # According to muP hyperparameter transfer (https://arxiv.org/abs/2203.03466),
        # the learning rate of non-embedding kernel params needs to be scaled
        # by an adjustment factor based on the hidden size.
        #
        # lr_adjustment_layers is the list of layer types whose kernel params
        # should receive this scaling.
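        # Illustrative config (the layer-type key below is an assumption; the
        # recognized names come from partition_params_groups_with_adjusted_lr):
        #   adjust_learning_rate:
        #     decoder_kernel: 0.1
        # would scale the LR of the matching param groups by 0.1.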
param_optimizer = list(named_parameters_requiring_grad(self.model))
# default: assemble all params in 1 group
param_optimizer_grouped = [{"params": list(param_optimizer)}]
# split param_groups in 2 groups: with and without weight decay
param_optimizer_grouped = partition_params_groups_with_weight_decay(
self.model,
param_optimizer_grouped,
oparams.get("weight_decay_rate", 0.0),
)
# create additional param groups for each layer type with lr adjustment scalar
param_optimizer_grouped = partition_params_groups_with_adjusted_lr(
self.model,
param_optimizer_grouped,
lr_adjustment_layers,
lr_adjustment_scalars,
)
        # strip the param name from each (name, param) tuple; the name was
        # only needed for referencing while grouping params
        for group in param_optimizer_grouped:
            group["params"] = [param for _, param in group["params"]]
# ensure number of params before and after group partitioning match
num_params_grouped = sum(
len(g["params"]) for g in param_optimizer_grouped
)
assert num_params_grouped == len(
param_optimizer
), "number of params before and after group partitioning don't match"
if optimizer_type == "sgd":
return SGD(
param_optimizer_grouped,
lr=learning_rate,
momentum=oparams["momentum"],
weight_decay=oparams.get("weight_decay_rate", 0.0),
nesterov=oparams.get("use_nesterov", False),
)
elif optimizer_type == "adam":
return Adam(
param_optimizer_grouped,
betas=(oparams.get("beta1", 0.9), oparams.get("beta2", 0.999)),
eps=oparams.get("eps", 1e-6),
weight_decay=oparams.get("weight_decay_rate", 0.0),
amsgrad=oparams.get("amsgrad", False),
)
elif optimizer_type == "adamw":
return AdamW(
param_optimizer_grouped,
betas=(oparams.get("beta1", 0.9), oparams.get("beta2", 0.999)),
eps=oparams.get("eps", 1e-6),
weight_decay=oparams.get("weight_decay_rate", 0.0),
correct_bias=oparams.get("correct_bias", True),
amsgrad=oparams.get("amsgrad", False),
)
elif optimizer_type == "adamax":
return Adamax(
param_optimizer_grouped,
lr=learning_rate,
betas=(oparams.get("beta1", 0.9), oparams.get("beta2", 0.999)),
eps=oparams.get("eps", 1e-6),
weight_decay=oparams.get("weight_decay_rate", 0.0),
maximize=oparams.get("maximize", False),
)
elif optimizer_type == "adadelta":
return Adadelta(
param_optimizer_grouped,
lr=learning_rate,
rho=oparams.get("rho", 0.9),
eps=oparams.get("eps", 1e-6),
weight_decay=oparams.get("weight_decay_rate", 0.0),
maximize=oparams.get("maximize", False),
)
elif optimizer_type == "adafactor":
eps = (oparams.get("eps1", 1e-30), oparams.get("eps2", 1e-3))
clip_threshold = oparams.get("clip_threshold", 1.0)
decay_rate = oparams.get("decay_rate", -0.8)
beta1 = oparams.get("beta1", None)
weight_decay = oparams.get("weight_decay_rate", 0.0)
scale_parameter = oparams.get("scale_parameter", True)
relative_step = oparams.get("relative_step", False)
warmup_init = oparams.get("warmup_init", False)
return Adafactor(
param_optimizer_grouped,
lr=learning_rate,
eps=eps,
clip_threshold=clip_threshold,
decay_rate=decay_rate,
beta1=beta1,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step=relative_step,
warmup_init=warmup_init,
)
elif optimizer_type == "adagrad":
return Adagrad(
param_optimizer_grouped,
lr=learning_rate,
lr_decay=oparams.get("lr_decay", 0.0),
weight_decay=oparams.get("weight_decay_rate", 0.0),
initial_accumulator_value=oparams.get(
"initial_accumulator_value", 0.0
),
eps=oparams.get("eps", 1e-6),
                maximize=oparams.get("maximize", False),
)
elif optimizer_type == "asgd":
return ASGD(
param_optimizer_grouped,
lr=learning_rate,
lambd=oparams.get("lambd", 1e-4),
alpha=oparams.get("alpha", 0.75),
t0=oparams.get("t0", 1e-6),
weight_decay=oparams.get("weight_decay", 0.0),
maximize=oparams.get("maximize", False),
)
elif optimizer_type == "lamb":
eps = oparams.get("eps", 1e-6)
betas = (oparams.get("beta1", 0.9), oparams.get("beta2", 0.999))
adam = oparams.get("adam", False)
weight_decay = oparams.get("weight_decay_rate", 0.0)
return Lamb(
param_optimizer_grouped,
lr=learning_rate,
betas=betas,
eps=eps,
weight_decay=weight_decay,
adam=adam,
)
elif optimizer_type == "lion":
betas = (oparams.get("beta1", 0.9), oparams.get("beta2", 0.99))
weight_decay = oparams.get("weight_decay_rate", 0.0)
return Lion(
param_optimizer_grouped,
lr=learning_rate,
betas=betas,
weight_decay=weight_decay,
)
elif optimizer_type == "radam":
eps = oparams.get("eps", 1e-6)
betas = (oparams.get("beta1", 0.9), oparams.get("beta2", 0.999))
weight_decay = oparams.get("weight_decay_rate", 0)
return RAdam(
param_optimizer_grouped,
lr=learning_rate,
eps=eps,
betas=betas,
weight_decay=weight_decay,
)
elif optimizer_type == "nadam":
eps = oparams.get("eps", 1e-6)
betas = (oparams.get("beta1", 0.9), oparams.get("beta2", 0.999))
weight_decay = oparams.get("weight_decay_rate", 0)
momentum_decay = oparams.get("momentum_decay", 4e-3)
return NAdam(
param_optimizer_grouped,
lr=learning_rate,
eps=eps,
betas=betas,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
)
elif optimizer_type == "rmsprop":
return RMSprop(
param_optimizer_grouped,
alpha=oparams.get("alpha", 0.99),
momentum=oparams.get("momentum", 0),
centered=oparams.get("centered", False),
eps=oparams.get("eps", 1e-8),
weight_decay=oparams.get("weight_decay_rate", 0.0),
)
elif optimizer_type == "rprop":
etas = (oparams.get("eta1", 0.5), oparams.get("eta2", 1.2))
step_sizes = (
oparams.get("step_size_min", 1e-6),
oparams.get("step_size_max", 50.0),
)
return Rprop(
param_optimizer_grouped, etas=etas, step_sizes=step_sizes,
)
else:
raise ValueError(
f"Unsupported optimizer type {optimizer_type}. Supported types:"
f"{SUPPORTED_OPTIMIZERS}."
)
    def get_optimizer(self):
"""
Returns the optimizer associated with this model.
"""
return self.optimizer
def _configure_lr_scheduler(self, lr_params):
"""
        Initializes the LR scheduler associated with this model and returns it.
"""
learning_rate = lr_params["learning_rate"]
disable_lr_steps_reset = lr_params["disable_lr_steps_reset"]
def _get_scheduler(optimizer, schedule_params):
"""
            Parses a dict of learning rate scheduler specifications and
            returns the corresponding learning rate scheduler.
            :param dict schedule_params:
                A dict with a "scheduler" key (e.g.,
                schedule_params["scheduler"] = "Exponential") and all
                params that schedulers of that type need.
            :returns: The learning rate scheduler instance.
"""
scheduler = schedule_params["scheduler"].lower()
            # normalize the two accepted spellings: "steps" and "decay_steps"
if "steps" in schedule_params:
schedule_params["decay_steps"] = schedule_params["steps"]
elif "decay_steps" in schedule_params:
schedule_params["steps"] = schedule_params["decay_steps"]
if "learning_rate" in schedule_params:
schedule_params["initial_learning_rate"] = schedule_params[
"learning_rate"
]
schedule_params["base_lr"] = schedule_params["learning_rate"]
elif "initial_learning_rate" in schedule_params:
schedule_params["learning_rate"] = schedule_params[
"initial_learning_rate"
]
schedule_params["base_lr"] = schedule_params[
"initial_learning_rate"
]
elif "base_lr" in schedule_params:
schedule_params["learning_rate"] = schedule_params["base_lr"]
schedule_params["initial_learning_rate"] = schedule_params[
"base_lr"
]
if "gamma" in schedule_params:
schedule_params["decay_rate"] = schedule_params["gamma"]
elif "decay_rate" in schedule_params:
schedule_params["gamma"] = schedule_params["decay_rate"]
def check_required_params(required_params):
missing = list(set(required_params) - set(schedule_params))
if missing:
raise ValueError(
f"Missing required parameters {missing} "
f"for the {scheduler} learning rate scheduler. "
f"Note, the {scheduler} learning rate scheduler "
f"requires the following parameters: {required_params}"
)
if scheduler == "constant" or scheduler == "constantlr":
check_required_params(["learning_rate"])
return lr_scheduler.ConstantLR(
optimizer,
val=schedule_params["learning_rate"],
decay_steps=schedule_params.get("steps", None),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif scheduler == "exponential" or scheduler == "exponentiallr":
check_required_params(["initial_learning_rate", "decay_rate"])
return lr_scheduler.ExponentialLR(
optimizer,
initial_learning_rate=float(
schedule_params["initial_learning_rate"]
),
decay_steps=schedule_params.get("decay_steps", 1),
decay_rate=schedule_params["decay_rate"],
staircase=schedule_params.get("staircase", False),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif (
scheduler == "piecewiseconstant"
or scheduler == "piecewiseconstantlr"
):
check_required_params(["values", "boundaries"])
return lr_scheduler.PiecewiseConstantLR(
optimizer,
learning_rates=schedule_params["values"],
milestones=schedule_params["boundaries"],
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif scheduler in (
"polynomial",
"polynomiallr",
"linear",
"linearlr",
):
check_required_params(
[
"initial_learning_rate",
"end_learning_rate",
"decay_steps",
]
)
                power = (
                    1.0
                    if scheduler in ("linear", "linearlr")
                    else schedule_params.get("power", 1.0)
                )
return lr_scheduler.PolynomialLR(
optimizer,
initial_learning_rate=float(
schedule_params["initial_learning_rate"]
),
end_learning_rate=schedule_params["end_learning_rate"],
decay_steps=schedule_params["decay_steps"],
power=power,
cycle=schedule_params.get("cycle", False),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif (
scheduler == "inverseexponentialtimedecay"
or scheduler == "inverseexponentialtimedecaylr"
):
check_required_params(
[
"initial_learning_rate",
"step_exponent",
"decay_steps",
"decay_rate",
]
)
return lr_scheduler.InverseExponentialTimeDecayLR(
optimizer,
initial_learning_rate=float(
schedule_params["initial_learning_rate"]
),
step_exponent=schedule_params["step_exponent"],
decay_steps=schedule_params["decay_steps"],
decay_rate=schedule_params["decay_rate"],
staircase=schedule_params.get("staircase", False),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif (
scheduler == "inversesquarerootdecay"
or scheduler == "inversesquarerootdecaylr"
):
return lr_scheduler.InverseSquareRootDecayLR(
optimizer,
initial_learning_rate=float(
schedule_params.get("initial_learning_rate", 1)
),
scale=schedule_params.get("scale", 1.0),
warmup_steps=schedule_params.get("warmup_steps", 1.0),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif scheduler == "cosinedecay" or scheduler == "cosinedecaylr":
check_required_params(
[
"initial_learning_rate",
"end_learning_rate",
"decay_steps",
]
)
return lr_scheduler.CosineDecayLR(
optimizer,
initial_learning_rate=float(
schedule_params["initial_learning_rate"]
),
end_learning_rate=schedule_params["end_learning_rate"],
decay_steps=schedule_params["decay_steps"],
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif (
scheduler == "cosineannealing"
or scheduler == "cosineannealinglr"
):
check_required_params(
["initial_learning_rate", "t_max",]
)
return lr_scheduler.CosineAnnealingLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
T_max=schedule_params["t_max"],
eta_min=schedule_params.get("eta_min", 0.0),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif scheduler == "step" or scheduler == "steplr":
check_required_params(
["initial_learning_rate", "step_size", "gamma",]
)
return lr_scheduler.StepLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
gamma=schedule_params["gamma"],
step_size=schedule_params["step_size"],
disable_lr_steps_reset=False,
)
elif scheduler == "multistep" or scheduler == "multisteplr":
check_required_params(
["initial_learning_rate", "gamma", "milestones",]
)
return lr_scheduler.MultiStepLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
gamma=schedule_params["gamma"],
milestones=schedule_params["milestones"],
disable_lr_steps_reset=False,
)
elif scheduler == "lambda" or scheduler == "lambdalr":
check_required_params(["initial_learning_rate"])
return lr_scheduler.LambdaLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
disable_lr_steps_reset=False,
)
elif scheduler == "cosineannealingwarmrestarts":
check_required_params(
["initial_learning_rate", "t_0",]
)
return lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
T_0=schedule_params["t_0"],
T_mult=schedule_params.get("t_mult", 1),
eta_min=schedule_params.get("eta_min", 0.0),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif (
scheduler == "multiplicative" or scheduler == "multiplicativelr"
):
check_required_params(
["initial_learning_rate", "coefficient",]
)
return lr_scheduler.MultiplicativeLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
coefficient=schedule_params["coefficient"],
disable_lr_steps_reset=False,
)
elif scheduler == "cyclic" or scheduler == "cycliclr":
check_required_params(
["base_lr", "max_lr",]
)
return lr_scheduler.CyclicLR(
optimizer,
base_lr=schedule_params["base_lr"],
max_lr=schedule_params["max_lr"],
step_size_up=schedule_params.get("step_size_up", 2000),
step_size_down=schedule_params.get("step_size_down", None),
mode=schedule_params.get("mode", "triangular"),
gamma=schedule_params.get("gamma", 1.0),
scale_mode=schedule_params.get("scale_mode", "cycle"),
disable_lr_steps_reset=disable_lr_steps_reset,
)
elif scheduler == "onecycle" or scheduler == "onecyclelr":
check_required_params(
["initial_learning_rate", "max_lr",]
)
return lr_scheduler.OneCycleLR(
optimizer,
initial_learning_rate=schedule_params[
"initial_learning_rate"
],
max_lr=schedule_params["max_lr"],
total_steps=schedule_params.get("total_steps", 1000),
pct_start=schedule_params.get("pct_start", 0.3),
final_div_factor=schedule_params.get(
"final_div_factor", 1e4
),
three_phase=schedule_params.get("three_phase", False),
anneal_strategy=schedule_params.get(
"anneal_strategy", "cos"
),
disable_lr_steps_reset=False,
)
else:
raise ValueError(f"Unsupported LR scheduler {scheduler}")
        # convert the learning rate spec into a list of dictionaries
        # to keep the scheduler handling uniform and lean
        # handle a constant learning rate below; scientific notation
        # (e.g. "1e-5") is parsed as a string in yaml
if isinstance(learning_rate, (float, str)):
learning_rate_dicts = [
{"scheduler": "constant", "learning_rate": float(learning_rate)}
]
elif isinstance(learning_rate, dict):
learning_rate_dicts = [learning_rate]
elif isinstance(learning_rate, list):
learning_rate_dicts = learning_rate
else:
            raise ValueError(
                f"Unsupported learning rate type {type(learning_rate)}. "
                f"The learning rate must be a float, a string, a dict "
                f"describing a single scheduler, or a list of such dicts."
            )
if len(learning_rate_dicts) > 1:
            for schedule in learning_rate_dicts[:-1]:
                assert "steps" in schedule or "decay_steps" in schedule, (
                    "Non-final learning rate schedulers must have either "
                    "the 'steps' or 'decay_steps' parameter given."
                )
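        # Illustrative two-phase schedule (placeholder values): linear warmup
        # followed by cosine decay; every non-final entry carries "steps" so
        # the SequentialLR below knows when to switch.
        #   - {"scheduler": "Linear", "initial_learning_rate": 0.0,
        #      "end_learning_rate": 1.0e-3, "steps": 1000}
        #   - {"scheduler": "CosineDecay", "initial_learning_rate": 1.0e-3,
        #      "end_learning_rate": 0.0, "decay_steps": 10000}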
schedulers = []
for learning_rate_ in learning_rate_dicts:
# wrap the scheduler in the ScalePerParamLR which
# adapts the layerwise learning rates depending upon
# the adjust_learning_rate key in the respective param group
assert isinstance(
learning_rate_, dict
), f'{learning_rate_} should be a dict'
schedulers.append(_get_scheduler(self.optimizer, learning_rate_))
if len(schedulers) == 1:
return lr_scheduler.ScalePerParamLR(
self.optimizer,
scheduler=schedulers[0],
decay_steps=learning_rate_.get("decay_steps", None),
disable_lr_steps_reset=disable_lr_steps_reset,
)
else:
if (
"main_scheduler" in learning_rate_dicts[0]
and learning_rate_dicts[0]["main_scheduler"] == "chained"
):
return lr_scheduler.ChainedScheduler(schedulers=schedulers,)
else:
milestones = [
scheduler.start_step for scheduler in schedulers[1:]
]
return lr_scheduler.ScalePerParamLR(
self.optimizer,
scheduler=lr_scheduler.SequentialLR(
self.optimizer,
schedulers=schedulers,
milestones=milestones,
),
decay_steps=learning_rate_.get("decay_steps", None),
disable_lr_steps_reset=disable_lr_steps_reset,
)
    def get_lr_scheduler(self):
"""
Returns the LR Scheduler associated with this model.
"""
return self.lr_scheduler
    def get_state(self, keep_vars=False):
"""
        Returns the state of the model, optimizer, LR scheduler, and, on CS
        runs, the AMP grad scaler.
"""
state_dict = {
"model": self.model.state_dict(keep_vars=keep_vars),
# PyTorchBaseModel state dict format version
"state_version": 0.2,
}
if self.optimizer:
state_dict["optimizer"] = self.optimizer.state_dict()
if self.lr_scheduler:
state_dict["lr_scheduler"] = self.lr_scheduler.state_dict()
if self.mixed_precision and cm.use_cs():
state_dict["amp"] = self.grad_scaler.state_dict()
return state_dict
def _named_members(self, model, get_member_fn, prefix='', recurse=True):
"""
Helper method which returns a map of param_name -> set of duplicate param names
"""
memo = dict()
names = dict()
modules = (
model.named_modules(prefix=prefix, remove_duplicate=False)
if recurse
else [(prefix, self)]
)
for module_prefix, module in modules:
members = get_member_fn(module)
for k, v in members:
name = module_prefix + ('.' if module_prefix else '') + k
if v is None:
continue
elif v in memo:
# whenever a duplicate is found
# update the existing list of duplicate
# names corresponding to the first name
duplicates = names.get(memo[v], set([memo[v]]))
duplicates.add(name)
names[memo[v]] = duplicates
# also add a key for new name with
# value as the duplicates list
names[name] = duplicates
continue
memo[v] = name
return names
    def set_state(self, state, strict=True):
"""
        Sets the state of the model, optimizer, LR scheduler, and grad scaler.
"""
if self.is_pretrained_checkpoint and self.mode != modes.EVAL:
# allow loading weights ignoring the missing and unexpected keys
# except when doing eval
strict = False
model = self.torch_model
if hasattr(model, "model") and state.get("state_version", 0.1) == 0.1:
# Required for backwards compatibility
# older checkpoints of models in the modelzoo
# will not contain the `model.` prefix
model = model.model
model.load_state_dict(state["model"], strict=strict)
if (
self.optimizer
and "optimizer" in state
and not self.is_pretrained_checkpoint
):
# load optimizer state for resuming training
self.optimizer.load_state_dict(state["optimizer"])
if self.lr_scheduler and "lr_scheduler" in state:
self.lr_scheduler.load_state_dict(state["lr_scheduler"])
if (
self.mixed_precision
and cm.is_wse_device()
and not self.is_pretrained_checkpoint
):
amp_state = state.get('amp')
if amp_state:
self.grad_scaler.load_state_dict(amp_state)
def __call__(self, *args, **kwargs):
"""
Given one iteration of a dataloader, returns the loss associated with
one forward pass of that batch.
"""
return self.model(*args, **kwargs)
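# Minimal usage sketch (illustrative; "MyModel", "batch", and the checkpoint
# path are assumptions, not part of this module):
#
#   model = PyTorchBaseModel(params, lambda p: MyModel(p), device=None)
#   model.train()
#   loss = model(batch)          # forwards to the wrapped torch.nn.Module
#   state = model.get_state()
#   torch.save(state, "checkpoint.mdl")
#   model.set_state(torch.load("checkpoint.mdl"))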