# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from modelzoo.common.pytorch.layers import FeedForwardNetwork
from modelzoo.transformers.pytorch.bert.bert_model import BertModel
[docs]class BertClassifierHead(nn.Module):
[docs]    def __init__(
        self,
        hidden_size=768,
        num_classes=2,
        use_bias=True,
        kernel_initializer="xavier_uniform",
    ):
        super().__init__()
        self.classifier = FeedForwardNetwork(
            input_unit=hidden_size,
            layers_units=[num_classes],
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
        ) 
    def reset_parameters(self):
        self.classifier.reset_parameters()
    def forward(self, pooled_output):
        cls_logits = self.classifier(pooled_output)
        return cls_logits 
[docs]class BertMLMHead(nn.Module):
[docs]    def __init__(
        self,
        hidden_size=768,
        vocab_size=30522,
        embedding_size=None,
        use_ffn_bias_in_mlm=True,
        use_output_bias_in_mlm=True,
        activation="gelu",
        layer_norm_epsilon=1.0e-5,
        dropout=None,
        kernel_initializer="xavier_uniform",
    ):
        super().__init__()
        if embedding_size is None:
            embedding_size = hidden_size
        self.mlm_transform = BertMLMHeadTransform(
            hidden_size=hidden_size,
            embedding_size=embedding_size,
            use_ffn_bias_in_mlm=use_ffn_bias_in_mlm,
            activation=activation,
            layer_norm_epsilon=layer_norm_epsilon,
            dropout=dropout,
            kernel_initializer=kernel_initializer,
        )
        self.classifier = FeedForwardNetwork(
            input_unit=embedding_size,
            layers_units=[vocab_size],
            use_bias=use_output_bias_in_mlm,
            kernel_initializer=kernel_initializer,
        ) 
    def reset_parameters(self):
        self.mlm_transform.reset_parameters()
        self.classifier.reset_parameters()
    def get_output_embeddings(self):
        return self.classifier.ffn[0].linear_layer
    def forward(self, bert_output):
        mlm_embeddings = self.mlm_transform(bert_output)
        mlm_logits = self.classifier(mlm_embeddings)
        return mlm_logits 
