Source code for cerebras.modelzoo.config_manager.config_classes.base.base_config

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base implementation for config classes with helper modules

"""

import ast
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import List, Literal, Optional  # pylint: disable=W0611

import yaml

from cerebras.modelzoo.common.registry import registry  # no qa

# pylint: disable=wildcard-import
from cerebras.modelzoo.config_manager.config_validators import *

from typing import (  # noqa
    Any,
    Callable,
    Union,
    get_args,
)


# Sentinel marking a config field as mandatory. It is the default `default` of
# `config_field`, and `BaseConfig.validate` raises a ValueError for any field
# whose value is still this exact object (compared by identity, so a
# legitimate None/0/"" value is never mistaken for "not provided").
required = object()


def config_field(
    default: Any = required,
    constraint: Optional[Callable[..., Any]] = None,
):
    """
    Build a dataclass field for a config class, hiding the metadata layout.

    Args:
        default: Default value for the field. When omitted, the ``required``
            sentinel is used, marking the field as mandatory.
        constraint: Optional callable stored in the field metadata under the
            ``"constraints"`` key; it is meant to check the parameter value.

    Returns:
        The object produced by ``dataclasses.field`` with the given default
        and the constraint recorded in its metadata.
    """
    return field(default=default, metadata={"constraints": constraint})
def is_union_type_hint(type_hint):
    """
    Check whether a type hint is a union type.

    Both spellings of a union are recognised: ``typing.Union[X, Y]`` /
    ``typing.Optional[X]`` and, on interpreters that provide
    ``types.UnionType``, the PEP 604 form ``X | Y``.

    Args:
        type_hint: The annotation object to inspect.

    Returns:
        True if ``type_hint`` is a union, False otherwise.
    """
    # Imported locally so this helper does not depend on which typing names
    # happen to be in module scope.
    import types
    from typing import Union, get_origin

    origin = get_origin(type_hint)
    if origin is Union:
        return True

    # ``X | Y`` reports ``types.UnionType`` as its origin on Python 3.10-3.13.
    # The attribute is absent on older interpreters, hence the getattr guard.
    pep604_union = getattr(types, "UnionType", None)
    return pep604_union is not None and origin is pep604_union
def get_member_type_hints(cls):
    """
    Collect the type hint of every dataclass field of ``cls``.

    Args:
        cls: A dataclass, or an instance of one, whose fields are inspected.

    Returns:
        A dictionary mapping each field name to the type arguments of its
        annotation (as returned by ``get_args``), or to the annotation
        itself when it has no type arguments.
    """
    # ``get_args`` yields an empty tuple for non-parameterised hints, in which
    # case the plain annotation is kept.
    return {
        member.name: get_args(member.type) or member.type
        for member in fields(cls)
    }
def set_constraint(current_constraint, updated_constraint):
    """
    Store ``updated_constraint`` in the first slot of ``current_constraint``.

    A warning is logged when that slot already holds a constraint, since it
    means more than one candidate was selected implicitly; the slot is
    overwritten regardless.

    Args:
        current_constraint: Mutable sequence whose element 0 holds the
            currently selected constraint (``None`` when nothing is selected).
        updated_constraint: The constraint to store in element 0.
    """
    already_selected = current_constraint[0] is not None
    if already_selected:
        logging.warning(
            "Trying to select constraint for config implicitly, more than one valid type exists"
        )
    current_constraint[0] = updated_constraint
@dataclass
class BaseConfig:
    """
    Base config dataclass, inherited by the sub config classes.

    It offers three services to subclasses: dumping the config to yaml
    (``to_yaml``), validating every field (``validate``) and, right after
    construction, turning dict / string initializers of dataclass-typed
    fields into instances of those dataclasses (``__post_init__``).
    """

    def to_yaml(self, file_path):
        """
        Write the config to a yaml file.

        Args:
            file_path: The path of the output yaml file
        """
        with open(file_path, "w") as file:
            yaml.dump(asdict(self), file)

    def validate(self):
        """
        Validate every dataclass field of this config.

        Nested ``BaseConfig`` values are validated recursively before the
        field holding them is checked. A field with an explicit constraint in
        its metadata is checked with that constraint only; otherwise the
        type-based ``validate_field_type`` check is applied. ``None`` values
        of optional fields are accepted without further checks.

        Raises:
            ValueError: If a mandatory field was never set, if a non-optional
                field is ``None``, or if a constraint returns ``False``.
        """
        for class_field in fields(self):
            field_name = class_field.name
            field_value = getattr(self, field_name)

            # Recurse into nested configs first so their errors surface
            # before the checks on the enclosing field.
            if isinstance(field_value, BaseConfig):
                field_value.validate()

            # Explicit constraint attached to the field metadata, if any.
            constraint = class_field.metadata.get("constraints")

            # Check that the mandatory param received a value
            if field_value is required:
                raise ValueError(
                    f"required value for {field_name}, which is mandatory and must be set"
                )

            if field_value is None:
                # None is only legal when the annotation is a union that
                # includes NoneType (i.e. the field is marked optional).
                is_optional = is_union_type_hint(class_field.type) and type(
                    None
                ) in get_args(class_field.type)
                if not is_optional:
                    raise ValueError(
                        f"None value for {field_name}, which is not of optional type"
                    )
                continue

            if constraint is not None:
                # Custom validation logic takes precedence over type checks.
                # Only an explicit False counts as a failure.
                if constraint(field_value) is False:
                    raise ValueError(
                        f"value for {field_name}, does not match the constraint"
                    )
            else:
                # No custom constraint: fall back to type based validation
                validate_field_type(class_field, field_value)

    def set_class_type(self, field_name, class_type, field_value):
        """
        Replace a raw field value with an instance of ``class_type``.

        The raw value is expanded as keyword arguments to the ``class_type``
        constructor, so construction fails there if the keys do not match the
        class signature.

        Args:
            field_name: Name of the attribute on ``self`` to overwrite.
            class_type: The class to instantiate.
            field_value: A dict of init params, a string holding a Python
                literal of such a dict, or a list of such dicts. Values that
                are already dataclass instances are left untouched; any other
                type is reported through ``logging.error`` and left as is.
        """
        if isinstance(field_value, str):
            # literal_eval only parses Python literals; it does not execute code.
            field_dict = ast.literal_eval(field_value)
            setattr(self, field_name, class_type(**field_dict))
        elif isinstance(field_value, list):
            # NOTE(review): every element is constructed, but each one
            # overwrites the previous, so only the last element's instance is
            # kept. Behavior preserved as-is; confirm whether a list of
            # instances was intended.
            for value in field_value:
                setattr(self, field_name, class_type(**value))
        elif isinstance(field_value, dict):
            setattr(self, field_name, class_type(**field_value))
        elif not is_dataclass(field_value):
            # Dicts were handled above, so only the dataclass check is needed.
            logging.error(
                f"We got a config class initialization with invalid type {type(field_value)}"
            )

    def __post_init__(self):
        """
        Convert raw initializers of dataclass-typed fields into instances.

        For each field whose annotation is a dataclass, or a ``Union``
        containing one (the first dataclass member is used), a non-None value
        is passed to ``set_class_type`` for conversion.
        """
        for curr_field in fields(self):
            field_name = curr_field.name
            field_value = getattr(self, field_name)
            if field_value is None:
                # Nothing to convert for unset / optional fields.
                continue

            field_type = curr_field.type
            if get_origin(field_type) is Union:
                # Pick the first dataclass among the union members, if any.
                class_type = next(
                    (
                        member
                        for member in get_args(field_type)
                        if is_dataclass(member)
                    ),
                    None,
                )
            elif is_dataclass(field_type):
                class_type = field_type
            else:
                class_type = None

            if class_type is not None:
                self.set_class_type(
                    field_name=field_name,
                    class_type=class_type,
                    field_value=field_value,
                )