Source code for cerebras.modelzoo.trainer.callbacks.model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model Callback class."""

from typing import Callable, Union
from warnings import warn

import torch

import cerebras.pytorch as cstorch
from cerebras.modelzoo.trainer.callbacks import CoreCallback


[docs]class ModelCallback(CoreCallback): """Callback class that handles setting up and compiling the model.""" def __init__( self, model: Union[Callable[[], torch.nn.Module], torch.nn.Module], ): """ Args: model: The model to train. It must be one of the following: - If a callable is passed, it is assumed to be a function that takes in no arguments returns a torch.nn.Module. - If a torch.nn.Module is passed, it is used as is. """ self.model = model def setup(self, trainer): if callable(self.model) and not isinstance(self.model, torch.nn.Module): with trainer.backend.device: trainer.model = self.model() elif isinstance(self.model, torch.nn.Module): trainer.model = self.model else: raise ValueError( f"Expected model to be a torch.nn.Module or a callable that " f"returns a torch.nn.Module, but got {type(self.model)}." ) trainer.compiled_model = cstorch.compile(trainer.model, trainer.backend) def on_train_start(self, trainer, model, train_dataloader, loop, loop_idx): model.train() def on_validate_start(self, trainer, model, val_dataloader, loop): model.eval() def on_before_backward(self, trainer, model, outputs): if "loss" not in outputs: raise ValueError( f"Expected the {model.__class__.__name__}'s forward call to " f"return a dictionary with a 'loss' key. " f"Got: {sorted(outputs.keys())}" ) def on_save_checkpoint(self, trainer, state_dict): state_dict["model"] = trainer.model.state_dict() def on_load_checkpoint(self, trainer, state_dict): if "model" not in state_dict: warn( f"Checkpoint does not contain a model state dict. " f"Model state was not loaded" ) else: # This check is required for backward compatibility with checkpoints # saved with older versions of ModelZoo (pre rel-2.0.0) # We check that the model state dict keys start with "model." # and if they don't, we load the state dict into the model's model if hasattr(trainer.model, "model") and not all( k.startswith("model.") for k in state_dict["model"].keys() ): trainer.model.model.load_state_dict( state_dict["model"], strict=not trainer.checkpoint.disable_strict_checkpoint_loading, ) # This should be the case that is used for all checkpoints saved # post rel-2.0.0 else: trainer.model.load_state_dict( state_dict["model"], strict=not trainer.checkpoint.disable_strict_checkpoint_loading, ) trainer.logger.info( f"Model state found in checkpoint and loaded successfully." )