# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from modelzoo.common.pytorch.layers.utils import ModuleWrapperClass
class BatchChannelNorm2D(nn.Module):
    """
    Implements Batch-Channel Normalization, proposed in
    `Micro-Batch Training with Batch-Channel Normalization and Weight
    Standardization <https://arxiv.org/abs/1903.10520>`_.

    The input is first batch-normalized (with a learnable per-channel
    affine transform), then normalized again over each group of channels
    of every sample, and finally scaled and shifted by a learnable
    per-group ``weight`` and ``bias``.

    Args:
        num_groups (int): number of groups to separate the channels into.
        num_channels (int): number of channels. `C` from an expected input
            of size (N, C, H, W). Must be divisible by ``num_groups``.
        eps (float): a value added to the denominator for numerical
            stability. Default: 1e-5.
        momentum (float): The `Update rate` value used for the
            `running_mean` and `running_var` computation of the batch
            normalization step. Default: 0.1.
        device (torch.device): Device to place the learnable parameters.
        dtype (torch.dtype): Data type of learnable parameters.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.

    Shape:
        input: `(N, C, H, W)`
        output: `(N, C, H, W)` (same shape as input)
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1.0e-5,
        momentum: float = 0.1,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.eps = eps
        self.momentum = momentum

        # Per-group scale and shift, shaped to broadcast against the
        # (N, num_groups, -1) view used in `forward`. Allocated
        # uninitialized here and filled in by `reset_parameters`.
        self.weight = nn.Parameter(
            torch.empty(1, num_groups, 1, **factory_kwargs)
        )
        self.bias = nn.Parameter(
            torch.empty(1, num_groups, 1, **factory_kwargs)
        )

        # Batch step: a regular BatchNorm2d with its own per-channel affine.
        self.batchnorm = nn.BatchNorm2d(
            num_channels,
            eps=self.eps,
            momentum=self.momentum,
            affine=True,
            **factory_kwargs,
        )

        # Channel step: a stateless, non-affine batch_norm call that
        # `forward` applies to a (1, N * num_groups, -1) view.
        cn = partial(
            nn.functional.batch_norm,
            running_mean=None,
            running_var=None,
            weight=None,
            bias=None,
            # Must always be True: without running statistics batch_norm
            # can only normalize with the statistics of the current input.
            # A value captured in a partial is frozen at construction time
            # anyway, so it could never follow train()/eval().
            training=True,
            momentum=0,
            eps=self.eps,
        )
        # NOTE(review): the `name`/`kwargs` arguments look like metadata
        # describing the wrapped call -- confirm against ModuleWrapperClass.
        self.channelnorm = ModuleWrapperClass(
            fcn=cn,
            name='ChannelNorm',
            kwargs={
                "momentum": 0,
                "eps": self.eps,
                "affine": False,
                "track_running_stats": False,
            },
        )
        self.reset_parameters()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply batch normalization followed by channel (group) normalization.

        Args:
            input (torch.Tensor): tensor of shape `(N, C, H, W)`.

        Returns:
            torch.Tensor: normalized tensor with the same shape as `input`.
        """
        batch_size = input.shape[0]
        out = self.batchnorm(input)
        # Fold every (sample, group) pair into the channel dimension of a
        # batch of size one, so the per-channel statistics computed by
        # batch_norm are taken over each group of each sample separately.
        out = out.reshape(1, batch_size * self.num_groups, -1)
        out = self.channelnorm(out)
        # Expose the group dimension so the (1, num_groups, 1) parameters
        # broadcast across samples and across the elements of each group.
        out = out.reshape(batch_size, self.num_groups, -1)
        out = self.weight * out + self.bias
        return out.reshape(input.shape)

    def reset_parameters(self) -> None:
        """
        Reset every affine transform to the identity.

        Sets the batch-norm and per-group weights to one and the
        corresponding biases to zero. The running statistics of the
        batch-norm layer are left untouched.
        """
        nn.init.ones_(self.batchnorm.weight)
        nn.init.zeros_(self.batchnorm.bias)
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)