
# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch


@dataclass
class DeviceOutputs:
    """Class for encapsulating the outputs of `CBMetric.update_on_device`.

    Args:
        args: positional arguments which are passed to `CBMetric.update_on_host`
            once they are converted to CPU tensors.
        kwargs: keyword arguments which are passed to `CBMetric.update_on_host`
            once they are converted to CPU tensors.
    """

    args: List[Any] = field(default_factory=list)
    kwargs: Dict[str, Any] = field(default_factory=dict)
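
# Illustrative sketch (not part of the original module): `DeviceOutputs` is
# typically constructed and returned inside `CBMetric.update_on_device`, e.g.
#
#     return DeviceOutputs(args=[predictions, labels])
#
# where `predictions` and `labels` are hypothetical device tensors (entries
# may also be None). The framework moves these to CPU at the step boundary
# and then calls `update_on_host(predictions, labels)`.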

class CBMetric(ABC):
    """Base class for creating metrics on CS devices.

    Subclasses must override methods to provide the full functionality of the
    metric. These methods are meant to split the computation graph into two
    portions:

        1. update_on_device: Compiles and runs on the device (i.e., Cerebras).
        2. update_on_host: Runs on the host (i.e., CPU).

    These metrics also support running on CPU and GPU.
    """

    def __init__(self, name: Optional[str] = None):
        """Constructs a `CBMetric` instance.

        This also registers the metric in the global pool of metrics.
        Therefore, it is important for subclasses to call `super().__init__()`.
        Otherwise, the metrics will not run.

        Args:
            name: Name of the metric. If None or an empty string, it defaults
                to the name of the class.
        """
        # Keeps track of the total number of times the metric was updated
        self._num_updates = 0

        # Get the metric name
        self._name = self._get_unique_name(name)

        # Register the metric in the global list
        global _METRICS
        assert self._name not in _METRICS
        _METRICS[self._name] = self

        self._is_appliance = False
        self._ws_enabled = False
        if cm.use_cs():
            self._is_appliance = cm.is_appliance()

            state = cbtorch.state()
            if state.is_inside_execution_loop:
                raise RuntimeError(
                    "Metrics must be created outside the execution loop."
                )

            self._ws_enabled = cbtorch.env().weight_streaming_mode

        # Stores the state tensors on the device
        self.init_state()

    @property
    def num_updates(self):
        """Returns the number of times the metric was updated (i.e., stepped)."""
        return self._num_updates

    @property
    def name(self):
        """Returns the name of the metric."""
        return self._name

    def on_device_state_dict(self) -> Dict[str, torch.Tensor]:
        """A hook for subclasses to inject metric state variables (WS only).

        In contrast to the pipeline execution strategy, where metrics are
        executed on the host, in weight streaming, metrics are part of the
        graph and are executed on device. As such, any metric state variables
        that are updated need to be tracked to create a correct graph. This
        hook provides a mechanism for metric implementations to specify their
        state variables, which will come up as outputs in the compile.
        """
        return dict()

    def init_state(self):
        """Sets the initial state of the metric.

        Subclasses should override this method to provide any metric-specific
        states. This method is called once as part of `__init__`.
        """

    def update_on_device(self, *args, **kwargs) -> DeviceOutputs:
        """Defines the portion of the metric computation that runs on device.

        This method must return a `DeviceOutputs` object whose args/kwargs can
        only contain an item/list/tuple/dict of torch tensors or Nones. These
        tensors are converted to CPU tensors at the step boundary and passed
        to `update_on_host` to do the host (i.e., CPU) portion of the
        computation.

        The default implementation is a passthrough where the arguments are
        converted to host tensors as-is.

        This method is called for every iteration.

        NOTE: No tensors should be evaluated in this method. This method
        merely defines the operations in the graph that runs on device.
        """
        return DeviceOutputs(args=list(args), kwargs=kwargs)

    @abstractmethod
    def update_on_host(self, *args, **kwargs) -> None:
        """Defines the portion of the metric computation that runs on host.

        This method takes as inputs the outputs of `update_on_device` whose
        tensors have been evaluated and converted to CPU tensors. It can do
        any sort of computation on the host (e.g., updating the metric state).

        This method is called for every iteration.
        """
        raise NotImplementedError

    @abstractmethod
    def compute(self) -> Any:
        """Returns the computed metric value over many iterations.

        This is the "reduction" part of the metric over all steps.
        """
        raise NotImplementedError

    def reset_state(self) -> None:
        """Resets the metric state.

        Subclasses should override this method to clear any metric-specific
        states.
        """

    def reset(self) -> None:
        """Resets the metric state.

        Instead of overriding this method, subclasses should override
        `reset_state`, which is called internally in this method.
        """
        self._num_updates = 0
        self.reset_state()

    def __call__(self, *args, **kwargs) -> None:
        """Runs the metric accumulator over one execution step.

        The arguments to this method are passed directly to
        `update_on_device`.
        """
        self._track_state()

        device_outputs = self.update_on_device(*args, **kwargs)
        assert isinstance(device_outputs, DeviceOutputs), (
            f"Expected device outputs to be of type `DeviceOutputs`, "
            f"but got `{type(device_outputs)}`."
        )

        if self._is_appliance:

            def _on_activations_received():
                # Fetch the received activations for each device tensor and
                # pass them through to the host portion of the computation
                cpu_args = [
                    state.get_activation_for_output(tensor)
                    if isinstance(tensor, torch.Tensor)
                    else tensor
                    for tensor in device_outputs.args
                ]
                cpu_kwargs = {
                    key: state.get_activation_for_output(tensor)
                    if isinstance(tensor, torch.Tensor)
                    else tensor
                    for key, tensor in device_outputs.kwargs.items()
                }
                self.update_on_host(*cpu_args, **cpu_kwargs)
                self._num_updates += 1

            state = cbtorch.state()
            state.track_object(
                {
                    "cb_metric": {
                        self.name: [device_outputs.args, device_outputs.kwargs]
                    }
                },
                force=True,
            )
            state.register_activation_callback(_on_activations_received)
        else:

            @cm.step_closure
            def _update_on_host_closure(args, kwargs):
                args = cm.to_cpu(args)
                kwargs = cm.to_cpu(kwargs)
                self.update_on_host(*args, **kwargs)
                self._num_updates += 1

            _update_on_host_closure(device_outputs.args, device_outputs.kwargs)

    def _get_unique_name(self, name: Optional[str] = None):
        """Returns a unique name for this metric.

        Args:
            name: The default name prefix to use. If None, the class name is
                used. Defaults to None.
        """
        idx = 0
        prefix = name or self.__class__.__name__

        unique_name = prefix
        for name in get_all_metrics():
            if name == unique_name:
                idx += 1
                unique_name = f"{prefix}_{idx}"

        return unique_name

    def _track_state(self):
        """Tracks and names the metric state."""
        state_dict = self.on_device_state_dict()
        if state_dict:
            if not self._ws_enabled:
                raise RuntimeError(
                    "On-device metric state variables aren't supported for "
                    "pipeline mode."
                )
            state = cbtorch.state()
            state.track_object(state_dict)
            cm.set_metric_state_names(state_dict, self.name)

    @classmethod
    def create_metric_impl_factory(
        cls,
        pipeline_metric_cls: Optional["CBMetric"] = None,
        ws_metric_cls: Optional["CBMetric"] = None,
    ) -> "CBMetric":
        """Returns a factory for generating a correct instance of a metric.

        Args:
            pipeline_metric_cls: Optional `CBMetric` which specifies the
                compute for the pipeline execution strategy. Can be used for,
                and is the default for, CPU/GPU.
            ws_metric_cls: Optional `CBMetric` which specifies the compute in
                the weight streaming execution strategy. Can be used for
                CPU/GPU.

        Returns:
            metric_factory: (*args, **kwargs) -> `CBMetric` that automatically
                gives an instance of the correct metric given the execution
                strategy.

        Raises:
            ValueError: if both `pipeline_metric_cls` and `ws_metric_cls` are
                None, or if the two implementations have mismatched signatures.
            TypeError: if a provided implementation is not a subclass of
                `CBMetric`.
        """
        if pipeline_metric_cls is None and ws_metric_cls is None:
            raise ValueError(
                f"At least one metric implementation for either pipeline or "
                f"weight streaming execution strategy must be provided, but "
                f"both are None."
            )
        if pipeline_metric_cls is not None and not (
            issubclass(pipeline_metric_cls, cls)
        ):
            raise TypeError(
                f"expected provided pipeline metric to be a subclass of {cls}, "
                f"got {type(pipeline_metric_cls)}"
            )
        if ws_metric_cls is not None and not (issubclass(ws_metric_cls, cls)):
            raise TypeError(
                f"expected provided WS metric to be a subclass of {cls}, "
                f"got {type(ws_metric_cls)}"
            )

        # Check whether `update_on_device` was overridden. The effective call
        # signature of a metric is defined by `update_on_device` if it was
        # overridden, otherwise by `update_on_host`; both implementations
        # must expose the same signature.
        if pipeline_metric_cls is not None and ws_metric_cls is not None:
            if (
                pipeline_metric_cls.update_on_device
                == CBMetric.update_on_device
            ):
                pipe_method = pipeline_metric_cls.update_on_host
            else:
                pipe_method = pipeline_metric_cls.update_on_device
            pipe_spec = inspect.getfullargspec(pipe_method)

            if ws_metric_cls.update_on_device == CBMetric.update_on_device:
                ws_method = ws_metric_cls.update_on_host
            else:
                ws_method = ws_metric_cls.update_on_device
            ws_spec = inspect.getfullargspec(ws_method)

            if pipe_spec != ws_spec:
                raise ValueError(
                    f"The signature seen by calling the pipeline metric "
                    f"implementation {pipeline_metric_cls} does not match the "
                    f"one seen by the WS metric implementation "
                    f"{ws_metric_cls}. The methods that define these "
                    f"signatures are {pipe_method} for the pipeline metric "
                    f"and {ws_method} for the weight streaming metric. Please "
                    f"ensure that they take the same args, kwargs, and have "
                    f"the same default values."
                )

        # Return a factory for the metric
        def metric_factory(*args, **kwargs):
            """Returns an instance of the proper metric for the execution
            strategy.

            Raises:
                TypeError: if the needed metric class was not provided.
            """
            if cm.use_cs():
                if cbtorch.env().weight_streaming_mode:
                    if ws_metric_cls:
                        return ws_metric_cls(*args, **kwargs)
                    raise TypeError(
                        f"No weight streaming implementation was provided for "
                        f"this metric, but you are running the weight "
                        f"streaming strategy. The only registered metric is a "
                        f"pipeline implementation {pipeline_metric_cls}. If "
                        f"you'd like to use this metric with the weight "
                        f"streaming execution strategy, please provide an "
                        f"implementation, or change the execution strategy."
                    )
                else:
                    if pipeline_metric_cls:
                        return pipeline_metric_cls(*args, **kwargs)
                    raise TypeError(
                        f"No pipeline implementation was provided for this "
                        f"metric, but you are running the pipeline strategy. "
                        f"The only registered metric is a weight streaming "
                        f"implementation {ws_metric_cls}. If you'd like to "
                        f"use this metric with the pipeline execution "
                        f"strategy, please provide an implementation, or "
                        f"change the execution strategy."
                    )
            else:
                # Pipeline is the CPU/GPU default
                if pipeline_metric_cls:
                    return pipeline_metric_cls(*args, **kwargs)
                return ws_metric_cls(*args, **kwargs)

        return metric_factory
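
# Illustrative sketch (a hypothetical example, not part of the original
# module): a minimal `CBMetric` subclass that accumulates a running mean of a
# scalar loss. The default `update_on_device` passthrough forwards the tensor,
# which arrives in `update_on_host` as a CPU tensor.
class _ExampleMeanLossMetric(CBMetric):
    def init_state(self):
        self._total = 0.0

    def update_on_host(self, loss: torch.Tensor) -> None:
        # `loss` has already been converted to a CPU tensor
        self._total += float(loss)

    def compute(self) -> float:
        # Average over all host updates seen so far
        return self._total / max(self.num_updates, 1)

    def reset_state(self) -> None:
        self._total = 0.0


# A metric like the sketch above could be exposed through the factory, e.g.:
#
#     MeanLossMetric = CBMetric.create_metric_impl_factory(
#         pipeline_metric_cls=_ExampleMeanLossMetric,
#     )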

# Keeps track of all registered metrics
_METRICS = dict()

def get_all_metrics() -> Dict[str, CBMetric]:
    """Returns all registered metrics."""
    return _METRICS

def compute_all_metrics() -> Dict[str, Any]:
    """Computes all the registered metrics and returns them in a dict."""
    metrics = dict()
    for name, metric in get_all_metrics().items():
        metrics[name] = metric.compute()
    return metrics

def reset_all_metrics() -> None:
    """Resets the internal state of all registered metrics."""
    for metric in get_all_metrics().values():
        metric.reset()
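
# Illustrative usage sketch (assumption: `model` and `eval_loader` are
# hypothetical, not part of the original module). Metrics register themselves
# on construction, so the registry helpers above can drive a full evaluation
# pass:
#
#     metric = _ExampleMeanLossMetric(name="mean_loss")
#     for batch in eval_loader:
#         loss = model(batch)
#         metric(loss)                    # update_on_device -> update_on_host
#     results = compute_all_metrics()     # {"mean_loss": <float>}
#     reset_all_metrics()                 # clear state before the next eval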