"""Base Metric class"""
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import wraps
from types import MethodType
import torch
import cerebras_pytorch.experimental as cstorch
from cerebras_pytorch.experimental.backend import current_backend_impl
class Metric(torch.nn.Module, ABC):
    """The abstract basemetric class

    Subclasses implement ``reset`` (initialize/clear state), ``update``
    (accumulate from a batch; also invoked via ``forward``) and ``compute``
    (produce the final value). Every instance is recorded in the class-level
    ``registry`` keyed by its name.
    """

    # Maps metric name -> list of every Metric instance created with that
    # name. Class attribute, so it is shared by all subclasses; duplicate
    # names are allowed and disambiguated later (see compute_all_metrics).
    registry = defaultdict(list)

    def __init__(self, name: str):
        """Construct the metric and immediately call ``reset``.

        Args:
            name: the name under which this metric is recorded in
                ``Metric.registry``.
        """
        super().__init__()
        self.name = name
        # Add the metric to the registry so that they can be referenced later
        self.registry[self.name].append(self)

        # Keep a handle to the subclass-provided reset before rebinding.
        reset_orig = self.reset

        # wrap the reset method in case the user wants to manually call it
        @wraps(reset_orig)
        def wrapped_reset(self):
            # Transient dict that register_state() writes into while the
            # subclass reset() runs; it only exists for the duration of a
            # reset call (deleted below).
            self.registered_states = {}
            # pylint: disable=not-callable
            reset_orig()
            backend = current_backend_impl()
            with backend.device:
                # need to register a module for appliance data
                self.registered_states = torch.nn.ParameterDict(
                    self.registered_states
                )
            # Re-register each state tensor as a plain parameter of this
            # module so it shows up in the module's parameter list.
            for key, value in self.registered_states.items():
                self.register_parameter(key, value)
            # once parameters are registered, no need to keep
            # registered state module
            del self.registered_states

        # Bind the wrapper to this instance so manual self.reset() calls
        # also rebuild the registered parameters.
        self.reset = MethodType(wrapped_reset, self)
        # pylint: disable=not-callable
        self.reset()

    def register_output(self, name: str):
        """
        Create and register a new property with provided name that handles
        fetching the tensor value when assigning to the property

        Note, this means that only tensors are allowed to be set for these
        properties

        Args:
            name: the name of the property
        """
        # Instance attribute that backs the property; it holds the fetched
        # Python scalar once the setter's step closure has executed.
        registered_name = f"_output_{name}_value"
        # Sentinel distinguishing "never assigned" from any real value
        # (including None or 0).
        UNASSIGNED = object()

        if name in self.__dict__:
            # NOTE(review): the property below is installed on the *class*,
            # so `name` would only appear in the instance __dict__ if
            # something assigned it directly — confirm this guard is
            # actually reachable.
            delattr(self, registered_name)
            return  # already registered

        def getter(self):
            # Raise a descriptive error if the output was read before the
            # subclass's update() ever assigned it.
            value = getattr(self, registered_name, UNASSIGNED)
            if value is UNASSIGNED:
                raise RuntimeError(
                    f"Attempting to retrieve {name} "
                    f"but it was never assigned to.\n"
                    f"Please make sure that {name} is assigned a "
                    f"torch.Tensor value before it is retrieved, e.g.\n\n"
                    f"\tclass {self.__class__.__name__}:\n"
                    f"\t\t...\n"
                    f"\t\tdef update(self, ...):\n"
                    f"\t\t\t...\n"
                    f"\t\t\tself.{name}: torch.Tensor = ..."
                )
            return value

        def setter(self, value):
            # Only scalar (single-element) tensors may be assigned.
            if not isinstance(value, torch.Tensor):
                raise TypeError(
                    f"Expected {name} to be assigned a torch.Tensor. "
                    f"Got: {type(value)}"
                )
            elif value.numel() > 1:
                raise ValueError(
                    f"Expected {name} to be assigned a torch Scalar. "
                    f"Got a torch.Tensor of size {value.size()}"
                )

            # Store the fetched scalar inside a step closure — presumably
            # this defers the .item() call until the tensor's value is
            # available on the host (see cstorch.step_closure); verify.
            @cstorch.step_closure
            def set_value(value):
                setattr(self, registered_name, value.item())

            set_value(value)

        # NOTE: the property is installed on the class, so it is shared by
        # every instance of this subclass.
        setattr(self.__class__, name, property(getter, setter))

    @abstractmethod
    # pylint: disable=method-hidden
    def reset(self):
        """ Reset the metric state """

    def register_state(self, name: str, value: torch.Tensor):
        """Registers a state variable to the module

        Only valid while ``reset`` is running: it writes into the transient
        ``registered_states`` dict created by the reset wrapper above.

        Args:
            name: key under which the state is registered.
            value: tensor state; wrapped in a non-trainable Parameter if it
                is not one already.

        Raises:
            TypeError: if ``value`` is not a ``torch.Tensor``.
        """
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"Cannot register non-Tensor state: {type(value)}")
        if not isinstance(value, torch.nn.Parameter):
            # State must never receive gradients.
            value = torch.nn.Parameter(value, requires_grad=False)
        self.registered_states[name] = value

    @abstractmethod
    def update(self, *args, **kwargs):
        """Update the metric value"""

    def forward(self, *args, **kwargs):
        """Updates the metric value"""
        # Calling the module directly is equivalent to calling update().
        return self.update(*args, **kwargs)

    @abstractmethod
    def compute(self) -> float:
        """Compute and return the final metric value"""

    def __float__(self):
        """
        Return the floating pointer representation of the final metric value
        """
        return float(self.compute())

    def _load_from_state_dict(self, state_dict, *args, **kwargs):
        # No state ever needs to be loaded for metrics. Skip loading altogether
        # to avoid missing keys warning
        pass
class CSMetric(Metric):
    """The base metric class for metrics used on a Cerebras wafer scale cluster"""

    def __init__(self, name):
        """Construct the metric and register its ``result`` output.

        Args:
            name: the name under which this metric is registered.
        """
        super().__init__(name)
        # Expose "result" as an output property: assigning a scalar tensor
        # to ``self.result`` in ``update`` feeds the default ``compute``
        # below. Subclasses that override ``compute`` need not use it.
        self.register_output("result")

    def compute(self) -> float:
        """Compute and return the final metric value"""
        return self.result
def compute_all_metrics():
    """Compute the floating point value of all registered metrics"""
    # TODO: Deprecate this eventually as we don't want to be keeping global
    # state if we can help it
    results = {}
    for name, metric_list in Metric.registry.items():
        # Duplicate names get an index suffix so every entry stays unique.
        needs_suffix = len(metric_list) > 1
        for i, metric in enumerate(metric_list):
            key = f"{name}.{i}" if needs_suffix else name
            results[key] = float(metric)
    return results