# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
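"""Checkpoint and config converters between HuggingFace LLaVA checkpoints and
the Cerebras Multimodal-Simple model (cs-2.3 / cs-2.4 formats)."""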
import logging
import math
import os
from collections import OrderedDict
from typing import Tuple
import torch
from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
BaseConfigConverter,
BaseConfigConverter_HF_CS,
ConversionRule,
EquivalentSubkey,
FormatIndices,
FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.clip_vit import (
ConfigConverter_CLIPViT_HF_CS21,
Converter_CLIPViT_Core_HF_CS21,
)
from cerebras.modelzoo.tools.checkpoint_converters.helper import (
Build_HF_CS_Converter_WithOptionalModel,
)
from cerebras.modelzoo.tools.checkpoint_converters.llama import (
Converter_LlamaModel_HF_CS21,
)
from cerebras.modelzoo.tools.checkpoint_converters.llava import (
ConfigConverter_LLaMaProjector_HF_CS22,
ConfigConverter_LLaVA_HF_CS22,
Converter_LLaVA_CLIPViT_WithoutModel_HF_CS22,
Converter_LLaVA_LLaMA_WithoutModel_HF_CS22,
Converter_LLaVA_WithoutModel_HF_CS22,
)


class Converter_MMSimple_LLaVA_LLaMA_WithoutModel_HF_CS23(
Converter_LLaVA_LLaMA_WithoutModel_HF_CS22
):
def __init__(self):
super().__init__()
self.rules = [
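            # New rules are listed before the inherited ones (spread in via
            # `*self.rules` below), so they are presumably matched first.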
ConversionRule(
[
EquivalentSubkey("lm_head", "text_model.lm_head"),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey("model.", "text_model."),
Converter_LlamaModel_HF_CS21(),
],
),
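            # Vision-tower weights are handled by the CLIP-ViT converter
            # (Converter_MMSimple_LLaVA_CLIPViT_WithoutModel_HF_CS23), so
            # they are dropped here.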
ConversionRule(
[
r"image_model.image_model_list.0.0.*",
],
exists="right",
action=None,
),
            # Projector: maps HF `model.mm_projector.{i}.{weight,bias}` to CS
            # `image_model.projection.ffn.{i}.linear_layer.{weight,bias}`
ConversionRule(
[
EquivalentSubkey(
"model.mm_projector", "image_model.projection.ffn"
),
r"\.\d+",
EquivalentSubkey(".", ".linear_layer."),
r"(?:weight|bias)",
],
action=self.convert_projector,
),
*self.rules,
]

    @classmethod
def converter_note(cls) -> str:
note = super().converter_note()
        return (
            note
            + " MMSimple LLaVA converter using CLIP-ViT and LLaMA backbones."
        )


class Converter_MMSimple_LLaVA_CLIPViT_WithoutModel_HF_CS23(
Converter_LLaVA_CLIPViT_WithoutModel_HF_CS22
):
def __init__(self):
super().__init__()
self.rules = [
            # `image_model.projection.*` is ignored here since the projector
            # is handled by the LLaMA-side converter, i.e.
            # Converter_MMSimple_LLaVA_LLaMA_WithoutModel_HF_CS23
ConversionRule(
[r"image_model.projection.*"],
exists="right",
action=None,
),
ConversionRule(
[
EquivalentSubkey(
"vision_model.",
"image_model.image_model_list.0.0.",
),
Converter_CLIPViT_Core_HF_CS21(),
],
),
*self.rules,
]

    @staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_MMSimple_LLaVA_CLIPViT_HF_CS23


Converter_MMSimple_LLaVA_CLIPViT_HF_CS23 = (
Build_HF_CS_Converter_WithOptionalModel(
"Converter_MMSimple_LLaVA_CLIPViT_HF_CS23",
Converter_MMSimple_LLaVA_CLIPViT_WithoutModel_HF_CS23,
derived_class=Converter_MMSimple_LLaVA_CLIPViT_WithoutModel_HF_CS23,
)
)


Converter_MMSimple_LLaVA_LLaMA_HF_CS23 = (
Build_HF_CS_Converter_WithOptionalModel(
"Converter_MMSimple_LLaVA_LLaMA_HF_CS23",
Converter_MMSimple_LLaVA_LLaMA_WithoutModel_HF_CS23,
derived_class=Converter_MMSimple_LLaVA_LLaMA_WithoutModel_HF_CS23,
)
)


class Converter_MMSimple_LLaVA_WithoutModel_HF_CS24(
Converter_LLaVA_WithoutModel_HF_CS22
):
def __init__(self):
super().__init__()

    @staticmethod
def converters():
return (
Converter_MMSimple_LLaVA_CLIPViT_HF_CS23,
Converter_MMSimple_LLaVA_LLaMA_HF_CS23,
)

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs-2.4"))

    @staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_MMSimple_LLaVA_HF_CS24

    def post_checkpoint_convert(
self,
input_checkpoint,
output_checkpoint,
configs: Tuple[dict, dict],
converter_indices: FormatIndices,
):
        if converter_indices.direction == 0:  # HF -> CS
            # When converting from HF to the Multimodal-Simple model, the
            # visual-token `projection` layer may be absent (phase-1
            # checkpoints); in that case it must be created and initialized
            # with default values.
            # Check whether any HF projector weights were mapped into the CS
            # namespace; if not, initialize defaults.
            projector_exists = any(
                "image_model.projection" in k
                for k in output_checkpoint["model"]
            )
            if not projector_exists:
                logging.info(
                    "HF checkpoint does not have projector weights; "
                    "initializing defaults"
                )
cs_config = configs[1]
im_proj_config = cs_config["model"]["image_model_list"][
"global_image_projection"
]
input_unit = im_proj_config["input_unit"]
layers_units = im_proj_config["layers_units"]
use_bias = im_proj_config["use_bias"]
input_ = [input_unit] + layers_units[:-1]
output_ = layers_units
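                # Layer i maps input_[i] -> output_[i]; the first layer
                # consumes the vision-encoder features (`input_unit`). The
                # uniform U(-1/sqrt(fan_in), 1/sqrt(fan_in)) init below
                # matches torch.nn.Linear's default initialization.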
for i, (inp, out) in enumerate(zip(input_, output_)):
scale = math.sqrt(1.0 / inp)
projection_weight = torch.zeros(out, inp)
projection_weight.uniform_(-scale, scale)
output_checkpoint["model"][
f"image_model.projection.ffn.{i}.linear_layer.weight"
] = projection_weight
if use_bias:
projection_bias = torch.zeros(out)
projection_bias.uniform_(-scale, scale)
output_checkpoint["model"][
f"image_model.projection.ffn.{i}.linear_layer.bias"
] = projection_bias
super(
Converter_LLaVA_WithoutModel_HF_CS22, self
).post_checkpoint_convert(
input_checkpoint,
output_checkpoint,
configs,
converter_indices,
)


Converter_MMSimple_LLaVA_HF_CS24 = Build_HF_CS_Converter_WithOptionalModel(
"Converter_MMSimple_LLaVA_HF_CS24",
Converter_MMSimple_LLaVA_WithoutModel_HF_CS24,
derived_class=Converter_MMSimple_LLaVA_WithoutModel_HF_CS24,
)


class Converter_MMSimple_LLaVA_WithoutModel_HF_CS23(
Converter_MMSimple_LLaVA_WithoutModel_HF_CS24
):
def __init__(self):
super().__init__()

    @staticmethod
def converters():
return (
Converter_MMSimple_LLaVA_CLIPViT_HF_CS23,
Converter_MMSimple_LLaVA_LLaMA_HF_CS23,
)

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs-2.3"))

    @staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_MMSimple_LLaVA_HF_CS23


Converter_MMSimple_LLaVA_HF_CS23 = Build_HF_CS_Converter_WithOptionalModel(
"Converter_MMSimple_LLaVA_HF_CS23",
Converter_MMSimple_LLaVA_WithoutModel_HF_CS23,
derived_class=Converter_MMSimple_LLaVA_WithoutModel_HF_CS23,
)


class ConfigConverter_MMSimple_LLaVA_CLIPViT_HF_CS23(
ConfigConverter_CLIPViT_HF_CS21
):
def __init__(self):
super().__init__()


class ConfigConverter_MMSimple_LLaVA_LLaMa_HF_CS23(
ConfigConverter_LLaMaProjector_HF_CS22
):
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(["extra_ffn_params.*"], exists="right", action=None),
*self.rules,
]


class ConfigConverter_MMSimple_LLaVA_HF_CS24(ConfigConverter_LLaVA_HF_CS22):
def __init__(self):
super().__init__()

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (
FormatVersions("hf"),
FormatVersions("cs-2.4"),
)

    @staticmethod
def converters():
return (
ConfigConverter_MMSimple_LLaVA_CLIPViT_HF_CS23,
ConfigConverter_MMSimple_LLaVA_LLaMa_HF_CS23,
)

    @staticmethod
def component_names():
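        # Order matters: `image_model` is saved first so that save() can
        # embed its directory into the text_model config (`mm_vision_tower`).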
return (
"image_model",
"text_model",
)

    @classmethod
def save(
cls,
file_without_ext: str,
config: OrderedDict,
converter_indices: FormatIndices,
**kwargs,
) -> str:
        # Saving a CS config requires only a single file
        if converter_indices.direction == 0:  # HF -> CS
            # Pop the `image_model` dict here, since popping it in
            # post_config_convert would break the checkpoint converter
config["model"].pop("image_model")
return super().save(
file_without_ext, config, converter_indices, **kwargs
)
        # Saving HF configs requires separating the encoders and saving both
        else:
            save_files = []
            out_dir = os.path.dirname(file_without_ext)
            for i, name in enumerate(cls.component_names()):
                component_dir = os.path.join(out_dir, name)
                path = os.path.join(component_dir, "config")
                if not os.path.exists(component_dir):
                    os.mkdir(component_dir)
                if name == "text_model":
                    # Point the text_model config at the directory that
                    # contains the image model (saved on a prior iteration)
                    config[i]["mm_vision_tower"] = os.path.dirname(
                        save_files[0]
                    )
                if name == "image_model":
                    preprocess_path = os.path.join(
                        component_dir, "preprocessor_config"
                    )
# Save preprocessor config after the dir is created
BaseConfigConverter_HF_CS.save(
preprocess_path,
cls.preprocessor_config_defaults,
converter_indices,
**kwargs,
)
save_file = BaseConfigConverter_HF_CS.save(
path, config[i], converter_indices, **kwargs
)
save_files.append(save_file)
return save_files

    def pre_config_convert(
self,
model,
config,
converter_indices,
):
"""
config: List[dicts] if converter_indices = 0 (HF-> CS) else dict (CS->HF)
"""
        if converter_indices.direction == 1:  # CS -> HF
            # Move the projector config into the text_model config so that
            # the CS keys match what the converter expects
            if ("trainer" in config) and ("model" not in config):
                config = config["trainer"]["init"]
projector_config = config["model"]["image_model_list"].pop(
"global_image_projection"
)
config["model"]["projector"] = {"image_model": projector_config}
image_model_list = config["model"].pop("image_model_list")
config["model"]["image_feature_select_layer_idx"] = (
image_model_list["image_models"][0]["image_encoder"][
"image_layer_idx"
]
)
config["model"]["image_feature_select_mode"] = image_model_list.pop(
"image_feature_select_mode"
)
config["model"]["image_model"] = image_model_list["image_models"][
0
]["image_encoder"]
return super().pre_config_convert(model, config, converter_indices)

    def post_config_convert(
self,
model,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
):
model_config = super().post_config_convert(
model,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
)
if converter_indices.direction == 0: # HF -> CS
v_cfg = model_config["model"]["image_model"]
            # Restructure the flat LLaVA config into the Multimodal-Simple
            # `image_model_list` layout.
            # The LLaVA converter's post_config_convert_defaults adds
            # `image_start_idx`, so we remove it here
            model_config["model"].pop("image_start_idx")
v_cfg["image_layer_idx"] = model_config["model"].pop(
"image_feature_select_layer_idx"
)
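            # Wrap the single encoder config in the `image_model_list`
            # structure expected by the cs-2.4 Multimodal-Simple config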
model_config["model"]["image_model_list"] = {
"image_models": [{"image_encoder": v_cfg}]
}
model_config["model"]["image_model_list"][
"image_feature_select_mode"
] = model_config["model"].pop("image_feature_select_mode")
img_projector = model_config["model"].pop("projector")
img_projector = img_projector.pop("image_model")
model_config["model"]["image_model_list"][
"global_image_projection"
] = img_projector
return model_config


class ConfigConverter_MMSimple_LLaVA_HF_CS23(ConfigConverter_LLaVA_HF_CS22):
def __init__(self):
super().__init__()

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (
FormatVersions("hf"),
FormatVersions("cs-2.3"),
)

    @staticmethod
def converters():
return (
ConfigConverter_MMSimple_LLaVA_CLIPViT_HF_CS23,
ConfigConverter_MMSimple_LLaVA_LLaMa_HF_CS23,
)

    @staticmethod
def component_names():
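        # Order matters: `image_model` is saved first so that save() can
        # embed its directory into the text_model config (`mm_vision_tower`).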
return (
"image_model",
"text_model",
)

    @classmethod
def save(
cls,
file_without_ext: str,
config: OrderedDict,
converter_indices: FormatIndices,
**kwargs,
) -> str:
        # Saving a CS config requires only a single file
        if converter_indices.direction == 0:  # HF -> CS
            # Pop the `image_model` dict here, since popping it in
            # post_config_convert would break the checkpoint converter
config["model"].pop("image_model")
return super().save(
file_without_ext, config, converter_indices, **kwargs
)
        # Saving HF configs requires separating the encoders and saving both
        else:
            save_files = []
            out_dir = os.path.dirname(file_without_ext)
            for i, name in enumerate(cls.component_names()):
                component_dir = os.path.join(out_dir, name)
                path = os.path.join(component_dir, "config")
                if not os.path.exists(component_dir):
                    os.mkdir(component_dir)
                if name == "text_model":
                    # Point the text_model config at the directory that
                    # contains the image model (saved on a prior iteration)
                    config[i]["mm_vision_tower"] = os.path.dirname(
                        save_files[0]
                    )
                if name == "image_model":
                    preprocess_path = os.path.join(
                        component_dir, "preprocessor_config"
                    )
# Save preprocessor config after the dir is created
BaseConfigConverter_HF_CS.save(
preprocess_path,
cls.preprocessor_config_defaults,
converter_indices,
**kwargs,
)
save_file = BaseConfigConverter_HF_CS.save(
path, config[i], converter_indices, **kwargs
)
save_files.append(save_file)
return save_files

    def pre_config_convert(
self,
model,
config,
converter_indices,
):
"""
config: List[dicts] if converter_indices = 0 (HF-> CS) else dict (CS->HF)
"""
        if converter_indices.direction == 1:  # CS -> HF
            # Move the projector config into the text_model config so that
            # the CS keys match what the converter expects
            if ("trainer" in config) and ("model" not in config):
                config = config["trainer"]["init"]
projector_config = config["model"]["image_model_list"].pop(
"global_image_projection"
)
config["model"]["projector"] = {"image_model": projector_config}
image_model_list = config["model"].pop("image_model_list")
config["model"]["image_feature_select_layer_idx"] = (
image_model_list["image_models"][0]["image_model"][0][
"image_layer_idx"
]
)
config["model"]["image_feature_select_mode"] = image_model_list.pop(
"image_feature_select_mode"
)
config["model"]["image_model"] = image_model_list["image_models"][
0
]["image_model"][0]
return super().pre_config_convert(model, config, converter_indices)

    def post_config_convert(
self,
model,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
):
model_config = super().post_config_convert(
model,
original_config,
old_config,
new_config,
converter_indices,
drop_unmatched_keys,
)
if converter_indices.direction == 0: # HF -> CS
v_cfg = model_config["model"]["image_model"]
            # Restructure the flat LLaVA config into the Multimodal-Simple
            # `image_model_list` layout.
            # The LLaVA converter's post_config_convert_defaults adds
            # `image_start_idx`, so we remove it here
            model_config["model"].pop("image_start_idx")
v_cfg["image_layer_idx"] = model_config["model"].pop(
"image_feature_select_layer_idx"
)
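            # Wrap the encoder in the cs-2.3 `image_model_list` layout
            # (note the extra nesting relative to cs-2.4)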
model_config["model"]["image_model_list"] = {
"image_models": [{"image_model": [v_cfg]}]
}
model_config["model"]["image_model_list"][
"image_feature_select_mode"
] = model_config["model"].pop("image_feature_select_mode")
img_projector = model_config["model"].pop("projector")
img_projector = img_projector.pop("image_model")
model_config["model"]["image_model_list"][
"global_image_projection"
] = img_projector
return model_config