Source code for common.pytorch.loss_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module which provides utilities for aggregating and saving losses"""

from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter

from modelzoo import CSOFT_PACKAGE, CSoftPackage
from modelzoo.common.pytorch import cb_model as cm


[docs]class LossSaver: """Helper class for storing losses during training/eval."""
[docs] def __init__(self, writer: Optional[SummaryWriter] = None): """Constructs a `LossSaver` instance. Args: writer: Tensorboard summary writer for writing losses. """ self.writer = writer self._name = "loss_reduce" self._last_saved_loss = -1 self._total_loss = 0 self._total_size = 0
[docs] def add(self, loss: torch.Tensor, step: int, epoch: int = None): """Store loss value. This method will reduce losses across workers. Args: loss: The loss tensor whose value will be stored. step: Global step at which loss was computed. epoch: The current epoch. """ if cm.use_cs(): # Get mean loss across workers using a mesh reduce reduced = cm.mesh_reduce(self._name, loss.item(), self.mean_reduce) self._last_saved_loss = reduced cm.write_to_summary( self.writer, global_step=step, dict_to_write={"loss": reduced} ) elif self.writer: self._last_saved_loss = loss.item() scalar_name = "loss" if epoch is None else f"loss (epoch {epoch})" self.writer.add_scalar(scalar_name, self._last_saved_loss, step)
[docs] def accumulate(self, loss: torch.Tensor): """Accumulates loss values. This method will reduce losses across workers and update a total_loss Args: loss: The loss tensor whose value will be stored. """ if self._last_saved_loss != -1: self._total_loss += self._last_saved_loss self._last_saved_loss = -1 elif cm.use_cs(): # Get mean loss across workers using a mesh reduce reduced = cm.mesh_reduce(self._name, loss.item(), self.mean_reduce) self._total_loss += reduced else: self._total_loss += loss.item() self._total_size += 1
[docs] def clear(self): """Clears the total_loss value""" self._last_saved_loss = -1 self._total_loss = 0 self._total_size = 0
@property def total_loss(self) -> float: """Return the total accumulated loss""" return self._total_loss @property def average_loss(self) -> float: """Return the total accumulated loss""" if self._total_size == 0: return float("nan") return self._total_loss / float(self._total_size)
[docs] @staticmethod def mean_reduce(vals: list): """Apply mean reduction over values. Args: vals: List of values to apply mean reduction over. Returns: The mean reduction of values. """ return sum(vals) / len(vals)