Source code for common.pytorch.model_utils.checkpoint_converters.bert_finetune

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Tuple

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseCheckpointConverter_PT_PT,
    BaseConfigConverter,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.bert import (
    ConfigConverter_Bert_CS16_CS17,
    ConfigConverter_Bert_CS16_CS18,
    ConfigConverter_Bert_HF_CS17,
    ConfigConverter_Bert_HF_CS18,
    Converter_BertModel_CS16_CS17,
)
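
# Summary (an editorial gloss grounded in the converters below): this module
# defines checkpoint and config converters for BERT finetuning model classes
# (sequence classification, token classification, question answering, and
# summarization), covering CS 1.6 -> CS 1.7/1.8/1.9 upgrades as well as
# Hugging Face <-> Cerebras conversions.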


class Converter_BertFinetuneModel_CS16_CS17(BaseCheckpointConverter_PT_PT):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule([r"bert\.", Converter_BertModel_CS16_CS17()]),
            ConversionRule(
                [r"classifier\.(?:weight|bias)"],
                action=self.replaceKey,
            ),
        ]

    def post_checkpoint_convert(self, checkpoint, from_index: int):
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a"
            " result, the optimizer state won't be included in the converted"
            " checkpoint.".format(*self.formats())
        )
        return {"model": checkpoint["model"]}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return (
            "BertForSequenceClassification, BertForTokenClassification, "
            "BertForQuestionAnswering, and BertForSummarization classes"
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS17

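# Illustrative sketch (an addition, not part of the original module): every
# converter in this file exposes `formats()`, `converter_note()`, and
# `get_config_converter_class()`, which a conversion driver can use to pick
# the matching checkpoint/config converter pair. `_describe_converter` is a
# hypothetical helper shown for illustration only.
def _describe_converter(converter_cls) -> str:
    src_fmt, tgt_fmt = converter_cls.formats()
    config_cls = converter_cls.get_config_converter_class()
    return "{} <-> {}: {} (config converter: {})".format(
        src_fmt, tgt_fmt, converter_cls.converter_note(), config_cls.__name__
    )
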
class Converter_BertFinetuneModel_CS16_CS18(BaseCheckpointConverter_PT_PT):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_BertFinetuneModel_CS16_CS17()],
                action=None,
            ),
            # Catch checkpoints from the deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertFinetuneModel_CS16_CS17(),
                ],
                action=None,
            ),
        ]

    def post_checkpoint_convert(self, checkpoint, from_index: int):
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a"
            " result, the optimizer state won't be included in the converted"
            " checkpoint.".format(*self.formats())
        )
        return {"model": checkpoint["model"]}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        return (
            "BertForSequenceClassification, BertForTokenClassification, "
            "BertForQuestionAnswering, and BertForSummarization classes"
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS18

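# Note (editorial reading of the two rules above): the
# EquivalentSubkey("", "model.") rule appears to account for the "model."
# prefix that checkpoints saved via the deprecated PyTorchBaseModel carry on
# their parameter keys, while checkpoints from the PyTorch 2.0 API match the
# first rule with no prefix adjustment.
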
class Converter_BertForSequenceClassification_HF_CS17(
    Converter_BertFinetuneModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    def __init__(self):
        super().__init__()

    def post_checkpoint_convert(self, checkpoint, from_index: int):
        return BaseCheckpointConverter_HF_CS.post_checkpoint_convert(
            self, checkpoint, from_index
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForSequenceClassification".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForSequenceClassification_HF_CS17

class Converter_BertForSequenceClassification_HF_CS18(
    BaseCheckpointConverter_HF_CS
):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_BertForSequenceClassification_HF_CS17()],
                action=None,
            ),
            # Catch checkpoints from the deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertForSequenceClassification_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForSequenceClassification".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForSequenceClassification_HF_CS18

class ConfigConverter_BertForSequenceClassification_HF_CS17(
    ConfigConverter_Bert_HF_CS17
):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Finetuning config params
            ConversionRule(
                [EquivalentSubkey("classifier_dropout", "task_dropout")],
                action=self.replaceKey,
            ),
            ConversionRule(["num_labels"], action=self.replaceKey),
            ConversionRule(["problem_type"], action=self.replaceKey),
            *self.rules,
        ]

    def pre_config_convert(self, config, from_index):
        config = super().pre_config_convert(config, from_index)

        # Defaults from https://github.com/huggingface/transformers/blob/23c146c38b42d1193849fbd6f2943bf754b7c428/src/transformers/models/bert/modeling_bert.py#L1579
        if from_index == 0:
            if "num_labels" not in config:
                if "id2label" in config:
                    config["num_labels"] = len(config["id2label"])
                else:
                    config["num_labels"] = 2
            if (
                "classifier_dropout" not in config
                or config["classifier_dropout"] is None
            ):
                config["classifier_dropout"] = config["hidden_dropout_prob"]
            if "problem_type" not in config or config["problem_type"] is None:
                if config["num_labels"] == 1:
                    config["problem_type"] = "regression"
                else:
                    raise ConfigConversionError(
                        "Cannot infer the problem_type (it is either"
                        " single_label_classification or"
                        " multi_label_classification). Please explicitly"
                        " include the problem_type field before re-running."
                    )
        return config

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

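# Illustrative sketch (an addition, not part of the original module): how the
# HF-side defaults in pre_config_convert above resolve for a typical binary
# sentiment config. The `demo_config` dict and this helper are hypothetical.
def _demo_hf_sequence_classification_defaults():
    demo_config = {
        "id2label": {0: "NEGATIVE", 1: "POSITIVE"},
        "hidden_dropout_prob": 0.1,
    }
    # num_labels falls back to len(id2label) when absent -> 2
    num_labels = len(demo_config["id2label"])
    # classifier_dropout falls back to hidden_dropout_prob when absent/None
    dropout = demo_config.get("classifier_dropout")
    if dropout is None:
        dropout = demo_config["hidden_dropout_prob"]
    # With num_labels != 1 and no explicit problem_type, the conversion above
    # raises ConfigConversionError rather than guessing between
    # single_label_classification and multi_label_classification.
    return num_labels, dropout
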
class ConfigConverter_BertForSequenceClassification_HF_CS18(
    ConfigConverter_BertForSequenceClassification_HF_CS17,
    ConfigConverter_Bert_HF_CS18,
):
    def __init__(self):
        super().__init__()

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

class Converter_BertForTokenClassification_HF_CS17(
    Converter_BertFinetuneModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    def __init__(self):
        super().__init__()

    def post_checkpoint_convert(self, checkpoint, from_index: int):
        return BaseCheckpointConverter_HF_CS.post_checkpoint_convert(
            self, checkpoint, from_index
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForTokenClassification".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForTokenClassification_HF_CS17

class Converter_BertForTokenClassification_HF_CS18(
    BaseCheckpointConverter_HF_CS
):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_BertForTokenClassification_HF_CS17()],
                action=None,
            ),
            # Catch checkpoints from the deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertForTokenClassification_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForTokenClassification".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForTokenClassification_HF_CS18

class ConfigConverter_BertForTokenClassification_HF_CS17(
    ConfigConverter_Bert_HF_CS17
):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Finetuning config params
            ConversionRule(
                [
                    EquivalentSubkey(
                        "classifier_dropout", "encoder_output_dropout_rate"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("num_labels", "num_classes")],
                action=self.replaceKey,
            ),
            *self.rules,
        ]

    def pre_config_convert(self, config, from_index):
        config = super().pre_config_convert(config, from_index)

        # Additional finetuning-specific defaults:
        if from_index == 0:
            if "num_labels" not in config:
                if "id2label" in config:
                    config["num_labels"] = len(config["id2label"])
                else:
                    config["num_labels"] = 2
            if (
                "classifier_dropout" not in config
                or config["classifier_dropout"] is None
            ):
                config["classifier_dropout"] = config["hidden_dropout_prob"]
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 0:
            if "loss_weight" not in new_config:
                new_config["loss_weight"] = 1.0
            if "include_padding_in_loss" not in new_config:
                new_config["include_padding_in_loss"] = False
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

class ConfigConverter_BertForTokenClassification_HF_CS18(
    ConfigConverter_BertForTokenClassification_HF_CS17,
    ConfigConverter_Bert_HF_CS18,
):
    def __init__(self):
        super().__init__()

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

class Converter_BertForQuestionAnswering_HF_CS17(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule([r"bert\.", Converter_BertModel_CS16_CS17()]),
            ConversionRule(
                [
                    EquivalentSubkey("qa_outputs", "classifier"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForQuestionAnswering".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForQuestionAnswering_HF_CS17

class Converter_BertForQuestionAnswering_HF_CS18(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from the PyTorch 2.0 API
            ConversionRule(
                [Converter_BertForQuestionAnswering_HF_CS17()],
                action=None,
            ),
            # Catch checkpoints from the deprecated PyTorchBaseModel
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertForQuestionAnswering_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForQuestionAnswering".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertForQuestionAnswering_HF_CS18

class ConfigConverter_BertForQuestionAnswering_HF_CS17(
    ConfigConverter_Bert_HF_CS17
):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Finetuning config params
            ConversionRule(
                ["num_labels"],
                action=BaseConfigConverter.assert_factory_fn(0, 2),
            ),
            *self.rules,
        ]

    def pre_config_convert(self, config, from_index):
        config = super().pre_config_convert(config, from_index)

        # Additional finetuning-specific defaults:
        if from_index == 0:
            if "num_labels" not in config:
                if "id2label" in config:
                    config["num_labels"] = len(config["id2label"])
                else:
                    config["num_labels"] = 2
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 1:
            new_config["num_labels"] = 2
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

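# Note (editorial): BertForQuestionAnswering predicts a span-start logit and a
# span-end logit per token, so num_labels is always 2 for this task. The
# assert_factory_fn rule above enforces this on the HF side, and
# post_config_convert pins num_labels back to 2 when converting CS -> HF.
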
class ConfigConverter_BertForQuestionAnswering_HF_CS18(
    ConfigConverter_BertForQuestionAnswering_HF_CS17,
    ConfigConverter_Bert_HF_CS18,
):
    def __init__(self):
        super().__init__()

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))