# Source code for experimental.metrics.metric

"""Base Metric class"""
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import wraps
from types import MethodType

import torch

import cerebras_pytorch.experimental as cstorch
from cerebras_pytorch.experimental.backend import current_backend_impl


class Metric(torch.nn.Module, ABC):
    """The abstract base metric class.

    Subclasses implement ``reset``, ``update`` and ``compute``. Every
    instance adds itself to the class-level ``registry`` keyed by its name,
    so all live metrics can be enumerated later (see the module-level
    ``compute_all_metrics`` helper).
    """

    # Class-level mapping: metric name -> list of instances created with
    # that name. Shared across all Metric subclasses.
    registry = defaultdict(list)

    def __init__(self, name: str):
        """Constructs a ``Metric`` and registers it.

        Args:
            name: the name of the metric; used as the registry key.
        """
        super().__init__()
        self.name = name
        # Add the metric to the registry so that they can be referenced later
        self.registry[self.name].append(self)

        # Keep a handle on the subclass's reset implementation; the wrapper
        # below replaces `self.reset` with a bound method.
        reset_orig = self.reset

        # wrap the reset method in case the user wants to manually call it
        @wraps(reset_orig)
        def wrapped_reset(self):
            # `register_state` (called from the subclass's reset) populates
            # this dict with torch.nn.Parameter state tensors.
            self.registered_states = {}
            # pylint: disable=not-callable
            reset_orig()

            backend = current_backend_impl()
            with backend.device:
                # need to register a module for appliance data
                self.registered_states = torch.nn.ParameterDict(
                    self.registered_states
                )

            # Promote each collected state to a real module parameter so it
            # participates in the module's parameter machinery.
            for key, value in self.registered_states.items():
                self.register_parameter(key, value)

            # once parameters are registered, no need to keep
            # registered state module
            del self.registered_states

        # Bind the wrapper to this instance so manual `metric.reset()` calls
        # also go through the state (re-)registration logic above.
        self.reset = MethodType(wrapped_reset, self)
        # pylint: disable=not-callable
        self.reset()

    def register_output(self, name: str):
        """
        Create and register a new property with provided name that handles
        fetching the tensor value when assigning to the property

        Note, this means that only tensors are allowed to be set for these
        properties

        Args:
            name: the name of the property
        """
        # Hidden instance attribute that backs the property.
        registered_name = f"_output_{name}_value"
        # Sentinel object distinguishing "never assigned" from any real value.
        UNASSIGNED = object()

        if name in self.__dict__:
            # NOTE(review): this checks the *instance* __dict__, although the
            # property is installed on the class below — confirm this guard
            # triggers as intended.
            delattr(self, registered_name)
            return  # already registered

        def getter(self):
            # Raise a descriptive error if the output was read before the
            # subclass's `update` ever assigned it.
            value = getattr(self, registered_name, UNASSIGNED)
            if value is UNASSIGNED:
                raise RuntimeError(
                    f"Attempting to retrieve {name} "
                    f"but it was never assigned to.\n"
                    f"Please make sure that {name} is assigned a "
                    f"torch.Tensor value before it is retrieved, e.g.\n\n"
                    f"\tclass {self.__class__.__name__}:\n"
                    f"\t\t...\n"
                    f"\t\tdef update(self, ...):\n"
                    f"\t\t\t...\n"
                    f"\t\t\tself.{name}: torch.Tensor = ..."
                )
            return value

        def setter(self, value):
            # Only scalar (single-element) tensors may be assigned.
            if not isinstance(value, torch.Tensor):
                raise TypeError(
                    f"Expected {name} to be assigned a torch.Tensor. "
                    f"Got: {type(value)}"
                )
            elif value.numel() > 1:
                raise ValueError(
                    f"Expected {name} to be assigned a torch Scalar. "
                    f"Got a torch.Tensor of size {value.size()}"
                )

            # Store the host Python scalar (`.item()`) inside a step closure —
            # presumably deferred until the tensor value is available on the
            # host; confirm against cstorch.step_closure semantics.
            @cstorch.step_closure
            def set_value(value):
                setattr(self, registered_name, value.item())

            set_value(value)

        # Install the property on the class so attribute access on instances
        # routes through the getter/setter above.
        setattr(self.__class__, name, property(getter, setter))

    @abstractmethod
    # pylint: disable=method-hidden
    def reset(self):
        """
        Reset the metric state
        """

    def register_state(self, name: str, value: torch.Tensor):
        """Registers a state variable to the module"""
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"Cannot register non-Tensor state: {type(value)}")
        # Wrap plain tensors so they can live in the ParameterDict built by
        # the wrapped reset; state never needs gradients.
        if not isinstance(value, torch.nn.Parameter):
            value = torch.nn.Parameter(value, requires_grad=False)
        self.registered_states[name] = value

    @abstractmethod
    def update(self, *args, **kwargs):
        """Update the metric value"""

    def forward(self, *args, **kwargs):
        """Updates the metric value"""
        # Calling the module directly is equivalent to calling `update`.
        return self.update(*args, **kwargs)

    @abstractmethod
    def compute(self) -> float:
        """Compute and return the final metric value"""

    def __float__(self):
        """
        Return the floating point representation of the final metric value
        """
        return float(self.compute())

    def _load_from_state_dict(self, state_dict, *args, **kwargs):
        # No state ever needs to be loaded for metrics. Skip loading altogether
        # to avoid missing keys warning
        pass
class CSMetric(Metric):
    """Base class for metrics that run on a Cerebras wafer scale cluster."""

    def __init__(self, name):
        """Constructs a ``CSMetric``.

        Args:
            name: the name under which the metric is registered.
        """
        super().__init__(name)
        # Expose the computed value through the `result` output property so
        # the default `compute` below can simply read it back. Subclasses
        # that override `compute` with a custom implementation are free to
        # ignore this output.
        self.register_output("result")

    def compute(self) -> float:
        """Return the final metric value stored in ``self.result``."""
        return self.result
def compute_all_metrics():
    """Return the floating point value of every registered metric.

    Names shared by several metrics are disambiguated with a ``.<index>``
    suffix; a unique name is used as-is.
    """
    # TODO: Deprecate this eventually as we don't want to be keeping global
    # state if we can help it
    results = {}
    for name, metric_list in Metric.registry.items():
        needs_suffix = len(metric_list) > 1
        for i, metric in enumerate(metric_list):
            key = f"{name}.{i}" if needs_suffix else name
            results[key] = float(metric)
    return results