Source code for common.pytorch.model_utils.create_initializer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch.nn as nn

from modelzoo.common.pytorch.model_utils.weight_initializers import (
    lecun_normal_,
    lecun_uniform_,
    trunc_normal_,
    variance_scaling_,
)

INIT2FN = {
    "constant": nn.init.constant_,
    "ones": nn.init.ones_,
    "zeros": nn.init.zeros_,
    "eye": nn.init.eye_,
    "uniform": nn.init.uniform_,
    "normal": nn.init.normal_,
    "xavier_normal": nn.init.xavier_normal_,
    "glorot_normal": nn.init.xavier_normal_,  # alias for `xavier_normal`
    "xavier_uniform": nn.init.xavier_uniform_,
    "glorot_uniform": nn.init.xavier_uniform_,  # alias for `xavier_uniform`
    "truncated_normal": trunc_normal_,
    "variance_scaling": variance_scaling_,
    "lecun_normal": lecun_normal_,
    "lecun_uniform": lecun_uniform_,
    "kaiming_normal": nn.init.kaiming_normal_,
    "kaiming_uniform": nn.init.kaiming_uniform_,
}


def create_initializer(spec):
    """
    Creates the specified initializer.

    :param dict/str spec: either a string indicating the name of the
        initializer or a dict that includes the name plus any other relevant
        parameters.
    :returns: initializer function that can be passed to layers
    """
    if isinstance(spec, str):
        spec = {"name": spec}
    if "name" not in spec:
        raise ValueError("Initializer name must be provided")
    name = spec["name"].lower()

    if name == "constant":
        return lambda tensor: INIT2FN[name](
            tensor, val=_get_spec_value(spec, "val", 0)
        )
    elif name in ["ones", "zeros", "eye", "lecun_normal", "lecun_uniform"]:
        return lambda tensor: INIT2FN[name](tensor)
    elif name == "uniform":
        return lambda tensor: INIT2FN[name](
            tensor,
            a=_get_spec_value(spec, "a", -0.05),
            b=_get_spec_value(spec, "b", 0.05),
        )
    elif name == "normal":
        return lambda tensor: INIT2FN[name](
            tensor,
            mean=_get_spec_value(spec, "mean", 0.0),
            std=_get_spec_value(spec, "std", 0.05),
        )
    elif name in [
        "xavier_normal",
        "xavier_uniform",
        "glorot_normal",
        "glorot_uniform",
    ]:
        return lambda tensor: INIT2FN[name](
            tensor, gain=_get_spec_value(spec, "gain", 1.0)
        )
    elif name in ["kaiming_normal", "kaiming_uniform"]:
        return lambda tensor: INIT2FN[name](
            tensor,
            a=_get_spec_value(spec, "a", 0.0),
            mode=_get_spec_value(spec, "mode", "fan_in"),
            nonlinearity=_get_spec_value(
                spec, "nonlinearity", "leaky_relu", override_gain_calc=True
            ),
        )
    elif name == "truncated_normal":
        std = _get_spec_value(spec, "std", 0.05)
        return lambda tensor: INIT2FN[name](
            tensor,
            mean=_get_spec_value(spec, "mean", 0.0),
            std=std,
            a=_get_spec_value(spec, "a", -2 * std),
            b=_get_spec_value(spec, "b", 2 * std),
        )
    elif name == "variance_scaling":
        return lambda tensor: INIT2FN[name](
            tensor,
            scale=_get_spec_value(spec, "scale", 1.0),
            mode=_get_spec_value(spec, "mode", "fan_in"),
            distribution=_get_spec_value(
                spec, "distribution", "truncated_normal"
            ),
        )
    else:
        raise ValueError(f"Invalid or unsupported initializer, '{name}'.")
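
# Usage sketch (illustrative; not part of the original module): the spec may
# be a bare initializer name or a dict carrying extra parameters, e.g.
#
#     create_initializer("lecun_normal")(tensor)
#     create_initializer({"name": "normal", "mean": 0.0, "std": 0.02})(tensor)
#
# Both calls return a function that initializes the given tensor in place.
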
def _get_spec_value(spec, key, default_value, override_gain_calc=False):
    """
    Returns the value of spec[key]. If the key is not present, logs a debug
    message and returns default_value. Unless override_gain_calc is True, a
    value naming a nonlinearity is converted to the matching gain via
    nn.init.calculate_gain.
    """

    def is_nonlinearity(value):
        return value in [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
            "sigmoid",
            "tanh",
            "relu",
            "leaky_relu",
        ]

    name = spec["name"]
    value = spec.get(key)
    if value is None:
        logging.debug(
            f"{name} initializer's {key} parameter not specified. "
            f"Using {default_value}."
        )
        value = default_value
    elif is_nonlinearity(value) and not override_gain_calc:
        # Convert a nonlinearity name (e.g. "relu") to its recommended gain.
        value = nn.init.calculate_gain(value)

    return value
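
# A minimal, self-contained demo (an assumption; not part of the original
# module). It shows that a nonlinearity name passed as `gain` is converted to
# the matching gain by _get_spec_value, and that unspecified parameters fall
# back to their defaults (logged at debug level). Tensor shape and spec
# values are illustrative.
if __name__ == "__main__":
    import torch

    weight = torch.empty(256, 128)

    # "relu" is converted to nn.init.calculate_gain("relu") == sqrt(2).
    create_initializer({"name": "xavier_uniform", "gain": "relu"})(weight)

    # `std` falls back to its default of 0.05; the fallback is reported
    # via logging.debug.
    create_initializer({"name": "truncated_normal", "mean": 0.0})(weight)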