# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
def set_defaults(params):
    """
    Update any missing parameters in the params dictionary with default values.

    The dictionary is modified in place and nothing is returned.

    Args:
        params: The dictionary containing the params. The "model",
            "train_input", "eval_input", "optimizer" and "runconfig" sections
            must be present (a KeyError is raised otherwise).
            "train_input.max_sequence_length" is only required when
            "model.max_position_embeddings" is not set.
    """
    # Make any configured vocab file path absolute.
    for section in ("train_input", "eval_input"):
        for key in ("vocab_file",):
            if params.get(section, {}).get(key):
                params[section][key] = os.path.abspath(params[section][key])

    model_params = params["model"]
    model_params.setdefault("disable_nsp", False)

    # Pass settings into data loader.
    for model_key in ("disable_nsp", "vocab_size", "mixed_precision"):
        for input_key in ("train_input", "eval_input"):
            params[input_key][model_key] = model_params.get(model_key)

    # Fall back to the training sequence length only when the model does not
    # define max_position_embeddings itself. Evaluating the fallback eagerly
    # would raise KeyError even when it is not needed.
    if "max_position_embeddings" not in model_params:
        model_params["max_position_embeddings"] = params["train_input"][
            "max_sequence_length"
        ]

    model_params.setdefault("to_float16", False)
    model_params.setdefault("fp16_type", "float16")
    params["optimizer"].setdefault("log_summaries", False)

    # Attention softmax is fp32 by default. It is disabled for
    # precision_opt_level 2, and for cbfloat16 at precision_opt_level 1.
    # fp16_type is guaranteed to be set by the setdefault above.
    precision_opt_level = params["runconfig"].get("precision_opt_level", 1)
    disable_fp32_softmax = precision_opt_level == 2 or (
        model_params["fp16_type"] == "cbfloat16" and precision_opt_level == 1
    )
    model_params["attention_softmax_fp32"] = not disable_fp32_softmax
def check_unused_model_params(model_params):
    """
    Warn about any model params that were not consumed by the model.

    While setting up the model, we pop used settings from model_params, so any
    key still present here is reported as unused.

    "to_float16" and "mixed_precision" are popped here, and "fp16_type" is
    skipped, so none of these three is ever reported. Note that model_params
    is mutated by this call.

    Args:
        model_params: The (partially consumed) model params dictionary.
    """
    model_params.pop("to_float16", None)
    model_params.pop("mixed_precision", None)
    ignored_keys = {"fp16_type"}
    unused_params = [key for key in model_params if key not in ignored_keys]
    if unused_params:
        # Lazy %-formatting: the message is only built if it is emitted.
        logging.warning(
            "The following model params are unused: %s",
            ", ".join(unused_params),
        )