# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch.nn as nn
from modelzoo.common.pytorch.model_utils.weight_initializers import (
    lecun_normal_,
    lecun_uniform_,
    trunc_normal_,
    variance_scaling_,
)
# Lookup table from initializer name (the lowercased ``name`` that
# ``create_initializer`` dispatches on) to the function implementing it.
# Entries from ``nn.init`` are PyTorch's built-ins; the rest come from
# modelzoo's own ``weight_initializers`` module.
INIT2FN = dict(
    constant=nn.init.constant_,
    ones=nn.init.ones_,
    zeros=nn.init.zeros_,
    eye=nn.init.eye_,
    uniform=nn.init.uniform_,
    normal=nn.init.normal_,
    xavier_normal=nn.init.xavier_normal_,
    glorot_normal=nn.init.xavier_normal_,  # same function as xavier_normal
    xavier_uniform=nn.init.xavier_uniform_,
    glorot_uniform=nn.init.xavier_uniform_,  # same function as xavier_uniform
    truncated_normal=trunc_normal_,
    variance_scaling=variance_scaling_,
    lecun_normal=lecun_normal_,
    lecun_uniform=lecun_uniform_,
    kaiming_normal=nn.init.kaiming_normal_,
    kaiming_uniform=nn.init.kaiming_uniform_,
)
def create_initializer(spec):
    """
    Creates the specified initializer.

    :param dict/str spec: either a string indicating the name of the
        initializer or a dict that includes the name + other params if
        relevant. The name is matched case-insensitively.
    :returns: a callable ``fn(tensor)`` that applies ``INIT2FN[name]`` to
        ``tensor`` with the keyword arguments resolved from ``spec``
        (missing parameters fall back to per-initializer defaults).
    :raises ValueError: if ``spec`` has no ``"name"`` entry or the name is
        not a supported initializer.
    """
    if isinstance(spec, str):
        spec = {"name": spec}
    if "name" not in spec:
        raise ValueError("Initializer name must be provided")
    name = spec["name"].lower()

    # Resolve every parameter once, here, instead of on each call of the
    # returned function: the values do not depend on the tensor, and this
    # keeps the "parameter not specified" debug message to a single log line.
    if name == "constant":
        kwargs = {"val": _get_spec_value(spec, "val", 0)}
    elif name in ("ones", "zeros", "eye", "lecun_normal", "lecun_uniform"):
        # These initializers take no extra parameters.
        kwargs = {}
    elif name == "uniform":
        kwargs = {
            "a": _get_spec_value(spec, "a", -0.05),
            "b": _get_spec_value(spec, "b", 0.05),
        }
    elif name == "normal":
        kwargs = {
            "mean": _get_spec_value(spec, "mean", 0.0),
            "std": _get_spec_value(spec, "std", 0.05),
        }
    elif name in (
        "xavier_normal",
        "xavier_uniform",
        "glorot_normal",
        "glorot_uniform",
    ):
        # A nonlinearity name given as `gain` is converted to its gain
        # value by `_get_spec_value`.
        kwargs = {"gain": _get_spec_value(spec, "gain", 1.0)}
    elif name in ("kaiming_normal", "kaiming_uniform"):
        kwargs = {
            "a": _get_spec_value(spec, "a", 0.0),
            "mode": _get_spec_value(spec, "mode", "fan_in"),
            # Kaiming expects the nonlinearity *name*, so skip the
            # name -> gain conversion.
            "nonlinearity": _get_spec_value(
                spec, "nonlinearity", "leaky_relu", override_gain_calc=True
            ),
        }
    elif name == "truncated_normal":
        std = _get_spec_value(spec, "std", 0.05)
        kwargs = {
            "mean": _get_spec_value(spec, "mean", 0.0),
            "std": std,
            # Default truncation bounds are derived from the resolved std.
            "a": _get_spec_value(spec, "a", -2 * std),
            "b": _get_spec_value(spec, "b", 2 * std),
        }
    elif name == "variance_scaling":
        kwargs = {
            "scale": _get_spec_value(spec, "scale", 1.0),
            "mode": _get_spec_value(spec, "mode", "fan_in"),
            "distribution": _get_spec_value(
                spec, "distribution", "truncated_normal"
            ),
        }
    else:
        raise ValueError(f"Invalid or unsupported initializer, '{name}'. ")

    init_fn = INIT2FN[name]

    def initializer(tensor):
        # `**kwargs` builds a fresh dict per call, so the shared resolved
        # kwargs cannot be mutated by the init function.
        return init_fn(tensor, **kwargs)

    return initializer
def _get_spec_value(spec, key, default_value, override_gain_calc=False):
    """
    Looks up ``spec[key]``, falling back to ``default_value``.

    If ``key`` is absent from ``spec`` (or maps to None), a debug-level log
    message is emitted and ``default_value`` is returned as-is. Otherwise,
    unless ``override_gain_calc`` is True, a value that names one of the
    nonlinearities listed below (e.g. ``"relu"``) is replaced by
    ``nn.init.calculate_gain(value)``; any other value is returned unchanged.

    :param dict spec: initializer spec; must contain a ``"name"`` entry
        (a KeyError is raised otherwise).
    :param str key: parameter name to read from ``spec``.
    :param default_value: value returned when ``key`` is missing or None.
    :param bool override_gain_calc: if True, return the raw value without
        converting nonlinearity names to gains.
    :returns: the resolved parameter value.
    """
    # Nonlinearity names that trigger the name -> gain conversion. A tuple
    # (not a set) keeps membership tests working for unhashable values.
    gain_nonlinearities = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "sigmoid",
        "tanh",
        "relu",
        "leaky_relu",
    )

    initializer_name = spec["name"]
    raw = spec.get(key)

    if raw is None:
        logging.debug(
            f"{initializer_name} initializer's {key} parameter not specified. "
            f"Using {default_value}."
        )
        return default_value

    if override_gain_calc or raw not in gain_nonlinearities:
        return raw

    return nn.init.calculate_gain(raw)