# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from modelzoo.common.pytorch.layers.AttentionHelper import get_attention_module
from modelzoo.vision.pytorch.dit.layers.vae.ResNetBlock2D import ResnetBlock2D
class UNetMidBlock2D(nn.Module):
    """UNet-style middle block: ResNet blocks interleaved with optional
    spatial self-attention.

    The block always starts with one ``ResnetBlock2D``. Each of the
    ``num_layers`` following stages then applies, when ``add_attention`` is
    True, GroupNorm followed by self-attention over the flattened spatial
    positions with a residual connection, and finally another
    ``ResnetBlock2D``. Every sub-block is built to map ``in_channels`` to
    ``in_channels``.

    Args:
        in_channels: Number of input (and output) channels.
        temb_channels: Time-embedding channel count forwarded to every
            ``ResnetBlock2D``.
        dropout: Dropout probability forwarded to every ``ResnetBlock2D``.
        num_layers: Number of (attention, resnet) stages after the first
            resnet.
        resnet_eps: Epsilon forwarded to every ``ResnetBlock2D`` (as ``eps``)
            and used by the attention GroupNorm.
        resnet_time_scale_shift: Forwarded to ``ResnetBlock2D`` as
            ``time_embedding_norm``.
        resnet_act_fn: Forwarded to ``ResnetBlock2D`` as ``non_linearity``.
        resnet_groups: Number of GroupNorm groups for the resnets and the
            attention norm. If ``None``, ``min(in_channels // 4, 32)`` is used.
        resnet_pre_norm: Forwarded to ``ResnetBlock2D`` as ``pre_norm``.
        add_attention: Whether each stage includes the norm + attention step.
        attn_num_head_channels: Channels per attention head; the head count is
            ``in_channels // attn_num_head_channels``. ``None`` means a single
            head.
        output_scale_factor: Divisor applied after the attention residual
            addition; also forwarded to every ``ResnetBlock2D``.
        attention_type: Name passed to ``get_attention_module`` to select the
            attention implementation class.
        extra_attn_params: Optional dict forwarded to
            ``get_attention_module``; ``None`` is treated as ``{}``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attn_num_head_channels=1,
        output_scale_factor: float = 1.0,
        attention_type: str = "aiayn_attention",
        extra_attn_params=None,
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)
        if extra_attn_params is None:
            extra_attn_params = {}
        self.add_attention = add_attention

        AttentionModule = get_attention_module(
            attention_type, extra_attn_params
        )
        self.output_scale_factor = output_scale_factor

        if attn_num_head_channels is not None:
            num_heads = in_channels // attn_num_head_channels
        else:
            num_heads = 1

        def make_resnet():
            # Every resnet in this block shares one configuration and keeps
            # the channel count at ``in_channels``.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=self.output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # There is always at least one resnet, applied before any attention.
        resnets = [make_resnet()]
        attentions = []
        norms = []
        for _ in range(num_layers):
            if self.add_attention:
                norms.append(
                    nn.GroupNorm(
                        num_channels=in_channels,
                        num_groups=resnet_groups,
                        eps=resnet_eps,
                        affine=True,
                    )
                )
                attentions.append(
                    AttentionModule(
                        embed_dim=in_channels,
                        num_heads=num_heads,
                        inner_dim=None,
                        dropout=0.0,
                        batch_first=True,
                        attention_type="scaled_dot_product",
                        softmax_dtype_fp32=True,
                        use_projection_bias=True,
                        use_ffn_bias=True,
                    )
                )
            else:
                # Placeholders keep norms / attentions / resnets[1:] aligned
                # so that forward() still runs every stage's resnet. ``None``
                # children contribute no parameters or state_dict entries.
                norms.append(None)
                attentions.append(None)
            resnets.append(make_resnet())

        self.norms = nn.ModuleList(norms)
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def _attend(self, hidden_states, norm, attn):
        """Apply GroupNorm + spatial self-attention with a scaled residual.

        ``hidden_states`` is unpacked as ``(batch, channel, height, width)``.
        It is flattened to ``(batch, height * width, channel)`` so that each
        spatial position is one token, passed through ``attn`` as
        query/key/value, and reshaped back. The result is
        ``(attn_out + hidden_states) / self.output_scale_factor``.
        """
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        hidden_states = norm(hidden_states)
        hidden_states = hidden_states.view(
            batch, channel, height * width
        ).transpose(1, 2)

        seq_len = hidden_states.shape[1]
        # NOTE(review): an all-ones (seq_len, seq_len) mask is passed; its
        # meaning depends on the attention module's mask convention, which is
        # not visible here. Presumably intended as "no masking" - confirm.
        attn_mask = torch.ones(
            seq_len,
            seq_len,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        hidden_states = attn(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=attn_mask,
        )
        hidden_states = hidden_states.transpose(-1, -2).reshape(
            batch, channel, height, width
        )
        return (hidden_states + residual) / self.output_scale_factor

    def forward(self, hidden_states, temb=None):
        """Run the mid block.

        Args:
            hidden_states: Input feature map; the attention path unpacks it
                as ``(batch, channel, height, width)``.
            temb: Optional time embedding passed to every ``ResnetBlock2D``.

        Returns:
            The feature map after the first resnet and every
            (optional attention, resnet) stage.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        # Zip only attentions with the trailing resnets: norms exist solely
        # for stages that have attention, so they must not bound the loop
        # (otherwise resnets would be skipped when attention is disabled).
        for idx, (attn, resnet) in enumerate(
            zip(self.attentions, self.resnets[1:])
        ):
            if attn is not None:
                hidden_states = self._attend(
                    hidden_states, self.norms[idx], attn
                )
            hidden_states = resnet(hidden_states, temb)
        return hidden_states