# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple
import torch
from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
BaseCheckpointConverter_HF_CS,
BaseConfigConverter,
BaseConfigConverter_HF_CS,
ConfigConversionError,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.gpt_neox_hf_cs import (
Converter_GPT_Neox_Attention_HF_CS17,
)
class Converter_Falcon_40B_Attention_HF_CS20(
    Converter_GPT_Neox_Attention_HF_CS17
):
    """Convert Falcon-40B self-attention weights between HF and CS layouts.

    HF stores Q, K and V in one packed ``query_key_value`` projection whose
    leading dimension is viewed as ``(num_kv_groups, kv_group_size + 2,
    head_size)``: per KV group, ``kv_group_size`` query heads followed by one
    key head and one value head. CS stores separate ``proj_q_dense_layer``,
    ``proj_k_dense_layer`` and ``proj_v_dense_layer`` tensors. Additionally,
    the first ``rotary_dim`` rows of every query/key head are reordered
    between the two formats (see ``interleave_helper``).
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("dense", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.qkv_converter,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
        ]

    def interleave_helper(self, rotary_dim, t):
        """Reorder the rotary rows of grouped Q/K heads from HF to CS order.

        Along dim 2, the first ``rotary_dim`` entries
        ``[a_0, ..., a_{r/2-1}, b_0, ..., b_{r/2-1}]`` become
        ``[a_0, b_0, a_1, b_1, ...]``; entries past ``rotary_dim`` are kept
        unchanged after them.

        Args:
            rotary_dim: Number of leading entries per head to reorder.
            t: Grouped tensor of rank 4 for weights
                ``(groups, heads, head_size, hidden)`` or rank 3 for biases
                ``(groups, heads, head_size)``.

        Returns:
            A tensor with the same shape as ``t``.

        Raises:
            ValueError: If ``t`` is not rank 3 or 4.
        """
        if t.dim() not in (3, 4):
            raise ValueError(
                "query/key projection tensor has to have shape of length 3 "
                "(biases) or 4 (weights) when converting from HF to CS, "
                "got shape {}".format(tuple(t.shape))
            )
        lead = t.shape[:2]
        trailing = t.shape[3:]  # () for biases, (hidden,) for weights
        to_rotate = t[:, :, :rotary_dim]
        to_pass = t[:, :, rotary_dim:]
        to_rotate = (
            to_rotate.reshape(*lead, 2, -1, *trailing)
            .transpose(2, 3)
            .reshape(*lead, -1, *trailing)
        )
        return torch.cat((to_rotate, to_pass), dim=2)

    def reverse_interleave_helper(
        self, rotary_dim, t, group_size=None, num_groups=None
    ):
        """Undo ``interleave_helper``: reorder CS Q/K rows back to HF order.

        Args:
            rotary_dim: Number of leading entries per head to reorder.
            t: Flat CS tensor, rank 2 for weights ``(heads * head_size,
                hidden)`` or rank 1 for biases ``(heads * head_size,)``.
            group_size: Heads per KV group (``heads / num_groups``).
            num_groups: Number of KV groups.

        Returns:
            Grouped tensor of shape ``(num_groups, group_size, head_size,
            hidden)`` for weights or ``(num_groups, group_size, head_size)``
            for biases.

        Raises:
            ValueError: If ``t`` is not rank 1 or 2.
        """
        if t.dim() not in (1, 2):
            raise ValueError(
                "query/key projection tensor has to have shape of length 1 "
                "(biases) or 2 (weights) when converting from CS to HF, "
                "got shape {}".format(tuple(t.shape))
            )
        trailing = t.shape[1:]  # () for biases, (hidden,) for weights
        grouped = t.reshape(num_groups, group_size, -1, *trailing)
        to_rotate = grouped[:, :, :rotary_dim]
        to_pass = grouped[:, :, rotary_dim:]
        restored = (
            to_rotate.reshape(num_groups, group_size, -1, 2, *trailing)
            .transpose(2, 3)
            .reshape(num_groups, group_size, -1, *trailing)
        )
        return torch.cat((restored, to_pass), dim=2)

    @staticmethod
    def _attention_dims(action_fn_args):
        """Return ``(head_size, num_kv_groups, kv_group_size)`` from the CS config.

        Raises:
            ValueError: If ``num_heads`` is not divisible by ``num_kv_groups``
                (the packed HF layout cannot be split evenly otherwise).
        """
        model_config = action_fn_args["configs"][1]["model"]
        num_heads = model_config["num_heads"]
        head_size = model_config["hidden_size"] // num_heads
        num_kv_groups = model_config["extra_attention_params"][
            "num_kv_groups"
        ]
        if num_heads % num_kv_groups != 0:
            raise ValueError(
                "num_heads ({}) must be divisible by num_kv_groups ({})".format(
                    num_heads, num_kv_groups
                )
            )
        return head_size, num_kv_groups, num_heads // num_kv_groups

    def _split_packed_qkv(
        self, packed, head_size, num_kv_groups, kv_group_size
    ):
        """Split a packed HF QKV tensor into flat CS ``(query, key, value)``.

        Works for both weights (rank 2) and biases (rank 1); the trailing
        hidden dimension of weights is carried through unchanged.
        """
        trailing = packed.shape[1:]
        grouped = packed.reshape(
            num_kv_groups, kv_group_size + 2, head_size, *trailing
        )
        query = grouped[:, :kv_group_size]
        key = grouped[:, kv_group_size : kv_group_size + 1]
        value = grouped[:, kv_group_size + 1 : kv_group_size + 2]
        # Only Q and K carry rotary embeddings, so only they get reordered.
        query = self.interleave_helper(head_size, query).reshape(-1, *trailing)
        key = self.interleave_helper(head_size, key).reshape(-1, *trailing)
        value = value.reshape(-1, *trailing)
        return query, key, value

    def qkv_converter_hf_to_cs(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        action_fn_args,
    ):
        """Unpack HF ``query_key_value`` into CS ``proj_{q,k,v}_dense_layer``.

        HF packs Q, K and V along dim 0 with size
        ``head_size * (num_heads + 2 * num_kv_groups)``.

        Raises:
            AssertionError: If the packed tensor has an unexpected shape.
            ValueError: If ``new_key`` is neither a ``.weight`` nor ``.bias``.
        """
        q_key = new_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)
        head_size, num_kv_groups, kv_group_size = self._attention_dims(
            action_fn_args
        )
        num_heads = num_kv_groups * kv_group_size
        packed = old_state_dict[old_key]

        if new_key.endswith(".bias"):
            assert len(packed.shape) == 1
            packed_dim = packed.shape[0]
        elif new_key.endswith(".weight"):
            packed_dim, _ = packed.shape
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))

        assert (
            head_size * (num_kv_groups * 2 + num_heads) == packed_dim
        ), "Invalid tensor shape {} at {}.".format(packed.shape, old_key)

        query, key, value = self._split_packed_qkv(
            packed, head_size, num_kv_groups, kv_group_size
        )
        new_state_dict[q_key] = query
        new_state_dict[k_key] = key
        new_state_dict[v_key] = value

    def qkv_converter_cs_to_hf(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        """Pack CS ``proj_{q,k,v}_dense_layer`` into HF ``query_key_value``.

        Raises:
            AssertionError: If the K or V tensor is missing from the CS
                checkpoint.
            ValueError: If ``new_key`` is neither a ``.weight`` nor ``.bias``.
        """
        q_key = old_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)
        head_size, num_kv_groups, kv_group_size = self._attention_dims(
            action_fn_args
        )
        assert (
            k_key in old_state_dict
        ), "Expected the following key to exist! {}".format(k_key)
        assert (
            v_key in old_state_dict
        ), "Expected the following key to exist! {}".format(v_key)
        if not new_key.endswith((".bias", ".weight")):
            raise ValueError("Invalid key after conversion: {}".format(new_key))

        query = old_state_dict[q_key]
        key = old_state_dict[k_key]
        value = old_state_dict[v_key]
        trailing = query.shape[1:]  # () for biases, (hidden,) for weights

        query = self.reverse_interleave_helper(
            head_size,
            query,
            group_size=kv_group_size,
            num_groups=num_kv_groups,
        )
        key = self.reverse_interleave_helper(
            head_size, key, group_size=1, num_groups=num_kv_groups
        )
        value = value.reshape(num_kv_groups, 1, -1, *value.shape[1:])
        # Per group: [q_0 .. q_{kv_group_size-1}, k, v]
        packed_qkv = torch.cat((query, key, value), dim=1)
        new_state_dict[new_key] = packed_qkv.reshape(-1, *trailing)

    def qkv_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Dispatch QKV conversion: ``from_index == 0`` is HF -> CS."""
        if from_index == 0:
            self.qkv_converter_hf_to_cs(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
        else:
            self.qkv_converter_cs_to_hf(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.9", "cs-2.0", "cs-2.1", "cs-2.2"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Falcon_40B_HF_CS20
class Converter_Falcon_40B_Headless_WithoutModelPrefix_HF_CS20(
    BaseCheckpointConverter_HF_CS
):
    """Convert a headless HF Falcon-40B checkpoint to/from CS.

    HF keys carry no ``transformer.`` prefix (e.g. ``word_embeddings.weight``,
    ``h.0.ln_attn.weight``) and CS keys carry no ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Embedding:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "word_embeddings", "embedding_layer.word_embeddings"
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_attn.", "norm1."),
                    r"(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_mlp.", "norm3."),
                    r"(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Attention:
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("self_attention.", "self_attn."),
                    Converter_Falcon_40B_Attention_HF_CS20(),
                ],
                action=None,
            ),
            # mlp
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "mlp.dense_h_to_4h", "ffn.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "mlp.dense_4h_to_h", "ffn.ffn.1.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # final norm
            ConversionRule(
                [
                    EquivalentSubkey("ln_f", "transformer_decoder.norm"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replace_final_norm,
            ),
            # other
            ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
            ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
        ]

    def replace_final_norm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the final norm, duplicating it under ``ln_f`` for HF -> CS."""
        new_state_dict[new_key] = old_state_dict[old_key]
        # CS model has both "ln_f" and "transformer_decoder.norm"
        # we need to copy the original ("ln_f") too:
        if from_index == 0:
            ln_f_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
            new_state_dict[ln_f_key] = old_state_dict[old_key]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Warn about the missing HF head and re-tie CS embeddings if needed.

        For CS -> HF with ``share_embedding_weights``, an embedding entry that
        is present but ``None`` is filled from ``lm_head.weight``.
        """
        if converter_indices.direction == 0:
            logging.warning(
                "{} Falcon has a language model head (lm_head) "
                "while {} RWModel does not. Initializing lm_head to default.".format(
                    self.formats()[1], self.formats()[0]
                )
            )

        # Manually tie weights
        if (
            converter_indices.direction == 1
            and configs[1]["model"]["share_embedding_weights"]
        ):
            if (
                old_state_dict.get("embedding_layer.word_embeddings.weight", 0)
                is None
            ):
                old_state_dict[
                    "embedding_layer.word_embeddings.weight"
                ] = old_state_dict["lm_head.weight"]

    @staticmethod
    def _hf_embedding_key(old_state_dict):
        """Return the key of the HF word-embedding weight in ``old_state_dict``.

        Headless HF checkpoints store it un-prefixed (matching the
        ``word_embeddings`` rule in ``__init__``); the ``transformer.``-prefixed
        name is accepted as a fallback.

        Raises:
            KeyError: If neither key is present.
        """
        candidates = (
            "word_embeddings.weight",
            "transformer.word_embeddings.weight",
        )
        for candidate in candidates:
            if candidate in old_state_dict:
                return candidate
        raise KeyError(
            "Cannot tie lm_head to the word embeddings: none of {} found in "
            "the HF checkpoint".format(candidates)
        )

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
        key_prefix="",
    ):
        """Create the CS ``lm_head`` when converting from headless HF Falcon.

        If ``tie_word_embeddings`` is set in the HF config, ``lm_head.weight``
        reuses the HF word-embedding tensor; otherwise it is drawn from
        ``N(0, 0.02)``. A zero ``lm_head.bias`` is added when the CS config
        sets ``use_bias_in_output``.
        """
        if converter_indices.direction == 0:
            # We are converting from HF Falcon (which is headless) -> CS Falcon
            # (which has a head). We need to create 'lm_head'.
            hf_config = configs[0]
            cs_config = configs[1]
            use_bias_in_output = cs_config["model"].get(
                "use_bias_in_output", False
            )
            vocab_size = cs_config["model"]["vocab_size"]
            embed_dim = cs_config["model"]["hidden_size"]
            if hf_config["tie_word_embeddings"]:
                lm_head_weight = old_state_dict[
                    self._hf_embedding_key(old_state_dict)
                ]
            else:
                lm_head_weight = torch.zeros((vocab_size, embed_dim))
                lm_head_weight.normal_(mean=0.0, std=0.02)
            new_state_dict[key_prefix + "lm_head.weight"] = lm_head_weight
            if use_bias_in_output:
                lm_head_bias = torch.zeros(vocab_size)
                new_state_dict[key_prefix + "lm_head.bias"] = lm_head_bias
        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            converter_indices,
            drop_unmatched_keys,
            key_prefix=key_prefix,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.9", "cs-2.0", "cs-2.1", "cs-2.2"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Falcon_40B_HF_CS20
class Converter_Falcon_40B_Headless_HF_CS20(
    Converter_Falcon_40B_Headless_WithoutModelPrefix_HF_CS20
):
    """Headless Falcon-40B converter accepting CS keys with or without ``model.``.

    The HF side is always un-prefixed; on the CS side both the bare layout and
    the ``model.``-prefixed layout are matched by a nested converter.
    """

    def __init__(self):
        super().__init__()
        nested = Converter_Falcon_40B_Headless_WithoutModelPrefix_HF_CS20
        # Catch checkpoints from Pytorch 2.0 API
        bare_rule = ConversionRule([nested()], action=None)
        # Catch checkpoints from 1.7/1.8
        prefixed_rule = ConversionRule(
            [EquivalentSubkey("", "model."), nested()],
            action=None,
        )
        self.rules = [bare_rule, prefixed_rule]
class Converter_Falcon_40B_WithoutModelPrefix_HF_CS20(
    BaseCheckpointConverter_HF_CS
):
    """Full (with ``lm_head``) Falcon-40B converter for un-prefixed CS keys.

    HF keys under ``transformer.`` are delegated to the headless converter;
    ``lm_head`` weights/biases are copied across unchanged.
    """

    def __init__(self):
        super().__init__()
        lm_head_rule = ConversionRule(
            ["lm_head", r"\.(?:weight|bias)"],
            action=self.replaceKey,
        )
        body_rule = ConversionRule(
            [
                EquivalentSubkey("transformer.", ""),
                Converter_Falcon_40B_Headless_WithoutModelPrefix_HF_CS20(),
            ],
            action=None,
        )
        self.rules = [lm_head_rule, body_rule]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Re-tie CS embeddings to ``lm_head`` before a CS -> HF conversion.

        Only acts when the CS config shares embedding weights and the
        embedding entry is present but ``None``.
        """
        embedding_key = "embedding_layer.word_embeddings.weight"
        is_cs_to_hf = converter_indices.direction == 1
        # Manually tie weights
        if not (is_cs_to_hf and configs[1]["model"]["share_embedding_weights"]):
            return
        if old_state_dict.get(embedding_key, 0) is None:
            old_state_dict[embedding_key] = old_state_dict["lm_head.weight"]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_formats = FormatVersions("cs-1.9", "cs-2.0", "cs-2.1", "cs-2.2")
        return hf_format, cs_formats

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Falcon_40B_HF_CS20
class Converter_Falcon_40B_HF_CS20(
    Converter_Falcon_40B_WithoutModelPrefix_HF_CS20
):
    """Full Falcon-40B converter accepting CS keys with or without ``model.``."""

    def __init__(self):
        super().__init__()
        nested = Converter_Falcon_40B_WithoutModelPrefix_HF_CS20
        # Catch checkpoints from Pytorch 2.0 API
        bare_rule = ConversionRule([nested()], action=None)
        # Catch checkpoints from 1.7/1.8
        prefixed_rule = ConversionRule(
            [EquivalentSubkey("", "model."), nested()],
            action=None,
        )
        self.rules = [bare_rule, prefixed_rule]
class ConfigConverter_Falcon_40B_HF_CS20(BaseConfigConverter_HF_CS):
    """Convert Falcon-40B model configs between HF (``RefinedWeb``) and CS.

    ``self.defaults[0]`` holds HF-side defaults and ``self.defaults[1]``
    CS-side defaults; they fill keys missing from the input config (before
    conversion) and from the output config (after conversion).
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "RefinedWeb"),
            ),
            # Embedding
            ConversionRule(["vocab_size"], action=self.replaceKey),
            ConversionRule(
                [EquivalentSubkey("alibi", "position_embedding_type")],
                action=self.convert_position_embedding_type,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "tie_word_embeddings", "share_embedding_weights"
                    )
                ],
                action=self.replaceKey,
            ),
            # Decoder Block
            ConversionRule(
                ["hidden_size"],
                action=self.convert_hidden_size,
            ),
            ConversionRule(
                [EquivalentSubkey("n_head", "num_heads")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("n_head_kv", "extra_attention_params")],
                action=self.convert_num_head_groups,
            ),
            ConversionRule(
                [EquivalentSubkey("n_layer", "num_hidden_layers")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["max_position_embeddings"],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("parallel_attn", "use_untied_layer_norm")],
                action=self.parallel_attn_convert,
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["use_ffn_bias"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["nonlinearity"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, "gelu"),
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "attention_dropout", "attention_dropout_rate"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_dropout", "residual_dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["layer_norm_epsilon"],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["use_bias_in_output"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["initializer_range"],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["bias"],
                exists="left",
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
            ConversionRule(
                ["alibi"],
                exists="left",
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
        ]

        self.defaults = [
            {
                "alibi": False,
                "architectures": ["RWForCausalLM"],
                "auto_map": {
                    "AutoConfig": "configuration_RW.RWConfig",
                    "AutoModel": "modelling_RW.RWModel",
                    "AutoModelForSequenceClassification": (
                        "modelling_RW.RWForSequenceClassification"
                    ),
                    "AutoModelForTokenClassification": "modelling_RW.RWForTokenClassification",
                    "AutoModelForQuestionAnswering": "modelling_RW.RWForQuestionAnswering",
                    "AutoModelForCausalLM": "modelling_RW.RWForCausalLM",
                },
                "parallel_attn": True,
                "bias": False,
                "bos_token_id": 11,
                "eos_token_id": 11,
                "model_type": "RefinedWeb",
                "torch_dtype": "bfloat16",
                "use_cache": True,
                "tie_word_embeddings": True,
            },
            {
                "position_embedding_type": "rotary",
                "embedding_dropout_rate": 0.0,
                "share_embedding_weights": True,
                "nonlinearity": "gelu",
                "max_position_embeddings": 2048,
                "attention_module": "multiquery_attention",
                "attention_type": "scaled_dot_product",
                "use_untied_layer_norm": True,
                "extra_attention_params": {"num_kv_groups": 1},
                "loss_scaling": "num_tokens",
            },
        ]

    def convert_num_head_groups(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map HF ``n_head_kv`` <-> CS ``extra_attention_params.num_kv_groups``."""
        if from_index == 0:
            extra = {"num_kv_groups": old_state_dict[old_key]}
            new_state_dict[new_key] = extra
        elif from_index == 1:
            new_state_dict[new_key] = old_state_dict[old_key]["num_kv_groups"]

    def convert_position_embedding_type(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map HF ``alibi`` <-> CS ``position_embedding_type``.

        HF Falcon with ``alibi=False`` corresponds to CS ``"rotary"``; no
        other combination is convertible.

        Raises:
            ConfigConversionError: If HF enables alibi, or the CS config uses
                a position embedding type other than ``"rotary"``.
        """
        if from_index == 0:
            if old_state_dict[old_key]:
                raise ConfigConversionError(
                    "CS model doesn't support falcon with position_embedding_type = alibi"
                )
            new_state_dict[new_key] = "rotary"
        else:
            # Emitting alibi=False for a non-rotary CS model would silently
            # produce an HF config that does not match the checkpoint.
            if old_state_dict[old_key] != "rotary":
                raise ConfigConversionError(
                    "HF falcon (alibi=False) only supports rotary position "
                    "embeddings, got position_embedding_type = {}".format(
                        old_state_dict[old_key]
                    )
                )
            new_state_dict[new_key] = False

    def convert_hidden_size(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy ``hidden_size``; derive/check ``filter_size == 4 * hidden_size``."""
        new_state_dict[new_key] = old_state_dict[old_key]
        if from_index == 0:
            # Falcon uses 4 * hidden as intermediate size
            new_state_dict["filter_size"] = old_state_dict[old_key] * 4
        else:
            assert (
                old_state_dict[old_key] * 4 == old_state_dict["filter_size"]
            ), "HF model only supports filter_size = 4 * hidden_size"

    def parallel_attn_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Require the flag to be enabled on the source side and set it on the target."""
        assert old_state_dict[
            old_key
        ], "parallel attention has to be enabled for falcon-40B"
        new_state_dict[new_key] = True

    def pre_config_convert(
        self,
        config,
        converter_indices,
    ):
        """Fill keys missing from the input config with source-side defaults."""
        import copy

        config = super().pre_config_convert(config, converter_indices)
        # Deep-copy so mutable defaults (lists/dicts) are never shared between
        # the returned config and self.defaults.
        for key, value in self.defaults[converter_indices.direction].items():
            if key not in config:
                config[key] = copy.deepcopy(value)
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Fill target-side defaults and apply Falcon-specific rotary/dropout rules.

        Raises:
            AssertionError: For CS -> HF if the CS config has non-zero
                embedding dropout or ``rotary_dim != hidden_size // num_heads``.
        """
        import copy

        # Apply defaults (deep-copied, see pre_config_convert)
        for key, value in self.defaults[
            1 - converter_indices.direction
        ].items():
            if key not in new_config:
                new_config[key] = copy.deepcopy(value)

        if converter_indices.direction == 0:
            # falcon uses rotary_dim == head_dim
            new_config["rotary_dim"] = (
                old_config["hidden_size"] // old_config["n_head"]
            )
        else:
            # embedding dropout check
            assert (
                old_config["embedding_dropout_rate"] == 0.0
            ), "Falcon has no embedding dropout"
            # rotary check
            assert (
                old_config["rotary_dim"]
                == old_config["hidden_size"] // old_config["num_heads"]
            ), "rotary dimension of falcon is equal to head_dim"

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.9", "cs-2.0", "cs-2.1", "cs-2.2"),
        )