cerebras.modelzoo.common.utils.model.generation_utils.sample_tokens
- cerebras.modelzoo.common.utils.model.generation_utils.sample_tokens(token_logits, rand_uniform, temperature=None, top_k=None, top_p=None)
Samples tokens from the logits when sampling is requested. Sampling supports temperature, top_k, top_p, or any combination of these. If all sampling arguments are None, the function falls back to greedy decoding (argmax over the logits).
- Parameters
token_logits (torch.Tensor) – Tensor with logits for tokens
rand_uniform (torch.Tensor) – Random uniform tensor for sampling
temperature (float) – Parameter to control the randomness of the predicted tokens
top_k (int) – Sample tokens by restricting to the k highest-probability tokens
top_p (float) – Sample tokens by restricting to the highest-probability tokens whose cumulative probability stays within the cut-off top_p (nucleus sampling)
- Returns
Greedily decoded or sampled token(s) from the logits, depending on the sampling parameters
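The parameter descriptions above can be illustrated with a minimal PyTorch sketch. It is not the modelzoo implementation: the function name `sample_tokens_sketch`, the assumption that `rand_uniform` holds one draw in [0, 1) per position (shape `token_logits.shape[:-1]`), the filtering order (temperature, then top-k, then top-p), and the use of inverse-CDF sampling are all assumptions made for illustration.

```python
import torch


def sample_tokens_sketch(token_logits, rand_uniform, temperature=None, top_k=None, top_p=None):
    """Hypothetical sketch for illustration only, not the modelzoo code.

    Assumes token_logits has shape [..., vocab] and rand_uniform has shape [...]
    with values in [0, 1), i.e. one uniform draw per position.
    """
    # Greedy decoding when no sampling argument is given.
    if temperature is None and top_k is None and top_p is None:
        return torch.argmax(token_logits, dim=-1)

    logits = token_logits.float()
    if temperature is not None:
        # Temperature scaling: values > 1 flatten the distribution, < 1 sharpen it.
        logits = logits / temperature

    # Sort once so that top-k and top-p filtering operate on ranked tokens.
    sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)

    if top_k is not None:
        # Mask every token ranked k or lower (0-based rank >= k).
        ranks = torch.arange(sorted_logits.shape[-1], device=logits.device)
        sorted_logits = sorted_logits.masked_fill(ranks >= top_k, float("-inf"))

    if top_p is not None:
        probs = torch.softmax(sorted_logits, dim=-1)
        # Probability mass of strictly higher-ranked tokens; a token is kept while
        # that mass is still below top_p, so the top token always survives.
        mass_before = torch.cumsum(probs, dim=-1) - probs
        sorted_logits = sorted_logits.masked_fill(mass_before >= top_p, float("-inf"))

    # Inverse-CDF sampling with the caller-supplied uniform draw.
    probs = torch.softmax(sorted_logits, dim=-1)
    cdf = torch.cumsum(probs, dim=-1)
    # Renormalize so the last CDF value is exactly 1.0 despite round-off.
    cdf = cdf / cdf[..., -1:]
    u = rand_uniform.to(cdf.dtype).unsqueeze(-1)
    # right=True returns the first rank whose CDF value is strictly greater than u.
    choice = torch.searchsorted(cdf, u, right=True)
    # Map the chosen rank back to the original vocabulary index.
    return torch.gather(sorted_idx, -1, choice).squeeze(-1)
```

Because the randomness arrives as an explicit tensor in this sketch, the same `rand_uniform` values always select the same tokens, which makes the filtering behavior easy to check. For example, with token probabilities `[0.1, 0.2, 0.3, 0.4]`, a draw of `0.95` selects token 0 under plain sampling but token 2 once `top_k=2` or `top_p=0.5` removes the two least likely tokens.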