# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Sparsity mask initialization methods and helpers, invoked by
:py:class:`~cerebras.pytorch.sparse.SparsityAlgorithm`.
"""
import inspect
from typing import Callable, Optional, Union
import numpy as np
import torch
from cerebras.pytorch.utils.typing import signature_matches_type_hint
from .utils import ScoreShaper, make_mask_topk_sparsity
# Signature shared by the mask-initialization callables in this module:
# ``(param, sparsity, score_shaper, device) -> boolean mask``.
InitMethodCallable = Callable[
    [
        torch.nn.Parameter,
        torch.FloatTensor,
        Optional[ScoreShaper],
        Optional[torch.device],
    ],
    torch.BoolTensor,
]
# An init method is specified either by name (a string) or directly as a
# callable with the signature above.
InitMethodType = Union[str, InitMethodCallable]
def random(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Build a sparsity mask from uniformly random scores.

    Random scores shaped like ``p`` are drawn with :py:func:`torch.rand_like`
    and handed, together with ``sparsity`` and ``score_shaper``, to
    :py:func:`make_mask_topk_sparsity
    <cerebras.pytorch.sparse.utils.make_mask_topk_sparsity>`, whose result is
    returned unchanged.

    Args:
        p: Parameter whose shape (and, by default, device) the scores follow.
        sparsity: Sparsity level tensor; it is moved to the target device.
        score_shaper: Forwarded as-is to ``make_mask_topk_sparsity``.
        device: Device on which to generate the scores. Defaults to
            ``p.device`` when ``None``.

    Returns:
        The mask produced by ``make_mask_topk_sparsity``.
    """
    target_device = p.device if device is None else device

    # The sparsity level must live on the target device so that it can take
    # part in tracing the random initialization.
    sparsity_on_device = sparsity.to(target_device)
    random_scores = torch.rand_like(p, device=target_device)

    return make_mask_topk_sparsity(
        random_scores, sparsity_on_device, score_shaper
    )
