Source code for cerebras.modelzoo.trainer.extensions.wandb.wandb

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains the WandbLogger class for logging metrics to Weights and Biases."""

from typing import List, Optional
from warnings import warn

import torch

from cerebras.modelzoo.trainer.loggers import Logger


class WandbLogger(Logger):
    """
    Logger class for logging metrics to Weights and Biases.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        group: Optional[str] = None,
        run_id: Optional[str] = None,
        run_name: Optional[str] = None,
        job_type: Optional[str] = None,
        tags: Optional[List[str]] = None,
        resume: str = "auto",
    ):
        """
        Args:
            project: The name of the project to which the run belongs.
            group: The name of the group to which the run belongs.
            run_id: The unique identifier for the run.
            run_name: The name of the run.
            job_type: The type of job.
            tags: List of tags to be associated with the run.
            resume: Resume mode for the run. It can be one of the following:

                - "never": Do not resume the run.
                - "allow": Allow the run to resume if a previous run exists.
                - "auto": Automatically resume the run if a previous run exists.
                - "must": Resume the run if a previous run exists.
        """
        self.project = project
        self.group = group
        self.run_id = run_id
        self.run_name = run_name
        self.job_type = job_type
        self.tags = tags
        self.resume = resume

    def pre_setup(self, trainer):  # pylint: disable=no-self-use
        try:
            # pylint: disable=unused-import
            import wandb  # noqa
        except ImportError:
            raise RuntimeError(
                "wandb is an optional dependency of modelzoo. "
                "In order to use it, 'pip install wandb==0.16.2' into this venv"
            )

    def finalize(self):
        try:
            import wandb

            if wandb.run is not None:
                wandb.run.finish()
        finally:
            pass

    def check_presence_of_wandb_dir(  # pylint: disable=no-self-use
        self, rundir
    ):
        """Check if the wandb directory is present in the run directory.

        Args:
            rundir: The directory where the run is being stored.
        """
        # Ensure that the wandb directory is not already present or empty
        wandb_dir = rundir / "wandb"
        if wandb_dir.exists():
            # Ensure there are no run-* folders in the wandb directory.
            if any(
                dir.is_dir() and dir.name.startswith('run-')
                for dir in wandb_dir.iterdir()
            ):
                raise FileExistsError(
                    f"A previous run seems to already exist in {wandb_dir}. "
                    "Please specify a different 'model_dir'."
                )

    def setup(self, trainer):
        import wandb
        from wandb.sdk.lib import RunDisabled
        from wandb.wandb_run import Run

        rundir = trainer.model_dir

        previous_run_id = None
        run_files = list((rundir / "wandb").glob("run-*"))
        if run_files:
            # Run directories are named like "run-<timestamp>-<run_id>",
            # so the trailing component is the run id.
            previous_run_id = run_files[0].name.split('-')[-1]

        if self.resume == "never":
            if (
                self.run_id is not None
                and previous_run_id is not None
                and self.run_id == previous_run_id
            ):
                raise ValueError(
                    f"The specified run_id ({self.run_id}) matches a "
                    f"previous_run_id ({previous_run_id}), "
                    "but 'never' mode requires them to be different."
                )
            self.check_presence_of_wandb_dir(rundir)
        elif self.resume in ["allow", "auto"]:
            if self.run_id is not None and previous_run_id is not None:
                if self.run_id == previous_run_id:
                    # Log into this previous run as it's the same run
                    pass
                else:
                    # Raise an error if a wandb run already exists inside
                    # the specified run dir.
                    self.check_presence_of_wandb_dir(rundir)
            elif previous_run_id is not None:
                # No new run ID provided, so default to the previous run ID
                self.run_id = previous_run_id
        elif self.resume == "must" and previous_run_id:
            if self.run_id is not None and self.run_id != previous_run_id:
                raise ValueError(
                    f"The specified run_id ({self.run_id}) does not match "
                    f"previous_run_id ({previous_run_id}), "
                    "but resume mode 'must' requires them to be the same."
                )
            self.run_id = previous_run_id

        if wandb.run is None:
            # pylint: disable=all
            self.run = wandb.init(
                dir=rundir,
                job_type=self.job_type,
                # config=params,
                project=self.project,
                group=self.group,
                tags=self.tags,
                name=self.run_name,
                id=self.run_id,
                resume=self.resume,
            )

            # define default x-axis
            if isinstance(self.run, (Run, RunDisabled)) and getattr(
                self.run, "define_metric", None
            ):
                self.run.define_metric("global_step")
                self.run.define_metric(
                    "*", step_metric="global_step", step_sync=True
                )
        else:
            self.run = wandb.run

    def log_metrics(self, metrics, step):
        m = {"global_step": step}
        summary = {}
        for name, value in metrics.items():
            if isinstance(value, torch.Tensor):
                if value.numel() == 1:
                    m[name] = value.item()
                else:
                    warn(
                        f"Attempting to log a non-scalar tensor for {name}. "
                        "WandB Logger does not support logging non-scalar tensors."
                    )
            elif isinstance(value, (int, float)):
                m[name] = value
            elif isinstance(value, str):
                summary[name] = value
            else:
                try:
                    import pandas as pd
                    import wandb

                    if isinstance(value, pd.DataFrame):
                        m[name] = wandb.Table(dataframe=value)
                        continue
                except ImportError:
                    pass

                warn(
                    f"Attempting to log a {type(value)} for {name}. "
                    f"WandB Logger does not support logging {type(value)}"
                )

        self.run.log(m)
        self.run.summary.update(summary)
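
In normal use the Trainer drives this logger, calling pre_setup, setup, log_metrics, and finalize as callbacks. The snippet below is a minimal standalone sketch of that life cycle, not part of the module above: it assumes wandb is installed, sets WANDB_MODE=offline so no W&B account is needed, and substitutes a SimpleNamespace exposing only the model_dir attribute that setup() reads in place of a real Trainer. The project and run names are hypothetical.

    import os
    from pathlib import Path
    from types import SimpleNamespace

    from cerebras.modelzoo.trainer.extensions.wandb.wandb import WandbLogger

    # Run wandb locally without syncing ("offline" is a standard WANDB_MODE value).
    os.environ.setdefault("WANDB_MODE", "offline")

    logger = WandbLogger(
        project="my-project",   # hypothetical project name
        run_name="smoke-test",  # hypothetical run name
        tags=["debug"],
        resume="auto",          # reuse a prior run id found under model_dir/wandb, if any
    )

    # Stand-in for the Trainer: setup() only reads `model_dir` (a pathlib.Path).
    trainer = SimpleNamespace(model_dir=Path("./model_dir"))
    trainer.model_dir.mkdir(parents=True, exist_ok=True)

    logger.pre_setup(trainer)   # verifies that wandb is importable
    logger.setup(trainer)       # resolves the resume mode and calls wandb.init
    logger.log_metrics({"loss": 0.42}, step=1)
    logger.finalize()           # finishes the active wandb run

As log_metrics shows, scalars and single-element tensors are logged against the "global_step" x-axis, strings go to the run summary, and pandas DataFrames are logged as wandb.Table objects when pandas is available; other types only produce a warning.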