# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base implementation for config classes with helper modules
"""
import ast
import logging
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import List, Literal, Optional  # pylint: disable=W0611
import yaml
from cerebras.modelzoo.common.registry import registry  # no qa
# pylint: disable=wildcard-import
from cerebras.modelzoo.config_manager.config_validators import *
from typing import (  # noqa
    Any,
    Callable,
    Union,
    get_args,
)
# Sentinel marking a mandatory config field that was not given a value.
# `config_field` uses it as the default for `default`, and `BaseConfig.validate`
# raises when a field still holds this exact object (identity check).
required = object()
def config_field(
    default: Any = required,
    constraint: Optional[Callable[..., Any]] = None,
):
    """
    Build a dataclass field for a config class, hiding the metadata layout.

    Args:
        default: Default value of the field. When omitted, the module-level
            ``required`` sentinel is used as the default.
        constraint: Optional callable stored in the field metadata under the
            ``"constraints"`` key; ``None`` when no constraint is attached.

    Returns:
        The object produced by ``dataclasses.field`` carrying ``default`` and
        the constraint metadata.
    """
    return field(default=default, metadata={"constraints": constraint})
def is_union_type_hint(type_hint):
    """
    Check whether ``type_hint`` is a union type.

    Both spellings are recognised: ``typing.Union[X, Y]`` / ``Optional[X]`` and,
    on interpreters that provide it, the PEP 604 form ``X | Y``.

    Args:
        type_hint: The annotation to inspect.

    Returns:
        ``True`` if the origin of ``type_hint`` is a union, ``False`` otherwise.
    """
    # Function-scope import: only needed for the PEP 604 lookup below.
    import types

    origin = get_origin(type_hint)
    if origin is Union:
        return True
    # ``X | Y`` reports ``types.UnionType`` as its origin on Python 3.10+;
    # the attribute does not exist on older interpreters.
    pep604_union = getattr(types, "UnionType", None)
    return pep604_union is not None and origin is pep604_union
def get_member_type_hints(cls):
    """
    Collect the declared type of every dataclass field of ``cls``.

    Args:
        cls: A dataclass or dataclass instance (anything accepted by
            ``dataclasses.fields``).

    Returns:
        A dictionary mapping each field name to the tuple returned by
        ``get_args`` for its annotation when that tuple is non-empty
        (e.g. ``Optional[str]`` -> ``(str, NoneType)``), otherwise to the
        annotation itself.
    """
    return {
        class_field.name: get_args(class_field.type) or class_field.type
        for class_field in fields(cls)
    }
def set_constraint(current_constraint, updated_constraint):
    """
    Store ``updated_constraint`` in the first slot of ``current_constraint``.

    A warning is logged when that slot already holds a constraint; the slot is
    overwritten either way.

    Args:
        current_constraint: Mutable sequence whose element 0 holds the
            currently selected constraint (``None`` when unset). It is
            modified in place.
        updated_constraint: The constraint to store in element 0.
    """
    already_selected = current_constraint[0] is not None
    if already_selected:
        logging.warning(
            "Trying to select constraint for config implicitly, more than one valid type exists"
        )
    current_constraint[0] = updated_constraint
@dataclass
class BaseConfig:
    """
    Base class for config dataclasses, inherited by sub config classes.

    Provides YAML export, per-field validation and, right after construction,
    conversion of str/dict/list initialisers into the dataclass types declared
    by the field annotations.
    """

    def to_yaml(self, file_path):
        """
        Write the config to a yaml file.

        Args:
            file_path: The path of the output yaml file
        """
        # Explicit encoding so the result does not depend on the platform default.
        with open(file_path, "w", encoding="utf-8") as file:
            yaml.dump(asdict(self), file)

    def validate(self):
        """
        Validate every dataclass field of this config.

        For each field, in order: a nested ``BaseConfig`` value is validated
        recursively; a value that is still the ``required`` sentinel is
        rejected; ``None`` is rejected unless the annotation is a union that
        includes ``NoneType``; finally a non-``None`` value is checked with the
        constraint stored in the field metadata or, when there is none, with
        ``validate_field_type``.

        Raises:
            ValueError: If a mandatory field was not set, a non-optional field
                is ``None``, or a constraint returns ``False`` for the value.
        """
        for class_field in fields(self):
            field_name = class_field.name
            field_value = getattr(self, field_name)

            # A field that carries the name of this method is not validated.
            if field_name == "validate":
                continue

            # Nested configs validate their own members first.
            if isinstance(field_value, BaseConfig):
                field_value.validate()

            # Check if all the mandatory params received a value
            if field_value is required:
                raise ValueError(
                    f"required value for {field_name}, which is mandatory and must be set"
                )

            # A field is optional when its annotation is a union containing NoneType
            field_type = class_field.type
            is_optional = is_union_type_hint(field_type) and type(
                None
            ) in get_args(field_type)

            if field_value is None:
                if not is_optional:
                    raise ValueError(
                        f"None value for {field_name}, which is not of optional type"
                    )
                # An optional field left as None needs no further checks.
                continue

            # An explicitly attached constraint replaces the type based validation
            constraint = class_field.metadata.get("constraints")
            if constraint is not None:
                # Only an explicit ``False`` counts as a failed constraint.
                if constraint(field_value) is False:
                    raise ValueError(
                        f"value for {field_name}, does not match the constraint"
                    )
            else:
                validate_field_type(class_field, field_value)

    def set_class_type(self, field_name, class_type, field_value):
        """
        Replace a field with an instance of ``class_type`` built from ``field_value``.

        The constructor of ``class_type`` is called with the entries of the
        given mapping as keyword arguments, so construction fails inside the
        class init when they do not match its signature.

        Args:
            field_name: Name of the attribute to set on this object.
            class_type: The class to instantiate.
            field_value: A dict of init params, a string holding a Python
                literal of such a dict, or a list of such dicts. A value that
                is already a dataclass is left untouched; any other type is
                logged as an error and left untouched.
        """
        if isinstance(field_value, str):
            # literal_eval only evaluates Python literals, never arbitrary code
            init_params = ast.literal_eval(field_value)
            setattr(self, field_name, class_type(**init_params))
        elif isinstance(field_value, dict):
            setattr(self, field_name, class_type(**field_value))
        elif isinstance(field_value, list):
            # NOTE(review): every element is instantiated, but each assignment
            # overwrites the previous one, so only the instance built from the
            # last element is kept -- confirm this is intended.
            for value in field_value:
                setattr(self, field_name, class_type(**value))
        elif not is_dataclass(field_value):
            logging.error(
                f"We got a config class initialization with invalid type {type(field_value)}"
            )

    def __post_init__(self):
        """
        Create sub-class objects from str/dict/list type initializations.

        Runs over all fields after init. A non-``None`` value whose annotation
        is a dataclass, or a union containing one (the first dataclass member
        is used), is handed to ``set_class_type``.
        """
        for curr_field in fields(self):
            field_name = curr_field.name
            field_value = getattr(self, field_name)
            if field_value is None:
                continue

            field_type = curr_field.type
            if is_union_type_hint(field_type):
                # Only the first dataclass member of the union is considered
                target_type = next(
                    (
                        union_type
                        for union_type in get_args(field_type)
                        if is_dataclass(union_type)
                    ),
                    None,
                )
            elif is_dataclass(field_type):
                target_type = field_type
            else:
                target_type = None

            if target_type is not None:
                self.set_class_type(
                    field_name=field_name,
                    class_type=target_type,
                    field_value=field_value,
                )