cerebras.modelzoo.common.utils.model.transformer_utils.get_extended_attention_mask#
- cerebras.modelzoo.common.utils.model.transformer_utils.get_extended_attention_mask(attention_mask: torch.Tensor, input_shape: Optional[Tuple[int]] = None, causal: bool = False, device: Optional[torch.device] = None, dtype=None) → torch.Tensor [source]#
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

- Parameters
  - attention_mask (torch.Tensor) – Mask with ones indicating tokens to attend to and zeros for tokens to ignore.
  - input_shape (Tuple[int]) – The shape of the input to the model (required for causal masks).
  - causal (bool) – If enabled, the returned mask will be causal.
  - device (torch.device) – The device of the input to the model.
- Returns
  The extended attention mask, with the same dtype as attention_mask.dtype.
- Return type
  torch.Tensor
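
A minimal usage sketch based on the signature above (the tensor values, shapes, and variable names are illustrative assumptions, not part of the API): build a broadcastable padding mask, then a combined causal-plus-padding mask for the same batch.

```python
import torch

from cerebras.modelzoo.common.utils.model.transformer_utils import (
    get_extended_attention_mask,
)

batch_size, seq_len = 2, 5  # illustrative sizes

# 1 = token to attend to, 0 = token to ignore (e.g. padding);
# the last two positions of the second sequence are padding.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ]
)

# Broadcastable padding-only mask.
padding_mask = get_extended_attention_mask(attention_mask)

# Combined causal + padding mask; input_shape is required for causal masks.
causal_mask = get_extended_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_len),
    causal=True,
)
```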