"""The Cerebras base optimizer class"""
from abc import ABC, abstractmethod
import torch
from cerebras_pytorch.experimental.backend import current_backend_impl
class Optimizer(torch.optim.Optimizer, ABC):
    """
    The abstract Cerebras base optimizer class.

    Enforces that the `preinitialize` method is implemented
    wherein the optimizer state should be initialized ahead of time.

    Subclasses must implement `preinitialize`, `step` and
    `state_names_to_sparsify`.
    """

    def __init__(self, *args, enable_global_step: bool = False, **kwargs):
        """
        Construct the optimizer and eagerly initialize its state.

        Args:
            *args: Forwarded to ``torch.optim.Optimizer.__init__``.
            enable_global_step: If True, every parameter's state gets a
                ``"step"`` entry holding a float32 scalar tensor of 0.0 on
                the parameter's device, which ``increment_global_step``
                then advances.
            **kwargs: Forwarded to ``torch.optim.Optimizer.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.backend = current_backend_impl()
        # NOTE(review): presumably the backend device context controls where
        # tensors created here (the preinitialized state) are placed.
        with self.backend.device:
            self.preinitialize()

            if enable_global_step:
                for group in self.param_groups:
                    for p in group["params"]:
                        self.state[p]["step"] = torch.tensor(
                            0.0, dtype=torch.float32
                        ).to(p.device)

        # NOTE(review): only initialized here; presumably filled in by LR
        # schedulers attached to this optimizer.
        self._lr_scheduler_registry = []

        self.backend.register_optimizer(self)

    def increment_global_step(self, p):
        """
        Increment the ``"step"`` entry of parameter ``p``'s state by 1.

        Args:
            p: The parameter whose step counter should be advanced.

        Returns:
            The updated ``self.state[p]["step"]`` value (a float32 scalar
            tensor when it was created via ``enable_global_step=True``).

        Raises:
            RuntimeError: If ``p`` has no ``"step"`` entry in its state.
        """
        # Use .get() rather than indexing: torch stores ``self.state`` as a
        # defaultdict, so ``self.state[p]`` would silently insert an empty
        # state dict for ``p`` even when we are about to raise.
        param_state = self.state.get(p)
        if param_state is None or "step" not in param_state:
            raise RuntimeError(
                "No global step in the state. "
                "Please pass in `enable_global_step=True` "
                "to initialize the global step"
            )
        param_state["step"] += 1.0
        return param_state["step"]

    def state_dict(self, *args, **kwargs):
        """Return the optimizer state dict (pass-through to the torch base)."""
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict):
        """
        Load optimizer state and notify the backend.

        The state is loaded inside the backend's device context. Afterwards,
        ``self.backend.post_optimizer_load_state_dict`` is called with this
        optimizer.

        Args:
            state_dict: Optimizer state, as accepted by
                ``torch.optim.Optimizer.load_state_dict``.
        """
        with self.backend.device:
            super().load_state_dict(state_dict)

        self.backend.post_optimizer_load_state_dict(self)

    def visit_state(self, fn):
        """
        Apply ``fn`` to every per-parameter state value.

        If ``fn`` returns a value other than ``None``, that value replaces the
        original entry in place. Returning ``None`` leaves the entry unchanged.

        Args:
            fn: Callable taking a single state value.
        """
        for state in self.state.values():
            for key, val in state.items():
                new_val = fn(val)
                if new_val is not None:
                    # Overwriting an existing key does not change the dict's
                    # size, so this is safe while iterating over it.
                    state[key] = new_val

    @abstractmethod
    def state_names_to_sparsify(self):
        """
        Return the names of per-parameter states that need to be sparsified
        when applying sparsity to the underlying parameters.
        """

    @abstractmethod
    def preinitialize(self):
        """
        The optimizer state must be initialized ahead of time in order
        to capture the full compute graph in the first iteration. This method
        must be overridden to perform the state preinitialization.
        """

    @abstractmethod
    def step(self, closure=None):
        """
        Perform the optimizer step itself. Note, there should be no new state
        being created in this function. All state must be created ahead of time in
        `preinitialize` and only updated in this method.
        """