# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Optional, Tuple
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter,
    BaseConfigConverter,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
# CS models may contain an extra 'model.' prefix. During HF -> CS conversion,
# we do not want to output checkpoints with this prefix. In CS -> HF conversion,
# we want to handle both the extra 'model.' prefix and no prefix cases.
def Build_HF_CS_Converter_WithOptionalModel(
    name,
    converter,
    derived_class,
    config_converter_class=None,
    formats=None,
    converter_note_fn=None,
):
    """Build a converter class that accepts keys with or without a 'model.' prefix.

    A new subclass of ``derived_class`` is created whose ``rules`` consist of
    two ``ConversionRule`` entries: one that applies ``converter()`` directly,
    and a fallback that pairs ``EquivalentSubkey("", "model.")`` with
    ``converter()``.

    Args:
        name: Value assigned to the generated class's ``__name__``.
        converter: Zero-argument callable; it is invoked once per rule each
            time the generated class is instantiated.
        derived_class: Base class for the generated converter. Must be a
            subclass of ``BaseCheckpointConverter``.
        config_converter_class: Optional ``BaseConfigConverter`` subclass. When
            given, it is returned by a static ``get_config_converter_class``
            installed on the generated class.
        formats: Optional tuple of exactly two ``FormatVersions``. When given,
            it is returned by a static ``formats`` installed on the generated
            class.
        converter_note_fn: Optional callable taking the class. When given, a
            ``converter_note`` classmethod that forwards to it is installed.

    Returns:
        The generated converter class (not an instance).

    Raises:
        AssertionError: If ``derived_class``, ``config_converter_class`` or
            ``formats`` fail the checks described above.
    """
    assert issubclass(
        derived_class, BaseCheckpointConverter
    ), "derived_class parameter must be a subclass of BaseCheckpointConverter"

    class ConverterWithOptionalModel(derived_class):
        def __init__(self) -> None:
            super().__init__()
            self.rules = [
                ConversionRule([converter()], action=None),
                # If above did not match, try to apply conversion with stripped
                # 'model.' prefix
                ConversionRule(
                    [EquivalentSubkey("", "model."), converter()],
                    action=None,
                ),
            ]

    ConverterWithOptionalModel.__name__ = name

    def _mark_implemented(method_name):
        # Drop the method from the abstract set so the generated class can be
        # instantiated once a concrete implementation has been attached.
        remaining = ConverterWithOptionalModel.__abstractmethods__.difference(
            {method_name}
        )
        ConverterWithOptionalModel.__abstractmethods__ = remaining

    if config_converter_class:
        assert issubclass(
            config_converter_class, BaseConfigConverter
        ), "config_converter_class parameter must be a subclass of BaseConfigConverter"

        @staticmethod
        def _get_config_converter_class() -> type:
            # Returns the config converter class itself, not an instance.
            return config_converter_class

        ConverterWithOptionalModel.get_config_converter_class = (
            _get_config_converter_class
        )
        _mark_implemented("get_config_converter_class")

    if formats:
        assert (
            isinstance(formats, tuple)
            and len(formats) == 2
            and all(isinstance(e, FormatVersions) for e in formats)
        ), "formats argument must be a tuple of two FormatVersions"

        @staticmethod
        def _formats_fn() -> Tuple[FormatVersions, FormatVersions]:
            return formats

        ConverterWithOptionalModel.formats = _formats_fn
        _mark_implemented("formats")

    if converter_note_fn:

        @classmethod
        def _converter_note(cls) -> str:
            return converter_note_fn(cls)

        ConverterWithOptionalModel.converter_note = _converter_note
        _mark_implemented("converter_note")

    return ConverterWithOptionalModel
def convert_use_rms_layer_norm_helper(
    self,
    old_key,
    new_key,
    old_state_dict,
    new_state_dict,
    from_index,
    action_fn_args,
):
    """Convert between a boolean RMS-norm flag and a norm-type string.

    When ``from_index == 0`` the source value is treated as a flag: a truthy
    value is written as ``"rmsnorm"`` and a falsy one as ``"layernorm"``.
    For any other ``from_index`` the mapping is reversed: ``"rmsnorm"``
    becomes ``True`` and ``"layernorm"`` becomes ``False``.

    Args:
        self: Converter instance; only ``self.formats()`` is used, and only
            when building the error message.
        old_key: Key to read from ``old_state_dict``.
        new_key: Key to write into ``new_state_dict``.
        old_state_dict: Source mapping.
        new_state_dict: Destination mapping, mutated in place.
        from_index: Direction selector (``0`` means flag -> string).
        action_fn_args: Unused by this helper.

    Raises:
        ConfigConversionError: In the string -> flag direction, when the
            source value is neither ``"rmsnorm"`` nor ``"layernorm"``.
    """
    value = old_state_dict[old_key]
    if from_index == 0:
        new_state_dict[new_key] = "rmsnorm" if value else "layernorm"
    elif value == "rmsnorm":
        new_state_dict[new_key] = True
    elif value == "layernorm":
        new_state_dict[new_key] = False
    else:
        raise ConfigConversionError(
            "{} did not support {}".format(self.formats()[0], value)
        )
