# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from modelzoo.common.pytorch.layers.utils import (
    get_2d_fixed_position_embeddings,
    patchify_helper,
)
from modelzoo.common.pytorch.model_utils.create_initializer import (
    create_initializer,
)
class ViTEmbeddingLayer(nn.Module):
    """Image patch embedding layer for Vision Transformer (ViT) style models.

    Turns a batch of images into a sequence of ``hidden_size`` embeddings:

    * patch projection, either an ``nn.Conv2d`` whose kernel and stride both
      equal ``patch_size`` (``use_conv_patchified_embedding=True``) or
      ``patchify_helper`` followed by ``nn.Linear``;
    * optional selection of a subset of patches (``patch_indices``, e.g. MAE);
    * optional position embeddings (``"learned"``, ``"fixed"`` or ``None``);
    * optional learnable CLS embedding prepended to the sequence;
    * dropout.
    """

    def __init__(
        self,
        image_size=[224, 224],
        num_channels=3,
        patch_size=[16, 16],
        hidden_size=768,
        initializer_range=0.02,
        embedding_dropout_rate=0.0,
        projection_initializer=None,
        position_embedding_initializer=None,
        position_embedding_type="learned",
        use_conv_patchified_embedding=False,
        prepend_cls_token=False,
        init_conv_like_linear=False,
    ):
        """Build the projection, position embeddings and optional CLS token.

        Args:
            image_size (list[int]): ``[height, width]`` of the input images.
                Each entry must be divisible by the matching ``patch_size``
                entry. A copy is stored on the instance.
            num_channels (int): Number of input image channels.
            patch_size (list[int]): ``[patch_height, patch_width]``. A copy is
                stored on the instance.
            hidden_size (int): Size of each output embedding.
            initializer_range (float): Std of the default truncated-normal
                initializer, truncated to
                ``[-2 * initializer_range, 2 * initializer_range]``.
            embedding_dropout_rate (float): Dropout rate applied to the output.
            projection_initializer: Initializer spec accepted by
                ``create_initializer`` for the projection weight; defaults to
                the truncated-normal spec above.
            position_embedding_initializer: Initializer spec accepted by
                ``create_initializer`` for learned position embeddings;
                defaults to the truncated-normal spec above.
            position_embedding_type (str | None): ``"learned"`` (trainable
                ``nn.Embedding``), ``"fixed"`` (non-trainable sin/cos table
                from ``get_2d_fixed_position_embeddings``) or ``None`` (no
                position embeddings).
            use_conv_patchified_embedding (bool): Project patches with a
                strided ``nn.Conv2d`` instead of patchify + ``nn.Linear``.
            prepend_cls_token (bool): Prepend a learnable CLS embedding. Its
                position embedding (if any) is stored at index
                ``num_patches[0] * num_patches[1]``, i.e. after all patch
                positions.
            init_conv_like_linear (bool): With the conv projection, initialize
                its weight through a ``[hidden_size, -1]`` view, i.e. the same
                2D shape an equivalent ``nn.Linear`` weight would have.
        """
        super().__init__()
        # Copy the sequences so an instance never aliases the shared mutable
        # default lists (mutating self.patch_size must not leak to other
        # instances).
        self.image_size = list(image_size)
        self.num_channels = num_channels
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.position_embedding_type = position_embedding_type
        self.use_conv_patchified_embedding = use_conv_patchified_embedding
        self.prepend_cls_token = prepend_cls_token
        self.init_conv_like_linear = init_conv_like_linear

        assert (
            self.image_size[0] % self.patch_size[0] == 0
            and self.image_size[1] % self.patch_size[1] == 0
        ), f"image size {self.image_size} is not divisible by patch_size {self.patch_size}"
        assert self.position_embedding_type in [
            None,
            "fixed",
            "learned",
        ], (
            "Only `learned`, `fixed` or None position embeddings are supported "
            f"for now, got {self.position_embedding_type!r}."
        )

        # Patch grid: [patches along height, patches along width].
        self.num_patches = [
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        ]
        num_image_positions = self.num_patches[0] * self.num_patches[1]

        if use_conv_patchified_embedding:
            # kernel == stride == patch_size: one output pixel per patch.
            self.linear_proj = nn.Conv2d(
                self.num_channels,
                self.hidden_size,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            )
        else:
            # Flattened patch length fed to the linear projection.
            self.embedding_size = (
                self.patch_size[0] * self.patch_size[1] * num_channels
            )
            self.linear_proj = nn.Linear(self.embedding_size, self.hidden_size)

        if self.position_embedding_type == "learned":
            # One extra slot (the last one) for the CLS token, if used.
            num_position_embeddings = num_image_positions + int(
                bool(self.prepend_cls_token)
            )
            self.position_embeddings = nn.Embedding(
                num_position_embeddings, self.hidden_size
            )
        elif self.position_embedding_type == "fixed":  # fixed sin & cos
            position_embeddings = get_2d_fixed_position_embeddings(
                self.num_patches,
                self.hidden_size,
                add_cls_token=prepend_cls_token,
            )
            # Registered as a Parameter (so it follows .to()/state_dict) but
            # excluded from training.
            self.position_embeddings = torch.nn.Parameter(
                position_embeddings, requires_grad=False
            )

        if self.prepend_cls_token:
            self.cls_embedding = nn.Parameter(
                torch.zeros(1, 1, self.hidden_size)
            )
            # seq_len + 1 - 1: the CLS position embedding is the last row.
            self.cls_embedding_position_index = num_image_positions

        self.default_initializer = {
            "name": "truncated_normal",
            "std": self.initializer_range,
            "mean": 0.0,
            "a": self.initializer_range * -2.0,
            "b": self.initializer_range * 2.0,
        }
        if projection_initializer is None:
            projection_initializer = self.default_initializer
        if position_embedding_initializer is None:
            position_embedding_initializer = self.default_initializer
        self.projection_initializer = projection_initializer
        self.position_embedding_initializer = position_embedding_initializer

        self.dropout_embd = nn.Dropout(embedding_dropout_rate)
        self.__reset_parameters()

    def reset_parameters(self):
        """Re-initialize all trainable parameters of the layer."""
        self.__reset_parameters()

    def __reset_parameters(self):
        """Initialize projection, CLS embedding and learned position embeddings.

        Fixed (sin/cos) position embeddings are left untouched.
        """
        projection_initializer = create_initializer(self.projection_initializer)
        weight = self.linear_proj.weight.data
        if self.use_conv_patchified_embedding and self.init_conv_like_linear:
            # Initialize through a 2D [hidden_size, C * ph * pw] view so the
            # initializer sees the fan-in/fan-out of the equivalent linear
            # layer. The bias is zeroed below.
            projection_initializer(weight.view([weight.shape[0], -1]))
        else:
            projection_initializer(weight)
        create_initializer("zeros")(self.linear_proj.bias.data)

        if self.prepend_cls_token:
            create_initializer(self.default_initializer)(
                self.cls_embedding.data
            )
        if self.position_embedding_type == "learned":
            create_initializer(self.position_embedding_initializer)(
                self.position_embeddings.weight.data
            )

    def get_image_sequence_position_embeddings(self, embeddings, indices=None):
        """Look up position embeddings for the image-patch part of a sequence.

        Must only be called when ``position_embedding_type`` is not None.

        Args:
            embeddings (Tensor): ``[batch_size, seq_len, hidden_size]`` patch
                embeddings (without CLS). Only its shape, device and dtype are
                used.
            indices (Tensor, optional): ``[batch_size, seq_len]`` patch
                positions. Defaults to ``0 .. seq_len - 1`` for every sample.

        Returns:
            Tensor: ``[batch_size, seq_len, hidden_size]`` position embeddings.
        """
        batch_size, seq_len = embeddings.shape[0], embeddings.shape[1]
        if indices is None:
            position_ids = torch.arange(
                0, seq_len, device=embeddings.device,
            ).expand((batch_size, -1))
        else:
            position_ids = indices

        if self.position_embedding_type == "learned":
            position_embeddings = self.position_embeddings(position_ids)
        elif self.position_embedding_type == "fixed":
            # Gather rows of the fixed table along the sequence dimension.
            gather_index = torch.broadcast_to(
                position_ids.unsqueeze(-1),
                (
                    position_ids.shape[0],
                    position_ids.shape[1],
                    embeddings.shape[-1],
                ),
            ).long()
            table = self.position_embeddings.to(embeddings.dtype).expand(
                gather_index.shape[0], -1, -1
            )
            position_embeddings = torch.gather(table, 1, gather_index)
        return position_embeddings

    def get_cls_token_position_embeddings(self, batch_size, dtype, device):
        """Return the CLS token's position embedding, ``[batch_size, 1, hidden_size]``.

        Must only be called when ``prepend_cls_token`` is set and
        ``position_embedding_type`` is not None. ``dtype`` is only applied to
        the fixed table; learned embeddings keep the ``nn.Embedding`` dtype.
        """
        if self.position_embedding_type == "learned":
            cls_indices = torch.full(
                (batch_size, 1),
                self.cls_embedding_position_index,
                dtype=torch.int32,
                device=device,
            )
            pe = self.position_embeddings(cls_indices)
        else:
            pe = (
                self.position_embeddings[self.cls_embedding_position_index :, :]
                .to(dtype)
                .expand(batch_size, -1, -1)
            )
        # [bs, 1, hidden_size]
        return pe

    def select_patches(self, patches, patch_indices=None):
        """Select from patches based on patch_indices.

        Args:
            patches (Tensor): shape ``[batch_size, full_sequence_length, hidden_size]``
            patch_indices (Tensor): shape ``[batch_size, subset_sequence_length]``.
                If None, ``patches`` is returned unchanged.

        Returns:
            patches (Tensor): shape ``[batch_size, subset_sequence_length, hidden_size]``
        """
        if patch_indices is None:
            return patches
        batch_size, subset_sequence_length = patch_indices.shape
        gather_index = torch.broadcast_to(
            patch_indices.unsqueeze(-1),
            (batch_size, subset_sequence_length, patches.shape[-1]),
        ).long()
        return torch.gather(patches, 1, gather_index)

    def forward(self, input_images, patch_indices=None):
        """Applies patching and linear projection to the input images.

        Args:
            input_images (Tensor): input images. The conv path feeds them
                directly to ``nn.Conv2d`` and therefore needs
                ``[batch_size, num_channels, height, width]``. The linear path
                first passes them through ``patchify_helper``.
                NOTE(review): an older docstring described the linear-path
                input as ``[batch_size, sequence_len, embedding_size]``; confirm
                against ``patchify_helper``.
            patch_indices (Tensor): shape ``[batch_size, subset_seq_length]``.
                If specified, the embedding layer selects a subset of all image
                patches based on these indices (used e.g. for MAE). Defaults to
                None.

        Returns:
            image_embeddings (Tensor): shape ``[batch_size, sequence_length, hidden_size]``,
            where ``sequence_length`` includes one extra leading position when
            ``prepend_cls_token`` is set.
        """
        batch_size = input_images.shape[0]
        if self.use_conv_patchified_embedding:
            # [bs, hidden, h_patches, w_patches] -> [bs, seq_length, hidden]
            projected = self.linear_proj(input_images)
            image_embeddings = projected.flatten(2).transpose(1, 2)
            image_embeddings = self.select_patches(
                image_embeddings, patch_indices=patch_indices
            )
        else:
            patchified_image = patchify_helper(input_images, self.patch_size)
            # Selecting before the projection saves compute compared to the
            # conv path, which projects every patch first.
            selected = self.select_patches(
                patchified_image, patch_indices=patch_indices
            )
            image_embeddings = self.linear_proj(selected)  # [bs, seq, hidden]

        embeddings = image_embeddings
        if self.position_embedding_type is not None:
            image_pe = self.get_image_sequence_position_embeddings(
                image_embeddings, indices=patch_indices
            )
            # In-place add keeps image_embeddings' dtype (embeddings aliases
            # image_embeddings).
            embeddings += image_pe

        if self.prepend_cls_token:
            cls_embeddings = self.cls_embedding.type_as(
                image_embeddings
            ).expand(batch_size, -1, -1)
            # Without position embeddings there is no CLS position entry
            # (self.position_embeddings is never created), so only add it
            # when position embeddings are enabled.
            if self.position_embedding_type is not None:
                cls_embeddings = (
                    cls_embeddings
                    + self.get_cls_token_position_embeddings(
                        batch_size,
                        image_embeddings.dtype,
                        cls_embeddings.device,
                    )
                )
            embeddings = torch.cat([cls_embeddings, embeddings], dim=1)

        return self.dropout_embd(embeddings)