# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility to load the config from yaml or .py config file
"""
import importlib
import importlib.util
import logging
import os
from dataclasses import asdict
import yaml
from cerebras.modelzoo.common.registry import registry
# TODO : This is a bit ugly and should be removed.
# This is an ignore list of some QA params that we find in our configs.
# These are injected via test params or similar before run.py call so we find them here.
# If we find these, we ignore them for now.
ignore_list = [
'verify_determinism',
'batch_size',
'use_fake_data',
'disable_convergence_checks',
]
[docs]def flatten_sparsity_params(kwargs):
"""
Config classes package sparsity related params in a sub dict.
ALthough, if we use native yaml config, they come unrolled.
This utility unwraps the sparsity related params(if present)
into an unroller sparsity param dict for consistency.
Args:
kwargs : Input args
Returns:
Flattened dict
"""
if isinstance(kwargs, (int, float, list, tuple)):
return kwargs
if 'groups' in kwargs:
kwargs = kwargs.pop('groups', {})
else:
return kwargs # No need to flatten if no groups present
if isinstance(kwargs, dict):
additional_dict = kwargs.pop('params', {})
flattened_dict = kwargs.copy()
for key, value in additional_dict.items():
new_key = f"{key}"
flattened_dict[new_key] = value
return flattened_dict
elif isinstance(kwargs, list):
param_list = []
for param in kwargs:
additional_dict = param.pop('params', {})
flattened_dict = param.copy()
for key, value in additional_dict.items():
new_key = f"{key}"
flattened_dict[new_key] = value
param_list.append(flattened_dict)
return param_list
else:
return kwargs
[docs]def flatten_optimizer_params(kwargs):
"""
Config classes package optimizer related params in a sub dict.
ALthough, if we use native yaml config, they come unrolled.
This utility unwraps the optimizer related params(if present)
into an unroller optimizer param dict for consistency.
Args:
kwargs : Input args dict
Returns:
flattened_args: Flattened dict
"""
additional_dict = kwargs.pop('optim_params', {})
flattened_dict = kwargs.copy()
for key, value in additional_dict.items():
new_key = f"{key}"
flattened_dict[new_key] = value
return flattened_dict
[docs]def process_config(config, config_class, params_conf):
"""
Perform config mapping and validation
Args:
config: The config class object
config_class: The clas the config belongs to
params_conf: Dictionary of params
"""
# Disabled by default, enable for internal test trains
# that might have some unused params left to be cleaned
allow_config_class_validation_failures = int(
os.environ.get('CONFIG_CLASS_VALIDATION_FALLBACK', 0)
)
if allow_config_class_validation_failures == 1:
logging.info(
"Config class validation failure fallback is enabled for the run which is intended for"
"internal runs only, unset CONFIG_CLASS_VALIDATION_FALLBACK env variable if you want a"
"strict validaiton enforced for configs using config class"
)
try:
config = config_class(**params_conf)
config.validate()
except Exception as e: # pylint: disable=broad-except
logging.warning(
f"CONFIG WARNING: Falling back to default flow because of config class error: {e}"
"config could not be validated via config class, proceed if this is expected"
)
# invalidate the config class object
config = None
else:
try:
config = config_class(**params_conf)
config.validate()
except Exception as e: # pylint: disable=broad-except
raise ValueError(
f"CONFIG ERROR : Invalid param configuration supplied. Please fix error : {e} or "
"contact Cerebras support"
)
return config
[docs]def validate_config_params(params_conf, model_name):
"""
Load the config class and run validation
check on the config based on parameter constraints
Args:
params_conf: The config params passed as a dict
model_name: The model key name used by config map to check what class of config to use
"""
# Pop the description field as we dont have a corresponding config class member
if "description" in params_conf:
descr = params_conf["description"]
logging.info(f"Loading config : {descr}")
params_conf.pop("description")
ignored_keys = {}
# Pop the keys to be ignored and store them in a separate dictionary
for key in ignore_list:
if key in params_conf:
ignored_keys[key] = params_conf.pop(key)
logging.warning(
f"CONFIG WARNING: Config class ignored usage of param {key} ."
"Please note this type of usage is not permitted, please modify your config."
)
config_class = registry.get_config_class(model_name)
config = None
if config_class is not None:
logging.info(f"Loading config class : {config_class}")
config = process_config(config, config_class, params_conf)
logging.info(f"Config has been validated using config class")
else:
# TODO: Add an error comment here once the config classes are ready and implemented.
# For now silently default to old path
logging.warning(
f"Config loaded using yaml path without using config class"
)
if config:
params = asdict(config)
else:
params = params_conf
if params.get("sparsity") is not None:
params["sparsity"] = flatten_sparsity_params(params["sparsity"])
if params.get("optimizer") is not None:
params["optimizer"] = flatten_optimizer_params(params["optimizer"])
# Insert the ignored keys back into the dictionary
params.update(ignored_keys)
return params
[docs]def get_config_from_yaml(yaml_path, model_name):
"""
Get the config object after reading input yaml. Also runs validation
check on the config based on parameter constraints
Args:
yaml_path: The path to the config yaml file
model_name: The model key name used by config map to check what class of config to use
"""
with open(yaml_path, 'r') as file:
params_conf = yaml.safe_load(file)
config_class = registry.get_config_class(model_name)
config = None
if config_class is not None:
logging.info(f"Loading config class{config_class} from yaml path")
config = process_config(config, config_class, params_conf)
logging.info(
f"Config {config} has been validated using config class from {yaml_path}"
)
else:
# TODO: Add an error comment here once the config classes are ready and implemented.
# For now silently default to old path
logging.warning(
f"Config loaded using yaml path without using config class"
)
params = {}
if config:
params = asdict(config)
else:
params = params_conf
return params
[docs]def read_from_config_class_file(file_path, function_name):
"""
Read the config class file and call the config generator function to get config object
Args:
file_path: The path to the config .py file
function_name: The config creation function defined in the config class creator .py file
"""
# Dynamically import the module fromt he config class file
config_module_name = os.path.splitext(os.path.basename(file_path))[0]
spec = importlib.util.spec_from_file_location(config_module_name, file_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
# Get the config creation function from the module
get_config = getattr(config_module, function_name, None)
if get_config and callable(get_config):
# Override the existing function with the same name to make sure we call the correct one
locals()[function_name] = get_config
# Call the imported function
return get_config()
else:
logging.warning(
f"Could not find a valid config class creator in {file_path},"
)
return None
[docs]def get_config_from_class(config_class_file):
"""
Read the config class file and returns a config object
Args:
config_class_file: The path to the config .py file
"""
config = read_from_config_class_file(
config_class_file, "get_variant_config"
)
if config is not None:
config.validate()
logging.info(
f"Config {config_class_file} has been validated using config class"
)
return config