# Source code for experimental.metrics.accuracy

""" Accuracy metric for PyTorch """
import warnings

import torch

from cerebras_pytorch.experimental.metrics.metric import CSMetric


class AccuracyMetric(CSMetric):
    """Computes the accuracy of the model's predictions.

    Accuracy is the running fraction of predictions that exactly match the
    labels, optionally restricted/weighted by a ``weights`` tensor.

    Args:
        name: Name of the metric
    """

    def reset(self):
        """Reset the running totals to zero.

        NOTE(review): tensors registered via ``register_state`` are
        presumably moved to the target device automatically — confirm
        against the ``CSMetric`` base class.
        """
        self.register_state(
            "total_correct_predictions", torch.tensor(0, dtype=torch.float32)
        )
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )

    def update(
        self, labels, predictions, weights=None, dtype=None
    ):  # pylint: disable=arguments-differ
        """Accumulate accuracy statistics for one batch.

        Args:
            labels: Ground-truth tensor.
            predictions: Predicted tensor. If its shape differs from
                ``labels``, a warning is emitted and it is reshaped to
                ``labels.shape``.
            weights: Optional per-element weights. When given, matches are
                weighted by it and only elements with a strictly positive
                weight count toward the token total.
            dtype: Optional dtype to cast the running result to before it
                is stored on ``self.result``.
        """
        if labels.shape != predictions.shape:
            warnings.warn(
                "Shapes mismatch in accuracy metric"
                f"\n labels: {labels.shape}"
                f"\n predictions {predictions.shape}"
            )
            predictions = predictions.reshape(labels.shape)

        correct_predictions = (labels == predictions).float()
        if weights is None:
            num_correct_predictions = correct_predictions.sum()
            num_tokens = torch.tensor(
                correct_predictions.numel(),
                dtype=torch.float32,
                device=predictions.device,
            )
        else:
            correct_predictions = correct_predictions * weights
            num_correct_predictions = correct_predictions.sum()
            num_tokens = (weights > 0).float().sum()

        # Running totals registered in reset(); updated in place so the
        # state tensors themselves persist across calls.
        self.total_correct_predictions.add_(num_correct_predictions)
        self.total_num_tokens.add_(num_tokens)

        result = self.total_correct_predictions / self.total_num_tokens
        # pylint: disable=attribute-defined-outside-init
        # BUG FIX: the original unconditionally assigned ``self.result =
        # result`` after the cast, overwriting it and making ``dtype`` a
        # no-op. The cast result must only be replaced when dtype is None.
        if dtype is not None:
            self.result = result.to(dtype)
        else:
            self.result = result