def topk(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Build a sparsity mask that prunes the lowest-magnitude weights.

    The absolute value of ``p`` is used as the score tensor and handed,
    together with ``sparsity`` and ``score_shaper``, to
    :py:func:`make_mask_topk_sparsity
    <cerebras.pytorch.sparse.utils.make_mask_topk_sparsity>`.

    Args:
        p: Parameter whose magnitudes are used as scores.
        sparsity: Sparsity level tensor; it is moved to the target device.
        score_shaper: Forwarded as-is to ``make_mask_topk_sparsity``.
        device: Device on which to compute the scores. Defaults to
            ``p.device`` when ``None``.

    Returns:
        The mask produced by ``make_mask_topk_sparsity``.
    """
    target_device = p.device if device is None else device

    # The sparsity level must live on the target device so that it can take
    # part in tracing the top-k computation.
    sparsity_on_device = sparsity.to(target_device)
    magnitudes = p.to(target_device).abs()

    return make_mask_topk_sparsity(
        magnitudes, sparsity_on_device, score_shaper
    )
def from_zeros(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Treat every zero already present in the weights as a pruned connection.

    NOTE: The configured sparsity is not actually honored; ``sparsity`` and
    ``score_shaper`` are accepted only to match the common init-method
    signature and are never read.

    Args:
        p: Parameter whose zero entries define the pruned positions.
        sparsity: Unused.
        score_shaper: Unused.
        device: Device on which to compute the mask. Defaults to
            ``p.device`` when ``None``.

    Returns:
        A boolean tensor that is ``True`` wherever ``p`` is non-zero.
    """
    target_device = device if device is not None else p.device
    return p.to(target_device).ne(0)
def checkerboard(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Mostly for stress and performance testing, creates a sparsity mask that is
    maximally distributed in a checkerboard across the weight.

    ``score_shaper`` and ``device`` are accepted only to match the common
    init-method signature; neither is read here, and the mask is built with
    torch's default device.

    Args:
        p: Parameter to build the mask for. Only its shape is used. For a
            2-D parameter the mask has shape ``(rows, cols)``; otherwise a
            single row of length ``p.shape[-1]`` is returned.
        sparsity: Single-element tensor holding the sparsity level; its
            value is read eagerly with ``.item()``.
        score_shaper: Unused.
        device: Unused.

    Returns:
        A boolean mask in which ``True`` entries are spaced uniformly at a
        rate of ``1 - sparsity``.
    """
    density = 1 - sparsity.item()
    col = p.shape[-1]
    is_matrix = len(p.shape) == 2

    if density <= 0:
        # Full sparsity keeps nothing. Without this guard the divisions by
        # `density` below raise ZeroDivisionError for a sparsity of 1.0.
        empty_shape = (p.shape[0], col) if is_matrix else (col,)
        return torch.zeros(empty_shape, dtype=torch.bool)

    # Create a row with a uniformly distributed sparsity pattern.
    # Allocate padding so that rolling the row still results in balance.
    padding = int(np.ceil(col / density + 1e-5))
    # e.g. for density 1/3: [ 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 ]
    steps = torch.floor(torch.arange(col + padding) * density + 1e-5)
    # A position is kept wherever the step value changes:
    # [ F, F, T, F, F, T, F, F ]
    mask = steps[1:] != steps[:-1]

    if is_matrix:
        row = p.shape[0]
        # Now evenly distribute this over the rows as well by rolling each.
        # This offset computation is equivalent to `-np.nonzero(mask)[0][0]`
        # but is more efficient, and more importantly allows torch.roll
        # to be traceable.
        offset = -int(np.floor(1 / density - 1e-5))
        mask = torch.stack([torch.roll(mask, x * offset) for x in range(row)])

    # Trim off padding columns and return
    return mask[..., :col].clone()
def _noop_compile_only(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Placeholder "init" method intended only for compile_only runs.

    No mask is actually computed: an uninitialized boolean tensor shaped like
    ``p`` is returned, which avoids computing masks on the CPU that aren't
    ultimately used. ``sparsity``, ``score_shaper`` and ``device`` are
    ignored.
    """
    placeholder_mask = torch.empty_like(p, dtype=torch.bool)
    return placeholder_mask
def make_init_method(init_method: InitMethodType) -> InitMethodCallable:
    """
    Resolve ``init_method`` to the callable used to initialize a sparsity mask.

    When the current backend is compile-only, a no-op initializer is returned
    immediately and ``init_method`` is neither inspected nor validated.

    Args:
        init_method: The method to use to initialize the sparsity mask,
            given either as a string or as a callable.

            As a string it must name one of the built-in methods:

            - ``"random"``:
              :py:func:`~cerebras.pytorch.sparse.init.random`,
              randomly initialize the mask
            - ``"topk"``:
              :py:func:`~cerebras.pytorch.sparse.init.topk`,
              prune the lowest magnitude weights
            - ``"from_zeros"``:
              :py:func:`~cerebras.pytorch.sparse.init.from_zeros`,
              any zeros in the weights represent pruned connections
            - ``"checkerboard"``:
              :py:func:`~cerebras.pytorch.sparse.init.checkerboard`,
              a mask that is maximally distributed across the weight

            As a callable, its signature is checked against
            ``InitMethodCallable``, i.e. the same shape as the built-ins:

            .. code-block:: python

                def init_method(
                    param: torch.nn.Parameter,
                    sparsity: torch.FloatTensor,
                    score_shaper: Optional[ScoreShaper] = None,
                    device: Optional[torch.device] = None,
                ) -> torch.BoolTensor:

            where

            - ``param`` is the original dense parameter
            - ``sparsity`` is the sparsity level
            - ``score_shaper`` is an optional ``ScoreShaper``
            - ``device`` is optionally the device to use to initialize the
              mask

    Returns:
        The built-in function for a recognized name, the given callable when
        its signature is accepted, or the no-op initializer in compile-only
        mode.

    Raises:
        ValueError: If ``init_method`` is an unknown name, a callable whose
            signature does not match, or neither a string nor a callable.
    """
    # Imported inside the function rather than at module level.
    from cerebras.pytorch.backend import current_backend_impl

    if current_backend_impl().compile_only:
        return _noop_compile_only

    builtin_methods = {
        "random": random,
        "topk": topk,
        "from_zeros": from_zeros,
        "checkerboard": checkerboard,
    }

    # One message covers every rejection path below.
    error_message = (
        f'Unknown `init_method`: "{init_method}". Valid options are one '
        f'of the built-in {list(builtin_methods)} or a function with '
        f'signature {InitMethodCallable}.'
    )

    # Named built-in: look it up directly.
    if isinstance(init_method, str):
        resolved = builtin_methods.get(init_method)
        if resolved is None:
            raise ValueError(error_message)
        return resolved

    # Anything that is neither a name nor a callable is rejected.
    if not callable(init_method):
        raise ValueError(error_message)

    # User-supplied callable: accept it only if its signature fits.
    user_signature = inspect.signature(init_method)
    if not signature_matches_type_hint(user_signature, InitMethodCallable):
        raise ValueError(error_message)

    return init_method