# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Utility functions for writing scalars and tensors to event files"""
import logging
from typing import Union
from warnings import warn
import torch
from ...backend import current_backend_impl
from ..step_closures import step_closure
def summarize_scalar(
    name: str, scalar: Union[int, float, torch.Tensor]
) -> None:
    """
    Save the scalar to the event file of the writer specified in the data
    executor

    The value is written with ``writer.add_scalar`` at global step
    ``writer.base_step + backend.run_context.iteration``. If the current run
    context has no writer, nothing is written; a log message and a warning
    are emitted instead.

    Args:
        name: the key to save the scalar in the event file
        scalar: the scalar value to summarize.
            Note, if a torch.Tensor is provided, it must be a scalar tensor
            for which scalar.item() can be called
    Raises:
        TypeError: if ``scalar`` is not an int, float, or torch.Tensor.
        RuntimeError: if a Python int/float is passed while a backend is
            active that does not support multi-tracing.
        ValueError: if ``scalar`` is a tensor that does not contain exactly
            one element.
    """
    if not isinstance(scalar, (int, float, torch.Tensor)):
        raise TypeError(
            "Expected int, float, or torch.Tensor for scalar summary "
            f"but got: {type(scalar)}"
        )
    if isinstance(scalar, (int, float)):
        # Python numbers are rejected when the active backend reports that
        # it does not support multi-tracing; only scalar tensors are allowed
        # in that case.
        backend = current_backend_impl(raise_exception=False)
        if backend and not backend.supports_multi_tracing:
            raise RuntimeError(
                "Passing a Python int or float scalar is not supported "
                "for the current backend. "
                "Only passing in a scalar torch.Tensor is supported"
            )
    if isinstance(scalar, torch.Tensor) and scalar.numel() != 1:
        raise ValueError(
            "Expected tensor to be a scalar but tensor has size: "
            f"{scalar.size()}"
        )

    @step_closure
    def scalar_summary(name, writer, **kwargs):
        # The value is handed over through **kwargs keyed by ``name``; how
        # step_closure processes it before this body runs is defined
        # elsewhere (NOTE(review): presumably it resolves tensor values).
        scalar = kwargs[name]
        if isinstance(scalar, torch.Tensor):
            scalar = scalar.item()
        # ``backend`` is read from the enclosing scope; it is bound below,
        # before this closure is invoked.
        writer.add_scalar(
            name, scalar, writer.base_step + backend.run_context.iteration
        )

    backend = current_backend_impl()
    writer = backend.run_context.writer
    if writer:
        scalar_summary(name, writer, **{name: scalar})
    else:
        logging.warning(
            f"Scalar summary for `{name}` was not saved as no SummaryWriter "
            "was provided. "
        )
        # Each example line ends with "\n" so the two statements render on
        # separate lines in the warning text.
        warn(
            "To enable writing scalar summaries, please pass in a "
            "SummaryWriter object to the DataExecutor, e.g.\n\n"
            "\twriter = cstorch.utils.tensorboard.SummaryWriter(...)\n"
            "\texecutor = cstorch.utils.data.DataExecutor(..., writer=writer)"
        )
def summarize_tensor(name: str, tensor: torch.Tensor) -> None:
    """
    Save the tensor to the event file of the writer specified in the data
    executor

    The tensor is detached and written with ``writer.add_tensor`` at step
    ``writer.base_step + backend.run_context.iteration``. If the current run
    context has no writer, nothing is written; a log message and a warning
    are emitted instead.

    Args:
        name: the key to save the tensor in the event file
        tensor: the torch.Tensor to summarize
    Raises:
        TypeError: if ``tensor`` is not a torch.Tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"Expected torch.Tensor for tensor summary but got: {type(tensor)}"
        )

    @step_closure
    def tensor_summary(name, writer, **kwargs):
        # The value is handed over through **kwargs keyed by ``name``; how
        # step_closure processes it before this body runs is defined
        # elsewhere. ``backend`` is read from the enclosing scope and is
        # bound below, before this closure is invoked.
        tensor = kwargs[name]
        writer.add_tensor(
            name,
            tensor.detach(),
            step=writer.base_step + backend.run_context.iteration,
        )

    backend = current_backend_impl()
    writer = backend.run_context.writer
    if writer:
        tensor_summary(name, writer, **{name: tensor})
    else:
        logging.warning(
            f"Tensor summary for `{name}` was not saved as no SummaryWriter "
            "was provided. "
        )
        # Each example line ends with "\n" so the two statements render on
        # separate lines in the warning text.
        warn(
            "To enable writing tensor summaries, please pass in a "
            "SummaryWriter object to the DataExecutor, e.g.\n\n"
            "\twriter = cstorch.utils.tensorboard.SummaryWriter(...)\n"
            "\texecutor = cstorch.utils.data.DataExecutor(..., writer=writer)"
        )