# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import zip_longest
import torch.nn as nn
from torch.nn import Dropout, MaxPool2d, MaxPool3d
from modelzoo.vision.pytorch.unet.layers.UNetBlock import UNetBlock
class Encoder(nn.Module):
    """
    Encoder half of a UNet: a stack of ``UNetBlock`` modules, each optionally
    followed by a max-pooling layer and a dropout layer.

    One ``UNetBlock`` is created for every entry of ``encoder_filters`` except
    the last one, i.e. ``len(encoder_filters) - 1`` blocks in total. The final
    entry of ``encoder_filters`` is not used by this class.

    Args:
        in_channels (int): The input channel dimension before passing through the encoder.
        encoder_filters ([int]): List of filter sizes for each block in the encoder.
        convs_per_block ([str]): List of conv specifications for each conv in the block.
        bias (bool): Flag to use bias vectors.
        norm_layer (nn.Module): Desired normalization layer.
        norm_kwargs (dict): A dictionary of the arguments to pass to the
            constructor of the normalization layer.
        act (str): Activation to use.
        skip_connect (bool): Flag for if the model concatenates encoder outputs to decoder inputs.
        residual_blocks (bool): Flag for using residual connections at the end of each block.
        downscale_method (str): Downscaling method at the end of each block.
            When set to ``"max_pool"``, a 2x2 stride-2 max-pooling layer is
            applied after every block.
        dropout_rate (float): The probability that each element is dropped.
            A falsy value (e.g. ``0.0``) disables dropout entirely.
        use_conv3d (bool): 3D convolutions will be used when set to True
        downscale_first_conv (bool): If True, the first convolution operation in each UNetBlock will
            be downscaled. If False, the last convolution in each UNetBlock will be downscaled
        downscale_encoder_blocks (bool or [bool]): bool or list of bools that determine whether each
            block in the Encoder includes downsampling. Length of the list must correspond to the
            number of UNetBlocks in the Encoder. If a single bool is provided, all blocks will use
            this value.
    """

    def __init__(
        self,
        in_channels,
        encoder_filters,
        convs_per_block,
        bias,
        norm_layer,
        norm_kwargs,
        act="relu",
        skip_connect=True,
        residual_blocks=False,
        downscale_method="max_pool",
        dropout_rate=0.0,
        use_conv3d=False,
        downscale_first_conv=False,
        downscale_encoder_blocks=True,
    ):
        super().__init__()
        self.skip_connect = skip_connect

        # Broadcast a single flag to one entry per filter size; only the
        # first `len(encoder_filters) - 1` entries are read below.
        if isinstance(downscale_encoder_blocks, bool):
            downscale_encoder_blocks = [downscale_encoder_blocks] * len(
                encoder_filters
            )

        # The pooling flavour does not change between blocks, so pick it once.
        pool_cls = MaxPool3d if use_conv3d else MaxPool2d

        unet_blocks = []
        pooling_layers = []
        dropout_layers = []
        for block_idx in range(len(encoder_filters) - 1):
            # The first block consumes the raw input; every later block
            # consumes the previous block's output channels.
            if block_idx == 0:
                block_in_channels = in_channels
            else:
                block_in_channels = encoder_filters[block_idx - 1]

            unet_blocks.append(
                UNetBlock(
                    in_channels=block_in_channels,
                    out_channels=encoder_filters[block_idx],
                    encoder=True,
                    convs_per_block=convs_per_block,
                    skip_connect=skip_connect,
                    norm_layer=norm_layer,
                    norm_kwargs=norm_kwargs,
                    downscale_method=downscale_method,
                    bias=bias,
                    residual_blocks=residual_blocks,
                    act=act,
                    use_conv3d=use_conv3d,
                    downscale_first_conv=downscale_first_conv,
                    downscale=downscale_encoder_blocks[block_idx],
                )
            )
            # NOTE(review): the pooling layer is added for every block, even
            # when downscale_encoder_blocks[block_idx] is False -- confirm
            # this is the intended interaction with UNetBlock's `downscale`.
            if downscale_method == "max_pool":
                pooling_layers.append(pool_cls(kernel_size=2, stride=2))
            if dropout_rate:
                dropout_layers.append(Dropout(p=dropout_rate))

        # Each of the optional lists is either empty or has exactly one entry
        # per block, which is what lets forward() pair them up positionally.
        self.unet_blocks = nn.ModuleList(unet_blocks)
        self.pooling_layers = nn.ModuleList(pooling_layers)
        self.dropout_layers = nn.ModuleList(dropout_layers)

    def forward(self, inputs):
        """
        Run ``inputs`` through every block, then its pooling and dropout
        layers when those are configured.

        Args:
            inputs: Input passed to the first ``UNetBlock``.

        Returns:
            tuple: ``(outputs, skip_connections)`` where ``outputs`` is the
            result after the last block (and its pooling/dropout), and
            ``skip_connections`` is the list of per-block skip outputs in
            block order. The list is empty when ``skip_connect`` is False.
        """
        skip_connections = []
        outputs = inputs
        # zip_longest pads the (possibly empty) pooling/dropout lists with
        # None so that every block is still visited.
        for unet_block, pooling_layer, dropout_layer in zip_longest(
            self.unet_blocks, self.pooling_layers, self.dropout_layers
        ):
            outputs, skip_connection = unet_block(outputs)
            if self.skip_connect:
                skip_connections.append(skip_connection)
            # Compare against the None fill value explicitly rather than
            # relying on nn.Module truthiness.
            if pooling_layer is not None:
                outputs = pooling_layer(outputs)
            if dropout_layer is not None:
                outputs = dropout_layer(outputs)
        return outputs, skip_connections