[docs]class BertPretrainModel(nn.Module):
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head. Following the paper: https://arxiv.org/abs/1810.04805.
    """
[docs]    def __init__(
        self,
        disable_nsp=False,
        mlm_loss_weight=1.0,
        label_smoothing=0.0,
        num_classes=2,
        mlm_nonlinearity=None,
        # Embedding
        vocab_size=50257,
        max_position_embeddings=1024,
        position_embedding_type="learned",
        embedding_pad_token_id=0,
        mask_padding_in_positional_embed=False,
        rotary_dim=None,
        rope_theta=10000,
        num_relative_attention_buckets=32,
        alibi_trainable_slopes=False,
        pos_scaling_factor=1.0,
        hidden_size=768,
        share_embedding_weights=True,
        # Encoder
        num_hidden_layers=12,
        layer_norm_epsilon=1.0e-5,
        # Encoder Attn
        num_heads=12,
        attention_module="aiayn_attention",
        extra_attention_params={},
        attention_type="scaled_dot_product",
        attention_softmax_fp32=True,
        dropout_rate=0.1,
        nonlinearity="gelu",
        pooler_nonlinearity=None,
        attention_dropout_rate=0.1,
        use_projection_bias_in_attention=True,
        use_ffn_bias_in_attention=True,
        # Encoder ffn
        filter_size=3072,
        use_ffn_bias=True,
        use_ffn_bias_in_mlm=True,
        use_output_bias_in_mlm=True,
        # Task-specific
        initializer_range=0.02,
        num_segments=2,
    ):
        """
        Args:
            disable_nsp (:obj:`bool` `optional`, defaults to False):
                Whether to disable next-sentence-prediction and only use masked-language-model.
            mlm_loss_weight (:obj:`float` `optional`, defaults to 1.0):
                The scaling factor for masked-language-model loss.
            label_smoothing (:obj:`float` `optional`, defaults to 0.0):
                The label smoothing factor used during training.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.disable_nsp = disable_nsp
        self.share_embedding_weights = share_embedding_weights
        self.initializer_range = initializer_range
        self.bert_encoder = BertModel(
            # Embedding
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            position_embedding_type=position_embedding_type,
            hidden_size=hidden_size,
            embedding_dropout_rate=dropout_rate,
            embedding_pad_token_id=embedding_pad_token_id,
            mask_padding_in_positional_embed=mask_padding_in_positional_embed,
            rotary_dim=rotary_dim,
            rope_theta=rope_theta,
            num_relative_attention_buckets=num_relative_attention_buckets,
            alibi_trainable_slopes=alibi_trainable_slopes,
            pos_scaling_factor=pos_scaling_factor,
            # Encoder
            num_hidden_layers=num_hidden_layers,
            layer_norm_epsilon=layer_norm_epsilon,
            # Encoder Attn
            num_heads=num_heads,
            attention_module=attention_module,
            extra_attention_params=extra_attention_params,
            attention_type=attention_type,
            attention_softmax_fp32=attention_softmax_fp32,
            dropout_rate=dropout_rate,
            nonlinearity=nonlinearity,
            pooler_nonlinearity=pooler_nonlinearity,
            attention_dropout_rate=attention_dropout_rate,
            use_projection_bias_in_attention=use_projection_bias_in_attention,
            use_ffn_bias_in_attention=use_ffn_bias_in_attention,
            # Encoder ffn
            filter_size=filter_size,
            use_ffn_bias=use_ffn_bias,
            # Task-specific
            initializer_range=initializer_range,
            num_segments=num_segments,
            add_pooling_layer=(not self.disable_nsp),
        )
        kernel_initializer = {
            "name": "truncated_normal",
            "std": self.initializer_range,
            "mean": 0.0,
            "a": self.initializer_range * -2.0,
            "b": self.initializer_range * 2.0,
        }
        if not self.disable_nsp:
            self.bert_cls_head = BertClassifierHead(
                hidden_size=hidden_size,
                num_classes=num_classes,
                use_bias=use_ffn_bias,
                kernel_initializer=kernel_initializer,
            )
        if mlm_nonlinearity is None:
            mlm_nonlinearity = nonlinearity
        self.bert_mlm_head = BertMLMHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            use_ffn_bias_in_mlm=use_ffn_bias_in_mlm,
            use_output_bias_in_mlm=use_output_bias_in_mlm,
            activation=mlm_nonlinearity,
            layer_norm_epsilon=layer_norm_epsilon,
            dropout=None,
            kernel_initializer=kernel_initializer,
        )
        self.tie_weights() 
    def reset_parameters(self):
        self.bert_encoder.reset_parameters()
        if not self.disable_nsp:
            self.bert_cls_head.reset_parameters()
        self.bert_mlm_head.reset_parameters()
    def tie_weights(self):
        if not self.share_embedding_weights:
            return
        output_embedding = self.get_output_embeddings()
        input_embedding = self.get_input_embeddings()
        output_embedding.weight = input_embedding.weight
        if getattr(output_embedding, "bias", None) is not None:
            output_embedding.bias.data = nn.functional.pad(
                output_embedding.bias.data,
                (
                    0,
                    output_embedding.weight.shape[0]
                    - output_embedding.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embedding, "out_features") and hasattr(
            input_embedding, "num_embeddings"
        ):
            output_embedding.out_features = input_embedding.num_embeddings
    def get_output_embeddings(self):
        return self.bert_mlm_head.get_output_embeddings()
    def get_input_embeddings(self):
        return self.bert_encoder.embedding_layer.get_input_embeddings()
[docs]    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        masked_lm_positions=None,
        should_gather_mlm_labels=False,
    ):
        """
        Args:
            input_ids (Tensor):
                The id of input tokens. Can be of shape ``[batch_size, seq_length]``
            attention_mask (Tensor):
                Can be 2D of shape ``[batch_size, seq_length]``,
                or 3D of shape ``[batch, query_length, seq_length]``,
                or 4D of shape ``[batch, num_heads, query_length, seq_length]``.
            position_ids (Tensor):
                The position id of input tokens. Can be of shape ``[batch_size, seq_length]``
            token_type_ids (Tensor):
                The segment id of input tokens, indicating which sequence the token belongs to.
                Can be of shape ```[batch_size, seq_length]`
            masked_lm_positions (Tensor):
                Position ids of mlm tokens. Shape ``[batch_size, max_predictions_per_seq]``
        """
        mlm_hidden_states, pooled_hidden_states = self.bert_encoder(
            input_ids,
            position_ids=position_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        batch_size, seq_len, hidden_size = list(mlm_hidden_states.size())
        focused_mlm_hidden_states = mlm_hidden_states
        if should_gather_mlm_labels:
            batch_size, max_num_pred = list(masked_lm_positions.size())
            index = torch.broadcast_to(
                masked_lm_positions.unsqueeze(-1),
                (batch_size, max_num_pred, hidden_size),
            ).long()
            focused_mlm_hidden_states = torch.gather(
                mlm_hidden_states, dim=1, index=index
            )
        mlm_logits = self.bert_mlm_head(focused_mlm_hidden_states)
        # nsp_logits
        nsp_logits = None
        if not self.disable_nsp:
            nsp_logits = self.bert_cls_head(pooled_hidden_states)
        return mlm_logits, nsp_logits, mlm_hidden_states, pooled_hidden_states