def convert_use_biasless_layer_norm_helper(
    self,
    old_key,
    new_key,
    old_state_dict,
    new_state_dict,
    from_index,
    action_fn_args,
):
    """Convert between a boolean biasless-norm flag and a norm-type string.

    When ``from_index == 0`` the source value is treated as a flag: a truthy
    value is written as ``"biasless-layernorm"`` and a falsy one as
    ``"layernorm"``. For any other ``from_index`` the mapping is reversed:
    ``"biasless-layernorm"`` becomes ``True`` and ``"layernorm"`` becomes
    ``False``.

    Args:
        self: Converter instance; only ``self.formats()`` is used, and only
            when building the error message.
        old_key: Key to read from ``old_state_dict``.
        new_key: Key to write into ``new_state_dict``.
        old_state_dict: Source mapping.
        new_state_dict: Destination mapping, mutated in place.
        from_index: Direction selector (``0`` means flag -> string).
        action_fn_args: Unused by this helper.

    Raises:
        ConfigConversionError: In the string -> flag direction, when the
            source value is neither ``"biasless-layernorm"`` nor
            ``"layernorm"``.
    """
    value = old_state_dict[old_key]
    if from_index == 0:
        new_state_dict[new_key] = (
            "biasless-layernorm" if value else "layernorm"
        )
    elif value == "biasless-layernorm":
        new_state_dict[new_key] = True
    elif value == "layernorm":
        new_state_dict[new_key] = False
    else:
        raise ConfigConversionError(
            "{} did not support {}".format(self.formats()[0], value)
        )
# Old cstorch checkpoints had a bug where aliased weights would show up as None.
# This helper function fixes this by tying old_key and new_key together
# when new_key doesn't exist or exactly one of the two values is None.
def tie_none_weights(
    old_key: str,
    new_key: str,
    old_state_dict: OrderedDict,
    new_state_dict: OrderedDict,
    from_index: int,
    action_fn_args: Optional[dict] = None,
) -> None:
    r"""
    Ties weights stored at old_key & new_key

    Both keys are looked up in ``old_state_dict`` and the result is written
    into ``new_state_dict``:

    * ``new_key`` missing from ``old_state_dict``: both keys receive the
      value stored at ``old_key`` (even if that value is ``None``).
    * exactly one of the two values is ``None``: both keys receive the
      non-``None`` value.
    * otherwise (both set, or both ``None``): only ``old_key`` is copied and
      ``new_key`` is left untouched.

    ``from_index`` and ``action_fn_args`` are not used.

    Raises:
        KeyError: If ``old_key`` is not present in ``old_state_dict``.
    """
    primary = old_state_dict[old_key]

    if new_key not in old_state_dict:
        new_state_dict[old_key] = primary
        new_state_dict[new_key] = primary
        return

    alias = old_state_dict[new_key]
    if primary is not None and alias is None:
        shared = primary
    elif primary is None and alias is not None:
        shared = alias
    else:
        # Nothing to repair: copy old_key only and leave new_key alone.
        new_state_dict[old_key] = primary
        return

    new_state_dict[old_key] = shared
    new_state_dict[new_key] = shared
# Ties old_key and new_key if share_embedding_weights is enabled in the config
# (default is enabled)
def maybe_tie_lm_head(
    old_key: str,
    new_key: str,
    old_state_dict: OrderedDict,
    new_state_dict: OrderedDict,
    from_index: int,
    action_fn_args: Optional[dict] = None,
) -> None:
    """Tie ``old_key`` and ``new_key`` only when embedding sharing is enabled.

    The CS config is read from ``action_fn_args["configs"][1]``. If its
    ``"model"`` section has ``share_embedding_weights`` truthy (or absent,
    in which case it defaults to ``True``), the work is delegated to
    ``tie_none_weights``. Otherwise the value at ``old_key`` is copied
    unchanged into ``new_state_dict``.

    Note: although ``action_fn_args`` defaults to ``None``, it is subscripted
    unconditionally, so callers must supply a dict containing ``"configs"``.
    """
    cs_config = action_fn_args["configs"][1]
    share_weights = cs_config["model"].get("share_embedding_weights", True)

    if not share_weights:
        # Weights are independent: plain copy, no tying.
        new_state_dict[old_key] = old_state_dict[old_key]
        return

    tie_none_weights(
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    )
def transpose_key_if_2D(
    old_key,
    new_key,
    old_state_dict,
    new_state_dict,
    from_index,
    action_fn_args,
):
    """Copy a weight to ``new_key``, swapping its two axes if it is 2-D.

    The value at ``old_state_dict[old_key]`` is stored at
    ``new_state_dict[new_key]``. If its ``shape`` has exactly two entries,
    ``transpose(0, 1)`` is applied first; values of any other rank are
    passed through unchanged. ``from_index`` and ``action_fn_args`` are not
    used.

    NOTE(review): the ``transpose(0, 1)`` call presumably targets
    ``torch.Tensor`` (where the two arguments are the dims to swap) --
    confirm that callers never pass NumPy arrays, for which the same call is
    an identity permutation.
    """
    # HF checkpoint stores some layers as Conv2D instead of Linear.
    # In those cases, we need to transpose the weight matrix for the
    # dimensions to line up when converting.
    weight = old_state_dict[old_key]
    is_matrix = len(weight.shape) == 2
    new_state_dict[new_key] = weight.transpose(0, 1) if is_matrix else weight