cerebras.pytorch.sparse
Sparsity Algorithms
These classes are the built-in sparsity algorithms.
SparsityAlgorithm is the abstract base
class that all sparsity algorithms should derive from.
- class cerebras.pytorch.sparse.SparsityAlgorithm(sparsity, init_method='random')
- Base class for all sparsity algorithms. This class is responsible for sparsifying parameters and for registering hooks that apply the sparsity pattern to the parameters before forward and to the gradients after backward. It also registers hooks that update the sparsity pattern after each optimizer step.
- Warning: In the cerebras.pytorch API, sparse parameters are represented via a mask tensor. This mask tensor is multiplied in place into the original dense parameter before forward and into the gradients after backward. However, this is not how sparse parameters are represented on a Cerebras system. There, sparse parameters are handled natively in CSR format, so there is no mask tensor that can be referenced on the system side. As a result, using the mask tensor carelessly can lead to compile failures, and even if compilation succeeds, operations performed on the mask can be very computationally expensive. That said, several operations on masks are supported on the Cerebras system. Use the prepackaged algorithms as a guide for when and how it is acceptable to use the mask.
- Constructs a SparsityAlgorithm instance.
- Parameters
- sparsity (Optional[Union[float, cerebras.pytorch.sparse.utils.HyperParameterSchedule]]) – The sparsity level to use for the algorithm. This can be a float or a HyperParameterSchedule. If a dictionary is passed in, it is automatically converted to a HyperParameterSchedule.
- init_method (Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras.pytorch.sparse.utils.ScoreShaper], Optional[torch.device]], torch.BoolTensor]]) – The method to use to initialize the sparsity mask. See make_init_method for more details.
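To make the mask-based representation described in the warning concrete, here is a minimal host-side sketch in plain Python, with lists standing in for tensors. The helper apply_mask and the plain SGD step are hypothetical illustrations, not cerebras.pytorch APIs; on a Cerebras system the sparse parameter is stored in CSR format instead.

```python
# Hypothetical illustration of mask-based sparsity; lists stand in for tensors.

def apply_mask(values, mask):
    """Zero out every element whose mask entry is False."""
    return [v if keep else 0.0 for v, keep in zip(values, mask)]

dense_weight = [0.5, -1.2, 0.03, 2.0]
mask = [True, True, False, True]  # 25% sparse: one connection pruned

# Before forward: the dense parameter is multiplied by the mask.
weight = apply_mask(dense_weight, mask)

# After backward: the gradient is masked as well, so a plain SGD step
# cannot revive a pruned position.
grad = [0.1, 0.2, 0.3, 0.4]
masked_grad = apply_mask(grad, mask)

lr = 0.1
updated = [w - lr * g for w, g in zip(weight, masked_grad)]
```

Masking the gradient is what keeps the pruned entry at exactly zero after the update; without it, the optimizer step would write a nonzero value back into that position.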
 
 - property num_sparse_params: int
- Return the number of parameters that have been sparsified by this algorithm.
 - get_sparse_params(obj)
- Get all sparse parameters that were sparsified by this algorithm.
- Parameters
- obj (Union[torch.Tensor, torch.nn.Module, torch.optim.Optimizer]) – The object to get sparse parameters from.
- Returns
- If obj is a Tensor, returns the sparse parameter associated with that tensor (if any). If obj is a Module, returns an iterator over all sparse parameters of the module and its submodules, recursively. If obj is an Optimizer, returns an iterator over all sparse parameters associated with the optimizer’s param groups.
 
- Return type
- Union[cerebras.pytorch.sparse.base.SparseParameter, Generator[cerebras.pytorch.sparse.base.SparseParameter, None, None]] 
 
 - initialize()
- Initialize the sparsity pattern for all parameters sparsified by this algorithm.
 - csx_annotate_sparsity(param)
- Annotate the parameter with hints about the sparsity pattern. These hints are used as performance hints for the Cerebras compiler.
- Parameters
- param (cerebras.pytorch.sparse.base.SparseParameter) – The sparse parameter to annotate with hints. 
 
 - property sparsity: Dict[torch.Tensor, cerebras.pytorch.sparse.utils.HyperParameterSchedule]
- Return the mapping between a parameter and its sparsity schedule.
 - sparsify_parameter(module, name, param)
- Initialize the mask for a parameter in the given module.
- Parameters
- module (torch.nn.Module) – The module that owns the parameter
- name (str) – The full name of the parameter
- param (torch.Tensor) – The parameter to initialize the sparsity mask for.
 
 
 - final apply(obj)
- Sparsify the passed-in object.
- Note: This is called implicitly when calling module.apply(sparsity) or optimizer.apply(sparsity).
- Parameters
- obj (Union[torch.nn.Module, cerebras.pytorch.optim.optimizer.Optimizer]) – A torch.nn.Module or a cstorch.optim.Optimizer object to sparsify.

 - sparsify_module(module)
- Sparsify the torch.nn.Module object.
- Parameters
- module (torch.nn.Module) – The torch.nn.Module object to sparsify
 
 - prune_weight(sparse_param)
- Prune the dense weight and register a hook to prune the gradients.
- Note: This is called automatically in a module forward pre-hook.
 - _grad_hook(p, grad)
- Hook to prune the gradients after backward().
- Note: This is called automatically in the parameter’s backward grad hook.
- Parameters
- p (torch.Tensor) – The original parameter. 
- grad (torch.Tensor) – The gradient of the parameter. 
 
 
 - sparsify_optimizer(optimizer)
- Sparsify the torch.optim.Optimizer object.
- Parameters
- optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer object to sparsify

 - abstract update(optimizer=None)
- Update the sparsity masks of the parameters.
- Parameters
- optimizer (Optional[cerebras.pytorch.optim.optimizer.Optimizer]) – The optimizer that is being used to update the sparse parameters.
 
 - register_target_sparsity_hook(hook)
- Register a hook which will be called when a new target sparsity is computed. It should have the following signature: hook(sparsity, name, target), where:
- sparsity is the sparsity instance being used.
- name is the name of the group of parameters that the target sparsity is being computed for.
- target is the computed target sparsity value.
- Parameters
- hook (Callable) – The user-defined hook to be registered.
- Returns
- A handle that can be used to remove the added hook by calling handle.remove()
- Return type
- torch.utils.hooks.RemovableHandle
 
 - register_computed_sparsity_hook(hook)
- Register a hook which will be called when a new sparsity mask is computed. It should have the following signature: hook(sparsity, name, computed), where:
- sparsity is the sparsity instance being used.
- name is the name of the parameter that the mask belongs to.
- computed is the calculated sparsity level of the newly computed mask.
- Parameters
- hook (Callable) – The user-defined hook to be registered.
- Returns
- A handle that can be used to remove the added hook by calling handle.remove()
- Return type
- torch.utils.hooks.RemovableHandle
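As a sketch of the two documented hook signatures, the hypothetical functions below simply record the arguments they receive. They are called directly here only to show the argument order; in practice the sparsity algorithm invokes them.

```python
# Hypothetical hooks matching the documented signatures:
#   hook(sparsity, name, target)    for register_target_sparsity_hook
#   hook(sparsity, name, computed)  for register_computed_sparsity_hook
target_log = []
computed_log = []

def record_target(sparsity, name, target):
    # `name` identifies the parameter group; `target` is its target sparsity.
    target_log.append((name, float(target)))

def record_computed(sparsity, name, computed):
    # `name` identifies the parameter; `computed` is the measured sparsity
    # of its newly computed mask.
    computed_log.append((name, float(computed)))

# Direct calls, only to illustrate argument order (no real algorithm here).
record_target(None, "fc1.weight", 0.9)
record_computed(None, "fc1.weight", 0.875)
```

With an algorithm instance named sparsity, handle = sparsity.register_target_sparsity_hook(record_target) would register the first hook, and handle.remove() would detach it again.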
 
 
Static Sparsity Algorithms
- class cerebras.pytorch.sparse.Static(sparsity=None, **kwargs)
- Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm
- Constructs a Static sparsity instance.
- Parameters
- sparsity (Optional[float]) – A float specifying the level of sparsity to apply to each parameter
 
Dynamic Sparsity Algorithms
- class cerebras.pytorch.sparse.DynamicSparsityAlgorithm(sparsity=None, update=None, **kwargs)
- Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm, abc.ABC
- Constructs a DynamicSparsityAlgorithm instance.
- Parameters
- sparsity (Optional[Union[float, dict]]) – A float specifying the level of sparsity to apply to each parameter, or a dictionary specifying the schedule to use for sparsity. The dictionary must have a “type” key, which specifies the type of schedule to use. The remaining keys are schedule-specific. The following schedule types are supported:
- update (Optional[Union[Dict, Callable[[torch.LongTensor], torch.BoolTensor]]]) – A dictionary specifying the schedule to use for updating the sparsity pattern. The dictionary must contain keys that can be used to construct either a FreqSchedule or a ListSchedule. If not provided, the sparsity pattern will be updated every step.
 
 - property is_update_step: torch.BoolTensor
- Returns a boolean tensor indicating whether the current step is an update step according to the update schedule.
 - abstract update_mask(p, mask, sparsity)
- Compute an updated sparsity pattern. - Parameters
- p (torch.Tensor) – the parameter to sparsify 
- mask (torch.tensor(dtype=torch.bool)) – the current mask of param p 
- sparsity (torch.tensor(dtype=torch.float32)) – the desired sparsity level 
 
- Returns
- The updated sparsity pattern on parameter p 
- Return type
 
 
- class cerebras.pytorch.sparse.GMP(**kwargs)
- Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm
- Implements Gradual Magnitude Pruning. Sparsity increases monotonically based on weight magnitude.
- See: https://arxiv.org/abs/1710.01878
- Parameters
- **kwargs – All arguments are passed to the DynamicSparsityAlgorithm’s constructor.
 - Example:
    import math

    import cerebras.pytorch as cstorch

    sparsity_opt = cstorch.sparse.GMP(
        sparsity={"type": "exp", "init": 0, "gamma": 1000*math.log(0.3)},
        update={"freq": 1000},
    )
- class cerebras.pytorch.sparse.SET(drop_fraction=0.3, **kwargs)
- Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm
- Implements Sparse Evolutionary Training (SET). Sparsity levels stay constant throughout training, but the lowest-magnitude weights are pruned and then regrown randomly.
- See: https://arxiv.org/abs/1707.04780
- Parameters
- drop_fraction (Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras.pytorch.sparse.utils.HyperParameterSchedule]) – Fraction of non-pruned weights to drop at each update step. Either a constant or a step-aware hyperparameter.
- **kwargs – Any additional arguments are passed to the DynamicSparsityAlgorithm’s constructor.

 - Example:
    import cerebras.pytorch as cstorch

    sparsity_opt = cstorch.sparse.SET(
        sparsity=0.9,
        update={"freq": 100, "stop": 1000},
        drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
    )
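The SET update rule described above can be sketched on flat Python lists. This illustrates the technique from the SET paper, not the library's implementation: set_update is a hypothetical helper, and regrowing only at positions that were already pruned before the drop is a choice made for this sketch.

```python
import random

def set_update(weights, mask, drop_fraction, rng):
    """One SET-style update on flat lists (hypothetical helper).

    Drops the lowest-magnitude fraction of active connections, then regrows
    the same number at random positions that were inactive before the drop,
    so the number of active connections is unchanged.
    """
    active = [i for i, keep in enumerate(mask) if keep]
    n_drop = int(drop_fraction * len(active))
    # Active positions sorted by ascending weight magnitude.
    to_drop = sorted(active, key=lambda i: abs(weights[i]))[:n_drop]
    inactive = [i for i, keep in enumerate(mask) if not keep]
    # Random regrowth; requires at least n_drop previously pruned positions.
    to_grow = rng.sample(inactive, n_drop)
    new_mask = list(mask)
    for i in to_drop:
        new_mask[i] = False
    for i in to_grow:
        new_mask[i] = True
    return new_mask
```

Passing an explicit random.Random instance keeps the regrowth reproducible across runs.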
- class cerebras.pytorch.sparse.RigL(drop_fraction=0.3, balance_in_groups=None, balance_out_groups=None, **kwargs)
- Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm
- Implements Rigging the Lottery (RigL). Sparsity levels stay constant throughout training, but the lowest-magnitude weights are pruned and then regrown using a proxy measure of where a pruned connection would have had the most impact: the highest-magnitude (dense) gradients of the pruned weights.
- See: https://arxiv.org/abs/1911.11134
- Parameters
- drop_fraction (Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras.pytorch.sparse.utils.HyperParameterSchedule]) – Fraction of non-pruned weights to drop at each update step. Either a constant or a step-aware hyperparameter.
- balance_in_groups (Optional[int]) – The number of groups used by InputGroupScoreShaper
- balance_out_groups (Optional[int]) – The number of groups used by OutputGroupScoreShaper
- **kwargs – Any additional arguments are passed to the DynamicSparsityAlgorithm’s constructor.

 - Example:
    import cerebras.pytorch as cstorch

    sparsity = cstorch.sparse.RigL(
        sparsity=0.9,
        update={"freq": 100, "stop": 1000},
        drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
    )
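A comparable sketch of one RigL update: the drop step matches SET, but regrowth picks the pruned positions whose dense gradients have the largest magnitude. rigl_update is hypothetical; it ignores drop_fraction schedules, balanced groups, and how regrown weights are initialized.

```python
def rigl_update(weights, grads, mask, drop_fraction):
    """One RigL-style update on flat lists (hypothetical helper).

    Drops the lowest-magnitude active weights, then regrows the same number
    of connections at previously pruned positions with the largest
    dense-gradient magnitude.
    """
    active = [i for i, keep in enumerate(mask) if keep]
    n_drop = int(drop_fraction * len(active))
    to_drop = sorted(active, key=lambda i: abs(weights[i]))[:n_drop]
    inactive = [i for i, keep in enumerate(mask) if not keep]
    # Highest |gradient| first among positions pruned before this update.
    to_grow = sorted(inactive, key=lambda i: abs(grads[i]), reverse=True)[:n_drop]
    new_mask = list(mask)
    for i in to_drop:
        new_mask[i] = False
    for i in to_grow:
        new_mask[i] = True
    return new_mask
```

The only difference from the SET sketch is the regrowth criterion, which is why RigL needs dense gradients of the pruned weights.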
Group Sparsity Algorithm
- class cerebras.pytorch.sparse.Group(groups=None)
- Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm
- Group sparsity algorithm. This algorithm allows multiple sparsity algorithms to be applied to different groups of parameters. For example:
    import cerebras.pytorch as cstorch

    sparsity = cstorch.sparse.Group({
        "fc1.*": cstorch.sparse.Static(sparsity=0.5),
        "fc2.*": cstorch.sparse.GMP(
            sparsity=[0.3, 0.4, 0.5],
            update={"freq": 100},
        ),
    })
    sparsity.add("fc3.*", cstorch.sparse.RigL(sparsity=0.5))

    # `model` (a torch.nn.Module) and `optimizer` (a cstorch.optim.Optimizer)
    # are assumed to be defined elsewhere.
    model.apply(sparsity)
    optimizer.apply(sparsity)
- The group sparsity algorithm applies each sparsity algorithm to the parameters that match its filter. If a parameter name matches multiple filters, the first matching filter is used.
- Parameters
- groups (Dict[str, cerebras.pytorch.sparse.base.SparsityAlgorithm]) – A dictionary of filter -> algorithm pairs. See add for more details.
 - add(filter, algorithm)
- Add a sparsity algorithm to the group.
- Parameters
- filter (Union[str, Callable[[str, torch.Tensor], bool]]) – A string, list of strings, or callable that takes a parameter name and a parameter tensor and returns True if the parameter should be sparsified. If one or more strings are provided, the filter will match if any of the strings match the parameter name. The strings may contain glob patterns, e.g. “fc1.*” will match all parameters in the “fc1” module.
- algorithm (cerebras.pytorch.sparse.base.SparsityAlgorithm) – An instance of SparsityAlgorithm
 
 
 - extend(group)
- Extend the group with the filters and algorithms from another group.
- Parameters
- group (cerebras.pytorch.sparse.group.Group) – An instance of Group
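The first-match rule for Group filters can be sketched in plain Python. assign_groups is hypothetical, and matching glob strings with fnmatch.fnmatchcase is an assumption about how patterns such as “fc1.*” are interpreted; callables receive (name, param) as described for add.

```python
from fnmatch import fnmatchcase

def assign_groups(params, filters):
    """Map each parameter name to the label of the first matching filter.

    `params` maps names to parameters (any object here); `filters` is an
    ordered list of (filter, label) pairs, where a filter is a glob string,
    a list of glob strings, or a callable taking (name, param).
    """
    assignment = {}
    for name, param in params.items():
        for flt, label in filters:
            if callable(flt):
                matched = bool(flt(name, param))
            else:
                patterns = [flt] if isinstance(flt, str) else list(flt)
                matched = any(fnmatchcase(name, p) for p in patterns)
            if matched:
                assignment[name] = label
                break  # first matching filter wins
    return assignment
```

Parameters that match no filter are simply left out of the result, i.e. they stay dense in this sketch.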
 
 
Configuration routine
The highest-level entry point for enabling sparsity is
configure, which configures a sparsity algorithm and returns it. The config
dictionary follows the same form as given in
Sparsity via YAML.
If param_filter is not provided, the following default param filter
is applied.
- cerebras.pytorch.sparse.configure.default_sparse_param_filter(name, param)
- Return True if the given parameter should be sparse. Returns True only if the parameter has more than one dimension and is not an embedding, norm, lm_head, or pe_helper parameter.
- Parameters
- name (str) – Name of the parameter 
- param (torch.nn.Parameter) – The parameter itself 
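A rough re-creation of the default filter rule, assuming parameter shapes are given as tuples and that embeddings, norms, lm_head, and pe_helper are recognized by substrings of the parameter name. The actual detection logic is not documented here, so default_filter_sketch is purely illustrative.

```python
def default_filter_sketch(name, shape):
    # Hypothetical version of the documented rule: sparsify only parameters
    # with more than one dimension that are not embeddings, norms, lm_head,
    # or pe_helper. Name-substring detection is an assumption of this sketch.
    excluded = ("embedding", "norm", "lm_head", "pe_helper")
    lowered = name.lower()
    return len(shape) > 1 and not any(key in lowered for key in excluded)
```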
 
 
Customizing Sparsity & Reference
Several building blocks can be inherited from or composed to help build new dynamic sparsity algorithms or customize the behavior of existing ones.
cerebras.pytorch.sparse.init
Sparsity mask initialization methods and helpers, invoked by
SparsityAlgorithm.
- cerebras.pytorch.sparse.init.random(p, sparsity, score_shaper=None, device=None)
- Uniformly random sparsity pattern. A score tensor with the same shape as the parameter is randomly generated with values between 0.0 and 1.0. The mask is then created by taking the top-k of the score tensor, where k is determined by the sparsity level.
- cerebras.pytorch.sparse.init.topk(p, sparsity, score_shaper=None, device=None)
- Prune the lowest-magnitude weights.
- cerebras.pytorch.sparse.init.from_zeros(p, sparsity, score_shaper=None, device=None)
- Any zeros currently in the weights represent pruned connections. Note: this method does not actually honor the configured sparsity.
- cerebras.pytorch.sparse.init.checkerboard(p, sparsity, score_shaper=None, device=None)
- Mostly for stress and performance testing; creates a sparsity mask that is maximally distributed in a checkerboard pattern across the weight.
- cerebras.pytorch.sparse.init.make_init_method(init_method)
- Returns the corresponding init method callable for the given init_method.
- Parameters
- init_method (Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras.pytorch.sparse.utils.ScoreShaper], Optional[torch.device]], torch.BoolTensor]]) – The method to use to initialize the sparsity mask. This can be a string or a callable. If a string, it must be one of:
- “random”: Randomly initialize the mask
- “topk”: Prune the lowest-magnitude weights
- “from_zeros”: Any zeros in the weights represent pruned connections
- “checkerboard”: Create a sparsity mask that is maximally distributed across the weight
 - If a callable, it must have the signature:
    def init_method(
        param: torch.Tensor,
        sparsity: float,
        score_shaper: Optional[ScoreShaper] = None,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
- where
- param is the original dense parameter
- sparsity is the sparsity level
- score_shaper is an optional callable that can be used to reshape the mask
- device is optionally the device to use to initialize the mask
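The string options accepted by make_init_method can be sketched as a small dispatch table over plain Python lists. Everything here is hypothetical: the helpers stand in for the documented random, topk, and from_zeros methods, checkerboard is omitted, and rounding k to the nearest integer is an assumption.

```python
import random

# Hypothetical stand-ins for the documented init methods; plain lists play
# the role of tensors and the returned list of bools plays the mask.

def _keep_top(scores, sparsity):
    # Keep the top k = (1 - sparsity) * n positions by score (k rounded).
    k = round((1.0 - sparsity) * len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [i in keep for i in range(len(scores))]

def init_random(p, sparsity, rng=random):
    # "random": uniform scores in [0, 1), then top-k.
    return _keep_top([rng.random() for _ in p], sparsity)

def init_topk(p, sparsity):
    # "topk": keep the largest-magnitude weights, prune the rest.
    return _keep_top([abs(v) for v in p], sparsity)

def init_from_zeros(p, sparsity):
    # "from_zeros": existing zeros are pruned; `sparsity` is ignored.
    return [v != 0 for v in p]

INIT_METHODS = {
    "random": init_random,
    "topk": init_topk,
    "from_zeros": init_from_zeros,
}

def make_init_method_sketch(init_method):
    # Callables are passed through unchanged; strings are looked up by name.
    if callable(init_method):
        return init_method
    try:
        return INIT_METHODS[init_method]
    except KeyError:
        raise ValueError(f"Unknown init method: {init_method!r}") from None
```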
 
 
 
cerebras.pytorch.sparse.utils
- class cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Base class for step-aware hyperparameters used in sparsity algorithms.
 - abstract compute(step)
- Return a torch.Tensor with the value of the hyperparameter at the given step.
- Parameters
- step (torch.Tensor) – int64 tensor holding the current step
- Returns
- torch.Tensor on the device of step with the value of the hyperparameter
 
- Return type
 
 - update(is_update_step)
- Given a boolean tensor indicating if this is an update step, update the internal state of this hyperparameter. - Parameters
- is_update_step (torch.Tensor) – A boolean tensor indicating if this is an update step. 
 
 
- class cerebras.pytorch.sparse.utils.Constant(value)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Constant at every step.
- Parameters
- value (float) – The constant value of the hyperparameter 
 
- class cerebras.pytorch.sparse.utils.Linear(init, slope)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Linear change from an initial value:
- \(y(step) = init + step \cdot slope\)
- Parameters
- init (float) – The initial value of the hyperparameter 
- slope (float) – The rate of change of the hyperparameter 
 
 
- class cerebras.pytorch.sparse.utils.Exp(init, gamma, final=1)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Exponential, approaching an asymptotic final value:
- \(y(step) = final + (init - final) e^{step \cdot gamma}\)
- Parameters
- init (float) – The initial value of the hyperparameter 
- gamma (float) – The rate of change of the hyperparameter 
- final (float) – The final value of the hyperparameter (Default: 1.0) 
 
 
- class cerebras.pytorch.sparse.utils.Power(init, beta)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Power law:
- \(y(step) = init \cdot beta^{step}\)
- Parameters
- init (float) – The initial value of the hyperparameter 
- beta (float) – The rate of change of the hyperparameter 
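The documented Constant, Linear, Exp, and Power formulas can be evaluated on plain floats to get a feel for their shapes. These functions are hypothetical helpers; the library classes compute on step tensors. With init=0 and the default final=1, a negative gamma makes Exp ramp up toward 1; for example, gamma = ln(0.3)/1000 yields 1 - 0.3 = 0.7 at step 1000.

```python
import math

# Plain-float versions of the documented schedule formulas (hypothetical
# helpers; the library versions operate on step tensors).

def constant(step, value):
    return value

def linear(step, init, slope):
    return init + step * slope

def exp_schedule(step, init, gamma, final=1.0):
    return final + (init - final) * math.exp(step * gamma)

def power(step, init, beta):
    return init * beta ** step

# Sparsity ramp from 0 toward 1 that reaches 0.7 at step 1000.
ramp = [exp_schedule(s, 0.0, math.log(0.3) / 1000) for s in (0, 500, 1000)]
```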
 
 
- class cerebras.pytorch.sparse.utils.Cosine(init, half_period, minimum=0.0)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Cosine function oscillating from an initial (maximum) value down to a minimum and back to the maximum every period:
- \(y(step) = o + a \cdot \cos(step \cdot \pi / half\_period)\), where \(o = (init + minimum)/2\) and \(a = init - o\).
- Parameters
- init (float) – The initial value of the hyperparameter
- half_period (float) – The number of steps to go from the initial (maximum) value to the minimum, i.e. half of a full cycle
- minimum (float) – The minimum value of the hyperparameter
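Evaluating the documented Cosine formula shows that it starts at init, reaches minimum after half_period steps, and returns to init after 2·half_period steps. cosine is a hypothetical plain-float helper, not the library class.

```python
import math

def cosine(step, init, half_period, minimum=0.0):
    """Documented cosine schedule: y = o + a * cos(step * pi / half_period)."""
    o = (init + minimum) / 2  # midpoint of the oscillation
    a = init - o              # amplitude
    return o + a * math.cos(step * math.pi / half_period)

# One full period of a drop_fraction schedule like the SET/RigL examples.
values = [cosine(s, 0.3, 1000) for s in (0, 500, 1000, 1500, 2000)]
```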
 
 
- class cerebras.pytorch.sparse.utils.Cycling(values)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Hyperparameter cycling between discrete values at update steps.
- Parameters
- values (List[float]) – A list of discrete values to cycle through 
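A hypothetical stand-in for Cycling, assuming the schedule advances to the next value each time update() is called with a true is_update_step and wraps around at the end of the list. The exact point at which the library advances is not documented here.

```python
class CyclingSketch:
    """Hypothetical Cycling stand-in working on plain Python values."""

    def __init__(self, values):
        self.values = list(values)
        self.index = 0

    def compute(self, step=None):
        # The current value does not depend on `step` directly, only on how
        # many update steps have been seen so far.
        return self.values[self.index]

    def update(self, is_update_step):
        # Advance (with wrap-around) only on update steps.
        if is_update_step:
            self.index = (self.index + 1) % len(self.values)
```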
 
- class cerebras.pytorch.sparse.utils.Lambda(fn)
- Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule
- Invoke a user’s lambda function of step to obtain the hyperparameter.
- Parameters
- fn (Callable[[torch.Tensor], torch.Tensor]) – A lambda function that takes a step and returns a hyperparameter 
 
- cerebras.pytorch.sparse.utils.make_hyperparam_schedule(schedule)
- Given some user-specified configuration, construct a step-aware HyperParameterSchedule object.
- class cerebras.pytorch.sparse.utils.FreqSchedule(freq=1, start=0, stop=None)
- Bases: cerebras.pytorch.sparse.utils.UpdateSchedule
- When scheduling sparsity update steps on a regular interval, this class allows configuring the start and stop step in addition to the update frequency.
- Parameters
- freq – The frequency of steps at which to update the sparsity pattern (Default: 1) 
- start – The step at which to start updating the sparsity pattern (Default: 0) 
- stop – The step at which to stop updating the sparsity pattern (Default: None) 
 
 
- class cerebras.pytorch.sparse.utils.ListSchedule(steps)
- Bases: cerebras.pytorch.sparse.utils.UpdateSchedule
- When scheduling requires an irregular update cadence, explicit steps can be provided as a list.
- Parameters
- steps (Union[List[int], torch.Tensor]) – A list of steps at which to update the sparsity pattern 
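The update-step predicates behind FreqSchedule and ListSchedule can be sketched as plain functions. Both are hypothetical; in particular, treating stop as exclusive and counting step start itself as an update step are assumptions made here.

```python
def freq_is_update_step(step, freq=1, start=0, stop=None):
    """Hypothetical FreqSchedule rule: every `freq` steps from `start`,
    with `stop` treated as exclusive (an assumption)."""
    if step < start or (stop is not None and step >= stop):
        return False
    return (step - start) % freq == 0

def list_is_update_step(step, steps):
    """Hypothetical ListSchedule rule: update exactly at the listed steps."""
    return step in set(steps)

# Mirrors update={"freq": 100, "stop": 1000} from the SET example.
update_steps = [s for s in range(1200) if freq_is_update_step(s, 100, 0, 1000)]
```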
 
- cerebras.pytorch.sparse.utils.make_update_schedule(update)
- Instantiate a supported schedule type.
- class cerebras.pytorch.sparse.utils.ScoreFlattener
- Bases: cerebras.pytorch.sparse.utils.ScoreShaper
- Default ScoreShaper in which everything is flattened, providing a global competition for magnitude. If only sub-portions of the weight should compete for magnitude, provide an alternative shaper object.
- class cerebras.pytorch.sparse.utils.OutputGroupScoreShaper(num_groups)
- Bases: cerebras.pytorch.sparse.utils.ScoreShaper
- A ScoreShaper for weights that are logically shaped as [num_groups*out_per_group, insize] but need to be scored in a “balanced” fashion as [num_groups, out_per_group*insize].
- Examples
    >>> import torch
    >>> from cerebras.pytorch.sparse.utils import (
    ...     OutputGroupScoreShaper,
    ...     make_mask_topk_sparsity,
    ... )
    >>> # Common score used for the following examples
    >>> score = torch.tensor([[1.0, 2.0],
    ...                       [0.0, -1.0]])
    >>> # 50% sparsity, drops the 2 lowest scores
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.5),
    ... )
    tensor([[ True,  True],
            [False, False]])
    >>> # 50% sparsity, but computed rowwise
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.5),
    ...     score_shaper=OutputGroupScoreShaper(num_groups=2)
    ... )
    tensor([[False,  True],
            [ True, False]])
- class cerebras.pytorch.sparse.utils.InputGroupScoreShaper(num_groups)
- Bases: cerebras.pytorch.sparse.utils.ScoreShaper
- A ScoreShaper for weights that are logically shaped as [outsize, num_groups*in_per_group] but need to be scored in a “balanced” fashion as [num_groups, outsize*in_per_group].
- Examples
    >>> import torch
    >>> from cerebras.pytorch.sparse.utils import (
    ...     InputGroupScoreShaper,
    ...     make_mask_topk_sparsity,
    ... )
    >>> # Common score used for the following examples
    >>> score = torch.tensor([[1.0, 0.0],
    ...                       [2.0, -1.0]])
    >>> # 50% sparsity, drops the 2 lowest scores
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.5),
    ... )
    tensor([[ True, False],
            [ True, False]])
    >>> # 50% sparsity, but computed columnwise
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.5),
    ...     score_shaper=InputGroupScoreShaper(num_groups=2)
    ... )
    tensor([[False,  True],
            [ True, False]])
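The grouping performed by the two score shapers can be reproduced on nested Python lists: each group of rows (or columns) gets its own top-k. The helpers below are hypothetical, and rounding k to the nearest integer is an assumption; with the scores from the examples, they produce the same rowwise and columnwise masks.

```python
def topk_mask(scores, sparsity):
    # Keep the top k = (1 - sparsity) * n entries of a flat list (k rounded).
    k = round((1.0 - sparsity) * len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [i in keep for i in range(len(scores))]

def output_group_mask(score, sparsity, num_groups):
    # Split the rows into num_groups contiguous blocks; each block competes
    # only internally, mirroring the [num_groups, out_per_group*insize] view.
    rows_per_group = len(score) // num_groups
    width = len(score[0])
    mask = []
    for g in range(num_groups):
        rows = score[g * rows_per_group:(g + 1) * rows_per_group]
        kept = topk_mask([v for row in rows for v in row], sparsity)
        mask.extend(kept[r * width:(r + 1) * width] for r in range(len(rows)))
    return mask

def input_group_mask(score, sparsity, num_groups):
    # Column groups: transpose, reuse the row-group logic, transpose back.
    transposed = [list(col) for col in zip(*score)]
    kept_t = output_group_mask(transposed, sparsity, num_groups)
    return [list(row) for row in zip(*kept_t)]
```

With num_groups=1 the row-group version degenerates to the default flattened (global) competition.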
- cerebras.pytorch.sparse.utils.make_mask_drop_minimum(score, mask, drop_fraction, score_shaper=None)
- Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where a drop_fraction portion of the currently present (mask==True) connections are dropped (mask==False).
- The connections are dropped at positions corresponding to the lowest values of score.
- Equivalently, a subset of mask is returned corresponding to the highest values of score.
- Parameters
- score (torch.FloatTensor) – Values used to evaluate which positions to drop
- mask (torch.BoolTensor) – Current connections, same shape as score
- drop_fraction (torch.FloatTensor) – What fraction of current connections to drop
- score_shaper (Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]]) – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure the sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.
 
- Returns
- New mask that has existing connections dropped. No connections will be regrown (unless drop_fraction is negative). 
- Return type
- torch.BoolTensor 
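A flat-list sketch of the drop step: among the active connections, the lowest-scoring drop_fraction portion is switched off and nothing is regrown. drop_minimum is hypothetical; it ignores score shapers, assumes 0 <= drop_fraction <= 1, and rounds the number of dropped connections.

```python
def drop_minimum(score, mask, drop_fraction):
    # Hypothetical flat-list version of make_mask_drop_minimum without a
    # score shaper. Assumes 0 <= drop_fraction <= 1 and rounds the count.
    active = [i for i, keep in enumerate(mask) if keep]
    n_drop = round(drop_fraction * len(active))
    # Lowest-scoring active positions are dropped first.
    to_drop = set(sorted(active, key=lambda i: score[i])[:n_drop])
    # Inactive positions stay inactive; no connections are regrown.
    return [keep and i not in to_drop for i, keep in enumerate(mask)]
```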
 
- cerebras.pytorch.sparse.utils.make_mask_grow_maximum(score, mask, sparsity, mask_nonzero=None, score_shaper=None)
- Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where some currently pruned connections are regrown (from the positions with the highest score) such that the returned mask has the given target sparsity.
- If mask is already less sparse (has more connections) than the target, none are regrown and the original mask is returned as-is. That is, the given mask should be more sparse than the target sparsity.
- Parameters
- score (torch.FloatTensor) – Values used to evaluate which positions to regrow
- mask (torch.BoolTensor) – Current connections, same shape as score
- sparsity – The target sparsity level of the returned mask
- mask_nonzero (Optional[torch.IntTensor]) – If given, the number of nonzero elements currently in the mask, used to control the number of connections needing regrowth. If it is not given, it will be computed as mask.nonzero().int(). Since make_mask_grow_maximum is often used in conjunction with make_mask_drop_minimum, this value is commonly available.
- score_shaper (Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]]) – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure the sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.
 
- Returns
- New mask that has connections regrown necessary to reach (decrease to) the target sparsity. 
- Return type
- torch.BoolTensor 
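A flat-list sketch of the grow step: pruned positions with the highest scores are switched back on until the target sparsity is reached, and a mask that is already denser than the target comes back unchanged. grow_maximum is hypothetical; it ignores score shapers and mask_nonzero and rounds the target number of connections.

```python
def grow_maximum(score, mask, sparsity):
    # Hypothetical flat-list version of make_mask_grow_maximum without a
    # score shaper. The target number of connections is rounded.
    target_active = round((1.0 - sparsity) * len(mask))
    # Never negative: a mask that is already denser is returned unchanged.
    n_grow = max(0, target_active - sum(mask))
    inactive = [i for i, keep in enumerate(mask) if not keep]
    to_grow = set(sorted(inactive, key=lambda i: score[i], reverse=True)[:n_grow])
    return [keep or i in to_grow for i, keep in enumerate(mask)]
```

Running a drop step and then a grow step with the parameter's current sparsity as the target keeps the sparsity level constant, mirroring the drop-then-regrow pattern described for SET and RigL.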
 
- cerebras.pytorch.sparse.utils.make_mask_topk_sparsity(score, sparsity, score_shaper=None)
- Given a dense score, return a torch.BoolTensor which is True at positions corresponding to values in the top k = (1-sparsity)*score.numel() of score.
- Parameters
- score (torch.FloatTensor) – Values used to evaluate which positions to keep.
- sparsity (torch.FloatTensor) – Rankless tensor in the range [0, 1] controlling the fraction of the resulting mask that will be pruned.
- score_shaper (Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]]) – If given, score will be interpreted as multiple independent subtensors. This can be used to ensure the sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score is reinterpreted as a 1D tensor, yielding completely unstructured sparsity.

- Returns
- mask with the given sparsity, keeping only the highest values from score.
- Return type
- torch.BoolTensor 
 - Examples
    >>> import torch
    >>> from cerebras.pytorch.sparse.utils import make_mask_topk_sparsity
    >>> # Common score used for the following examples
    >>> score = torch.tensor([[1.0, 2.0],
    ...                       [0.0, -1.0]])
    >>> # 25% sparsity, drops the single lowest score
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.25),
    ... )
    tensor([[ True,  True],
            [ True, False]])
    >>> # 75% sparsity, drops the 3 lowest scores
    >>> make_mask_topk_sparsity(
    ...     score=score,
    ...     sparsity=torch.tensor(0.75),
    ... )
    tensor([[False,  True],
            [False, False]])
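For readers without the library at hand, the default (unstructured) behavior of make_mask_topk_sparsity can be sketched on nested Python lists. make_mask_topk_sketch is hypothetical, ranks by value exactly as the examples do, and rounds k to the nearest integer as an assumption; with the example scores it reproduces both masks.

```python
def make_mask_topk_sketch(score, sparsity):
    """Keep the top k = (1 - sparsity) * numel positions of a 2-D score.

    Flattens the rows (the default, unstructured behavior), ranks by value,
    and reshapes the kept positions back into rows.
    """
    flat = [v for row in score for v in row]
    k = round((1.0 - sparsity) * len(flat))
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)
    keep = set(order[:k])
    width = len(score[0])
    return [[(r * width + c) in keep for c in range(width)] for r in range(len(score))]

score = [[1.0, 2.0], [0.0, -1.0]]
mask_25 = make_mask_topk_sketch(score, 0.25)
mask_75 = make_mask_topk_sketch(score, 0.75)
```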