# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import cerebras_pytorch as cstorch
from modelzoo.common.pytorch.layers import (
    EmbeddingLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)
from modelzoo.common.pytorch.model_utils.norms import get_norm
from modelzoo.transformers.pytorch.gpt2.sparse_mask import (
    create_fixed_sparse_attention_mask,
)
from modelzoo.transformers.pytorch.transformer_utils import (
    build_broadcastable_attention_mask,
    make_sparse_mask_broadcastable,
)
class GPT2LMHeadModel(nn.Module):
    """
    GPT-2 model with LM head.

    Input embeddings (token + the configured position scheme) feed a stack of
    ``num_hidden_layers`` pre-norm (``norm_first=True``) decoder layers without
    cross-attention. The stack is followed by a final norm (``ln_f``) and a
    linear LM head projecting ``hidden_size`` -> ``vocab_size``.

    Notable arguments:
        initializer: Initializer config shared by the attention and FFN
            weights. When ``None``, both default to a truncated normal with
            ``std=initializer_range``. In that case ``output_layer_initializer``
            (if not given) defaults to a truncated normal with
            ``std=initializer_range / sqrt(2 * num_hidden_layers)``.
        output_layer_initializer: Initializer for the attention and FFN output
            projections. It is only defaulted when ``initializer`` is ``None``.
            Otherwise it is passed through unchanged (possibly ``None``).
        embedding_initializer: Initializer for token and position embeddings.
            Defaults to a truncated normal with ``std=initializer_range``.
        extra_attention_params: Extra attention config forwarded to
            ``TransformerDecoderLayer``. ``None`` (the default) means ``{}``.
        rotary_dim: Only used with ``position_embedding_type="rotary"``.
            Defaults to ``hidden_size // num_heads``. Must be even and
            ``<= hidden_size / num_heads``.
        fixed_sparse_attention: If not ``None``, a dict of keyword arguments
            for ``create_fixed_sparse_attention_mask`` (fixed sparse attention
            as used in GPT-3).
        share_embedding_weights: Tie the LM head weight to the input token
            embedding weight.
        output_logits_scale: muP multiplier applied to the logits. A falsy
            value (``None`` or ``0``) disables the scaling.
        embeddings_scale: muP multiplier applied to the input embeddings.
    """

    def __init__(
        self,
        # Embedding
        vocab_size=50257,
        max_position_embeddings=1024,
        embd_pdrop=0.1,
        position_embedding_type="learned",
        position_embedding_offset=0,
        hidden_size=768,
        share_embedding_weights=True,
        embedding_layer_norm=False,
        num_relative_attention_buckets=32,
        rotary_dim=None,
        rope_theta=10000,
        # Encoder
        num_hidden_layers=12,
        dropout_rate=0.1,
        norm_type="layernorm",
        layer_norm_epsilon=1.0e-5,
        # Encoder - Attention
        num_heads=12,
        attention_type="scaled_dot_product",
        attention_module="aiayn_attention",
        extra_attention_params=None,
        use_projection_bias_in_attention=True,
        use_ffn_bias_in_attention=True,
        attention_dropout_rate=0.1,
        attention_softmax_fp32=True,
        fixed_sparse_attention=None,
        # Encoder - ffn
        filter_size=3072,
        nonlinearity="gelu",
        use_ffn_bias=True,
        # Task-specific
        use_bias_in_output=False,
        initializer_range=0.02,
        embedding_initializer=None,
        initializer=None,
        output_layer_initializer=None,
        # muP (maximal update parameterization)  parameters
        output_logits_scale=None,
        embeddings_scale=1.0,
        scale_qk_dot_by_d=False,
        alibi_trainable_slopes=False,
        pos_scaling_factor=1.0,
        scale_qk_dot_by_layer_idx=False,
    ):
        super(GPT2LMHeadModel, self).__init__()
        # A literal ``{}`` default would be one dict shared by every
        # instance; build a fresh one per call instead.
        if extra_attention_params is None:
            extra_attention_params = {}

        # std deviation for weight initialization
        self.initializer_range = initializer_range
        self.num_hidden_layers = num_hidden_layers
        self.share_embedding_weights = share_embedding_weights
        self.embedding_layer_norm = embedding_layer_norm
        self.max_position_embeddings = max_position_embeddings
        self.position_embedding_type = position_embedding_type
        self.embeddings_scale = embeddings_scale
        self.num_heads = num_heads

        if initializer is None:
            attention_initializer = {
                "name": "truncated_normal",
                "mean": 0.0,
                "std": self.initializer_range,
            }
            ffn_initializer = {
                "name": "truncated_normal",
                "mean": 0.0,
                "std": self.initializer_range,
            }
            if output_layer_initializer is None:
                # Output projections are scaled down with depth
                # (std / sqrt(2 * num_layers)).
                output_layer_initializer = {
                    "name": "truncated_normal",
                    "mean": 0.0,
                    "std": self.initializer_range
                    / math.sqrt(2 * self.num_hidden_layers),
                }
        else:
            attention_initializer = initializer
            ffn_initializer = initializer

        if embedding_initializer is None:
            embedding_initializer = {
                "name": "truncated_normal",
                "mean": 0.0,
                "std": self.initializer_range,
            }

        norm_class = get_norm(norm_type)

        if position_embedding_type == "rotary":
            if rotary_dim is None:
                rotary_dim = hidden_size // num_heads
            # https://github.com/huggingface/transformers/blob/f0577df6de36e7e7f28e90fa76da0657de038a39/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L84-L85
            # https://arxiv.org/pdf/2104.09864.pdf Section 3.3
            assert (
                rotary_dim <= hidden_size / num_heads
            ), "Rotary dimensions should be <= hidden size divided by number of attention heads."
            assert (
                rotary_dim % 2 == 0
            ), "Rotary dimension must be an even number."

        self.embedding_layer = EmbeddingLayer(
            vocab_size=vocab_size,
            embedding_size=hidden_size,
            embeddings_initializer=embedding_initializer,
            position_embedding_type=position_embedding_type,
            position_embeddings_initializer=embedding_initializer,
            max_position_embeddings=max_position_embeddings,
            position_embedding_offset=position_embedding_offset,
            num_heads=num_heads,
            num_relative_attention_buckets=num_relative_attention_buckets,
            rotary_dim=rotary_dim,
            rope_theta=rope_theta,
            alibi_trainable_slopes=alibi_trainable_slopes,
            pos_scaling_factor=pos_scaling_factor,
        )

        if self.embedding_layer_norm:
            self.embedding_ln_f = norm_class(
                hidden_size, eps=layer_norm_epsilon
            )

        self.drop_embd = nn.Dropout(embd_pdrop)

        decoder_layer = TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=filter_size,
            dropout=dropout_rate,
            activation=nonlinearity,
            layer_norm_eps=layer_norm_epsilon,
            norm_layer=norm_class,
            norm_first=True,
            extra_attention_params=extra_attention_params,
            add_cross_attention=False,
            attention_type=attention_type,
            scale_qk_dot_by_d=scale_qk_dot_by_d,
            scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
            attention_module=attention_module,
            attention_dropout_rate=attention_dropout_rate,
            attention_softmax_fp32=attention_softmax_fp32,
            use_projection_bias_in_attention=use_projection_bias_in_attention,
            use_ffn_bias_in_attention=use_ffn_bias_in_attention,
            use_ffn_bias=use_ffn_bias,
            attention_initializer=attention_initializer,
            attention_output_layer_initializer=output_layer_initializer,
            ffn_initializer=ffn_initializer,
            ffn_output_layer_initializer=output_layer_initializer,
            use_ff_layer1_dropout=False,
        )

        self.output_logits_scale = output_logits_scale

        # Final LayerNorm
        self.ln_f = norm_class(hidden_size, eps=layer_norm_epsilon)

        self.transformer_decoder = TransformerDecoder(
            decoder_layer, num_layers=num_hidden_layers, norm=self.ln_f,
        )

        # The fixed sparsity pattern is built once, sized to the maximum
        # sequence length, and adapted per batch in ``forward``.
        if fixed_sparse_attention is not None:
            self.fixed_sparsity_mask = create_fixed_sparse_attention_mask(
                max_sequence_length=max_position_embeddings,
                n_heads=num_heads,
                **fixed_sparse_attention,
            )
        else:
            self.fixed_sparsity_mask = None

        self.lm_head = nn.Linear(
            hidden_size, vocab_size, bias=use_bias_in_output
        )

        self.tie_weights()

        self.__reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embeddings, the decoder stack and the model-level layers."""
        self.embedding_layer.reset_parameters()
        self.transformer_decoder.reset_parameters()
        self.__reset_parameters()

    @staticmethod
    def _reset_norm_parameters(norm):
        """Reset a norm layer to identity: weight filled with 1, bias (if any) zeroed."""
        # ``bias`` may be absent on some norm types, or registered as ``None``
        # when the norm is built without one; only zero a real tensor.
        if getattr(norm, "bias", None) is not None:
            norm.bias.data.zero_()
        norm.weight.data.fill_(1.0)

    def __reset_parameters(self):
        """Initialize the parameters owned directly by this module (norms and LM head)."""
        # Init final norm layer
        self._reset_norm_parameters(self.ln_f)
        # Init the optional embedding norm the same way, so that
        # ``reset_parameters()`` leaves no parameter untouched.
        if self.embedding_layer_norm:
            self._reset_norm_parameters(self.embedding_ln_f)

        # Initialize LM head. When weights are shared, the weight *is* the
        # input embedding matrix and is initialized by the embedding layer.
        if not self.share_embedding_weights:
            self.lm_head.weight.data.normal_(
                mean=0.0, std=self.initializer_range
            )
        if self.lm_head.bias is not None:
            self.lm_head.bias.data.zero_()

    def tie_weights(self):
        """
        Share the input token-embedding weight with the LM head.

        Does nothing unless ``share_embedding_weights`` is set. If the output
        layer has a bias, it is zero-padded up to the tied weight's first
        dimension. ``out_features`` is synced to the embedding's
        ``num_embeddings`` when both attributes exist.
        """
        if not self.share_embedding_weights:
            return

        output_embedding = self.get_output_embeddings()
        input_embedding = self.get_input_embeddings()
        output_embedding.weight = input_embedding.weight

        if getattr(output_embedding, "bias", None) is not None:
            output_embedding.bias.data = nn.functional.pad(
                output_embedding.bias.data,
                (
                    0,
                    output_embedding.weight.shape[0]
                    - output_embedding.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embedding, "out_features") and hasattr(
            input_embedding, "num_embeddings"
        ):
            output_embedding.out_features = input_embedding.num_embeddings

    def get_output_embeddings(self):
        """Return the LM head (``nn.Linear(hidden_size, vocab_size)``)."""
        return self.lm_head

    def get_input_embeddings(self):
        """Return the token-embedding module exposed by the embedding layer."""
        return self.embedding_layer.get_input_embeddings()

    def compute_input_embeddings(self, input_ids, position_ids=None):
        """
        Embed ``input_ids`` and prepare them for the decoder.

        Applies the embedding layer, the optional embedding norm, the muP
        ``embeddings_scale`` multiplier and embedding dropout, in that order.
        """
        hidden_states = self.embedding_layer(
            input_ids, position_ids=position_ids
        )
        if self.embedding_layer_norm:
            hidden_states = self.embedding_ln_f(hidden_states)
        # Scale via a 0-dim tensor of the activation dtype, so the multiply
        # does not change hidden_states' dtype.
        hidden_states = hidden_states * torch.tensor(
            float(self.embeddings_scale), dtype=hidden_states.dtype
        )
        hidden_states = self.drop_embd(hidden_states)
        return hidden_states

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        attention_span=None,
        position_ids=None,
    ):
        """
        Compute LM logits for ``input_ids``.

        ``input_ids`` is indexed as ``(batch, seq, ...)``; ``shape[1]`` is
        used as the sequence length. The result is ``lm_head`` applied to
        the decoder output, so its last dimension is ``vocab_size``
        (presumably ``(batch, seq, vocab_size)`` — confirm against the
        embedding layer's output shape).
        """
        hidden_states = self.compute_input_embeddings(input_ids, position_ids)

        causal_attention_mask = build_broadcastable_attention_mask(
            attention_mask,
            attention_span=attention_span,
            build_causal=True,
            device=input_ids.device,
            dtype=hidden_states.dtype,
            num_heads=self.num_heads,
        )

        # Fixed sparse attention, used in GPT-3 model
        sparse_attention_mask = None
        if self.fixed_sparsity_mask is not None:
            sparse_attention_mask = make_sparse_mask_broadcastable(
                self.fixed_sparsity_mask,
                attention_mask,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                revert_mask=False,
            )

        # Helpers on alibi/relative position embeddings bias
        length = input_ids.shape[1]
        self_attn_position_bias = self.embedding_layer.compute_position_bias(
            length, length
        )

        hidden_states = self.transformer_decoder(
            hidden_states,
            tgt_mask=causal_attention_mask,
            sparse_mask=sparse_attention_mask,
            rotary_position_embedding_helper=self.embedding_layer.get_rope_helper(),
            self_attn_position_bias=self_attn_position_bias,
        )

        # On Cerebras hardware at precision_opt_level 1, the LM head is run
        # under ``cstorch.pol(bwd_level=0)``. See the cstorch docs for the
        # exact precision semantics.
        if (
            cstorch.use_cs()
            and cstorch.current_executor().cs_config.precision_opt_level == 1
        ):
            lm_logits = cstorch.pol(bwd_level=0)(self.lm_head)(hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        # scale lm_logits for muP transfer (a falsy scale disables this)
        if self.output_logits_scale:
            lm_logits = lm_logits * torch.tensor(
                float(self.output_logits_scale), dtype=lm_logits.dtype,
            )

        return lm_logits