# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Base class for implementing metrics."""
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import wraps
from types import MethodType
from typing import Any
import torch
import cerebras_pytorch as cstorch
class _empty:
    """An empty object to use as a sentinel for caching metric results."""
class Metric(torch.nn.Module, ABC):
    """Base class for implementing metrics compatible with Cerebras WSC.

    This class is designed to be used as a base class for implementing metrics
    compatible with Cerebras Wafer-Scale Cluster, but they also work with CPU
    and GPU backends.

    To implement a new metric, subclass `Metric` and implement the following:
    - `reset`: This is to initialize the metric state.
    - `update`: This is to update the metric state at every iteration.
    - `compute`: This is to compute the final metric value based on the state.

    To use metrics, instantiate them and call them with the appropriate inputs.
    For example:
        >>> metric = MyMetric()
        >>> metric(input_1, input_2)  # Calls update and compute
        >>> metric.compute()  # Returns the final (cached) metric value
    """

    # Keep a registry of all metrics so we can reference them in the backend.
    # Keyed by metric name; every instance appends itself in `__init__`.
    # NOTE(review): entries are strong references that are never removed here,
    # so registered metrics stay alive as long as this class does.
    registry = defaultdict(list)

    def __init__(self, name: str):
        """Constructs a `Metric` instance.

        Args:
            name: The name of the metric. This is used to reference the metric
                and does not have to be unique.
        """
        super().__init__()
        self.name = name
        # Keeps track of total number of times the metric was updated
        self._num_updates = 0
        # Add the metric to the registry
        self.registry[self.name].append(self)
        # Cached result of the last compute call (`_empty` means "no cache")
        self._cached_result = _empty
        # Wrap reset, update and compute methods. These replace the bound
        # methods on this instance only, so subclasses implement the plain
        # methods and still get the caching/counting behavior.
        self._wrap_reset()
        self._wrap_update()
        self._wrap_compute()
        # Call reset to initialize metric state
        self.reset()

    @property
    def num_updates(self) -> int:
        """Returns the number of times the metric was updated."""
        return self._num_updates

    @abstractmethod
    def reset(self) -> None:
        """Resets the metric state."""

    @abstractmethod
    def update(self, *args, **kwargs) -> None:
        """Updates the metric state."""

    @abstractmethod
    def compute(self) -> Any:
        """Computes and returns the current metric value."""

    def register_state(
        self, name: str, tensor: torch.Tensor, persistent: bool = False
    ) -> None:
        """Registers a state variable to the module.

        By default, metric state variables are non-persistent buffers that
        are not included in the module's state dictionary. To have them as part
        of the state dictionary, set `persistent=True`.

        Once registered, the state variable can be accessed as an attribute
        on the module by the given name.

        Args:
            name: The name of the state variable.
            tensor: The tensor to register.
            persistent: Whether this state is part of the module's `state_dict`.
        """
        self.register_buffer(name, tensor, persistent=persistent)

    def forward(self, *args, **kwargs) -> Any:
        """Updates and computes the metric value."""
        self.update(*args, **kwargs)
        return self.compute()

    def _wrap_reset(self) -> None:
        """Wraps the reset method to invalidate derived state.

        The wrapper clears the cached `compute()` result before running the
        user's `reset` and zeroes the update counter afterwards. Without
        clearing the cache, a `compute()` issued after `reset()` (and before
        any `update()`) would return a value computed from the old state.
        """
        reset = self.reset

        @wraps(reset)
        def wrapped_reset(self: Metric, *args, **kwargs):
            # The state the cached result was derived from is about to be
            # re-initialized, so the cached result is no longer valid.
            self._cached_result = _empty
            reset_result = reset(*args, **kwargs)
            self._num_updates = 0
            return reset_result

        self.reset = MethodType(wrapped_reset, self)

    def _wrap_update(self) -> None:
        """Wraps the update method to clear the cache before running update.

        The update counter is incremented only after the user's `update`
        returns without raising.
        """
        update = self.update

        @wraps(update)
        def wrapped_update(self: Metric, *args, **kwargs):
            self._cached_result = _empty
            update_result = update(*args, **kwargs)
            self._num_updates += 1
            return update_result

        self.update = MethodType(wrapped_update, self)

    def _wrap_compute(self) -> None:
        """Wraps the compute method to cache the result."""
        compute = self.compute

        @wraps(compute)
        def wrapped_compute(self: Metric):
            if self._cached_result is not _empty:
                return self._cached_result

            # Cache the result (which could be a lazy tensor) so if we call
            # compute() again without an explicit update(), we don't recompute
            # unnecessarily. User should not be explicitly calling compute()
            # or update() anyway, so this is mostly a sanity check. Instead,
            # they should be calling metric instance as a function which will
            # both update and compute.
            self._cached_result = compute()

            @cstorch.step_closure
            def cache_result(r):
                # Once the step closure runs, the cached result is no longer
                # a lazy tensor and has been materialized to a CPU value.
                # This allows further `compute()` calls with no `update()`
                # to return the cached result without unnecessary re-computing.
                self._cached_result = r

            cache_result(self._cached_result)
            return self._cached_result

        self.compute = MethodType(wrapped_compute, self)

    def __float__(self) -> float:
        """Returns the floating point representation of the metric value."""
        return float(self.compute())
def compute_all_metrics():
    """Compute the floating point value of all registered metrics.

    Every instance in `Metric.registry` is converted with `float()`. When
    several instances share a name, each is keyed as ``"{name}.{i}"``, where
    ``i`` is its index in registration order. A name with exactly one
    instance is used unchanged as the key.

    Returns:
        A dict mapping metric keys to their float values.
    """
    # TODO: Deprecate this eventually as we don't want to be keeping global
    #       state if we can help it
    return {
        f"{name}.{i}" if len(metric_list) > 1 else name: float(metric)
        for name, metric_list in Metric.registry.items()
        for i, metric in enumerate(metric_list)
    }