Source code for common.pytorch.gradient_clipper

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from modelzoo.common.pytorch import amp


class GradientClipper:
    """Clips gradients either by global norm or by value (mutually exclusive)."""

    def __init__(
        self, max_gradient_norm: float = 0.0, max_gradient_value: float = 0.0
    ):
        self.max_gradient = 0.0
        self.max_gradient_fn = None
        self.set_max_gradients(max_gradient_norm, max_gradient_value)

    def set_max_gradients(
        self, max_gradient_norm: float = 0.0, max_gradient_value: float = 0.0
    ):
        if max_gradient_norm is None or max_gradient_norm < 0.0:
            raise ValueError(
                f"max_gradient_norm cannot be None or negative. Got "
                f"{max_gradient_norm}"
            )
        if max_gradient_value is None or max_gradient_value < 0.0:
            raise ValueError(
                f"max_gradient_value cannot be None or negative. Got "
                f"{max_gradient_value}"
            )
        if max_gradient_norm > 0.0 and max_gradient_value > 0.0:
            raise ValueError(
                f"Gradients can be clipped by norm (={max_gradient_norm}) or by "
                f"value (={max_gradient_value}), but not both. "
                f"Do not set both `max_gradient_norm` and `max_gradient_value`."
            )
        elif max_gradient_norm > 0.0:
            self.max_gradient = max_gradient_norm
            self.max_gradient_fn = torch.nn.utils.clip_grad_norm_
        elif max_gradient_value > 0.0:
            self.max_gradient = max_gradient_value
            self.max_gradient_fn = torch.nn.utils.clip_grad_value_

        self.check_amp()

    def check_amp(self):
        """Disable gradient clipping here if global gradient clipping (GGC) with
        dynamic loss scaling (DLS) is already handled by the GradScaler."""
        if (
            self.max_gradient > 0
            and amp.get_init_params().get("max_gradient_norm", None) is not None
        ):
            # The GradScaler only handles clipping by norm, so clipping by value
            # should never reach this point.
            assert self.max_gradient_fn == torch.nn.utils.clip_grad_norm_
            self.max_gradient_fn = None

    def clip(self, params: dict):
        # `params` is forwarded unchanged to the selected torch clipping utility.
        if self.max_gradient_fn is not None:
            self.max_gradient_fn(params, self.max_gradient)

    def __call__(self, *args, **kwargs):
        return self.clip(*args, **kwargs)
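
A minimal usage sketch follows. It assumes the surrounding run flow has already set up the Cerebras `amp` module (which `check_amp` queries) without a `max_gradient_norm` of its own, and the toy model, optimizer, and loss below are illustrative rather than part of this module.

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Clip by global norm; passing a positive max_gradient_value instead would
# select torch.nn.utils.clip_grad_value_. Assumes amp.get_init_params() does
# not report its own max_gradient_norm, so clipping stays enabled here.
grad_clipper = GradientClipper(max_gradient_norm=1.0)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Clipping happens in place on the parameters' .grad tensors before the step.
grad_clipper(model.parameters())
optimizer.step()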