# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple
import torch
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
BaseCheckpointConverter_HF_CS,
BaseConfigConverter,
BaseConfigConverter_HF_CS,
ConfigConversionError,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
class Converter_GPT2_Attention_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter for one GPT-2 attention block (HF <-> CS 1.7).

    HF keeps the query, key and value projections packed into a single
    ``c_attn`` tensor, whereas CS 1.7 stores three separate layers
    (``proj_q_dense_layer``, ``proj_k_dense_layer``, ``proj_v_dense_layer``).
    The packed tensor is matched against ``proj_q_dense_layer`` and
    ``c_attn_converter`` splits/joins all three at once; the K/V rules only
    verify that this has already happened.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Output projection. 2-D weights are transposed because HF stores
            # them in Conv1D layout (see transpose_key_if_2D).
            ConversionRule(
                [
                    EquivalentSubkey("c_proj", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            # Packed QKV <-> separate Q/K/V, handled entirely here.
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.c_attn_converter,
            ),
            # An HF `q_attn` key is rejected by assert_already_converted.
            ConversionRule(
                [
                    EquivalentSubkey("q_attn", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
        ]

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """This converter has no config converter of its own (it is nested)."""
        return None

    def c_attn_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Dispatch packed-QKV conversion by direction (0: HF -> CS, else CS -> HF)."""
        if from_index == 0:
            self.c_attn_converter_hf_to_cs17(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
        else:
            self.c_attn_converter_cs17_to_hf(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )

    def c_attn_converter_hf_to_cs17(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        """Split HF's packed ``c_attn`` tensor into separate Q, K and V tensors.

        ``new_key`` is the CS ``proj_q_dense_layer`` key; the K and V keys are
        derived from it by substitution.

        Raises:
            AssertionError: if a bias is not 1-D or the packed dimension is
                not three times the embedding dimension.
            ValueError: if ``new_key`` is neither a ``.weight`` nor a ``.bias``.
        """
        packed = old_state_dict[old_key]
        q_key = new_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)

        # Validation uses explicit raises (not `assert`) so it survives -O.
        if new_key.endswith(".bias"):
            if len(packed.shape) != 1:
                raise AssertionError(
                    "Invalid tensor shape {} at {}. Bias should be 1-D".format(
                        packed.shape, old_key
                    )
                )
            if packed.shape[0] % 3 != 0:
                raise AssertionError(
                    "Invalid tensor shape {} at {}. Bias should be divisible by 3 since Q, K, and V are packed".format(
                        packed.shape, old_key
                    )
                )
            chunks = torch.chunk(packed, 3, dim=0)
        elif new_key.endswith(".weight"):
            embed_dim, packed_dim = packed.shape
            if 3 * embed_dim != packed_dim:
                raise AssertionError(
                    "Invalid tensor shape {} at {}. The second dimension should be 3x the first dimension (embed_dim) since Q, K, and V are packed".format(
                        packed.shape, old_key
                    )
                )
            # Conv1D weight is (in, out); transpose to Linear's (out, in)
            # before splitting along the output dimension.
            chunks = torch.chunk(torch.transpose(packed, 0, 1), 3, dim=0)
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))

        (
            new_state_dict[q_key],
            new_state_dict[k_key],
            new_state_dict[v_key],
        ) = chunks

    def c_attn_converter_cs17_to_hf(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args,
    ):
        """Pack CS 1.7's separate Q, K and V tensors into HF's ``c_attn``.

        When converting the bias, this also creates HF's attention buffers:
        ``attn.bias`` (a lower-triangular uint8 mask of shape
        ``(1, 1, max_position_embeddings, max_position_embeddings)``) and
        ``attn.masked_bias`` (the scalar ``-1e4``).

        Raises:
            AssertionError: if the matching K or V key is missing.
        """
        q_key = old_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)
        for required_key in (k_key, v_key):
            if required_key not in old_state_dict:
                raise AssertionError(
                    "Expected the following key to exist! {}".format(required_key)
                )

        packed = torch.cat(
            (
                old_state_dict[q_key],
                old_state_dict[k_key],
                old_state_dict[v_key],
            ),
            dim=0,
        )
        # Need to transpose to convert from Linear.weight -> Conv1D.weight
        if len(packed.shape) == 2:
            packed = torch.transpose(packed, 0, 1)
        new_state_dict[new_key] = packed

        if new_key.endswith(".bias"):
            max_position_embeddings = action_fn_args["configs"][1]["model"][
                "max_position_embeddings"
            ]
            attn_bias_key = re.sub(r"\.c_attn\.", ".", new_key)
            new_state_dict[attn_bias_key] = torch.tril(
                torch.ones(
                    (max_position_embeddings, max_position_embeddings),
                    dtype=torch.uint8,
                )
            ).view(1, 1, max_position_embeddings, max_position_embeddings)
            masked_bias_key = re.sub(r"\.c_attn\.", ".masked_", new_key)
            new_state_dict[masked_bias_key] = torch.tensor(-1e4)

    def assert_already_converted(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Verify a key that is handled as part of the packed QKV conversion.

        HF -> CS: reaching this rule means the key is unsupported.
        CS -> HF: the ``proj_q_dense_layer`` rule already wrote the packed
        HF tensor that K/V map to, so only its presence is checked.

        Raises:
            AssertionError: on an unsupported key or a missing packed tensor.
        """
        if from_index == 0:
            # Explicit raise so this cannot be stripped by `python -O`,
            # which would silently drop the weight.
            raise AssertionError("Invalid key: {}".format(old_key))
        if new_key not in new_state_dict:
            raise AssertionError(
                "Key should've been already converted: {} -> {}".format(
                    old_key, new_key
                )
            )
class Converter_GPT2Model_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter: HF ``GPT2Model`` <-> CS 1.7 ``GPT2LMHeadModel``.

    The HF model is headless while the CS model has an ``lm_head``: when
    converting to CS the head is created (tied to ``wte`` or randomly
    initialized); when converting to HF it has no destination.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Embeddings
            ConversionRule(
                [
                    EquivalentSubkey("wte", "embedding_layer.word_embeddings"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "wpe", "embedding_layer.position_embeddings"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Decoder layers; attention is delegated to a nested converter.
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("attn.", "self_attn."),
                    Converter_GPT2_Attention_HF_CS17(),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_1", "norm1"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_2", "norm3"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # MLP projections are Conv1D in HF, so 2-D weights get transposed.
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("mlp.c_fc", "ffn.ffn.0.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("mlp.c_proj", "ffn.ffn.1.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            # Final layer norm
            ConversionRule(
                [
                    EquivalentSubkey("ln_f", "transformer_decoder.norm"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replace_final_norm,
            ),
            # Keys that only exist on one side.
            ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
            ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
            ConversionRule(
                [r"h\.\d+\.attn\.(?:masked_bias|bias)"], exists="left"
            ),
        ]

    def replace_final_norm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the final layer norm; for HF -> CS also write the CS ``ln_f`` copy."""
        new_state_dict[new_key] = old_state_dict[old_key]
        # CS 1.7 has both "ln_f" and "transformer_decoder.norm"
        # we need to copy the original ("ln_f") too:
        if from_index == 0:
            ln_f_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
            new_state_dict[ln_f_key] = old_state_dict[old_key]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Warn about the missing HF head (HF -> CS) or re-tie embeddings (CS -> HF).

        For CS -> HF with shared embeddings, a ``None`` word-embedding entry
        in the CS checkpoint is filled from ``lm_head.weight``.
        ``share_embedding_weights`` defaults to True, matching the default
        applied by the GPT-2 config converter.
        """
        if from_index == 0:
            logging.warning(
                "{} GPT2 has a language model head (lm_head) "
                "while {} GPT2Model does not. Initializing lm_head to default.".format(
                    self.formats()[1], self.formats()[0]
                )
            )
        # Manually tie weights
        if from_index == 1 and configs[1]["model"].get(
            "share_embedding_weights", True
        ):
            word_embeddings_key = "embedding_layer.word_embeddings.weight"
            if (
                word_embeddings_key in old_state_dict
                and old_state_dict[word_embeddings_key] is None
            ):
                old_state_dict[word_embeddings_key] = old_state_dict[
                    "lm_head.weight"
                ]

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """For HF -> CS, create the CS ``lm_head`` that HF GPT2Model lacks.

        The weight is tied to ``wte.weight`` when ``tie_word_embeddings`` is
        set (default True, as in the config converter). Otherwise it is drawn
        from N(0, 0.02). A zero bias is added only if ``use_bias_in_output``.
        """
        if from_index != 0:
            return
        hf_config = configs[0]
        cs_model_config = configs[1]["model"]
        use_bias_in_output = cs_model_config.get("use_bias_in_output", False)
        vocab_size = cs_model_config["vocab_size"]
        embed_dim = cs_model_config["hidden_size"]
        if hf_config.get("tie_word_embeddings", True):
            lm_head_weight = old_state_dict['wte.weight']
        else:
            lm_head_weight = torch.zeros((vocab_size, embed_dim))
            lm_head_weight.normal_(mean=0.0, std=0.02)
        new_state_dict["lm_head.weight"] = lm_head_weight
        if use_bias_in_output:
            new_state_dict["lm_head.bias"] = torch.zeros(vocab_size)

    @classmethod
    def converter_note(cls) -> str:
        return (
            "{} GPT2Model <-> {} GPT2LMHeadModel\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS17
class Converter_GPT2Model_HF_CS18(Converter_GPT2Model_HF_CS17):
    """Checkpoint converter: HF ``GPT2Model`` <-> CS 1.8 ``GPT2LMHeadModel``.

    Wraps a nested ``Converter_GPT2Model_HF_CS17`` so that CS keys are
    accepted both bare and with a leading ``model.`` prefix.
    ``converter_note`` is inherited unchanged from the CS 1.7 converter.
    """

    def __init__(self):
        super().__init__()
        # Replace the inherited CS 1.7 rules with two wrappers around a
        # nested CS 1.7 converter.
        self.rules = [
            # Catch checkpoints from PyTorch 2.0 API
            ConversionRule([Converter_GPT2Model_HF_CS17()], action=None),
            # Catch checkpoints from deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_GPT2Model_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS18
class ConfigConverter_GPT2Model_HF_CS17(BaseConfigConverter_HF_CS):
    """Config converter for GPT-2 between HF and CS 1.7.

    Most keys are renamed one-to-one. Keys checked with
    ``BaseConfigConverter.assert_factory_fn`` are options that only one
    side can express; presumably the conversion fails unless they hold the
    listed value.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Embedding
            ConversionRule(["vocab_size"], action=self.replaceKey),
            ConversionRule(
                ["position_embedding_type"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, "learned"),
            ),
            ConversionRule(
                ["use_position_embedding"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                [EquivalentSubkey("embd_pdrop", "embedding_dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "tie_word_embeddings", "share_embedding_weights"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["embedding_layer_norm"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            # Decoder Block
            ConversionRule(
                [EquivalentSubkey("n_embd", "hidden_size")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("n_head", "num_heads")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("n_layer", "num_hidden_layers")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("n_positions", "max_position_embeddings")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("scale_attn_weights", "attention_type")],
                action=self.convert_attention_type,
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                [EquivalentSubkey("n_inner", "filter_size")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("activation_function", "nonlinearity")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("attn_pdrop", "attention_dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("resid_pdrop", "dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(["rotary_dim"], action=self.replaceKey),
            ConversionRule(["layer_norm_epsilon"], action=self.replaceKey),
            ConversionRule(
                ["use_bias_in_output"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(["initializer_range"], action=self.replaceKey),
            ConversionRule(
                ["fixed_sparse_attention"],
                action=BaseConfigConverter.assert_factory_fn(1, None),
            ),
            ConversionRule(
                ["norm_first"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ff_layer1_dropout"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            # HF-only options
            ConversionRule(
                ["scale_attn_by_inverse_layer_idx"],
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
            ConversionRule(
                ["reorder_and_upcast_attn"],
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
        ]

    def convert_attention_type(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map HF ``scale_attn_weights`` (bool) <-> CS ``attention_type`` (str).

        True <-> ``"scaled_dot_product"``, False <-> ``"dot_product"``.

        Raises:
            ConfigConversionError: if the CS ``attention_type`` is anything
                other than ``"scaled_dot_product"`` or ``"dot_product"``.
        """
        value = old_state_dict[old_key]
        if from_index == 0:
            new_state_dict[new_key] = (
                "scaled_dot_product" if value else "dot_product"
            )
            return
        if value not in ("scaled_dot_product", "dot_product"):
            raise ConfigConversionError(
                "Can't convert config with {}={}. Only {} is supported.".format(
                    old_key, value, "scaled_dot_product and dot_product",
                )
            )
        new_state_dict[new_key] = value.startswith("scaled_")

    def pre_config_convert(
        self, config, from_index,
    ):
        """Fill in defaults that the rules above rely on.

        - Both sides: embeddings are tied unless specified otherwise.
        - HF: a missing or None ``n_inner`` becomes ``4 * n_embd``.
        - CS: missing embedding/attention dropout rates fall back to
          ``dropout_rate``.
        """
        config = super().pre_config_convert(config, from_index)
        defaults = [
            {"tie_word_embeddings": True},
            {"share_embedding_weights": True},
        ]
        for key, value in defaults[from_index].items():
            config.setdefault(key, value)

        if from_index == 0:
            if config.get("n_inner") is None:
                config["n_inner"] = 4 * config["n_embd"]
        else:
            # Explicit membership checks (not setdefault) so dropout_rate is
            # only read when a fallback is actually needed.
            if "embedding_dropout_rate" not in config:
                config["embedding_dropout_rate"] = config["dropout_rate"]
            if "attention_dropout_rate" not in config:
                config["attention_dropout_rate"] = config["dropout_rate"]
        return config
class ConfigConverter_GPT2Model_HF_CS18(ConfigConverter_GPT2Model_HF_CS17):
    """Config converter for GPT-2 between HF and CS 1.8.

    Reuses the rules of ``ConfigConverter_GPT2Model_HF_CS17`` without
    modification; no ``__init__`` override is needed.
    """
class Converter_GPT2LMHeadModel_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter: HF ``GPT2LMHeadModel`` <-> CS 1.7 ``GPT2LMHeadModel``.

    ``lm_head`` keys are copied as-is. All other keys drop/gain HF's
    ``transformer.`` prefix and are handled by ``Converter_GPT2Model_HF_CS17``.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r"lm_head\.(?:weight|bias)"], action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("transformer.", ""),
                    Converter_GPT2Model_HF_CS17(),
                ],
                action=None,
            ),
        ]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """For CS -> HF with shared embeddings, re-tie the word embeddings.

        If the CS checkpoint stores ``None`` for the word-embedding weight,
        it is filled from ``lm_head.weight``. ``share_embedding_weights``
        defaults to True, matching the default applied by the GPT-2 config
        converter.
        """
        # Manually tie weights
        if from_index == 1 and configs[1]["model"].get(
            "share_embedding_weights", True
        ):
            word_embeddings_key = "embedding_layer.word_embeddings.weight"
            if (
                word_embeddings_key in old_state_dict
                and old_state_dict[word_embeddings_key] is None
            ):
                old_state_dict[word_embeddings_key] = old_state_dict[
                    "lm_head.weight"
                ]

    @classmethod
    def converter_note(cls) -> str:
        return "{} GPT2LMHeadModel <-> {} GPT2LMHeadModel".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS17
class Converter_GPT2LMHeadModel_HF_CS18(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter: HF ``GPT2LMHeadModel`` <-> CS 1.8 ``GPT2LMHeadModel``.

    Delegates to ``Converter_GPT2LMHeadModel_HF_CS17`` and accepts CS keys
    either bare or prefixed with ``model.``.
    """

    def __init__(self):
        super().__init__()
        # Checkpoints written through the PyTorch 2.0 API use bare keys.
        bare_key_rule = ConversionRule(
            [Converter_GPT2LMHeadModel_HF_CS17()], action=None
        )
        # Checkpoints from the deprecated PyTorchBaseModel prefix every key
        # with "model.".
        prefixed_key_rule = ConversionRule(
            [EquivalentSubkey("", "model."), Converter_GPT2LMHeadModel_HF_CS17()],
            action=None,
        )
        self.rules = [bare_key_rule, prefixed_key_rule]

    @classmethod
    def converter_note(cls) -> str:
        """Describe the pair of model formats this converter handles."""
        left_fmt = cls.formats()[0]
        right_fmt = cls.formats()[1]
        return "{} GPT2LMHeadModel <-> {} GPT2LMHeadModel".format(
            left_fmt, right_fmt
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter paired with this checkpoint converter."""
        return ConfigConverter_GPT2Model_HF_CS18
#### Action Helper Functions
def transpose_key_if_2D(
    old_key,
    new_key,
    old_state_dict,
    new_state_dict,
    from_index,
    action_fn_args,
):
    """Copy ``old_state_dict[old_key]`` to ``new_state_dict[new_key]``,
    transposing it when it is 2-D.

    HF GPT-2 implements several projections with ``Conv1D``, whose weight is
    laid out as (in_features, out_features). That is the transpose of
    ``torch.nn.Linear``'s (out_features, in_features). Swapping the two axes
    converts between the layouts in either direction. Tensors of any other
    rank (e.g. 1-D biases) are copied unchanged.

    ``from_index`` and ``action_fn_args`` are unused; they are part of the
    common conversion-action signature.
    """
    tensor = old_state_dict[old_key]
    if len(tensor.shape) == 2:
        tensor = tensor.transpose(0, 1)
    new_state_dict[new_key] = tensor