Source code for common.pytorch.layers.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch

from .AlibiPositionEmbeddingLayer import AlibiPositionEmbeddingLayer
from .RelativePositionEmbeddingLayer import RelativePositionEmbeddingLayer

LOSS_SCOPE = "loss"


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )
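

# Example usage (an illustrative sketch, not part of the original module):
# `_get_activation_fn` returns the functional activation itself, and
# `_get_clones` returns independent deep copies wrapped in an nn.ModuleList.
#
#     act = _get_activation_fn("gelu")
#     y = act(torch.randn(2, 8))          # equivalent to F.gelu(...)
#
#     layers = _get_clones(nn.Linear(8, 8), 3)
#     assert len(layers) == 3             # three independent copies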


def apply_loss_reduction(loss, reduction):
    if reduction == 'mean':
        return torch.mean(loss)
    elif reduction == 'sum':
        return torch.sum(loss)
    else:
        return loss
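

# Example usage (an illustrative sketch, not part of the original module):
# `apply_loss_reduction` mirrors the `reduction` argument of the standard
# torch.nn losses; any value other than 'mean' or 'sum' leaves the loss
# unreduced.
#
#     per_token_loss = torch.tensor([0.5, 1.5, 2.0])
#     apply_loss_reduction(per_token_loss, 'mean')  # tensor(1.3333)
#     apply_loss_reduction(per_token_loss, 'sum')   # tensor(4.)
#     apply_loss_reduction(per_token_loss, 'none')  # returned unreduced

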
def apply_position_bias(embedding_helper, seq_length, key_length, past_kv=None):
    self_attn_position_bias = None
    if isinstance(
        embedding_helper,
        (RelativePositionEmbeddingLayer, AlibiPositionEmbeddingLayer,),
    ):
        self_attn_position_bias = embedding_helper(
            seq_length, key_length, past_kv=past_kv
        )
    return self_attn_position_bias
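

# Example usage (an illustrative sketch, not part of the original module):
# `apply_position_bias` only produces a bias for the relative and Alibi
# position-embedding helpers; any other helper (including None) yields no
# bias. Constructor arguments for the helper layers are omitted here because
# they depend on those layers' own signatures.
#
#     assert apply_position_bias(None, seq_length=128, key_length=128) is None
#
#     # rel_pos = RelativePositionEmbeddingLayer(...)  # see that layer's API
#     # bias = apply_position_bias(rel_pos, 128, 128)  # additive attention bias

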
# Since we do not have automatic loss scope detection yet, some changes in the
# user model are required to tag the ops belonging to the loss.
# A wrapper that wraps a loss function for Autogen support, so that the loss
# function will be handled by Autogen.
# To use this Autogen loss function wrapper, initialize the loss function
# using our loss from the Layer API and add the keyword argument
# 'use_autogen=True' (the default is False).
# For example:
#     loss = MSELoss(..., use_autogen=True)
#
# To apply the Autogen loss wrapper to any custom loss function, just add this
# wrapper as the decorator of the custom loss class.
# For example:
#     @autogen_loss
#     class CustomLoss(nn.Module):
#         def __init__(...):
#
# In the future, we may remove this temporary wrapper once we have developed
# better techniques to support Autogen.
def autogen_loss(loss_cls):
    loss_cls._old_init = loss_cls.__init__
    loss_cls._old_forward = loss_cls.forward

    def autogen_init(self, *args, **kwargs):
        self.autogen_enabled = kwargs.pop("use_autogen", False)
        self._old_init(*args, **kwargs)
        if self.autogen_enabled and cm.use_cs():
            self.mark_with_autogen = cbtorch.nn.Scope(scope_name=LOSS_SCOPE)

    def autogen_forward(self, *args, **kwargs):
        if self.autogen_enabled and cm.use_cs():
            args = [self.mark_with_autogen(arg) for arg in args]
            kwargs = {
                k: self.mark_with_autogen(v) for k, v in kwargs.items()
            }
            loss = self._old_forward(*args, **kwargs)
            return self.mark_with_autogen.exit(loss)
        else:
            return self._old_forward(*args, **kwargs)

    loss_cls.__init__ = autogen_init
    loss_cls.forward = autogen_forward

    return loss_cls
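

# Example usage (a sketch; `CustomSquaredError` is a hypothetical loss class,
# not part of this module): decorating a user-defined loss with `autogen_loss`
# lets callers opt in via `use_autogen=True`, which tags the loss ops with the
# loss scope when running on a CS system and is a no-op otherwise.
#
#     @autogen_loss
#     class CustomSquaredError(nn.Module):
#         def __init__(self, reduction="mean"):
#             super().__init__()
#             self.reduction = reduction
#
#         def forward(self, outputs, targets):
#             return apply_loss_reduction(
#                 (outputs - targets) ** 2, self.reduction
#             )
#
#     loss_fn = CustomSquaredError(use_autogen=True)
#     loss = loss_fn(torch.randn(4, 8), torch.randn(4, 8))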