Source code for experimental.optim

"""Emulates torch.optim"""
import inspect
import logging
from copy import deepcopy

import numpy

from . import lr_scheduler
from .Adadelta import Adadelta
from .Adafactor import Adafactor
from .Adagrad import Adagrad
from .Adamax import Adamax
from .AdamBase import Adam, AdamW
from .ASGD import ASGD
from .Lamb import Lamb
from .Lion import Lion
from .NAdam import NAdam
from .optimizer import Optimizer
from .RAdam import RAdam
from .RMSprop import RMSprop
from .Rprop import Rprop
from .SGD import SGD


def _retrieve_all_subclasses(cls):
    for subcls in cls.__subclasses__():
        yield subcls
        yield from _retrieve_all_subclasses(subcls)


def configure_optimizer(optimizer_type: str, params, **kwargs):
    """
    Configures and returns an Optimizer specified using the provided
    optimizer type

    The optimizer class's signature is inspected and relevant parameters are
    extracted from the keyword arguments

    Args:
        optimizer_type: The name of the optimizer to configure
        params: The model parameters passed to the optimizer
    """
    optimizer_map = {
        cls.__name__.lower(): cls for cls in _retrieve_all_subclasses(Optimizer)
    }

    if optimizer_type.lower() not in optimizer_map:
        raise ValueError(
            f"Invalid optimizer type. Expected one of "
            f"{sorted(optimizer_map.keys())}. Got: {optimizer_type}"
        )

    cls = optimizer_map[optimizer_type.lower()]

    learning_rate = kwargs.pop("learning_rate", None)
    if isinstance(learning_rate, (float, str)):
        learning_rate = float(learning_rate)
    else:
        learning_rate = None

    # common aliases
    aliases = {
        "weight_decay": ["weight_decay_rate"],
        "betas": [("beta1", "beta2")],
        "eps": [("eps1", "eps2")],
        "etas": [("eta1", "eta2")],
    }

    warning_str = (
        f"{cls.__name__} expected parameter {{expected}} "
        f"but found {{actual}}. "
        f"It will accept {{actual}} for now and assign it to {{expected}}, "
        f"but this behaviour is deprecated and will be removed "
        f"in a future release"
    )

    cls_kwargs = {}

    # inspect the optimizer and extract the required parameters from the kwargs
    signature = inspect.signature(cls.__init__)
    for name, parameter in signature.parameters.items():
        if name in ("self", "params"):
            continue

        # pylint: disable=protected-access
        if parameter.kind == inspect._ParameterKind.VAR_KEYWORD:
            cls_kwargs.update(kwargs)
            break

        if name in kwargs:
            cls_kwargs[name] = kwargs.pop(name)
        elif name in ("lr", "learning_rate"):
            if learning_rate is None:
                if parameter.default is not inspect._empty:
                    learning_rate = parameter.default
                else:
                    learning_rate = 0.1  # default dummy value

            cls_kwargs[name] = learning_rate
        elif name in aliases:
            for alias in aliases[name]:
                if isinstance(alias, str) and alias in kwargs:
                    logging.warning(
                        warning_str.format(expected=name, actual=alias)
                    )
                    cls_kwargs[name] = kwargs.pop(alias)
                    break
                elif isinstance(alias, (list, tuple)) and all(
                    a in kwargs for a in alias
                ):
                    logging.warning(
                        warning_str.format(
                            expected=name, actual=str(alias).replace("'", "")
                        )
                    )
                    cls_kwargs[name] = type(alias)(kwargs.pop(a) for a in alias)
                    break

    if len(kwargs) > 0:
        # Replace the default values in the signature to show the user the
        # values they passed in in the warning message so that they can verify
        # what they actually passed in
        signature = signature.replace(
            parameters=[
                inspect.Parameter(
                    name=name,
                    kind=param.kind,
                    default=cls_kwargs.get(name, param.default),
                    annotation=param.annotation,
                )
                for name, param in signature.parameters.items()
                if name != "self"
            ]
        )

        logging.warning(
            f"{cls.__name__} got {len(kwargs)} unexpected "
            f"and unused parameters: {sorted(kwargs.keys())}.\n"
            f"Please ensure that you specified the correct parameters:\n"
            f"{cls.__name__}{signature}\n"
            f"Passing in unused parameters is deprecated behaviour and "
            f"support for it will be removed in a future release."
        )

    try:
        return cls(params, **cls_kwargs)
    except TypeError as e:
        raise RuntimeError(
            f"Failed to configure {cls.__name__} optimizer"
        ) from e

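# Illustrative usage sketch (not part of the original module):
# configure_optimizer matches optimizer_type case-insensitively against the
# names of all Optimizer subclasses and forwards matching keyword arguments.
# The model object and the weight_decay value below are assumptions for
# illustration only:
#
#   optimizer = configure_optimizer(
#       optimizer_type="Adam",          # resolved among Optimizer subclasses
#       params=model.parameters(),      # hypothetical model
#       learning_rate=1e-3,             # mapped onto the class's lr parameter
#       weight_decay=0.01,              # forwarded only if Adam accepts it
#   )
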
def configure_lr_scheduler(optimizer, learning_rate):
    """
    Configures a learning rate scheduler specified using the provided
    learning rate configuration

    The learning rate scheduler class's signature is inspected and relevant
    parameters are extracted from the configuration

    Args:
        optimizer: The optimizer passed to the lr_scheduler
        learning_rate: The learning rate configuration. A float/str means no
            scheduler is needed, a dict specifies a single scheduler, and a
            list/tuple of dicts specifies multiple schedulers
    """
    if not learning_rate:
        return None

    learning_rate = deepcopy(learning_rate)

    lr_scheduler_map = {
        cls.__name__.lower(): cls
        for cls in _retrieve_all_subclasses(lr_scheduler.LRScheduler)
    }

    def get_scheduler(learning_rate: dict):
        scheduler = learning_rate.pop("scheduler").lower()

        for name in (scheduler, f"{scheduler}lr"):
            if name in lr_scheduler_map:
                cls = lr_scheduler_map[name]
                break
        else:
            raise ValueError(
                f"Invalid lr_scheduler type. Expected one of "
                f"{list(lr_scheduler_map.keys())}. Got: {scheduler}"
            )

        # common aliases
        aliases = {
            "total_iters": ["steps", "decay_steps"],
            "initial_learning_rate": ["learning_rate"],
            "learning_rates": ["values"],
            "milestones": ["boundaries"],
        }

        warning_str = (
            f"{cls.__name__} expected parameter {{expected}} "
            f"but found {{actual}}. "
            f"It will accept {{actual}} for now and assign it to {{expected}}, "
            f"but this behaviour is deprecated and will be removed "
            f"in a future release"
        )

        cls_kwargs = {}

        # inspect the scheduler and extract the required parameters from the
        # learning_rate configuration
        signature = inspect.signature(cls.__init__)
        for name, parameter in signature.parameters.items():
            if name in ("self", "optimizer"):
                continue

            # pylint: disable=protected-access
            if parameter.kind == inspect._ParameterKind.VAR_KEYWORD:
                cls_kwargs.update(learning_rate)
                break

            if name in learning_rate:
                cls_kwargs[name] = learning_rate.pop(name)
            elif name.lower() in learning_rate:
                logging.warning(
                    warning_str.format(expected=name, actual=name.lower())
                )
                cls_kwargs[name] = learning_rate.pop(name.lower())
            elif name in aliases:
                for alias in aliases[name]:
                    if alias in learning_rate:
                        logging.warning(
                            warning_str.format(expected=name, actual=alias)
                        )
                        cls_kwargs[name] = learning_rate.pop(alias)
                        break

        if len(learning_rate) > 0:
            # Replace the default values in the signature to show the user the
            # values they passed in in the warning message so that they can
            # verify what they actually passed in
            signature = signature.replace(
                parameters=[
                    inspect.Parameter(
                        name=name,
                        kind=param.kind,
                        default=cls_kwargs.get(name, param.default),
                        annotation=param.annotation,
                    )
                    for name, param in signature.parameters.items()
                    if name != "self"
                ]
            )

            logging.warning(
                f"{cls.__name__} got {len(learning_rate)} unexpected "
                f"and unused parameters: {sorted(learning_rate.keys())}.\n"
                f"Please ensure that you specified the correct parameters:\n"
                f"{cls.__name__}{signature}\n"
                f"Passing in unused parameters is deprecated behaviour and "
                f"support for it will be removed in a future release."
            )

        try:
            return cls(optimizer, **cls_kwargs)
        except TypeError as e:
            raise RuntimeError(
                f"Failed to configure {cls.__name__} scheduler"
            ) from e

    if isinstance(learning_rate, (float, str)):
        return None  # No learning rate scheduler needed

    if isinstance(learning_rate, dict):
        return get_scheduler(learning_rate)

    if isinstance(learning_rate, (list, tuple)):
        if len(learning_rate) == 1:
            return get_scheduler(learning_rate[0])

        schedulers = []
        main_scheduler = None
        for params in learning_rate:
            # TODO: figure out a better way to specify this
            if main_scheduler is None and "main_scheduler" in params:
                main_scheduler = params["main_scheduler"]

            schedulers.append(get_scheduler(params))

        if main_scheduler is not None and main_scheduler.lower() in (
            "chained",
            "chainedlr",
        ):
            return lr_scheduler.ChainedScheduler(schedulers)
        else:  # default to sequential
            total_iters = [
                scheduler.total_iters for scheduler in schedulers[:-1]
            ]
            assert all(total_iter is not None for total_iter in total_iters)
            milestones = list(numpy.array(total_iters).cumsum())
            return lr_scheduler.SequentialLR(optimizer, schedulers, milestones)

    raise ValueError(
        f"Unsupported LR scheduler type. "
        f"Expected one of float/dict/list/tuple. "
        f"Got: {learning_rate}"
    )

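# Illustrative usage sketch (not part of the original module):
# configure_lr_scheduler takes the learning_rate configuration as a dict
# (single scheduler) or a list of dicts (multiple schedulers, wrapped in a
# SequentialLR by default, or in a ChainedScheduler when a "main_scheduler"
# of "chained" is given). The "scheduler" value is matched case-insensitively
# against the LRScheduler subclasses, with or without the "LR" suffix. The
# scheduler names and parameters below are assumptions for illustration only:
#
#   scheduler = configure_lr_scheduler(
#       optimizer,
#       [
#           {"scheduler": "Linear", "total_iters": 100},  # e.g. LinearLR
#           {"scheduler": "Constant"},                    # e.g. ConstantLR
#       ],
#   )
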
__all__ = [cls.__name__ for cls in _retrieve_all_subclasses(Optimizer)] + [
    "configure_optimizer",
    "configure_lr_scheduler",
    "lr_scheduler",
]