Source code for experimental.metrics.perplexity

""" Perplexity metric for PyTorch """

import torch

from cerebras_pytorch.experimental.metrics.metric import CSMetric


class PerplexityMetric(CSMetric):
    """
    Computes the perplexity of the model's predictions.

    Loss and token counts are accumulated across calls to :meth:`update`,
    and the current perplexity ``exp(total_loss / total_num_tokens)`` is
    stored on ``self.result`` after every update.

    Args:
        name: Name of the metric (this class defines no ``__init__``, so
            constructor arguments are handled by ``CSMetric``).
    """

    def reset(self):
        """Register the running-loss and token-count states, both 0 (float32).

        The states are read back as ``self.total_loss`` and
        ``self.total_num_tokens`` in :meth:`update`.
        """
        self.register_state("total_loss", torch.tensor(0, dtype=torch.float32))
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )

    def update(
        self, labels, loss, weights=None, dtype=None
    ):  # pylint: disable=arguments-differ
        """Accumulate one batch and refresh ``self.result``.

        Args:
            labels: Label tensor. When ``weights`` is None, every element of
                ``labels`` counts as one token.
            loss: Loss for this batch, added to the running total.
                NOTE(review): since the total is divided by the token count,
                this presumably must be the summed (not mean) per-token
                loss - confirm with callers.
            weights: Optional tensor; entries greater than 0 are counted as
                tokens instead of using ``labels.numel()``.
            dtype: Optional dtype to cast the stored result to.
        """
        if weights is None:
            num_tokens = torch.tensor(
                labels.numel(), dtype=torch.float32, device=labels.device
            )
        else:
            num_tokens = (weights > 0).float().sum()

        self.total_loss.add_(loss)
        self.total_num_tokens.add_(num_tokens)

        # NOTE(review): if no tokens have been counted yet, this is 0 / 0 and
        # the result is nan.
        result = torch.exp(self.total_loss / self.total_num_tokens)
        if dtype is not None:
            # Cast before storing. Previously the cast result was assigned and
            # then immediately overwritten by the uncast tensor, so ``dtype``
            # had no effect.
            result = result.to(dtype)
        # pylint: disable=attribute-defined-outside-init
        self.result = result