"""contains the Cerebras Adam and AdamW implementation"""importmathimportsysfromtypingimportCallable,Iterable,Tupleimporttorchimportcerebras_pytorch.experimentalascstorchfrom.optimizerimportOptimizerclassAdamBase(Optimizer):""" Base for Adam and AdamW optimizer implemented to conform to execution within the constraints of the Cerebras WSE, including pre-initilizing optimizer state and performing a gradual reduction of bias correction using exponential decay of `beta1_power` and `beta2_power` rather than recomputing `beta1^step` each step. """def__init__(self,params:Iterable[torch.nn.Parameter],lr:float=1e-3,betas:Tuple[float,float]=(0.9,0.999),eps:float=1e-6,weight_decay:float=0.0,l2_regularization_rate:float=0.0,correct_bias:bool=True,amsgrad:bool=False,):iflr<0.0:raiseValueError(f"Invalid learning rate: {lr} - should be >= 0.0")ifnot0.0<=betas[0]<1.0:raiseValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0]")ifnot0.0<=betas[1]<1.0:raiseValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0]")ifeps<0.0:raiseValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")defaults=dict(lr=lr,betas=betas,eps=eps,weight_decay=weight_decay,l2_regularization_rate=l2_regularization_rate,correct_bias=correct_bias,amsgrad=amsgrad,)super().__init__(params,defaults)defstate_names_to_sparsify(self):# Only return state names which can be maskable by sparsity optimizer:# those with the same shape as their corresponding parameterreturn["exp_avg","exp_avg_sq","max_exp_avg_sq"]defpreinitialize(self):""" Allocates tensors for the optimizer state to allow direct compilation of the model before the first step. """forgroupinself.param_groups:forpingroup["params"]:state=self.state[p]# State initialization# Exponential moving average of gradient valuesstate["exp_avg"]=cstorch.zeros_like(p)# Exponential moving average of squared gradient valuesstate["exp_avg_sq"]=cstorch.zeros_like(p)ifgroup["amsgrad"]:state["max_exp_avg_sq"]=cstorch.zeros_like(p)ifgroup["correct_bias"]:# No bias correction for Bertbeta1,beta2=group["betas"]# beta1 ^ step, initialized for used on step 1state["beta1_power"]=torch.tensor(beta1).to(p.device)state["beta2_power"]=torch.tensor(beta2).to(p.device)@torch.no_grad()defstep(self,closure:Callable=None):""" Performs a single optimization step. Arguments: closure: A closure that reevaluates the model and returns the loss. 
"""loss=NoneifclosureisnotNone:loss=closure()forgroupinself.param_groups:forpingroup["params"]:ifp.gradisNone:continuegrad=p.grad# This is equivalent to Algorithm 2 i.e Adam with L2 regularization# (https://arxiv.org/pdf/1711.05101.pdf)ifgroup["l2_regularization_rate"]>0.0:grad=grad.add(p,alpha=group["l2_regularization_rate"])state=self.state[p]exp_avg,exp_avg_sq=state["exp_avg"],state["exp_avg_sq"]beta1,beta2=group["betas"]# Decay the first and second moment running average coefficient# In-place operations to update the averages at the same time.exp_avg.mul_(beta1).add_(grad,alpha=1.0-beta1)exp_avg_sq.mul_(beta2).addcmul_(grad,grad,value=1.0-beta2)ifgroup["amsgrad"]:max_exp_avg_sq=state["max_exp_avg_sq"]torch.maximum(max_exp_avg_sq,exp_avg_sq,out=max_exp_avg_sq)state["max_exp_avg_sq"]=max_exp_avg_sqdenom=max_exp_avg_sq.sqrt().add_(group["eps"])else:denom=exp_avg_sq.sqrt().add_(group["eps"])update=exp_avg/denomifgroup["correct_bias"]:# No bias correction for Bert.one=torch.tensor(1.0,dtype=torch.float32,device=p.device)bias_correction1=one-state["beta1_power"]bias_correction2=one-state["beta2_power"]step_size=torch.sqrt(bias_correction2)/bias_correction1update*=step_size# Update `beta1^step` for the next step.state["beta1_power"]*=beta1state["beta2_power"]*=beta2# Applying weight decay here is equivalent to Algorithm 2# (https://arxiv.org/pdf/1711.05101.pdf)# Decoupled Weight Decay regularization i.e AdamWifgroup["weight_decay"]>0.0:update.add_(p,alpha=group["weight_decay"])# Scale the update by the learning rate.update*=group["lr"]# Finally, update the weight data.p.sub_(update)returnloss# pylint: disable=no-self-usedefconvert_state_dict_for_checkpoint(self,state_dict):""" Converts the state_dict for compatibility with AdamW from huggingface_common, which is the optimizer used by PyTorchBaseModel when not run on WSE and is otherwise API compatible. Arguments: state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. Returns the modified state_dict. """# huggingface AdamW and PyTorch Adam stores a `step`# Cerebras (this) AdamW/Adam stores `beta1^step` as `beta1_power`.forparam,stateinstate_dict["state"].items():if"beta1_power"notinstate:# huggingface AdamW increments `step` always, but doesn't use# it if isn't performing bias correction. We don't have the# step available, so save a dummy value.state["step"]=0continueforparam_groupsinstate_dict["param_groups"]:ifparaminparam_groups["params"]:beta1,beta2=param_groups["betas"]# beta1_power = beta1**step# so step = log(beta1_power, beta1)ifstate["beta1_power"]:# With default beta1=0.9, this should be finite (and# have the most resolution) until ~6700 steps.beta1_power=state["beta1_power"]state["step"]=int(math.log(beta1_power,beta1))else:# if beta1_power is 0, it likely underflowed. Check# beta2_power. Otherwise, use DOUBLE_MIN# With default beta2=0.999, this should be finite# until ~700k step.beta2_power=state["beta2_power"]orsys.float_info.minstate["step"]=int(math.log(beta2_power,beta2))breakreturnstate_dictdefload_state_dict(self,state_dict):""" Loads the optimizer state. Arguments: state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. This overrides torch.optim.Optimizer to add checkpoint compatibility with the AdamW from huggingface_common, which is otherwise API compatible. 
"""# huggingface AdamW and PyTorch Adam stores a `step`# Cerebras (this) AdamW/Adam stores `beta1^step` as `beta1_power`.forparam,stateinstate_dict["state"].items():if"step"instateand"beta1_power"notinstate:step=state.pop("step")# go find betas for this parametercorrect_bias=Falsebeta1=Nonebeta2=Noneforparam_groupinstate_dict["param_groups"]:ifparaminparam_group["params"]:correct_bias=param_group["correct_bias"]beta1,beta2=param_group["betas"]breakifcorrect_bias:state["beta1_power"]=torch.tensor(beta1**step,dtype=torch.float32)state["beta2_power"]=torch.tensor(beta2**step,dtype=torch.float32)super().load_state_dict(state_dict)


class AdamW(AdamBase):
    """
    AdamW specific overrides to AdamBase
    """

    def load_state_dict(self, state_dict):
        """
        Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Adds checkpoint compatibility with the AdamW from HuggingFace.
        """
        for group in state_dict["param_groups"]:
            group["l2_regularization_rate"] = 0.0
        super().load_state_dict(state_dict)
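

# Usage sketch only: `_example_adamw_step` is NOT part of this module's API,
# and running it assumes a Cerebras-enabled environment in which AdamW can be
# constructed (state pre-initialization goes through `cstorch`). It shows how
# the optimizer above is typically driven from a training loop, e.g. after
# `optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)`.
def _example_adamw_step(
    optimizer: AdamW, model: torch.nn.Module, batch: torch.Tensor
) -> torch.Tensor:
    """Run one training step with decoupled weight decay (AdamW)."""
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()  # placeholder loss, for illustration
    loss.backward()  # populate p.grad for every parameter
    optimizer.step()  # Adam update plus decoupled weight decay, scaled by lr
    return loss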


class Adam(AdamBase):
    """
    Adam specific overrides to AdamBase
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
    ):
        # This init takes `weight_decay` to stay in sync with the PyTorch API;
        # internally it is treated as an L2 regularization rate.
        super().__init__(
            params=params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=0.0,
            l2_regularization_rate=weight_decay,
            correct_bias=True,
            amsgrad=amsgrad,
        )
        # Param groups track the rate under `l2_regularization_rate` rather
        # than `weight_decay`.
        for group in self.param_groups:
            group["l2_regularization_rate"] = group.pop("weight_decay", 0.0)
            group["weight_decay"] = 0.0

    def load_state_dict(self, state_dict):
        """
        Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Adds checkpoint compatibility with the Adam from PyTorch.
        """
        for group in state_dict["param_groups"]:
            group["l2_regularization_rate"] = group.pop("weight_decay", 0.0)
            group["weight_decay"] = 0.0
            group["correct_bias"] = True
        super().load_state_dict(state_dict)
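

# Checkpoint-interop sketch only: `_example_checkpoint_roundtrip` is NOT part
# of this module's API, and it assumes an already constructed Adam/AdamW
# instance from this module. It exercises the two compatibility paths defined
# above: exporting a state dict in the `step`-based layout used by the
# huggingface_common AdamW, and loading a state dict back in.
def _example_checkpoint_roundtrip(optimizer: AdamBase) -> None:
    """Round-trip a state dict through the huggingface-compatible layout."""
    # Export: an integer `step` is derived from `beta1_power` so that the
    # huggingface_common AdamW can consume the checkpoint.
    hf_compatible = optimizer.convert_state_dict_for_checkpoint(
        optimizer.state_dict()
    )
    # Import: state dicts that store only `step` (as huggingface_common AdamW
    # does) have `beta1_power` / `beta2_power` reconstructed on load; ones
    # that already carry `beta1_power`, like `hf_compatible` here, load
    # unchanged.
    optimizer.load_state_dict(hf_compatible)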