common.pytorch.sparsity package#

Submodules#

common.pytorch.sparsity.appliance module#

Preview release of static sparsity work.

This module contains a function that runs in the appliance_client to modify groups of tensors while in-flight to the appliance according to sparsification settings. It also contains a function to help build those tensor groups and configure the sparsifier.

This only works in appliance mode.

common.pytorch.sparsity.appliance.appliance_sparsify(params: dict, weight_fw_name: str, tensors: List[cerebras_appliance.appliance_manager.TensorSendPayload]) → None[source]#

This function is used as a closure on the TensorGroup for applying sparsity. params and weight_fw_name are added as a functools.partial closure and must be pickleable to send over to the multiprocessing.Pool.

Parameters
  • params – numpy sparsifier params for this single weight.

  • weight_fw_name – The fw_name of the weight in tensors.

  • tensors – tensors in the group, with their tensor member set.
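
As a rough sketch of how this closure is wired up, assuming `TensorSendPayload` carries `fw_name` and `tensor` members and that `params` selects a simple magnitude-pruning level (both are illustrative assumptions, not the real types or schema):

```python
from dataclasses import dataclass
from functools import partial
from typing import List

import numpy as np

# Stand-in for cerebras_appliance.appliance_manager.TensorSendPayload,
# reduced to the two fields this sketch needs (an assumption).
@dataclass
class TensorSendPayload:
    fw_name: str
    tensor: np.ndarray

def appliance_sparsify(params: dict, weight_fw_name: str,
                       tensors: List[TensorSendPayload]) -> None:
    # Zero out the smallest-magnitude fraction of the named weight in
    # place. The real sparsifier is configured by ``params``; magnitude
    # pruning here is purely illustrative.
    for t in tensors:
        if t.fw_name == weight_fw_name:
            k = int(params["sparsity"] * t.tensor.size)
            if k > 0:
                threshold = np.sort(np.abs(t.tensor).ravel())[k - 1]
                t.tensor[np.abs(t.tensor) <= threshold] = 0.0

# functools.partial bakes in the pickleable arguments, so only the
# tensor group itself travels to the multiprocessing.Pool worker.
closure = partial(appliance_sparsify, {"sparsity": 0.5}, "model.w")
group = [TensorSendPayload("model.w", np.array([0.1, -2.0, 0.05, 3.0]))]
closure(group)  # group[0].tensor is now [0.0, -2.0, 0.0, 3.0]
```

Binding `params` and `weight_fw_name` up front keeps the worker-side call signature down to just the tensor group, which is why both bound arguments must be pickleable.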

common.pytorch.sparsity.appliance.build_sparsify_grouper(params: dict, model: modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel) → cerebras_appliance.appliance_manager.TensorGrouper[source]#

Construct a function that builds tensor groups according to which parameters need sparsification, grouping each such parameter with its associated optimizer state.

Parameters
  • params – top-level “sparsity” params from the yaml

  • model – Model to pull parameters and associated optimizer state from.

Returns

Function to be used as the appliance send_weights_grouper

common.pytorch.sparsity.appliance.compute_mask(params: dict, weight: numpy.ndarray) → Tuple[numpy.ndarray, numpy.ndarray][source]#

Compute a sparsity mask for the given weight according to params.

Parameters
  • params – configuration of the sparsity to compute

  • weight – The initial weight values.

Returns

mask with np.dtype bool and the same shape as weight, indicating the sparsity pattern. True: keep weight; False: prune weight.

regrow with np.dtype bool, indicating positions which were pruned but should instead be regrown as zeros.
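
A minimal sketch of a mask computation with this return shape, assuming `params` carries a single "sparsity" fraction and that plain magnitude pruning is in effect (both assumptions; the real algorithm is whatever `params` configures):

```python
from typing import Tuple

import numpy as np

def compute_mask(params: dict, weight: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # Magnitude pruning: keep the largest-magnitude weights. The
    # "sparsity" key used here is an assumption for illustration.
    k = int(params["sparsity"] * weight.size)
    order = np.argsort(np.abs(weight).ravel())
    mask = np.ones(weight.size, dtype=bool)
    mask[order[:k]] = False              # False: prune this weight
    mask = mask.reshape(weight.shape)
    # Positions that were pruned but should be regrown as zeros; with a
    # purely static pattern nothing regrows, so this stays all-False.
    regrow = np.zeros_like(mask)
    return mask, regrow

w = np.array([[0.1, -2.0], [0.05, 3.0]])
mask, regrow = compute_mask({"sparsity": 0.5}, w)
# mask == [[False, True], [False, True]]: the two smallest magnitudes pruned
```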

common.pytorch.sparsity.appliance.sparsify_grouper(sparse_tensor_groups: Dict[str, Tuple[dict, List[str]]], tensors: Iterable[cerebras_appliance.appliance_manager.TensorSendPayload]) → Iterable[cerebras_appliance.appliance_manager.TensorGroup][source]#

Construct a grouping of tensors from the given lazy tensors. Each group consists of a weight needing sparsification together with its associated optimizer state, or a single-tensor “group” for a tensor not needing sparsification.

Parameters
  • sparse_tensor_groups – FW names of weights needing sparsity applied, mapped to their sparsity params and the fw names of their optimizer state.

  • tensors – all tensors to actually send

Yields

TensorGroup objects containing either a single tensor, or a group of tensors with a closure to apply sparsity
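
The grouping logic can be sketched as follows, using simplified stand-ins for `TensorSendPayload` and `TensorGroup` (their real definitions live in cerebras_appliance, and the closure here is a stub):

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple

# Stand-ins for the cerebras_appliance types (assumptions about their shape).
@dataclass
class TensorSendPayload:
    fw_name: str

@dataclass
class TensorGroup:
    tensors: List[TensorSendPayload]
    closure: Optional[Callable] = None

def apply_sparsity(params: dict, weight_fw_name: str, tensors) -> None:
    """Stub standing in for the appliance_sparsify closure."""

def sparsify_grouper(
    sparse_tensor_groups: Dict[str, Tuple[dict, List[str]]],
    tensors: Iterable[TensorSendPayload],
) -> Iterable[TensorGroup]:
    by_name = {t.fw_name: t for t in tensors}
    grouped = set()
    # Weights needing sparsity travel together with their optimizer state.
    for weight_name, (params, opt_names) in sparse_tensor_groups.items():
        members = [by_name[weight_name]] + [by_name[n] for n in opt_names]
        grouped.update(t.fw_name for t in members)
        yield TensorGroup(members, partial(apply_sparsity, params, weight_name))
    # Everything else goes as a single-tensor "group" with no closure.
    for name, tensor in by_name.items():
        if name not in grouped:
            yield TensorGroup([tensor])

groups = list(sparsify_grouper(
    {"w": ({"sparsity": 0.5}, ["opt.w.m"])},
    [TensorSendPayload("w"), TensorSendPayload("opt.w.m"), TensorSendPayload("b")],
))
# groups[0] holds "w" and "opt.w.m" with a sparsity closure; groups[1] holds "b".
```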

common.pytorch.sparsity.appliance.validate_sparsity_params(params: dict)[source]#

Validates the sparsity block of the model configuration. A ValueError will be raised if there are any invalid or unsupported settings.
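
A hedged sketch of what such a validator might look like; the key names and accepted ranges checked here are illustrative assumptions, not the real schema:

```python
def validate_sparsity_params(params: dict) -> None:
    # Minimal validation sketch. The "sparsity" key and its [0, 1)
    # range are assumptions for illustration only.
    if "sparsity" not in params:
        raise ValueError("sparsity block must specify a 'sparsity' level")
    level = params["sparsity"]
    if not isinstance(level, float) or not 0.0 <= level < 1.0:
        raise ValueError(f"'sparsity' must be a float in [0, 1), got {level!r}")

validate_sparsity_params({"sparsity": 0.9})  # passes silently
```

Raising early on a malformed "sparsity" block surfaces configuration mistakes before any tensors are sent to the appliance.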

common.pytorch.sparsity.finalizer module#

Sideband sparse training on CS2 uses an inband representation for pruned weights, which must be finalized before running training or inference on another device. This module exposes a helper script for finalizing the sparsity.

common.pytorch.sparsity.finalizer.finalize_cs2_sparsity(model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None) → dict#

Given a module loaded from a checkpoint trained on CS2 with sideband sparsity, finalize the sparsity into the model’s parameters as zeros and return the mask representing the sparsity pattern for each sparse parameter.

Parameters
  • model – The model whose parameters should be updated to freeze sparsity

  • optimizer – If given, the corresponding optimizer state’s sparsity pattern is also frozen.

Returns

Dict mapping the parameter names to the sparsity pattern as a bool torch tensor, where True values indicate present weights and False values indicate pruned weights.

common.pytorch.sparsity.finalizer.finalize_cs2_sparsity_checkpoint(state_dict: dict)[source]#

Given a state_dict trained on CS2 with sideband sparsity, finalize the sparsity of all tensors, both weights and optimizer state, by replacing the inband pruned-weight representation with zeros for use in dense training or evaluation.

Parameters

state_dict – state_dict to finalize the sparsity pattern in. Modified in place.
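
As a rough sketch of this finalization, assuming (purely for illustration) that pruned positions are encoded in-band as NaN, and with NumPy arrays standing in for the checkpoint's torch tensors:

```python
import numpy as np

def finalize_cs2_sparsity_checkpoint(state_dict: dict) -> None:
    # In-place finalization sketch. The NaN encoding of pruned weights
    # is an assumption about the inband representation, not confirmed
    # behavior; integer tensors (e.g. step counters) are left untouched.
    for name, tensor in state_dict.items():
        if np.issubdtype(tensor.dtype, np.floating):
            tensor[np.isnan(tensor)] = 0.0  # dense devices expect plain zeros

sd = {"w": np.array([1.0, np.nan, 3.0]), "step": np.array([7])}
finalize_cs2_sparsity_checkpoint(sd)  # sd["w"] is now [1.0, 0.0, 3.0]
```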

Module contents#