""" Accuracy metric for PyTorch """
import warnings
import torch
from cerebras_pytorch.experimental.metrics.metric import CSMetric
class AccuracyMetric(CSMetric):
    """
    Computes the accuracy of the model's predictions.

    Counts are accumulated across calls to ``update``: the running accuracy
    is ``total_correct_predictions / total_num_tokens`` and is stored on
    ``self.result`` at the end of every ``update`` call.

    Args:
        name: Name of the metric (this class defines no ``__init__``, so
            constructor arguments are handled by ``CSMetric``).
    """

    def reset(self):
        """Register the two float32 accumulator states, initialised to zero."""
        # NOTE(review): ``register_state`` is provided by ``CSMetric``; whether
        # it moves these tensors to the metric's device is decided there and
        # is not visible in this file -- confirm in the base class.
        self.register_state(
            "total_correct_predictions", torch.tensor(0, dtype=torch.float32)
        )
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )

    def update(
        self, labels, predictions, weights=None, dtype=None
    ):  # pylint: disable=arguments-differ
        """
        Accumulate correct-prediction and token counts for one batch.

        Args:
            labels: Tensor of ground-truth labels.
            predictions: Tensor of predicted labels. If its shape differs
                from ``labels.shape`` a warning is emitted and it is reshaped
                to ``labels.shape`` (so the element counts must match).
            weights: Optional tensor multiplied element-wise into the
                correct-prediction indicator. Each element whose weight is
                ``> 0`` counts as exactly one token in the denominator.
            dtype: Optional dtype; when given, ``self.result`` is cast to it.

        Note:
            If no tokens have been counted yet (e.g. every weight is zero),
            the float division ``0 / 0`` makes ``self.result`` NaN.
        """
        if labels.shape != predictions.shape:
            warnings.warn(
                "Shapes mismatch in accuracy metric"
                f"\n labels: {labels.shape}"
                f"\n predictions {predictions.shape}"
            )
            predictions = predictions.reshape(labels.shape)

        # 1.0 where the prediction matches the label, 0.0 otherwise.
        correct_predictions = (labels == predictions).float()
        if weights is None:
            num_correct_predictions = correct_predictions.sum()
            num_tokens = torch.tensor(
                correct_predictions.numel(),
                dtype=torch.float32,
                device=predictions.device,
            )
        else:
            num_correct_predictions = (correct_predictions * weights).sum()
            # The denominator counts tokens rather than summing weights: any
            # element with a positive weight contributes exactly one.
            num_tokens = (weights > 0).float().sum()

        self.total_correct_predictions.add_(num_correct_predictions)
        self.total_num_tokens.add_(num_tokens)

        result = self.total_correct_predictions / self.total_num_tokens
        # Only cast when a dtype was requested; assigning the uncast value
        # afterwards would silently discard the requested dtype.
        # pylint: disable=attribute-defined-outside-init
        self.result = result if dtype is None else result.to(dtype)