# Source code for common.pytorch.optim.CSOptimizer

# Copyright 2022 Cerebras Systems.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Abstract base class for Cerebras Optimizers.
"""
from abc import ABC, abstractmethod
from collections import defaultdict

import torch
from torch.optim import Optimizer

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch

class CSOptimizer(Optimizer, ABC):
    """
    Cerebras Base Optimizer class
    """

    def __init__(self, params, defaults, enable_global_step=False):
        """
        Cerebras Base Optimizer class handles preinitialization of
        optimizer states for non-CS runs, making the implementation
        of the optimizer compatible with both CS and non-CS runs.
        It also preinitializes global steps tensor and provides a method
        to retrieve the global steps.

        Args:
            params: Forwarded unchanged to ``torch.optim.Optimizer.__init__``.
            defaults: Forwarded unchanged to
                ``torch.optim.Optimizer.__init__``.
            enable_global_step (bool): If True, a float32 ``"step"`` tensor
                initialised to 0.0 is stored in the state of every parameter
                (created on CPU, then moved to the parameter's device).
        """
        super().__init__(params, defaults)

        # The tracker is only looked up on CS runs; ``None`` means "do not
        # report progress" for the rest of this constructor.
        progress = None
        if cm.use_cs():
            progress = cbtorch.state().progress_tracker

        if progress is not None:
            # Add progress updates for optimizer state initialization
            progress.set_description(
                f"Initializing {self.__class__.__name__} optimizer"
            )
            progress.set_postfix(note="Preinitializing optimizer state")
            progress.update()

            param_num = 0

            class LoggedDict(dict):
                """Per-parameter state dict that reports every write."""

                def __init__(self, *args, **kwargs):
                    nonlocal param_num
                    super().__init__(*args, **kwargs)
                    # Each new state dict gets the next sequential index,
                    # which is used in the progress message below.
                    self.param_num = param_num
                    param_num += 1

                def __setitem__(self, name, value):
                    progress.set_postfix(
                        note=f"Initialized optimizer.state.{self.param_num}.{name}"
                    )
                    progress.update()
                    return super().__setitem__(name, value)

                def update(self, *args, **kwargs):
                    # dict.update does not go through __setitem__, so route
                    # every item through it explicitly to keep reporting.
                    for k, v in dict(*args, **kwargs).items():
                        self[k] = v

            self.state = defaultdict(LoggedDict)

        self.preinitialize()

        # Need to change state back into a normal dict
        # so that it can be saved to a checkpoint
        if progress is not None:
            progress.update()
            for group in self.param_groups:
                for p in group["params"]:
                    self.state[p] = dict(self.state[p])
            # Entries created after initialization must be plain dicts too,
            # otherwise they would keep reporting to the progress tracker.
            if isinstance(self.state, defaultdict):
                self.state.default_factory = dict

        if enable_global_step:
            for group in self.param_groups:
                for p in group["params"]:
                    self.state[p]["step"] = torch.tensor(
                        0.0, device="cpu", dtype=torch.float32
                    ).to(p.device)

        self.post_load_state_dict()

    def load_state_dict(self, state_dict):
        """
        Loads the state dict through the parent class and then runs
        `post_load_state_dict` on the loaded values.

        Args:
            state_dict: Forwarded unchanged to
                ``torch.optim.Optimizer.load_state_dict``.
        """
        super().load_state_dict(state_dict)
        self.post_load_state_dict()

    def post_load_state_dict(self):
        """
        Actions to perform after initializing state and loading the state dict
        """

        def tensor_cast(value):
            # NOTE: bool is a subclass of int, so boolean values also take
            # the first branch and become int32 tensors.
            if isinstance(value, int):
                value = torch.tensor(value, dtype=torch.int32)
            elif isinstance(value, float):
                value = torch.tensor(value, dtype=torch.float32)
            elif isinstance(value, (list, tuple)):
                # Cast element-wise and rebuild the same container type.
                value = type(value)(map(tensor_cast, value))
            # Anything else (tensors, strings, None, ...) is returned as is.
            return value

        # Convert all python scalars in the param groups to 32 bit torch tensors
        for param_group in self.param_groups:
            # Snapshot the keys; values are reassigned while looping.
            for key in list(param_group):
                if key == "params":
                    # The parameters themselves are left alone; only their
                    # per-parameter state entries are cast.
                    for p in param_group["params"]:
                        state = self.state[p]
                        for name in list(state):
                            state[name] = tensor_cast(state[name])
                else:
                    param_group[key] = tensor_cast(param_group[key])

    def _get_global_step(self, p):
        """
        Increases the global steps by 1 and returns the current
        value of global step tensor in torch.float32 format.

        Args:
            p: Parameter whose ``"step"`` state entry is incremented.
        """
        # In-place increment, so the stored state entry itself is updated.
        self.state[p]["step"] += 1.0
        return self.state[p]["step"]

    @abstractmethod
    def state_names_to_sparsify(self):
        """
        Return the names of per-parameter states that need to be sparsified
        when applying sparsity to the underlying parameters.
        """

    @abstractmethod
    def preinitialize(self):
        """
        Allocates tensors for the optimizer state to allow direct compilation
        of the model before the first step.
        """
        raise NotImplementedError(
            "preinitialize must be implemented in a child class!"
        )

    @abstractmethod
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        raise NotImplementedError("step must be implemented in a child class!")

    def to(self, device=None):
        """Moves optimizer state onto specified device or onto corresponding
        parameter's device if no device is specified.

        State values that are lists or tuples are moved element-wise and
        rebuilt with the same container type. Values without a ``to`` method
        (for example ``None``) are left unchanged.

        Args:
            device (optional): Device to move state tensors to. If not
                specified, the corresponding parameter's device will be used.

        Returns:
            self
        """

        def move(value, target):
            if isinstance(value, (list, tuple)):
                return type(value)(move(item, target) for item in value)
            mover = getattr(value, "to", None)
            # Anything exposing ``.to`` is moved exactly as before; other
            # values are kept instead of raising AttributeError.
            return mover(target) if callable(mover) else value

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                to_device = device if device is not None else p.device
                # Only existing keys are reassigned, so iterating the dict
                # while writing to it is safe.
                for key in state:
                    state[key] = move(state[key], to_device)
        return self