# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Tuple
import torch
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.bert import (
    ConfigConverter_Bert_HF_CS18,
    Converter_BertLayerNorm_HF_CS,
    Converter_BertModel_CS16_CS17,
    Converter_BertModel_WithoutOptionalModel_HF_CS21,
    Converter_BertPretrainModel_HF_CS18,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.helper import (
    Build_HF_CS_Converter_WithOptionalModel,
)
class Converter_RobertaPretrainModel_HF_CS(BaseCheckpointConverter_HF_CS):
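    """RoBERTa pretrain model converter between HF and CS.

    Maps the encoder keys ("roberta." <-> "bert_encoder.") and the MLM head
    keys ("lm_head." <-> "bert_mlm_head.").
    """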
    def __init__(self):
        super().__init__()
        self.rules = [
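            # Encoder: "roberta." <-> "bert_encoder."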
            ConversionRule(
                [
                    EquivalentSubkey("roberta.", "bert_encoder."),
                    Converter_BertModel_CS16_CS17(),  # CS 1.6 encoder keys match HF
                ],
            ),
            # MLM head ("lm_head" <-> "bert_mlm_head"):
            ConversionRule(
                [
                    EquivalentSubkey(
                        "lm_head.dense",
                        "bert_mlm_head.mlm_transform.ffn.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "lm_head.", "bert_mlm_head.mlm_transform.",
                    ),
                    Converter_BertLayerNorm_HF_CS("layer_norm", "ln"),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "lm_head.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "lm_head.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.bias",
                ],
                action=self.convert_cls_predictions_bias,
            ),
            ConversionRule([r"lm_head\.bias"], exists="left"),
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return None
    def convert_cls_predictions_bias(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        self.replaceKey(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            action_fn_args,
        )
        if from_index == 1:
            # HF additionally stores a copy of the decoder bias directly on the lm_head
            bias_key = re.sub(r"\.decoder\.", ".", new_key)
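            # e.g. new_key "lm_head.decoder.bias" -> bias_key "lm_head.bias"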
            self.replaceKey(
                old_key,
                bias_key,
                old_state_dict,
                new_state_dict,
                from_index,
                action_fn_args,
            ) 
class Converter_RobertaPretrainModel_HF_CS18(
    Converter_BertPretrainModel_HF_CS18
):
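    """RoBERTa pretrain model converter between HF and CS.

    Reuses Converter_RobertaPretrainModel_HF_CS and additionally accepts
    checkpoints whose keys carry the top-level "model." prefix.
    """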
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_RobertaPretrainModel_HF_CS(),], action=None,
            ),
            # Catch checkpoints from CS 1.7/1.8 (keys prefixed with "model.")
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_RobertaPretrainModel_HF_CS(),
                ],
                action=None,
            ),
        ] 
    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for RobertaForPreTraining".format(
            cls.formats()[0], cls.formats()[1]
        )
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Roberta_HF_CS18
    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
        key_prefix="",
    ):
        if from_index == 1:
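            # HF RoBERTa always has a token_type_embeddings table, while CS can
            # disable segment embeddings entirely. When converting CS -> HF in
            # that case, emit an all-zero table so the HF model still loads.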
            num_segments = configs[1]["model"]["num_segments"]
            if not num_segments:
                new_state_dict[
                    key_prefix
                    + "roberta.embeddings.token_type_embeddings.weight"
                ] = torch.zeros(
                    configs[0]["type_vocab_size"], configs[0]["hidden_size"]
                )
        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            from_index,
            drop_unmatched_keys,
            key_prefix=key_prefix,
        )  
class ConfigConverter_Roberta_HF_CS18(ConfigConverter_Bert_HF_CS18):
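    """RoBERTa config converter between HF and CS.

    Extends the BERT config converter with RoBERTa-specific handling of the
    positional-embedding offset, segment embeddings, and the pad token id.
    """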
    def __init__(self):
        super().__init__()
        # Prepend RoBERTa-specific rules; Bert's rules are retained via *self.rules below:
        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "roberta"),
            ),
            ConversionRule(
                ["max_position_embeddings"], action=self.convert_max_pos_embed,
            ),
            ConversionRule(
                [EquivalentSubkey("type_vocab_size", "num_segments")],
                action=self.convert_num_segments,
            ),
            ConversionRule(["pad_token_id"], action=self.replaceKey,),
            ConversionRule(
                ["mask_padding_in_positional_embed"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["disable_nsp"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["mlm_nonlinearity"],
                action=BaseConfigConverter.assert_factory_fn(1, "gelu"),
            ),
            *self.rules,
        ]
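        # Index 0 holds HF-side defaults, index 1 holds CS-side defaults
        # (matching the order returned by formats()).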
        self.pre_convert_defaults[0].update(
            {
                "vocab_size": 50265,
                "position_embedding_type": "absolute",
                "type_vocab_size": 2,
                "pad_token_id": 1,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "disable_nsp": False,
                "pad_token_id": 0,
                "mask_padding_in_positional_embed": False,
            }
        )
        self.post_convert_defaults[0].update({"model_type": "roberta"})
        self.post_convert_defaults[1].update(
            {"disable_nsp": True, "mask_padding_in_positional_embed": True,}
        ) 
    def convert_num_segments(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        # CS allows segment embeddings to be disabled, while HF does not.
        # When they are disabled in CS, we need to enable them in HF and set
        # the embedding weights to zero.
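        # e.g. CS num_segments = 0 maps to HF type_vocab_size = 1; the
        # embedding weights themselves are zeroed in post_model_convert.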
        if from_index == 1 and old_state_dict[old_key] == 0:
            new_state_dict[new_key] = 1
        else:
            new_state_dict[new_key] = old_state_dict[old_key]
    def convert_max_pos_embed(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        # Number of positional embeddings = MSL + pad_token_id + 1.
        # HF's max_position_embeddings is this total, while CS's
        # max_position_embeddings is the MSL.
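        # e.g. HF roberta-base: max_position_embeddings = 514, pad_token_id = 1
        # -> CS max_position_embeddings (MSL) = 514 - 1 - 1 = 512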
        if from_index == 0:
            new_state_dict[new_key] = (
                old_state_dict[old_key] - old_state_dict["pad_token_id"] - 1
            )
        else:
            new_state_dict[new_key] = (
                old_state_dict[old_key] + old_state_dict["pad_token_id"] + 1
            )
    def pre_config_convert(
        self, config, from_index,
    ):
        config = super().pre_config_convert(config, from_index)
        if from_index == 1:
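            # If the CS config omits num_segments, infer it: disabled NSP
            # implies segment embeddings are unused, otherwise BERT-style 2.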
            if "num_segments" not in config:
                config["num_segments"] = 0 if config["disable_nsp"] else 2
        return config
    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 0:
            new_config["mlm_nonlinearity"] = "gelu"
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        ) 
###########################################################
# In CS 2.1, we refactored the embedding layer.
# CS 2.0 <> CS 2.1, and HF <> CS 2.1 converters:
###########################################################
class ConfigConverter_Roberta_HF_CS21(ConfigConverter_Roberta_HF_CS18):
    "CS 2.1 config is the same as CS 2.0"
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1")) 
class Converter_RobertaPretrainModel_WithoutOptionalModel_HF_CS21(
    Converter_RobertaPretrainModel_HF_CS
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("roberta.", "bert_encoder."),
                    Converter_BertModel_WithoutOptionalModel_HF_CS21(),  # HF <-> CS 2.1 encoder
                ],
            ),
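            # Remaining rules (MLM head, etc.) are inherited from
            # Converter_RobertaPretrainModel_HF_CS: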
            *self.rules,
        ] 
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1"))
    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Roberta_HF_CS21 
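# Public HF <-> CS 2.1 converter: built from the WithoutOptionalModel converter
# above so checkpoints with or without the top-level "model." key prefix are
# handled (mirroring Converter_RobertaPretrainModel_HF_CS18).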
Converter_RobertaPretrainModel_HF_CS21 = Build_HF_CS_Converter_WithOptionalModel(
    "Converter_RobertaPretrainModel_HF_CS21",
    Converter_RobertaPretrainModel_WithoutOptionalModel_HF_CS21,
    derived_class=Converter_RobertaPretrainModel_HF_CS18,
    config_converter_class=ConfigConverter_Roberta_HF_CS21,
    formats=(FormatVersions("hf"), FormatVersions("cs-2.1")),
)