# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Sparsity mask initialization methods and helpers, invoked by
``BaseSparsityOptimizer``.
"""

import inspect
from typing import Callable, Optional, Union

import numpy as np
import torch

from cerebras_pytorch.utils.typing import signature_matches_type_hint

from .utils import ScoreShaper, make_mask_topk_sparsity

InitMethodCallable = Callable[
    [torch.nn.Parameter, torch.FloatTensor, Optional[ScoreShaper]],
    torch.BoolTensor,
]
InitMethodType = Union[str, InitMethodCallable]
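
# A user-provided ``init_method`` may also be any callable matching
# ``InitMethodCallable``. A minimal sketch of such a callable (the name
# ``noisy_magnitude`` is hypothetical, not part of this module):
#
#     def noisy_magnitude(p, sparsity, score_shaper=None, device=None):
#         device = device or p.device
#         # Score by absolute magnitude plus a small random tie-breaker.
#         score = p.to(device).abs() + 1e-6 * torch.rand_like(p, device=device)
#         return make_mask_topk_sparsity(score, sparsity.to(device), score_shaper)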


def random(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Uniformly random sparsity pattern.
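
    Example (a minimal sketch; ``sparsity`` is the fraction of weights to
    prune, given as a scalar tensor)::

        p = torch.nn.Parameter(torch.randn(4, 8))
        mask = random(p, torch.tensor(0.75))
        # ``mask`` is a boolean tensor of p's shape in which roughly 25% of
        # the entries are True (kept), chosen uniformly at random.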
    """
    if device is None:
        device = p.device
    # Move sparsity to device so we can use it to trace random initialization
    sparsity = sparsity.to(device)
    score = torch.rand_like(p, device=device)
    return make_mask_topk_sparsity(score, sparsity, score_shaper) 


def topk(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Prune lowest magnitude weights.
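
    Example (a minimal sketch)::

        p = torch.nn.Parameter(torch.tensor([[0.1, -2.0, 0.3, 4.0]]))
        mask = topk(p, torch.tensor(0.5))
        # Keeps the two largest-magnitude weights:
        # mask == [[False, True, False, True]]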
    """
    if device is None:
        device = p.device
    # Move sparsity to the device so we can use it to trace topk
    sparsity = sparsity.to(device)
    score = p.to(device).abs()
    return make_mask_topk_sparsity(score, sparsity, score_shaper) 


def from_zeros(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Any zeros currently in the weights represent pruned connections.
    NOTE: Doesn't actually honor the configured sparsity.
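
    Example (a minimal sketch)::

        p = torch.nn.Parameter(torch.tensor([0.0, 1.5, 0.0, -0.2]))
        mask = from_zeros(p, torch.tensor(0.5))
        # mask == [False, True, False, True]; the requested sparsity of 0.5
        # is ignored, only the existing zeros are treated as pruned.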
    """
    if device is None:
        device = p.device
    return p.to(device) != 0 


def checkerboard(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Mostly for stress and performance testing, creates a sparsity mask that is
    maximally distributed in a checkerboard across the weight.
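
    Example (a minimal sketch): for a 4 x 6 weight at ``sparsity=0.5`` the
    kept (True) entries alternate along each row, with each row offset from
    the one above it::

        p = torch.nn.Parameter(torch.empty(4, 6))
        mask = checkerboard(p, torch.tensor(0.5))
        # mask (T = kept):
        #   F T F T F T
        #   T F T F T F
        #   F T F T F T
        #   T F T F T F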
    """
    density = 1 - sparsity.item()
    # Create a row with a uniformly distributed sparsity pattern
    col = p.shape[-1]
    # Allocate padding so that potential rolling still results in balance.
    padding = int(np.ceil(col / density + 1e-5))
    # [ 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 ]
    steps = torch.floor(torch.arange(col + padding) * density + 1e-5)
    # [ F,   F,   T,   F,   F,   T,   F,   F]
    mask = steps[1:] != steps[:-1]
    if len(p.shape) == 2:
        row = p.shape[0]
        # Now evenly distribute this over the rows as well by rolling each
        # This offset computation is equivalent to `-np.nonzero(mask)[0][0]`
        # but is more efficient, and more importantly allows torch.roll
        # to be traceable.
        offset = -int(np.floor(1 / density - 1e-5))
        mask = torch.stack([torch.roll(mask, x * offset) for x in range(row)])
    # Trim off padding columns and return
    return mask[..., :col] 


def _noop_compile_only(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    "init" method that doesn't init to be used only with compile_only. This
    avoids computing masks on the CPU that aren't ultimately used.
    """
    return torch.empty_like(p, dtype=torch.bool)


def make_init_method(init_method: InitMethodType) -> InitMethodCallable:
    from cerebras_pytorch.backend import current_backend_impl
    if current_backend_impl().compile_only:
        return _noop_compile_only
    init_methods = {
        "random": random,
        "topk": topk,
        "from_zeros": from_zeros,
        "checkerboard": checkerboard,
    }
    init_method_error = (
        f'Unknown `init_method`: "{init_method}". Valid options are one '
        f'of the built-in {list(init_methods.keys())} or a function with '
        f'signature {InitMethodCallable}.'
    )
    if isinstance(init_method, str):
        if init_method not in init_methods:
            raise ValueError(init_method_error)
        init_method = init_methods[init_method]
    elif callable(init_method):
        signature = inspect.signature(init_method)
        if not signature_matches_type_hint(signature, InitMethodCallable):
            raise ValueError(init_method_error)
    else:
        raise ValueError(init_method_error)
    return init_method
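

# Typical usage (a minimal sketch; ``param`` here is a hypothetical dense
# parameter tensor). ``BaseSparsityOptimizer`` resolves its configured
# ``init_method`` through this helper and then calls the returned function on
# each parameter it sparsifies:
#
#     init_fn = make_init_method("topk")          # or any InitMethodCallable
#     mask = init_fn(param, torch.tensor(0.9))    # boolean keep-mask for param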