cerebras.modelzoo.common.utils.model.generation_utils.sample_tokens

cerebras.modelzoo.common.utils.model.generation_utils.sample_tokens(token_logits, rand_uniform, temperature=None, top_k=None, top_p=None)

Selects tokens from the given logits. Sampling can use temperature, top_k, top_p, or any combination of them. If all sampling arguments are None, the function falls back to greedy decoding of the logits.

Parameters
  • token_logits (torch.Tensor) – Tensor with logits for tokens

  • rand_uniform (torch.Tensor) – Random uniform tensor for sampling

  • temperature (float) – Parameter to control the randomness of the predicted tokens

  • top_k (int) – Sample tokens by restricting to the k highest probability elements

  • top_p (float) – Sample tokens by restricting to the highest-probability tokens whose cumulative probability stays within the cutoff top_p

Returns

The token selected from the logits: the greedy token when no sampling parameters are set, otherwise a token sampled according to those parameters
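
Example

The following is a minimal pure-Python sketch of how temperature, top-k, and top-p filtering are commonly combined with a greedy fallback. It is not the Cerebras implementation: the helper name sample_token_sketch, the scalar argument u (standing in for rand_uniform), the order in which the filters are applied, and the exact top-p cutoff convention (keep the shortest prefix whose cumulative probability reaches top_p) are all assumptions made for illustration. It works on a plain list of logits for a single position rather than on a torch.Tensor.

```python
import math
import random


def sample_token_sketch(logits, u, temperature=None, top_k=None, top_p=None):
    """Illustrative only: pick one token index from a 1-D list of logits.

    `u` is a uniform random number in [0, 1), playing the role that
    `rand_uniform` plays in the documented signature (an assumption).
    """
    # With no sampling arguments, fall back to greedy decoding (argmax).
    if temperature is None and top_k is None and top_p is None:
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature scaling: values below 1 sharpen the distribution,
    # values above 1 flatten it.
    if temperature is not None:
        logits = [x / temperature for x in logits]

    # Sort token indices from highest to lowest logit.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)

    # Top-k: keep only the k highest-scoring tokens.
    if top_k is not None:
        order = order[:top_k]

    # Softmax over the surviving tokens (shifted by the max for stability).
    peak = logits[order[0]]
    weights = [math.exp(logits[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the shortest prefix whose cumulative probability
    # reaches top_p, then renormalize (assumed convention).
    if top_p is not None:
        cumulative, keep = 0.0, 0
        for p in probs:
            cumulative += p
            keep += 1
            if cumulative >= top_p:
                break
        order, probs = order[:keep], probs[:keep]
        total = sum(probs)
        probs = [p / total for p in probs]

    # Inverse-CDF sampling: walk the cumulative distribution until it passes u.
    cumulative = 0.0
    for index, p in zip(order, probs):
        cumulative += p
        if u < cumulative:
            return index
    return order[-1]  # guard against floating-point round-off


# Usage: mix all three controls; the result depends on the random draw.
token = sample_token_sketch(
    [1.0, 3.0, 2.0, 0.5], u=random.random(), temperature=0.8, top_k=3, top_p=0.9
)
```

Passing the random draw in as an argument, as the documented signature does with rand_uniform, keeps the selection step deterministic for a given input, which makes it easy to reproduce a result by reusing the same draw.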