Source code for cerebras.modelzoo.config.base_config

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import ClassVar, final
from warnings import warn

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

_GLOBAL_CONFIG = None


[docs]class BaseConfig(BaseModel): """ Base class for all config classes in the ModelZoo. This class is a thin wrapper around Pydantic's BaseModel class that adds some additional functionality and constraints to the models. """ __orig_class__: ClassVar = None model_config = ConfigDict( # Forbid extra attributes during config initialization extra="forbid", # Make the model faux-immutable frozen=True, # Allow arbitrary types for field types arbitrary_types_allowed=True, # Always revalidate models during validation revalidate_instances="always", # Always validate default values during validation validate_default=True, # Always validate the return value from call validators validate_return=True, # Always validate assignments validate_assignment=True, # Use attribute docstrings for field descriptions use_attribute_docstrings=True, # Allow populating the model with the field names as well as the aliases populate_by_name=True, ) def __call__(self, **kwargs): """ Construct the original class with the current config. By original class, we mean the class that this config class is associated with. """ if self.__orig_class__ is None: raise TypeError( f"{self.__class__.__name__} has no original class associated with it. " ) u = self.get_orig_class_args(**kwargs) try: return self.__orig_class__(**u) except: from rich import print as pprint pprint(u) raise
[docs] @classmethod def get_orig_class(cls): return cls.__orig_class__
[docs] def get_orig_class_args(self, **kwargs) -> dict: fields = { k: kwargs.get(k, getattr(self, k)) for k, f in self.model_fields.items() } if self.model_extra and self.model_config.get("include_extra", True): fields.update(self.model_extra) return fields
[docs] def copy(self, *, validate: bool = True, **kwargs) -> Self: copy = super().model_copy(**kwargs) if "update" in kwargs and validate: return self.model_validate(copy) return copy
[docs] def model_copy(self, **kwargs) -> Self: if kwargs.get("update"): warn( "model_copy() does not validate the data before creating the new model. " "If you trust this data, then ignore this warning. " "Otherwise, please call copy() instead." ) return super().model_copy(**kwargs)
@model_validator(mode="before") @classmethod def check_for_deprecated_fields(cls, data): if deprecated_fields := [ name for name, field in cls.model_fields.items() if field.deprecated and name in data ]: warn( f"Found deprecated fields for {cls.__name__}: " f"{sorted(deprecated_fields)}\n" f"Support for passing these fields in will be removed in the future.", category=FutureWarning, ) return data
[docs] def post_init(self, context): pass
[docs] @final def model_post_init(self, context): if context is None: context = {} orig = self.model_config["frozen"] try: self.model_config["frozen"] = False self.post_init(context) finally: self.model_config["frozen"] = orig