# Source code for common.pytorch.model_utils.checkpoint_converters.gpt2_hf_cs

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Tuple

import torch

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)


class Converter_GPT2_Attention_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Convert GPT-2 self-attention weights between HF and CS 1.7 formats.

    HF stores the query/key/value projections of an attention block as one
    packed ``c_attn`` (``Conv1D``) tensor, while CS 1.7 stores them as three
    separate ``proj_{q,k,v}_dense_layer`` (``Linear``) layers. The output
    projection (``c_proj`` <-> ``proj_output_dense_layer``) only needs its 2D
    weight transposed.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Output projection: 2D weights are transposed between layouts.
            ConversionRule(
                [
                    EquivalentSubkey("c_proj", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            # Packed Q/K/V: c_attn_converter (un)packs all three projections
            # when it sees the q key.
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.c_attn_converter,
            ),
            # The remaining rules produce nothing themselves; they only check
            # that the key was already handled (see assert_already_converted).
            ConversionRule(
                [
                    EquivalentSubkey("q_attn", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("c_attn", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # No config converter is associated with this attention-level converter.
        return None

    def c_attn_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Dispatch packed-QKV conversion by direction.

        ``from_index == 0`` converts HF -> CS 1.7; any other value converts
        CS 1.7 -> HF.
        """
        if from_index == 0:
            self.c_attn_converter_hf_to_cs17(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
        else:
            self.c_attn_converter_cs17_to_hf(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )

    def c_attn_converter_hf_to_cs17(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        """Split HF's packed ``c_attn`` tensor into CS Q, K and V tensors.

        ``new_key`` is the CS ``proj_q_dense_layer`` key; the K and V keys are
        derived from it by substitution.
        - A 1D bias of length ``3 * embed_dim`` is split into three chunks.
        - A weight of shape ``(embed_dim, 3 * embed_dim)`` is transposed and
          then split into three chunks along dim 0.

        Raises:
            AssertionError: (when assertions are enabled) if the packed
                dimension is not 3x the embedding dimension.
            ValueError: if ``new_key`` is neither a ``.weight`` nor a
                ``.bias`` key.
        """
        # HF represents Q, K, and V in a packed format. We need to unpack the
        # weight and bias tensor for CS 1.7 format.
        q_key = new_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)

        if new_key.endswith(".bias"):
            assert len(old_state_dict[old_key].shape) == 1
            packed_dim = old_state_dict[old_key].shape[0]
            embed_dim = packed_dim // 3
            assert (
                3 * embed_dim == packed_dim
            ), "Invalid tensor shape {} at {}. Bias should be divisible by 3 since Q, K, and V are packed".format(
                old_state_dict[old_key].shape, old_key
            )
            (
                new_state_dict[q_key],
                new_state_dict[k_key],
                new_state_dict[v_key],
            ) = torch.chunk(old_state_dict[old_key], 3, dim=0)
        elif new_key.endswith(".weight"):
            embed_dim, packed_dim = old_state_dict[old_key].shape
            assert (
                3 * embed_dim == packed_dim
            ), "Invalid tensor shape {} at {}. The second dimension should be 3x the first dimension (embed_dim) since Q, K, and V are packed".format(
                old_state_dict[old_key].shape, old_key
            )
            # Conv1D -> Linear layout, then split the packed output dim.
            (
                new_state_dict[q_key],
                new_state_dict[k_key],
                new_state_dict[v_key],
            ) = torch.chunk(
                torch.transpose(old_state_dict[old_key], 0, 1), 3, dim=0
            )
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))

    def c_attn_converter_cs17_to_hf(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        action_fn_args,
    ):
        """Pack CS Q, K and V tensors into HF's single ``c_attn`` tensor.

        ``old_key`` is the CS ``proj_q_dense_layer`` key; the matching K and V
        keys must also exist in ``old_state_dict``. The three tensors are
        concatenated along dim 0, and a 2D result is transposed (Linear ->
        Conv1D layout).

        When converting the bias, two extra HF buffers are also created:
        - ``...attn.bias``: a lower-triangular ``uint8`` mask of shape
          ``(1, 1, max_position_embeddings, max_position_embeddings)``.
        - ``...attn.masked_bias``: the scalar ``-1e4``.
        ``max_position_embeddings`` is read from
        ``action_fn_args["configs"][1]["model"]``.
        """
        # HF represents Q, K, and V in a packed format. It also contains
        # special ".bias" and ".masked_bias" register buffers that need to be
        # initialized.
        q_key = old_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)
        assert (
            k_key in old_state_dict
        ), "Expected the following key to exist! {}".format(k_key)
        assert (
            v_key in old_state_dict
        ), "Expected the following key to exist! {}".format(v_key)
        new_state_dict[new_key] = torch.cat(
            (
                old_state_dict[q_key],
                old_state_dict[k_key],
                old_state_dict[v_key],
            ),
            dim=0,
        )
        # Need to transpose to convert from Linear.weight -> Conv1D.weight
        if len(new_state_dict[new_key].shape) == 2:
            new_state_dict[new_key] = torch.transpose(
                new_state_dict[new_key], 0, 1
            )

        if new_key.endswith(".bias"):
            max_position_embeddings = action_fn_args["configs"][1]["model"][
                "max_position_embeddings"
            ]
            # "...attn.c_attn.bias" -> "...attn.bias"
            attn_bias_key = re.sub(r"\.c_attn\.", ".", new_key)
            new_state_dict[attn_bias_key] = torch.tril(
                torch.ones(
                    (max_position_embeddings, max_position_embeddings),
                    dtype=torch.uint8,
                )
            ).view(1, 1, max_position_embeddings, max_position_embeddings)
            # "...attn.c_attn.bias" -> "...attn.masked_bias"
            masked_bias_key = re.sub(r"\.c_attn\.", ".masked_", new_key)
            new_state_dict[masked_bias_key] = torch.tensor(-1e4)

    def assert_already_converted(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Check that a key has already been handled by the packed c_attn rule.

        Raises:
            AssertionError: when converting HF -> CS (this rule should never
                match in that direction), or when converting CS -> HF and
                ``new_key`` is not yet in ``new_state_dict``.
        """
        if from_index == 0:
            # We should never hit this case as this key should have been
            # matched already. Raise explicitly so the check survives
            # ``python -O`` (which strips ``assert`` statements).
            raise AssertionError("Invalid key: {}".format(old_key))
        else:
            # When we convert from CS -> HF, the proj_q_dense_layer should also
            # handle conversion of proj_k_dense_layer and proj_v_dense_layer
            # since HF represents these three layers in a packed format. We
            # simply need to test that the key containing the packed format
            # has already been converted.
            assert (
                new_key in new_state_dict
            ), "Key should've been already converted: {} -> {}".format(
                old_key, new_key
            )
class Converter_GPT2Model_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Convert between HF ``GPT2Model`` and CS 1.7 ``GPT2LMHeadModel`` checkpoints.

    The HF model is headless while the CS model has an ``lm_head``. When
    converting HF -> CS, ``post_model_convert`` creates the head. When
    converting CS -> HF, the head is not mapped to any HF key.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("wte", "embedding_layer.word_embeddings"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "wpe", "embedding_layer.position_embeddings"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("attn.", "self_attn."),
                    Converter_GPT2_Attention_HF_CS17(),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_1", "norm1"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("ln_2", "norm3"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # HF MLP layers are Conv1D; CS uses Linear, so 2D weights are
            # transposed.
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("mlp.c_fc", "ffn.ffn.0.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("h", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("mlp.c_proj", "ffn.ffn.1.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=transpose_key_if_2D,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("ln_f", "transformer_decoder.norm"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replace_final_norm,
            ),
            # Keys that only exist on the CS side: the LM head, and the "ln_f"
            # copy that replace_final_norm writes.
            ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
            ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
            # HF-only attention buffers; they are regenerated by the attention
            # converter when converting CS -> HF.
            ConversionRule(
                [r"h\.\d+\.attn\.(?:masked_bias|bias)"], exists="left"
            ),
        ]

    def replace_final_norm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the final layer norm.

        When converting HF -> CS, the tensor is also written under the
        ``ln_f.`` key.
        """
        new_state_dict[new_key] = old_state_dict[old_key]
        # CS 1.7 has both "ln_f" and "transformer_decoder.norm"
        # we need to copy the original ("ln_f") too:
        if from_index == 0:
            ln_f_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
            new_state_dict[ln_f_key] = old_state_dict[old_key]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Warn about the missing HF head and restore tied embedding weights.

        Tied weights are handled only when converting CS -> HF with
        ``share_embedding_weights`` enabled.
        """
        if from_index == 0:
            logging.warning(
                "{} GPT2 has a language model head (lm_head) "
                "while {} GPT2Model does not. Initializing lm_head to default.".format(
                    self.formats()[1], self.formats()[0]
                )
            )

        # Manually tie weights: when converting CS -> HF with shared
        # embeddings, a word-embedding entry that is present but None is
        # filled from lm_head.weight. (A missing key is left untouched.)
        if from_index == 1 and configs[1]["model"]["share_embedding_weights"]:
            if (
                old_state_dict.get("embedding_layer.word_embeddings.weight", 0)
                is None
            ):
                old_state_dict[
                    "embedding_layer.word_embeddings.weight"
                ] = old_state_dict["lm_head.weight"]

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Create the CS ``lm_head`` when converting HF -> CS.

        If word embeddings are tied, the head reuses ``wte.weight``. Otherwise
        it is sampled from ``N(0, 0.02)``. A zero bias is added when the CS
        config sets ``use_bias_in_output``.
        """
        if from_index == 0:
            # We are converting from HF GPT2Model (which is headless) -> CS
            # GPT2LMHeadModel. We need to create 'lm_head' and init to
            # default values.
            hf_config = configs[0]
            cs_config = configs[1]
            use_bias_in_output = cs_config["model"].get(
                "use_bias_in_output", False
            )
            vocab_size = cs_config["model"]["vocab_size"]
            embed_dim = cs_config["model"]["hidden_size"]
            # HF configs may omit "tie_word_embeddings"; default to True, the
            # same default ConfigConverter_GPT2Model_HF_CS17.pre_config_convert
            # applies, instead of raising KeyError.
            if hf_config.get("tie_word_embeddings", True):
                lm_head_weight = old_state_dict["wte.weight"]
            else:
                lm_head_weight = torch.zeros((vocab_size, embed_dim))
                lm_head_weight.normal_(mean=0.0, std=0.02)
            new_state_dict["lm_head.weight"] = lm_head_weight
            if use_bias_in_output:
                lm_head_bias = torch.zeros(vocab_size)
                new_state_dict["lm_head.bias"] = lm_head_bias

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return (
            "{} GPT2Model <-> {} GPT2LMHeadModel\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS17
class Converter_GPT2Model_HF_CS18(Converter_GPT2Model_HF_CS17):
    """HF GPT2Model <-> CS 1.8/1.9 GPT2LMHeadModel checkpoint converter.

    Delegates to the CS 1.7 converter's rules. CS keys are accepted both
    bare and nested under a ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        legacy_converter = Converter_GPT2Model_HF_CS17
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API (bare keys).
            ConversionRule([legacy_converter()], action=None),
            # Catch checkpoints from the deprecated PyTorchBaseModel, whose
            # keys are nested under "model.".
            ConversionRule(
                [EquivalentSubkey("", "model."), legacy_converter()],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        # The parent's note is formatted with cls.formats(), so delegating
        # yields the CS 1.8/1.9 version of the same text.
        return super().converter_note()

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS18
class ConfigConverter_GPT2Model_HF_CS17(BaseConfigConverter_HF_CS):
    """Convert GPT-2 model configs between HF and CS 1.7 formats."""

    def __init__(self):
        super().__init__()

        def same(key):
            # Key has the same name on both sides; value is copied.
            return ConversionRule([key], action=self.replaceKey)

        def rename(hf_key, cs_key):
            # Key is renamed between sides; value is copied.
            return ConversionRule(
                [EquivalentSubkey(hf_key, cs_key)], action=self.replaceKey
            )

        require = BaseConfigConverter.assert_factory_fn

        self.rules = [
            # Embedding
            same("vocab_size"),
            ConversionRule(
                ["position_embedding_type"],
                exists="right",
                action=require(1, "learned"),
            ),
            ConversionRule(
                ["use_position_embedding"],
                exists="right",
                action=require(1, True),
            ),
            rename("embd_pdrop", "embedding_dropout_rate"),
            rename("tie_word_embeddings", "share_embedding_weights"),
            ConversionRule(["embedding_layer_norm"], action=require(1, False)),
            # Decoder Block
            rename("n_embd", "hidden_size"),
            rename("n_head", "num_heads"),
            rename("n_layer", "num_hidden_layers"),
            rename("n_positions", "max_position_embeddings"),
            ConversionRule(
                [EquivalentSubkey("scale_attn_weights", "attention_type")],
                action=self.convert_attention_type,
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"], action=require(1, True)
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=require(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias"], exists="right", action=require(1, True)
            ),
            rename("n_inner", "filter_size"),
            rename("activation_function", "nonlinearity"),
            rename("attn_pdrop", "attention_dropout_rate"),
            rename("resid_pdrop", "dropout_rate"),
            same("rotary_dim"),
            same("layer_norm_epsilon"),
            ConversionRule(["use_bias_in_output"], action=require(1, False)),
            same("initializer_range"),
            ConversionRule(
                ["fixed_sparse_attention"], action=require(1, None)
            ),
            ConversionRule(["norm_first"], action=require(1, True)),
            ConversionRule(["use_ff_layer1_dropout"], action=require(1, False)),
            ConversionRule(
                ["scale_attn_by_inverse_layer_idx"], action=require(0, False)
            ),
            ConversionRule(
                ["reorder_and_upcast_attn"], action=require(0, False)
            ),
        ]

    def convert_attention_type(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map HF ``scale_attn_weights`` (bool) <-> CS ``attention_type`` (str).

        Raises:
            ConfigConversionError: when converting CS -> HF with an
                ``attention_type`` other than ``scaled_dot_product`` or
                ``dot_product``.
        """
        value = old_state_dict[old_key]
        if from_index == 0:
            new_state_dict[new_key] = (
                "scaled_dot_product" if value else "dot_product"
            )
            return

        if value not in ("scaled_dot_product", "dot_product"):
            raise ConfigConversionError(
                "Can't convert config with {}={}. Only {} is supported.".format(
                    old_key, value, "scaled_dot_product and dot_product",
                )
            )
        new_state_dict[new_key] = value.startswith("scaled_")

    def pre_config_convert(
        self, config, from_index,
    ):
        """Fill in implicit config defaults before the rules run.

        - Both directions: a missing tie/share-embedding flag defaults to
          True.
        - HF -> CS: a missing or None ``n_inner`` becomes ``4 * n_embd``.
        - CS -> HF: missing embedding/attention dropout rates fall back to
          ``dropout_rate``.
        """
        config = super().pre_config_convert(config, from_index)

        per_side_defaults = (
            {"tie_word_embeddings": True},
            {"share_embedding_weights": True},
        )
        for key, value in per_side_defaults[from_index].items():
            config.setdefault(key, value)

        if from_index == 0:
            if config.get("n_inner") is None:
                config["n_inner"] = 4 * config["n_embd"]
        else:
            for dropout_key in (
                "embedding_dropout_rate",
                "attention_dropout_rate",
            ):
                if dropout_key not in config:
                    config[dropout_key] = config["dropout_rate"]
        return config

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))
class ConfigConverter_GPT2Model_HF_CS18(ConfigConverter_GPT2Model_HF_CS17):
    """GPT-2 config converter between HF and CS 1.8/1.9.

    All conversion rules and hooks are inherited unchanged from the CS 1.7
    converter; only the supported format versions differ.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))
class Converter_GPT2LMHeadModel_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Convert between HF and CS 1.7 ``GPT2LMHeadModel`` checkpoints.

    ``lm_head`` keys are copied directly. Everything under the HF
    ``transformer.`` prefix is delegated to the GPT2Model converter.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r"lm_head\.(?:weight|bias)"], action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("transformer.", ""),
                    Converter_GPT2Model_HF_CS17(),
                ],
                action=None,
            ),
        ]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Restore tied embedding weights before converting CS -> HF.

        If ``share_embedding_weights`` is set and the CS word-embedding entry
        is present but None, it is filled from ``lm_head.weight``. A missing
        key is left untouched.
        """
        # Manually tie weights
        if from_index == 1 and configs[1]["model"]["share_embedding_weights"]:
            if (
                old_state_dict.get("embedding_layer.word_embeddings.weight", 0)
                is None
            ):
                old_state_dict[
                    "embedding_layer.word_embeddings.weight"
                ] = old_state_dict["lm_head.weight"]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} GPT2LMHeadModel <-> {} GPT2LMHeadModel".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS17
class Converter_GPT2LMHeadModel_HF_CS18(BaseCheckpointConverter_HF_CS):
    """HF GPT2LMHeadModel <-> CS 1.8/1.9 GPT2LMHeadModel checkpoint converter.

    Delegates to the CS 1.7 converter. CS keys are accepted both bare and
    nested under a ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from the PyTorch 2.0 API (bare keys).
        bare_keys = ConversionRule(
            [Converter_GPT2LMHeadModel_HF_CS17()], action=None,
        )
        # Catch checkpoints from the deprecated PyTorchBaseModel, whose keys
        # are nested under "model.".
        prefixed_keys = ConversionRule(
            [
                EquivalentSubkey("", "model."),
                Converter_GPT2LMHeadModel_HF_CS17(),
            ],
            action=None,
        )
        self.rules = [bare_keys, prefixed_keys]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        hf_format, cs_format = cls.formats()
        return "{} GPT2LMHeadModel <-> {} GPT2LMHeadModel".format(
            hf_format, cs_format
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT2Model_HF_CS18
#### Action Helper Functions
def transpose_key_if_2D(
    old_key,
    new_key,
    old_state_dict,
    new_state_dict,
    from_index,
    action_fn_args,
):
    """Copy ``old_state_dict[old_key]`` to ``new_state_dict[new_key]``,
    transposing it when it is 2D.

    HF GPT-2 stores some layers as ``Conv1D`` instead of ``Linear``. A
    ``Conv1D`` weight is laid out as the transpose of the equivalent
    ``Linear`` weight, so 2D tensors are transposed for the dimensions to
    line up. Other tensors (e.g. 1D biases) are copied unchanged.

    A transpose is its own inverse, so the same action serves both
    conversion directions; ``from_index`` and ``action_fn_args`` are unused.
    """
    tensor = old_state_dict[old_key]
    if len(tensor.shape) == 2:
        tensor = tensor.transpose(0, 1)
    new_state_dict[new_key] = tensor