""" Perplexity metric for PyTorch """
import torch
from cerebras_pytorch.experimental.metrics.metric import CSMetric
class PerplexityMetric(CSMetric):
    """
    Computes the perplexity of the model's predictions.

    Perplexity is ``exp(total_loss / total_num_tokens)``, where both totals
    are running sums accumulated over every call to :meth:`update` since the
    last :meth:`reset`. The latest value is stored in ``self.result``.

    Args:
        name: Name of the metric
    """

    def reset(self):
        """Register the two running-total states, each starting at zero."""
        # Both accumulators are float32 so their ratio in ``update`` is
        # computed in float32.
        self.register_state("total_loss", torch.tensor(0, dtype=torch.float32))
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )

    def update(
        self, labels, loss, weights=None, dtype=None
    ):  # pylint: disable=arguments-differ
        """
        Accumulate one batch into the running totals and refresh the result.

        Args:
            labels: Tensor of labels. Only used when ``weights`` is None, in
                which case every element of ``labels`` counts as one token.
            loss: Loss for the batch, added in place into ``total_loss``.
                NOTE(review): presumably the loss summed over tokens (not
                the per-token mean), since the running total is divided by
                the running token count - confirm with callers.
            weights: Optional tensor; when given, only positions where
                ``weights > 0`` are counted as tokens.
            dtype: Optional ``torch.dtype``; when given, the stored
                ``self.result`` is cast to this dtype.
        """
        if weights is None:
            num_tokens = torch.tensor(
                labels.numel(), dtype=torch.float32, device=labels.device
            )
        else:
            num_tokens = (weights > 0).float().sum()

        self.total_loss.add_(loss)
        self.total_num_tokens.add_(num_tokens)

        result = torch.exp(self.total_loss / self.total_num_tokens)
        if dtype is not None:
            # Cast the local so the converted tensor is what gets stored;
            # assigning the cast to ``self.result`` and then assigning the
            # uncast value afterwards would silently discard ``dtype``.
            result = result.to(dtype)
        # pylint: disable=attribute-defined-outside-init
        self.result = result