# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
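"""Checkpoint and config converters for SantaCoder between Hugging Face (HF)
and Cerebras (CS) checkpoint formats."""
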
import re
from typing import Tuple
import torch
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.gpt2_hf_cs import (
    ConfigConverter_GPT2Model_HF_CS20,
    Converter_GPT2LMHeadModel_CS20_CS21,
    Converter_GPT2Model_HF_CS17,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.helper import (
    Build_HF_CS_Converter_WithOptionalModel,
    transpose_key_if_2D,
)


class Converter_Santacoder_Attention_HF_CS(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("c_proj", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("q_attn", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("kv_attn", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.kv_attn_converter,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("kv_attn", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-X.X"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return None
    def kv_attn_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if from_index == 0:
            self.kv_attn_converter_hf_to_cs(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
        else:
            self.kv_attn_converter_cs_to_hf(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
    def kv_attn_converter_hf_to_cs(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        # HF represents K and V in a packed format. We need to unpack the
        # weight and bias tensors into the separate K and V tensors used by
        # the CS format.
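        # The packed tensor stores K followed by V along the output dimension
        # (matching the chunk order below). For the 2-D weight, HF uses a
        # Conv1D-style (in_features, out_features) layout, so we transpose to
        # the Linear-style (out_features, in_features) layout expected by CS
        # before chunking along dim 0.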
        k_key = new_key
        v_key = re.sub(r"\.proj_k_dense_layer\.", ".proj_v_dense_layer.", k_key)
        if new_key.endswith(".bias"):
            assert len(old_state_dict[old_key].shape) == 1
            (new_state_dict[k_key], new_state_dict[v_key],) = torch.chunk(
                old_state_dict[old_key], 2, dim=0
            )
        elif new_key.endswith(".weight"):
            (new_state_dict[k_key], new_state_dict[v_key],) = torch.chunk(
                torch.transpose(old_state_dict[old_key], 0, 1), 2, dim=0
            )
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))
    def kv_attn_converter_cs_to_hf(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args,
    ):
        # HF represents K and V in a packed format. It also contains
        # special ".bias" and ".masked_bias" registered buffers that need to
        # be initialized.
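        # ".bias" holds the lower-triangular causal mask (built below from the
        # CS config's max_position_embeddings) while ".masked_bias" holds the
        # large negative fill value used by the HF attention implementation;
        # neither exists in the CS checkpoint, so both are recreated here.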
        k_key = old_key
        v_key = re.sub(r"\.proj_k_dense_layer\.", ".proj_v_dense_layer.", k_key)
        assert (
            k_key in old_state_dict
        ), "Expected the following key to exist! {}".format(k_key)
        assert (
            v_key in old_state_dict
        ), "Expected the following key to exist! {}".format(v_key)
        new_state_dict[new_key] = torch.cat(
            (old_state_dict[k_key], old_state_dict[v_key],), dim=0,
        )
        # Need to transpose to convert from Linear.weight -> Conv1D.weight
        if len(new_state_dict[new_key].shape) == 2:
            new_state_dict[new_key] = torch.transpose(
                new_state_dict[new_key], 0, 1
            )
        if new_key.endswith(".bias"):
            max_position_embeddings = action_fn_args["configs"][1]["model"][
                "max_position_embeddings"
            ]
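            # Strip the ".kv_attn." component to address the attention
            # module's mask buffer, e.g. (illustrative)
            # "...attn.kv_attn.bias" -> "...attn.bias".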
            attn_bias_key = re.sub(r"\.kv_attn\.", ".", new_key)
            new_state_dict[attn_bias_key] = torch.tril(
                torch.ones(
                    (max_position_embeddings, max_position_embeddings),
                    dtype=torch.uint8,
                )
            ).view(1, 1, max_position_embeddings, max_position_embeddings)
            masked_bias_key = re.sub(r"\.kv_attn\.", ".masked_", new_key)
            new_state_dict[masked_bias_key] = torch.tensor(-1e4)
    def assert_already_converted(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if from_index == 0:
            # We should never hit this case as this key should have been matched
            # already
            assert False, "Invalid key: {}".format(old_key)
        else:
            # When we convert from CS -> HF, the proj_q_dense_layer should also handle
            # conversion of proj_k_dense_layer and proj_v_dense_layer since HF
            # represents these three layers in a packed format. We simply need
            # to test that the key containing the packed format has already
            # been converted.
            assert (
                new_key in new_state_dict
            ), "Key should've been already converted: {} -> {}".format(
                old_key, new_key
            )


# This is a base converter for Santacoder that inherits from the GPT-2
# CS17 converter, which contains most of the rules necessary for
# converting GPT-2 checkpoints. This class is meant to be used within
# the rules of the CS-2.0 converter below, which catches checkpoints
# from the PyTorch 2.0 API and the deprecated PyTorchBaseModel.
# It is not meant for use on its own, because this model was not
# included in the codebase before release 2.0. Note that we include a
# formats() method in this class and the SantacoderLMHeadModel
# converter below because it is a required method, due to the
# declaration as an @abstractmethod in the BaseDictionaryConverter.
# The cs-X.X version string in formats() is meant to call attention to this.
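# In practice the inherited GPT-2 rules handle most of the mapping; this
# subclass only overrides attention_converter_class() so that the packed HF
# kv_attn tensors are split into the separate proj_k_dense_layer /
# proj_v_dense_layer tensors on the CS side (see
# Converter_Santacoder_Attention_HF_CS above).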
class Converter_SantacoderModel_HF_CS(Converter_GPT2Model_HF_CS17):
    def attention_converter_class(self):
        return Converter_Santacoder_Attention_HF_CS()
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-X.X"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS20


class Converter_SantacoderLMHeadModel_HF_CS(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r"lm_head\.(?:weight|bias)"], action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("transformer.", ""),
                    Converter_SantacoderModel_HF_CS(),
                ],
                action=None,
            ),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-X.X"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS20


class Converter_SantacoderModel_HF_CS20(Converter_SantacoderModel_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule([Converter_SantacoderModel_HF_CS(),], action=None,),
            # Catch checkpoints from deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_SantacoderModel_HF_CS(),
                ],
                action=None,
            ),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.0"))
    @classmethod
    def converter_note(cls) -> str:
        return (
            "{} GPT2CustomModel <-> {} GPT2LMHeadModel (configured as SantaCoder)\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS20


class Converter_SantacoderLMHeadModel_HF_CS20(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_SantacoderLMHeadModel_HF_CS(),], action=None,
            ),
            # Catch checkpoints from deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_SantacoderLMHeadModel_HF_CS(),
                ],
                action=None,
            ),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.0"))
    @classmethod
    def converter_note(cls) -> str:
        return "{} GPT2LMHeadCustomModel <-> {} GPT2LMHeadModel (configured as SantaCoder)".format(
            cls.formats()[0], cls.formats()[1]
        )
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS20


class ConfigConverter_SantacoderModel_HF_CS20(
    ConfigConverter_GPT2Model_HF_CS20
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey(
                        "scale_attn_by_inverse_layer_idx",
                        "scale_qk_dot_by_layer_idx",
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["attention_head_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "multiquery"),
            ),
            ConversionRule(
                ["attention_module"],
                action=BaseConfigConverter.assert_factory_fn(
                    1, "multiquery_attention"
                ),
            ),
            ConversionRule(
                ["extra_attention_params"],
                action=BaseConfigConverter.assert_factory_fn(
                    1, {"num_kv_groups": 1}
                ),
            ),
            *self.rules,
        ]
        self.post_convert_defaults[0].update(
            {
                "architectures": ["GPT2LMHeadCustomModel"],
                "attention_head_type": "multiquery",
                "scale_attn_by_inverse_layer_idx": False,
                "scale_attn_weight": True,
                "auto_map": {
                    "AutoConfig": "configuration_gpt2_mq.GPT2CustomConfig",
                    "AutoModelForCausalLM": "modeling_gpt2_mq.GPT2LMHeadCustomModel",
                },
                "model_type": "gpt2",
                "reorder_and_upcast_attn": False,
                "summary_activation": None,
                "summary_first_dropout": 0.1,
                "summary_proj_to_labels": True,
                "summary_type": "cls_index",
                "summary_use_proj": True,
                "torch_dtype": "float32",
                "transformers_version": "4.24.0",
                "use_cache": True,
            },
        )
        self.post_convert_defaults[1].update(
            {
                "position_embedding_type": "learned",
                "attention_module": "multiquery_attention",
                "softmax_dtype_fp32": False,
                "scale_by_layer_index": False,
                "extra_attention_params": {"num_kv_groups": 1},
                "use_projection_bias_in_attention": True,
                "use_ffn_bias_in_attention": True,
                "use_ffn_bias": True,
                "loss_scaling": "num_tokens",
                "use_bfloat16": True,
            },
        ) 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.0"))


###########################################################
# In CS 2.1, we refactored the embedding layer.
###########################################################
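# As part of that refactor, HF's "wpe" position embeddings map to
# "embedding_layer.position_embeddings.embed" on the CS side; see the extra
# rule in Converter_SantacoderModel_WithoutOptionalModel_HF_CS21 below.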
class Converter_SantacoderLMHeadModel_CS20_CS21(
    Converter_GPT2LMHeadModel_CS20_CS21
):
    @classmethod
    def converter_note(cls) -> str:
        return "GPT2LMHeadModel class (configured as Santacoder)" 
class ConfigConverter_SantacoderModel_HF_CS21(
    ConfigConverter_SantacoderModel_HF_CS20
):
    def __init__(self) -> None:
        super().__init__()
        del self.post_convert_defaults[1]["use_bfloat16"]
        self.post_convert_defaults[1].update({"fp16_type": "bfloat16"}) 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1"))
    def supports_mup_conversion(self):
        return True


class Converter_SantacoderModel_WithoutOptionalModel_HF_CS21(
    Converter_SantacoderModel_HF_CS
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey(
                        "wpe", "embedding_layer.position_embeddings.embed"
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            *self.rules,
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS21
    @classmethod
    def converter_note(cls) -> str:
        return (
            "{} GPT2CustomModel <-> {} GPT2LMHeadModel (configured as SantaCoder)\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])


Converter_SantacoderModel_HF_CS21 = Build_HF_CS_Converter_WithOptionalModel(
    "Converter_SantacoderModel_HF_CS21",
    Converter_SantacoderModel_WithoutOptionalModel_HF_CS21,
    derived_class=Converter_SantacoderModel_WithoutOptionalModel_HF_CS21,
)


class Converter_SantacoderLMHeadModel_WithoutOptionalModel_HF_CS21(
    BaseCheckpointConverter_HF_CS
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r"lm_head\.(?:weight|bias)"], action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("transformer.", ""),
                    Converter_SantacoderModel_WithoutOptionalModel_HF_CS21(),
                ],
                action=None,
            ),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_SantacoderModel_HF_CS21
    @classmethod
    def converter_note(cls) -> str:
        return "{} GPT2LMHeadCustomModel <-> {} GPT2LMHeadModel (configured as SantaCoder)".format(
            cls.formats()[0], cls.formats()[1]
        )
    def supports_mup_conversion(self):
        return True


Converter_SantacoderLMHeadModel_HF_CS21 = Build_HF_CS_Converter_WithOptionalModel(
    "Converter_SantacoderLMHeadModel_HF_CS21",
    Converter_SantacoderLMHeadModel_WithoutOptionalModel_HF_CS21,
    derived_class=Converter_SantacoderLMHeadModel_WithoutOptionalModel_HF_CS21,
)