# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import torch
import torch.nn as nn

from modelzoo.common.pytorch.model_utils.create_initializer import (
    create_initializer,
)
from .AttentionLayer import MultiheadAttention


class MultiQueryAttention(MultiheadAttention):
    """Implements the Multi-Query Attention Layer from
        `Fast Transformer Decoding: One Write-Head is All You Need
        <https://arxiv.org/abs/1911.02150>`
    Args:
        embed_dim (int): Number of input units in each projection output
        num_heads (int): Number of attention heads.
        inner_dim (int): Number of output units in attention query/key/value projection. Defaults to ``embed_dim``.
        dropout (float): Dropout rate for key-query weights. Defaults to 0.0.
        batch_first (bool): If True, then the input and output tensors are
            provided as (batch, seq, feature), otherwise the format will be
            (seq, batch, feature). Default: True (batch, seq, feature).
        add_bias_kv (bool): If specified, adds bias to the key and value sequences at dim=0. Default: False.
        add_zero_attn (bool): If specified, adds a new batch of zeros to the key and value
            sequences at dim=1. Default: False
        kdim (int): Number of output units in key projection
        vdim (int): Number of output units in value projection
        use_projection_bias (bool): Whether to use bias in the key, query, and
            value projections.
        use_ffn_bias (bool): Whether to use bias in the output projection.
        attention_initializer (str): Projection kernel initializer. Defaults to
            ``xavier_uniform``.
        attention_q_initializer: Query projection kernel initializer. If not
            specified, the query will be initialized via ``attention_initializer``
        output_layer_initializer (str or initializer): If not None, use this
            initializer for the output transform layer. Defaults to None.
        bias_initializer (str): Bias initializer. Defaults to ``zeros``.
        attention_type (str): The attention variant to execute. Currently
            accepts ``dot_product`` and ``scaled_dot_product``. Defaults to
            ``scaled_dot_product``.
        softmax_dtype_fp32 (bool): Use an FP32 softmax implementation.
        num_kv_groups (int): Number of key/value head groups shared across the
            query heads. The default of ``1`` corresponds to multi-query
            attention; values between ``1`` and ``num_heads`` give grouped-query
            attention. ``num_heads`` must be divisible by ``num_kv_groups``.
        device (optional): Device to create the model parameters on, can be a cuda device or CS device.
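
    Example:
        A minimal usage sketch; it assumes ``batch_first=True`` inputs and the
        inherited ``MultiheadAttention.forward(q, k, v, ...)`` call signature,
        and the tensor sizes are illustrative only.

        .. code-block:: python

            import torch

            # 8 query heads share a single key/value head (pure MQA)
            mqa = MultiQueryAttention(embed_dim=512, num_heads=8, num_kv_groups=1)
            x = torch.randn(2, 128, 512)  # [batch, seq, embed_dim]
            out = mqa(x, x, x)            # self-attention; same shape as x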
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        inner_dim=None,
        dropout=0.0,
        batch_first=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        use_projection_bias=None,
        use_ffn_bias=False,
        attention_initializer="xavier_uniform",
        attention_q_initializer=None,
        output_layer_initializer=None,
        bias_initializer="zeros",
        attention_type="scaled_dot_product",
        scale_qk_dot_by_d=False,
        softmax_dtype_fp32=True,
        scale_qk_dot_by_layer_idx=False,
        device=None,
        # MQA specific
        num_kv_groups=1,
    ):
        super(MultiQueryAttention, self).__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            inner_dim=inner_dim,
            dropout=dropout,
            batch_first=batch_first,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            vdim=vdim,
            kdim=kdim,
            use_projection_bias=use_projection_bias,
            use_ffn_bias=use_ffn_bias,
            attention_initializer=attention_initializer,
            attention_q_initializer=attention_q_initializer,
            output_layer_initializer=output_layer_initializer,
            bias_initializer=bias_initializer,
            attention_type=attention_type,
            scale_qk_dot_by_d=scale_qk_dot_by_d,
            softmax_dtype_fp32=softmax_dtype_fp32,
            scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
            device=device,
        )
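        # Unlike the parent class, only the query projection has num_heads
        # heads; the key/value projections below are shared across the query
        # heads within each group.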
        self.head_dim = self.inner_dim // self.num_heads
        self.num_kv_groups = num_kv_groups
        assert (
            self.num_heads % self.num_kv_groups == 0
        ), f"num_heads has to be a multiple of num_kv_groups but got {self.num_heads} and {self.num_kv_groups}"
        self.per_group_num_heads = self.num_heads // self.num_kv_groups
        # key and value projections use num_kv_groups heads (a single shared
        # head for pure multi-query attention)
        self.proj_k_dense_layer = nn.Linear(
            self.kdim,
            self.num_kv_groups * self.head_dim,
            bias=use_projection_bias,
            device=device,
        )
        self.proj_v_dense_layer = nn.Linear(
            self.vdim,
            self.num_kv_groups * self.head_dim,
            bias=use_projection_bias,
            device=device,
        )
        # reset newly initialized parameters
        self.__reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        self.__reset_parameters()
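
    # Name-mangled so that calling it from __init__ re-initializes only the
    # MQA-specific key/value projections; reset_parameters() additionally
    # re-runs the parent initialization.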
    def __reset_parameters(self):
        # bias initialization
        bias_initializer = create_initializer(self.bias_initializer)
        if self.use_projection_bias:
            bias_initializer(self.proj_k_dense_layer.bias.data)
            bias_initializer(self.proj_v_dense_layer.bias.data)
        # k, v projections
        weight_initializer = create_initializer(self.initializer)
        weight_initializer(self.proj_k_dense_layer.weight.data)
        weight_initializer(self.proj_v_dense_layer.weight.data)

    def construct_key_vector(self, k, attn_mask=None, key_padding_mask=None):
        # linear projection
        k = self.proj_k_dense_layer(
            k
        )  # [batch_size, seq_length, self.num_kv_groups * self.head_dim]
        if self.num_kv_groups == 1:
            return torch.unsqueeze(
                k, 2
            )  # [batch_size, seq_length, 1, kv_channels]
        batch_size, seq_length, _ = k.shape
        # [batch_size, seq_length, self.num_kv_groups, self.head_dim]
        k = k.reshape(batch_size, seq_length, self.num_kv_groups, self.head_dim)
        return k

    def construct_value_vector(self, v, attn_mask=None, key_padding_mask=None):
        # linear projection
        v = self.proj_v_dense_layer(
            v
        )  # [batch_size, seq_length, self.num_kv_groups * self.head_dim]
        if self.num_kv_groups == 1:
            return torch.unsqueeze(
                v, 1
            )  # [batch_size, 1, seq_length, kv_channels]
        batch_size, seq_length, _ = v.shape
        v = v.reshape(batch_size, seq_length, self.num_kv_groups, self.head_dim)
        v = v.transpose(2, 1)
        # [batch_size, self.num_kv_groups, seq_length, self.head_dim]
        return v

    def expand_kv_over_group_dim(self, x):
        # expand the num_kv_groups key/value heads across the query heads in each group
        batch_size, _, seq_length, _ = x.shape
        x = x.unsqueeze(
            2
        )  # [batch_size, self.num_kv_groups, 1, seq_length, self.head_dim]
        # expand over per_group_num_heads
        x = x.expand(
            batch_size,
            self.num_kv_groups,
            self.per_group_num_heads,
            seq_length,
            self.head_dim,
        )
        x = x.reshape(batch_size, self.num_heads, seq_length, self.head_dim)
        return x

    def calculate_attention_logits(self, q, k, layer_idx):
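        # With a single KV group, the singleton head dimension of k broadcasts
        # against all query heads inside the parent matmul; with multiple
        # groups, each group's head is expanded explicitly first.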
        if self.num_kv_groups > 1:
            k = self.expand_kv_over_group_dim(k)
        return super().calculate_attention_logits(q, k, layer_idx)

    def calculate_attention_output(self, attention_scores, v):
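        # Values share the same group-to-head expansion as the keys.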
        if self.num_kv_groups > 1:
            v = self.expand_kv_over_group_dim(v)
        return super().calculate_attention_output(attention_scores, v)

    @staticmethod
    def check_extra_params(params):
        if "num_kv_groups" not in params:
            params["num_kv_groups"] = 1
            logging.warning(
                "num_kv_groups is not set in the yaml, it is set to 1 by default. "
                "Please provide a value if this is not intended."
            )