# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
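"""Checkpoint and config converters for ESM-2, mapping between HuggingFace's
``EsmForMaskedLM`` and the Cerebras Model Zoo ESM-2 pretraining model
(CS 2.1/2.2)."""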
import re
from typing import Tuple
import torch
from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
BaseCheckpointConverter_HF_CS,
BaseConfigConverter,
BaseConfigConverter_HF_CS,
ConfigConversionError,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.bert import (
ConfigConverter_Bert_HF_CS21,
Converter_BertLayerNorm_HF_CS,
Converter_BertModel_WithoutOptionalModel_HF_CS21,
)
from cerebras.modelzoo.tools.checkpoint_converters.helper import (
Build_HF_CS_Converter_WithOptionalModel,
)


class Converter_Esm2Model_WithoutOptionalModel_HF_CS21(
Converter_BertModel_WithoutOptionalModel_HF_CS21
):
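    """Converts the ESM-2 encoder weights between the HF ``EsmModel`` layout
    and the CS BERT-style encoder.

    Extends the BERT model converter with ESM-2 specifics: query/key
    re-ordering for rotary embeddings, ESM's renamed layer norms, the final
    encoder norm (``emb_layer_norm_after`` <-> ``norm``), and dropping the HF
    ``rotary_embeddings.inv_freq`` buffer.
    """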
    def __init__(self) -> None:
super().__init__()
self.rules = [
# Encoder Layers:
ConversionRule(
[
EquivalentSubkey(
"encoder.layer",
"transformer_encoder.layers",
),
r"\.\d+\.",
EquivalentSubkey(
"attention.self.query", "self_attn.proj_q_dense_layer"
),
r"\.(?:weight|bias)",
],
action=self.convert_with_interleaving_query_key,
),
ConversionRule(
[
EquivalentSubkey(
"encoder.layer",
"transformer_encoder.layers",
),
r"\.\d+\.",
EquivalentSubkey(
"attention.self.key", "self_attn.proj_k_dense_layer"
),
r"\.(?:weight|bias)",
],
action=self.convert_with_interleaving_query_key,
),
ConversionRule(
[
EquivalentSubkey("embeddings.", ""),
Converter_BertLayerNorm_HF_CS("layer_norm", "embed_ln_f"),
],
action=None,
),
ConversionRule(
[
r"encoder\.layer\.\d+\.attention\.self\.rotary_embeddings"
r"\.inv_freq",
],
exists="left",
action=None,
),
ConversionRule(
[
EquivalentSubkey(
"encoder.layer",
"transformer_encoder.layers",
),
r"\.\d+\.",
EquivalentSubkey("attention.", ""),
Converter_BertLayerNorm_HF_CS("LayerNorm", "norm1"),
],
action=None,
),
ConversionRule(
[
EquivalentSubkey(
"encoder.layer",
"transformer_encoder.layers",
),
r"\.\d+\.",
Converter_BertLayerNorm_HF_CS("LayerNorm", "norm2"),
],
action=None,
),
ConversionRule(
[
EquivalentSubkey(
"encoder.",
"transformer_encoder.",
),
Converter_BertLayerNorm_HF_CS(
"emb_layer_norm_after", "norm"
),
],
action=None,
),
*self.rules,
]

    def convert_with_interleaving_query_key(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
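        """Copies a query/key projection weight or bias. When rotary position
        embeddings are in use, the rotary dimensions are re-ordered because HF
        and CS lay them out differently (see ``interleave_helper``); otherwise
        the tensor is copied unchanged."""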
cs_config = action_fn_args["configs"][1]
if cs_config["model"]["position_embedding_type"] != "rotary":
new_state_dict[new_key] = old_state_dict[old_key]
else:
# Query & Keys should be interleaved since HF and CS RoPE differ
tensor = old_state_dict[old_key]
initial_shape = tensor.size()
num_heads = cs_config["model"]["num_heads"]
if from_index == 0:
if len(tensor.size()) == 2:
tensor = tensor.view(
num_heads, tensor.size(0) // num_heads, tensor.size(-1)
)
elif len(tensor.size()) == 1:
tensor = tensor.view(num_heads, tensor.size(0) // num_heads)
tensor = self.interleave_helper(tensor, cs_config)
else:
tensor = self.reverse_interleave_helper(
tensor, cs_config, num_heads
)
tensor = tensor.view(*initial_shape)
new_state_dict[new_key] = tensor

    def interleave_helper(self, t, cs_config):
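        """Re-orders the rotary dimensions of a per-head HF projection tensor
        into the interleaved layout used by CS.

        Illustrative example: with rotary_dim=4, a row stored by HF as
        [x0, x1, y0, y1] (two halves) becomes [x0, y0, x1, y1]. Dimensions
        beyond ``rotary_dim`` are passed through unchanged.
        """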
rotary_dim = cs_config["model"]["rotary_dim"]
if len(t.shape) == 3:
to_rotate = t[:, :rotary_dim, :]
to_pass = t[:, rotary_dim:, :]
to_rotate = (
to_rotate.reshape(t.shape[0], 2, -1, t.shape[-1])
.permute(0, 2, 1, 3)
.reshape(t.shape[0], -1, t.shape[-1])
)
interleaved = torch.cat((to_rotate, to_pass), dim=1)
elif len(t.shape) == 2:
to_rotate = t[:, :rotary_dim]
to_pass = t[:, rotary_dim:]
to_rotate = (
to_rotate.reshape(t.shape[0], 2, -1)
.permute(0, 2, 1)
.reshape(t.shape[0], -1)
)
interleaved = torch.cat((to_rotate, to_pass), dim=1)
else:
            assert False, (
                "query/key/value projection tensor must have rank 2 (biases) "
                "or 3 (weights) when converting from HF to CS."
            )
return interleaved

    def reverse_interleave_helper(self, t, cs_config, num_heads):
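        """Inverse of ``interleave_helper``: reshapes the tensor to expose the
        per-head dimension and maps the CS interleaved rotary layout back to
        HF's two-halves layout, e.g. with rotary_dim=4 a row [x0, y0, x1, y1]
        becomes [x0, x1, y0, y1]."""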
rotary_dim = cs_config["model"]["rotary_dim"]
if len(t.shape) == 2:
t = t.reshape(num_heads, -1, t.shape[-1])
to_rotate = t[:, :rotary_dim, :]
to_pass = t[:, rotary_dim:, :]
# pylint: disable=redefined-builtin
reversed = (
to_rotate.reshape(num_heads, -1, 2, t.shape[-1])
.permute(0, 2, 1, 3)
.reshape(num_heads, rotary_dim, t.shape[-1])
)
reversed = torch.cat((reversed, to_pass), dim=1)
elif len(t.shape) == 1:
t = t.reshape(num_heads, -1)
to_rotate = t[:, :rotary_dim]
to_pass = t[:, rotary_dim:]
reversed = (
to_rotate.reshape(num_heads, -1, 2)
.permute(0, 2, 1)
.reshape(num_heads, -1)
)
reversed = torch.cat((reversed, to_pass), dim=1)
else:
            assert False, (
                "query/key/value projection tensor must have rank 1 (biases) "
                "or 2 (weights) when converting from CS to HF."
            )
return reversed

    def position_embeddings_convert(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
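        """Copies learned position embedding weights. The key is skipped when
        converting HF -> CS with a non-absolute embedding type; when converting
        CS -> HF, the ``position_ids`` buffer that HF keeps alongside the
        embedding table is created as well."""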
if not (
from_index == 0
and action_fn_args["configs"][0]["position_embedding_type"]
!= "absolute"
):
self.replaceKey(
old_key, new_key, old_state_dict, new_state_dict, from_index
)
if from_index == 1:
            # HF stores position_ids as a registered buffer alongside the
            # position embedding table
position_id_key = re.sub(
r"\.position_embeddings\.weight", ".position_ids", new_key
)
if "max_position_embeddings" in action_fn_args["configs"][0]:
max_position_embeddings = action_fn_args["configs"][0][
"max_position_embeddings"
]
else:
max_position_embeddings = action_fn_args["configs"][1]["model"][
"max_position_embeddings"
]
new_state_dict[position_id_key] = torch.arange(
max_position_embeddings
).expand((1, -1))


class Converter_Esm2PretrainModel_WithoutOptionalModel_HF_CS21(
BaseCheckpointConverter_HF_CS
):
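    """Converts a full HF ``EsmForMaskedLM`` checkpoint (encoder plus LM head)
    to/from the CS ESM-2 pretraining model. The HF contact head has no CS
    counterpart and is dropped when converting to CS."""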
    def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[
EquivalentSubkey("esm.", "bert_encoder."),
Converter_Esm2Model_WithoutOptionalModel_HF_CS21(),
],
),
# Language Model Head:
ConversionRule(
[
EquivalentSubkey(
"lm_head.dense",
"bert_mlm_head.mlm_transform.ffn.ffn.0.linear_layer",
),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey(
"lm_head.",
"bert_mlm_head.mlm_transform.",
),
Converter_BertLayerNorm_HF_CS("layer_norm", "ln"),
],
action=None,
),
ConversionRule(
[
EquivalentSubkey(
"lm_head.decoder",
"bert_mlm_head.classifier.ffn.0.linear_layer",
),
r"\.weight",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey(
"lm_head",
"bert_mlm_head.classifier.ffn.0.linear_layer",
),
r"\.bias",
],
action=self.replaceKey,
),
# Contact Head:
ConversionRule(
[r"esm\.contact_head\.regression\.(?:weight|bias)"],
action=None,
),
]

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (
FormatVersions("hf"),
FormatVersions("cs-2.1", "cs-2.2"),
)

    @staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_Esm2_HF_CS21

    @classmethod
def converter_note(cls) -> str:
return "{} EsmForMaskedLM <-> {} for Esm2ForPreTrainingModel".format(
cls.formats()[0], cls.formats()[1]
)
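

# Build a wrapper that also handles checkpoints whose keys are nested under an
# optional "model." prefix (as CS training runs typically write them).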
Converter_Esm2PretrainModel_HF_CS21 = Build_HF_CS_Converter_WithOptionalModel(
"Converter_Esm2PretrainModel_HF_CS21",
Converter_Esm2PretrainModel_WithoutOptionalModel_HF_CS21,
derived_class=Converter_Esm2PretrainModel_WithoutOptionalModel_HF_CS21,
)


class ConfigConverter_Esm2_HF_CS21(ConfigConverter_Bert_HF_CS21):
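    """Config converter for ESM-2, built on top of the BERT config converter.

    Adds ESM-specific fields: token dropout, mask/pad token ids, the
    pre-embedding layer norm (``emb_layer_norm_before`` <->
    ``embedding_layer_norm``), a ``rotary_dim`` compatibility check, and the
    learned-position-embedding offset handling in ``convert_max_pos_embed``.
    """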
    def __init__(self) -> None:
if not hasattr(self, "model_type"):
self.model_type = "esm"
super().__init__()
self.rules = [
ConversionRule(
["max_position_embeddings"],
action=self.convert_max_pos_embed,
),
ConversionRule(
["encoder_nonlinearity"],
action=BaseConfigConverter.assert_factory_fn(1, "gelu"),
),
ConversionRule(
["mlm_nonlinearity"],
action=BaseConfigConverter.assert_factory_fn(1, "gelu"),
),
ConversionRule(
["use_final_layer_norm"],
action=BaseConfigConverter.assert_factory_fn(1, True),
),
ConversionRule(
[
EquivalentSubkey(
"emb_layer_norm_before", "embedding_layer_norm"
)
],
action=self.replaceKey,
),
ConversionRule(
["token_dropout"],
action=self.replace_token_dropout,
),
ConversionRule(
["mask_token_id"],
action=self.replaceKey,
),
ConversionRule(
["pad_token_id"],
action=self.replaceKey,
),
ConversionRule(
["disable_nsp"],
action=BaseConfigConverter.assert_factory_fn(1, True),
),
ConversionRule(
["rotary_dim"], exists="right", action=self.assert_rotary_dim
),
*self.rules,
]
self.pre_convert_defaults[0].update(
{
"mask_token_id": None,
"pad_token_id": None,
"token_dropout": False,
"emb_layer_norm_before": False,
}
)
self.pre_convert_defaults[1].update(
{
"disable_nsp": False,
"pad_token_id": 0,
"mask_padding_in_positional_embed": False,
"use_final_layer_norm": False,
"token_dropout": False,
"embedding_layer_norm": True,
}
)
self.post_convert_defaults[0].update(
{
"is_folding_model": False,
"esmfold_config": None,
}
)
self.post_convert_defaults[1].update(
{
"use_final_layer_norm": True,
"disable_nsp": True,
}
)

    def assert_rotary_dim(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
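        """CS-only key: when rotary embeddings are used, ``rotary_dim`` must
        equal ``hidden_size // num_heads``, since the HF ESM-2 implementation
        applies RoPE to the full head dimension."""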
assert from_index == 1, "{} should only exist in CS config".format(
old_key
)
if (
old_state_dict["position_embedding_type"] == "rotary"
and old_state_dict[old_key]
!= old_state_dict["hidden_size"] // old_state_dict["num_heads"]
):
raise ConfigConversionError(
"rotary_dim must be hidden_size // num_heads in order to be compatible with HF"
)

    def replace_token_dropout(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
token_dropout = old_state_dict[old_key]
if token_dropout and old_state_dict.get("mask_token_id") is None:
raise ConfigConversionError(
"mask_token_id must be provided when token_dropout is enabled"
)
new_state_dict[new_key] = token_dropout

    def convert_max_pos_embed(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
        # This only applies to learned position embeddings; RoPE is unaffected.
        # HF counts the full positional-embedding table as
        # max_position_embeddings (MSL + pad_token_id + 1), whereas CS uses
        # max_position_embeddings to mean the MSL itself.
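        # For example, an HF config with max_position_embeddings=1026 and
        # pad_token_id=1 corresponds to a CS max_position_embeddings (MSL) of
        # 1026 - 1 - 1 = 1024, and vice versa.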
if (
from_index == 0
and old_state_dict["position_embedding_type"] == "absolute"
):
new_state_dict[new_key] = (
old_state_dict[old_key] - old_state_dict["pad_token_id"] - 1
)
elif (
from_index == 1
and old_state_dict["position_embedding_type"] == "learned"
):
new_state_dict[new_key] = (
old_state_dict[old_key] + old_state_dict["pad_token_id"] + 1
)
else:
new_state_dict[new_key] = old_state_dict[old_key]

    def convert_position_embedding_type(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
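        """Maps HF ``absolute`` <-> CS ``learned`` (requiring
        ``mask_padding_in_positional_embed=True`` on the CS side) and passes
        ``rotary`` through unchanged; any other value raises
        ``ConfigConversionError``."""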
# HF supports absolute, relative_key, relative_key_query, rotary
# CS supports learned, fixed, rotary
embed_type = old_state_dict[old_key]
if embed_type == "rotary":
new_state_dict[new_key] = embed_type
elif from_index == 0:
if embed_type == "absolute":
new_state_dict[new_key] = "learned"
new_state_dict["mask_padding_in_positional_embed"] = True
else:
raise ConfigConversionError(
"CS model doesn't support HF's position_embedding_type={}".format(
embed_type
)
)
else:
if embed_type == "learned":
if (
old_state_dict.get("mask_padding_in_positional_embed")
!= True
):
raise ConfigConversionError(
"ESM-2 trained in CS with learned embeddings must have "
"mask_padding_in_positional_embed=True in order to "
"convert to HF"
)
new_state_dict[new_key] = "absolute"
else:
raise ConfigConversionError(
"HF model doesn't support CS's position_embedding_type={}".format(
embed_type
)
)

    def pre_config_convert(
self,
config,
converter_indices,
):
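        # Skip the BERT parent's pre-conversion handling and use the generic
        # HF <-> CS behavior directly.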
return BaseConfigConverter_HF_CS.pre_config_convert(
self, config, converter_indices
)

    def post_config_convert(
self,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
):
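        # HF ESM-2 configs carry no explicit rotary_dim, so derive it as the
        # full head dimension (hidden_size // num_heads) when converting to CS.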
if converter_indices.direction == 0:
if new_config["position_embedding_type"] == "rotary":
new_config["rotary_dim"] = (
new_config["hidden_size"] // new_config["num_heads"]
)
return BaseConfigConverter_HF_CS.post_config_convert(
self,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
)