common.pytorch.sparsity package#
Submodules#
common.pytorch.sparsity.appliance module#
Preview release of static sparsity work.
This module contains a function that runs in the appliance_client to modify groups of tensors while in-flight to the appliance according to sparsification settings. It also contains a function to help build those tensor groups and configure the sparsifier.
This only works in appliance mode.
- common.pytorch.sparsity.appliance.appliance_sparsify(params: dict, weight_fw_name: str, tensors: List[cerebras_appliance.appliance_manager.TensorSendPayload]) None [source]#
This function is used as a closure on the TensorGroup for applying sparsity. params and weight_fw_name are bound via functools.partial and must be picklable to send over to the multiprocessing.Pool.
- Parameters
params – numpy sparsifier params for this single weight.
weight_fw_name – The fw_name of the weight in tensors.
tensors – tensors in the group, with their tensor member set.
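A minimal sketch of the partial-binding pattern this docstring describes. The sparsify function, its "threshold" param key, and the tensor values are all hypothetical stand-ins, not the real appliance_sparsify; the point is that the bound arguments must be picklable so the closure can be shipped to a multiprocessing.Pool worker.

```python
import functools

def sparsify(params: dict, weight_fw_name: str, tensors: list) -> list:
    """Stand-in for appliance_sparsify: zero entries below a threshold."""
    return [0.0 if abs(t) < params["threshold"] else t for t in tensors]

# Bind the per-weight configuration up front, as the docstring describes.
# Both bound arguments are plain picklable objects (a dict and a str), so
# the resulting closure can be sent to a multiprocessing.Pool.
closure = functools.partial(sparsify, {"threshold": 0.5}, "model.weight")
```

The remaining positional argument, the tensor group itself, is supplied when the pool invokes the closure.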
- common.pytorch.sparsity.appliance.build_sparsify_grouper(params: dict, model: modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel) cerebras_appliance.appliance_manager.TensorGrouper [source]#
Construct a function building the tensor groups according to parameters needing sparsification and their associated optimizer state.
- Parameters
params – Top-level “sparsity” params from the YAML.
model – Model to pull parameters and associated optimizer state from.
- Returns
Function to be used as the appliance send_weights_grouper
- common.pytorch.sparsity.appliance.compute_mask(params: dict, weight: numpy.ndarray) Tuple[numpy.ndarray, numpy.ndarray] [source]#
Compute a sparsity mask for the given weight according to params.
- Parameters
params – configuration of the sparsity to compute
weight – The initial weight values.
- Returns
mask with np.dtype bool of the same shape as weight indicating the sparsity pattern: True: keep weight. False: prune weight.
regrow with np.dtype bool indicating positions which _were_ pruned but should instead be regrown as zeros.
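The mask/regrow contract above can be illustrated with a simple magnitude-pruning sketch. This is not the real compute_mask: the actual sparsity configuration it accepts is not documented here, so the "sparsity" key and the NaN inband encoding of already-pruned positions are assumptions for illustration only.

```python
import numpy as np

def compute_mask_sketch(params: dict, weight: np.ndarray):
    """Hypothetical magnitude pruning matching the documented return
    contract: (mask, regrow), both bool and shaped like `weight`."""
    sparsity = params["sparsity"]          # assumed key: fraction to prune
    already_pruned = ~np.isfinite(weight)  # assumed NaN inband encoding
    # Rank already-pruned positions lowest so they are pruned first.
    magnitude = np.where(already_pruned, -np.inf, np.abs(weight))
    k = int(round(sparsity * weight.size))
    order = np.argsort(magnitude, axis=None)
    mask = np.ones(weight.size, dtype=bool)
    mask[order[:k]] = False                # prune the k smallest magnitudes
    mask = mask.reshape(weight.shape)
    # Kept positions that currently hold the pruned encoding must be
    # regrown as zeros.
    regrow = mask & already_pruned
    return mask, regrow
```

With this contract, regrow is only non-empty when the requested sparsity keeps fewer positions pruned than the weight already encodes.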
- common.pytorch.sparsity.appliance.sparsify_grouper(sparse_tensor_groups: Dict[str, Tuple[dict, List[str]]], tensors: Iterable[cerebras_appliance.appliance_manager.TensorSendPayload]) Iterable[cerebras_appliance.appliance_manager.TensorGroup] [source]#
Construct a grouping of tensors from the given lazy tensors. The group will consist of weights needing sparsification and their associated optimizer state, or single-tensor “groups” for tensors not needing sparsification.
- Parameters
sparse_tensor_groups – FW names of weights needing sparsity applied, mapping to the sparsity params and the optimizer state FW names.
tensors – All tensors to actually send.
- Yields
- TensorGroup objects with either single tensors or groups of tensors with a closure to apply sparsity
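The grouping logic can be sketched without the appliance types. TensorSendPayload below is a minimal stand-in (the real class lives in cerebras_appliance.appliance_manager and carries more state), and the groups are returned as plain lists rather than TensorGroup objects.

```python
from typing import Dict, Iterable, List, Tuple

class TensorSendPayload:
    """Minimal stand-in for the appliance payload type."""
    def __init__(self, fw_name: str):
        self.fw_name = fw_name

def sparsify_grouper_sketch(
    sparse_tensor_groups: Dict[str, Tuple[dict, List[str]]],
    tensors: Iterable[TensorSendPayload],
) -> Iterable[List[TensorSendPayload]]:
    by_name = {t.fw_name: t for t in tensors}
    grouped = set()
    # Weights needing sparsity travel together with their optimizer state.
    for weight_name, (params, opt_names) in sparse_tensor_groups.items():
        group = [by_name[weight_name]] + [by_name[n] for n in opt_names]
        grouped.update(t.fw_name for t in group)
        yield group
    # Everything else becomes a single-tensor "group".
    for name, tensor in by_name.items():
        if name not in grouped:
            yield [tensor]
```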
common.pytorch.sparsity.finalizer module#
Sideband sparse training on CS2 uses an inband representation for pruned weights which needs to be finalized before running training or inference on another device. This module exposes a helper script for finalizing the sparsity.
- common.pytorch.sparsity.finalizer.finalize_cs2_sparsity(model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None) dict #
Given a module loaded from a checkpoint trained on CS2 with sideband sparsity, finalize the sparsity into the model’s parameters as zeros and return the mask representing the sparsity pattern for each sparse parameter.
- Parameters
model – The model whose parameters should be updated to freeze sparsity
optimizer – If given, the corresponding optimizer state’s sparsity pattern is also frozen.
- Returns
Dict mapping the parameter names to the sparsity pattern as a bool torch tensor, where True values indicate present weights and False represents pruned weights.
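A sketch of the finalization step, assuming the inband representation for pruned weights is NaN (an assumption; this document does not specify the encoding). The real finalize_cs2_sparsity also handles the optimizer state, which is omitted here.

```python
import torch

def finalize_sparsity_sketch(model: torch.nn.Module) -> dict:
    """Hypothetical: zero out NaN-encoded pruned weights in place and
    return a bool mask per parameter (True: present, False: pruned)."""
    masks = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            mask = torch.isfinite(param)
            param.masked_fill_(~mask, 0.0)  # freeze pruned slots to zero
            masks[name] = mask
    return masks
```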
- common.pytorch.sparsity.finalizer.finalize_cs2_sparsity_checkpoint(state_dict: dict)[source]#
Given a state_dict trained on CS2 with sideband sparsity, finalize the sparsity of all tensors, both weights and optimizer state, by replacing the inband pruned weight representation with zeros for use in dense training or evaluation.
- Parameters
state_dict – The state_dict to finalize the sparsity pattern in. Modified in place.
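The checkpoint-level variant can be sketched the same way, again assuming a NaN inband encoding (an assumption not stated by this document). It walks every floating-point tensor in the state_dict, covering optimizer state as well as weights, and modifies them in place.

```python
import torch

def finalize_checkpoint_sketch(state_dict: dict) -> None:
    """Hypothetical: replace the assumed NaN inband pruned representation
    with zeros across all floating-point tensors, in place."""
    for name, tensor in state_dict.items():
        if torch.is_tensor(tensor) and tensor.is_floating_point():
            tensor.masked_fill_(~torch.isfinite(tensor), 0.0)
```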